Skip to content

Commit

Permalink
Add sqlc.slice() to support IN clauses in MySQL (sqlc-dev#695)
Browse files Browse the repository at this point in the history
This feature (currently MySQL-specific) allows passing in a slice to an
IN clause. Adding the new function sqlc.slice() as opposed to overloading
the parsing of "IN (?)" was chosen to guarantee backwards compatibility.

   SELECT * FROM tab WHERE col IN (sqlc.slice("go_param_name"))

This commit is based on sqlc-dev#1312 by
Paul Cameron. I just rebased and did some cleanup.

Co-authored-by: Paul Cameron <[email protected]>
  • Loading branch information
Jille and cameronpm committed Aug 23, 2022
1 parent 32010e9 commit 8ca1ff1
Show file tree
Hide file tree
Showing 37 changed files with 1,074 additions and 171 deletions.
110 changes: 110 additions & 0 deletions docs/howto/select.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ func (q *Queries) GetInfoForAuthor(ctx context.Context, id int) (GetInfoForAutho

## Passing a slice as a parameter to a query

### PostgreSQL

In PostgreSQL,
[ANY](https://www.postgresql.org/docs/current/functions-comparisons.html#id-1.5.8.28.16)
allows you to check if a value exists in an array expression. Queries using ANY
Expand Down Expand Up @@ -262,3 +264,111 @@ func (q *Queries) ListAuthorsByIDs(ctx context.Context, ids []int) ([]Author, er
return items, nil
}
```

### MySQL

MySQL differs from PostgreSQL in that placeholders must be generated based on
the number of elements in the slice you pass in. Though trivial it is still
something of a nuisance. The passed in slice must not be nil or empty or an
error will be returned (ie not a panic). The placeholder insertion location is
marked by the meta-function `sqlc.slice()` (which is similar to `sqlc.arg()`
that you see documented under [Naming parameters](named_parameters.md)).

To rephrase, the `sqlc.slice('param')` behaves identically to `sqlc.arg()` it
terms of how it maps the explicit argument to the function signature, eg:

* `sqlc.slice('ids')` maps to `ids []GoType` in the function signature
* `sqlc.slice(cust_ids)` maps to `custIds []GoType` in the function signature
(like `sqlc.arg()`, the parameter does not have to be quoted)

This feature is not compatible with `emit_prepared_queries` statement found in the
[Configuration file](../reference/config.md).

```sql
CREATE TABLE authors (
id SERIAL PRIMARY KEY,
bio text NOT NULL,
birth_year int NOT NULL
);

-- name: ListAuthorsByIDs :many
SELECT * FROM authors
WHERE id IN (sqlc.slice('ids'));
```

The above SQL will generate the following code:

```go
package db

import (
"context"
"database/sql"
"fmt"
"strings"
)

type Author struct {
ID int
Bio string
BirthYear int
}

type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

func New(db DBTX) *Queries {
return &Queries{db: db}
}

type Queries struct {
db DBTX
}

func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}
}

const listAuthorsByIDs = `-- name: ListAuthorsByIDs :many
SELECT id, bio, birth_year FROM authors
WHERE id IN (/*SLICE:ids*/?)
`

func (q *Queries) ListAuthorsByIDs(ctx context.Context, ids []int64) ([]Author, error) {
sql := listAuthorsByIDs
var queryParams []interface{}
if len(ids) == 0 {
return nil, fmt.Errorf("slice ids must have at least one element")
}
for _, v := range ids {
queryParams = append(queryParams, v)
}
sql = strings.Replace(sql, "/*SLICE:ids*/?", strings.Repeat(",?", len(ids))[1:], 1)
rows, err := q.db.QueryContext(ctx, sql, queryParams...)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Author
for rows.Next() {
var i Author
if err := rows.Scan(&i.ID, &i.Bio, &i.BirthYear); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
```
1 change: 1 addition & 0 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
Length: int32(l),
IsNamedParam: c.IsNamedParam,
IsFuncCall: c.IsFuncCall,
IsSqlcSlice: c.IsSqlcSlice,
}

if c.Type != nil {
Expand Down
5 changes: 5 additions & 0 deletions internal/codegen/golang/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Field struct {
Type string
Tags map[string]string
Comment string
Column *plugin.Column
}

func (gf Field) Tag() string {
Expand All @@ -28,6 +29,10 @@ func (gf Field) Tag() string {
return strings.Join(tags, " ")
}

func (gf Field) HasSqlcSlice() bool {
return gf.Column.IsSqlcSlice
}

func JSONTagName(name string, settings *plugin.Settings) string {
style := settings.Go.JsonTagsCaseStyle
if style == "" || style == "none" {
Expand Down
100 changes: 82 additions & 18 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,63 @@ func (t *tmplCtx) OutputQuery(sourceName string) bool {
return t.SourceName == sourceName
}

func (t *tmplCtx) codegenDbarg() string {
if t.EmitMethodsWithDBArgument {
return "db DBTX, "
}
return ""
}

// Called as a global method since subtemplate queryCodeStdExec does not have
// access to the toplevel tmplCtx
func (t *tmplCtx) codegenEmitPreparedQueries() bool {
return t.EmitPreparedQueries
}

func (t *tmplCtx) codegenQueryMethod(q Query) string {
db := "q.db"
if t.EmitMethodsWithDBArgument {
db = "db"
}

switch q.Cmd {
case ":one":
if t.EmitPreparedQueries {
return "q.queryRow"
}
return db + ".QueryRowContext"

case ":many":
if t.EmitPreparedQueries {
return "q.query"
}
return db + ".QueryContext"

default:
if t.EmitPreparedQueries {
return "q.exec"
}
return db + ".ExecContext"
}
}

func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
switch q.Cmd {
case ":one":
return "row :=", nil
case ":many":
return "rows, err :=", nil
case ":exec":
return "_, err :=", nil
case ":execrows":
return "result, err :=", nil
case ":execresult":
return "return", nil
default:
return "", fmt.Errorf("unhandled q.Cmd case %q", q.Cmd)
}
}

func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
enums := buildEnums(req)
structs := buildStructs(req)
Expand All @@ -61,24 +118,6 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
Structs: structs,
}

funcMap := template.FuncMap{
"lowerTitle": sdk.LowerTitle,
"comment": sdk.DoubleSlashComment,
"escape": sdk.EscapeBacktick,
"imports": i.Imports,
"hasPrefix": strings.HasPrefix,
}

tmpl := template.Must(
template.New("table").
Funcs(funcMap).
ParseFS(
templates,
"templates/*.tmpl",
"templates/*/*.tmpl",
),
)

golang := req.Settings.Go
tctx := tmplCtx{
EmitInterface: golang.EmitInterface,
Expand Down Expand Up @@ -108,6 +147,31 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
return nil, errors.New(":batch* commands are only supported by pgx")
}

funcMap := template.FuncMap{
"lowerTitle": sdk.LowerTitle,
"comment": sdk.DoubleSlashComment,
"escape": sdk.EscapeBacktick,
"imports": i.Imports,
"hasPrefix": strings.HasPrefix,

// These methods are Go specific, they do not belong in the codegen package
// (as that is language independent)
"dbarg": tctx.codegenDbarg,
"emitPreparedQueries": tctx.codegenEmitPreparedQueries,
"queryMethod": tctx.codegenQueryMethod,
"queryRetval": tctx.codegenQueryRetval,
}

tmpl := template.Must(
template.New("table").
Funcs(funcMap).
ParseFS(
templates,
"templates/*.tmpl",
"templates/*/*.tmpl",
),
)

output := map[string]string{}

execute := func(name, templateName string) error {
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func goType(req *plugin.CodeGenRequest, col *plugin.Column) string {
}
}
typ := goInnerType(req, col)
if col.IsArray {
if col.IsArray || col.IsSqlcSlice {
return "[]" + typ
}
return typ
Expand Down
14 changes: 13 additions & 1 deletion internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,12 +366,24 @@ func (i *importer) queryImports(filename string) fileImports {
return false
}

// Search for sqlc.slice() calls
sqlcSliceScan := func() bool {
for _, q := range gq {
if q.Arg.HasSqlcSlices() {
return true
}
}
return false
}

if anyNonCopyFrom {
std["context"] = struct{}{}
}

sqlpkg := SQLPackageFromString(i.Settings.Go.SqlPackage)
if sliceScan() && sqlpkg != SQLPackagePGX {
if sqlcSliceScan() {
std["strings"] = struct{}{}
} else if sliceScan() && sqlpkg != SQLPackagePGX {
pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{}
}

Expand Down
22 changes: 20 additions & 2 deletions internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ type QueryValue struct {
Struct *Struct
Typ string
SQLPackage SQLPackage

// Column is kept so late in the generation process around to differentiate
// between mysql slices and pg arrays
Column *plugin.Column
}

func (v QueryValue) EmitStruct() bool {
Expand Down Expand Up @@ -93,14 +97,14 @@ func (v QueryValue) Params() string {
}
var out []string
if v.Struct == nil {
if strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && v.SQLPackage != SQLPackagePGX {
if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && v.SQLPackage != SQLPackagePGX {
out = append(out, "pq.Array("+v.Name+")")
} else {
out = append(out, v.Name)
}
} else {
for _, f := range v.Struct.Fields {
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && v.SQLPackage != SQLPackagePGX {
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && v.SQLPackage != SQLPackagePGX {
out = append(out, "pq.Array("+v.Name+"."+f.Name+")")
} else {
out = append(out, v.Name+"."+f.Name)
Expand All @@ -125,6 +129,20 @@ func (v QueryValue) ColumnNames() string {
return "[]string{" + strings.Join(escapedNames, ", ") + "}"
}

// When true, we have to build the arguments to q.db.QueryContext in addition to
// munging the SQL
func (v QueryValue) HasSqlcSlices() bool {
if v.Struct == nil {
return v.Column != nil && v.Column.IsSqlcSlice
}
for _, v := range v.Struct.Fields {
if v.Column.IsSqlcSlice {
return true
}
}
return false
}

func (v QueryValue) Scan() string {
var out []string
if v.Struct == nil {
Expand Down
2 changes: 2 additions & 0 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
Name: paramName(p),
Typ: goType(req, p.Column),
SQLPackage: sqlpkg,
Column: p.Column,
}
} else if len(query.Params) > 1 {
var cols []goColumn
Expand Down Expand Up @@ -294,6 +295,7 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
DBName: colName,
Type: goType(req, c.Column),
Tags: tags,
Column: c.Column,
})
if _, found := seen[baseFieldName]; !found {
seen[baseFieldName] = []int{i}
Expand Down
Loading

0 comments on commit 8ca1ff1

Please sign in to comment.