From 18dfd6971d5e842c5b792a56a7e97829387aa81d Mon Sep 17 00:00:00 2001 From: Brandur Date: Mon, 29 Jul 2024 17:35:39 -0700 Subject: [PATCH] `database/sql` driver: Quote named parameter strings Fixes a problem reported in #478 in which during a job list, strings aren't properly quoted before being sent to Postgres. This is a bit of an unfortunate problem that stems from the driver being unable to take advantage of Pgx's named parameter system, nor `database/sql`'s `sql.Named` system (because neither Pgx nor `lib/pq` implement it), which required a rough implementation of a custom named parameter system that uses string find/replace. Here, handle a number of possible argument types that `JobList` might support and make sure that the driver sends them to Postgres in an appropriate format, making sure to quote strings and escape subquotes they may have contained. This is all still a little rough and something more robust would definitely be nice, but this should all be input from River's job module, so it doesn't need to have the most robust implementation ever. Fixes #481. --- CHANGELOG.md | 3 +- .../riverdrivertest/riverdrivertest.go | 107 +++++++++++------- .../river_database_sql_driver.go | 106 +++++++++++++++-- .../river_database_sql_driver_test.go | 47 ++++++++ 4 files changed, 213 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c6f58ee..5d4a5ea6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- Include `pending` state in `JobListParams` by default so pending jobs are included in `JobList` / `JobListTx` results. +- Include `pending` state in `JobListParams` by default so pending jobs are included in `JobList` / `JobListTx` results. [PR #477](https://github.com/riverqueue/river/pull/477). +- Quote strings when using `Client.JobList` functions with the `database/sql` driver. [PR #481](https://github.com/riverqueue/river/pull/481). 
## [0.10.1] - 2024-07-23 diff --git a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go index 848b0686..5d8c9a54 100644 --- a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go +++ b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go @@ -1206,47 +1206,76 @@ func Exercise[TTx any](ctx context.Context, t *testing.T, t.Run("JobList", func(t *testing.T) { t.Parallel() - exec, _ := setup(ctx, t) + t.Run("ListsJobs", func(t *testing.T) { + exec, _ := setup(ctx, t) - now := time.Now().UTC() + now := time.Now().UTC() - job := testfactory.Job(ctx, t, exec, &testfactory.JobOpts{ - Attempt: ptrutil.Ptr(3), - AttemptedAt: &now, - CreatedAt: &now, - EncodedArgs: []byte(`{"encoded": "args"}`), - Errors: [][]byte{[]byte(`{"error": "message1"}`), []byte(`{"error": "message2"}`)}, - FinalizedAt: &now, - Metadata: []byte(`{"meta": "data"}`), - ScheduledAt: &now, - State: ptrutil.Ptr(rivertype.JobStateCompleted), - Tags: []string{"tag"}, - }) - - fetchedJobs, err := exec.JobList( - ctx, - fmt.Sprintf("SELECT %s FROM river_job WHERE id = @job_id_123", exec.JobListFields()), - map[string]any{"job_id_123": job.ID}, - ) - require.NoError(t, err) - require.Len(t, fetchedJobs, 1) - - fetchedJob := fetchedJobs[0] - require.Equal(t, job.Attempt, fetchedJob.Attempt) - require.Equal(t, job.AttemptedAt, fetchedJob.AttemptedAt) - require.Equal(t, job.CreatedAt, fetchedJob.CreatedAt) - require.Equal(t, job.EncodedArgs, fetchedJob.EncodedArgs) - require.Equal(t, "message1", fetchedJob.Errors[0].Error) - require.Equal(t, "message2", fetchedJob.Errors[1].Error) - require.Equal(t, job.FinalizedAt, fetchedJob.FinalizedAt) - require.Equal(t, job.Kind, fetchedJob.Kind) - require.Equal(t, job.MaxAttempts, fetchedJob.MaxAttempts) - require.Equal(t, job.Metadata, fetchedJob.Metadata) - require.Equal(t, job.Priority, fetchedJob.Priority) - require.Equal(t, job.Queue, fetchedJob.Queue) - require.Equal(t, 
job.ScheduledAt, fetchedJob.ScheduledAt) - require.Equal(t, job.State, fetchedJob.State) - require.Equal(t, job.Tags, fetchedJob.Tags) + job := testfactory.Job(ctx, t, exec, &testfactory.JobOpts{ + Attempt: ptrutil.Ptr(3), + AttemptedAt: &now, + CreatedAt: &now, + EncodedArgs: []byte(`{"encoded": "args"}`), + Errors: [][]byte{[]byte(`{"error": "message1"}`), []byte(`{"error": "message2"}`)}, + FinalizedAt: &now, + Metadata: []byte(`{"meta": "data"}`), + ScheduledAt: &now, + State: ptrutil.Ptr(rivertype.JobStateCompleted), + Tags: []string{"tag"}, + }) + + fetchedJobs, err := exec.JobList( + ctx, + fmt.Sprintf("SELECT %s FROM river_job WHERE id = @job_id_123", exec.JobListFields()), + map[string]any{"job_id_123": job.ID}, + ) + require.NoError(t, err) + require.Len(t, fetchedJobs, 1) + + fetchedJob := fetchedJobs[0] + require.Equal(t, job.Attempt, fetchedJob.Attempt) + require.Equal(t, job.AttemptedAt, fetchedJob.AttemptedAt) + require.Equal(t, job.CreatedAt, fetchedJob.CreatedAt) + require.Equal(t, job.EncodedArgs, fetchedJob.EncodedArgs) + require.Equal(t, "message1", fetchedJob.Errors[0].Error) + require.Equal(t, "message2", fetchedJob.Errors[1].Error) + require.Equal(t, job.FinalizedAt, fetchedJob.FinalizedAt) + require.Equal(t, job.Kind, fetchedJob.Kind) + require.Equal(t, job.MaxAttempts, fetchedJob.MaxAttempts) + require.Equal(t, job.Metadata, fetchedJob.Metadata) + require.Equal(t, job.Priority, fetchedJob.Priority) + require.Equal(t, job.Queue, fetchedJob.Queue) + require.Equal(t, job.ScheduledAt, fetchedJob.ScheduledAt) + require.Equal(t, job.State, fetchedJob.State) + require.Equal(t, job.Tags, fetchedJob.Tags) + }) + + t.Run("HandlesRequiredArgumentTypes", func(t *testing.T) { + exec, _ := setup(ctx, t) + + job1 := testfactory.Job(ctx, t, exec, &testfactory.JobOpts{Kind: ptrutil.Ptr("test_kind1")}) + job2 := testfactory.Job(ctx, t, exec, &testfactory.JobOpts{Kind: ptrutil.Ptr("test_kind2")}) + + { + fetchedJobs, err := exec.JobList( + ctx, + 
fmt.Sprintf("SELECT %s FROM river_job WHERE kind = @kind", exec.JobListFields()), + map[string]any{"kind": job1.Kind}, + ) + require.NoError(t, err) + require.Len(t, fetchedJobs, 1) + } + + { + fetchedJobs, err := exec.JobList( + ctx, + fmt.Sprintf("SELECT %s FROM river_job WHERE kind = any(@kind::text[])", exec.JobListFields()), + map[string]any{"kind": []string{job1.Kind, job2.Kind}}, + ) + require.NoError(t, err) + require.Len(t, fetchedJobs, 2) + } + }) }) t.Run("JobListFields", func(t *testing.T) { diff --git a/riverdriver/riverdatabasesql/river_database_sql_driver.go b/riverdriver/riverdatabasesql/river_database_sql_driver.go index 4d3e1a8b..48625d4e 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver.go @@ -14,6 +14,7 @@ import ( "fmt" "io/fs" "math" + "strconv" "strings" "time" @@ -314,16 +315,9 @@ func (e *Executor) JobInsertUnique(ctx context.Context, params *riverdriver.JobI } func (e *Executor) JobList(ctx context.Context, query string, namedArgs map[string]any) ([]*rivertype.JobRow, error) { - // `database/sql` has an `sql.Named` system that should theoretically work - // for named parameters, but neither Pgx or lib/pq implement it, so just use - // dumb string replacement given we're only injecting a very basic value - // anyway. 
- for name, value := range namedArgs { - newQuery := strings.Replace(query, "@"+name, fmt.Sprintf("%v", value), 1) - if newQuery == query { - return nil, fmt.Errorf("named query parameter @%s not found in query", name) - } - query = newQuery + query, err := replaceNamed(query, namedArgs) + if err != nil { + return nil, err } rows, err := e.dbtx.QueryContext(ctx, query) @@ -364,6 +358,98 @@ func (e *Executor) JobList(ctx context.Context, query string, namedArgs map[stri return mapSliceError(items, jobRowFromInternal) } +func escapeSinglePostgresValue(value any) string { + switch typedValue := value.(type) { + case bool: + return strconv.FormatBool(typedValue) + case float32: + // The `-1` arg tells Go to represent the number with as few digits as + // possible. i.e. No unnecessary trailing zeroes. + return strconv.FormatFloat(float64(typedValue), 'f', -1, 32) + case float64: + // The `-1` arg tells Go to represent the number with as few digits as + // possible. i.e. No unnecessary trailing zeroes. + return strconv.FormatFloat(typedValue, 'f', -1, 64) + case int, int16, int32, int64, uint, uint16, uint32, uint64: + return fmt.Sprintf("%d", value) + case string: + return "'" + strings.ReplaceAll(typedValue, "'", "''") + "'" + default: + // unreachable as long as new types aren't added to the switch in `replaceNamed` below + panic("type not supported") + } +} + +func makePostgresArray[T any](values []T) string { + var sb strings.Builder + sb.WriteString("ARRAY[") + + for i, value := range values { + sb.WriteString(escapeSinglePostgresValue(value)) + + if i < len(values)-1 { + sb.WriteString(",") + } + } + + sb.WriteString("]") + return sb.String() +} + +// `database/sql` has an `sql.Named` system that should theoretically work for +// named parameters, but neither Pgx nor lib/pq implement it, so just use dumb +// string replacement given we're only injecting a very basic value anyway. 
+func replaceNamed(query string, namedArgs map[string]any) (string, error) { + for name, value := range namedArgs { + var escapedValue string + + switch typedValue := value.(type) { + case bool, float32, float64, int, int16, int32, int64, string, uint, uint16, uint32, uint64: + escapedValue = escapeSinglePostgresValue(value) + + // This is pretty awkward, but typedValue reverts back to `any` if + // any of these conditions are combined together, and that prevents + // us from ranging over the slice. Technically only `[]string` is + // needed right now, but I included other slice types just so there + // isn't a surprise later on. + case []bool: + escapedValue = makePostgresArray(typedValue) + case []float32: + escapedValue = makePostgresArray(typedValue) + case []float64: + escapedValue = makePostgresArray(typedValue) + case []int: + escapedValue = makePostgresArray(typedValue) + case []int16: + escapedValue = makePostgresArray(typedValue) + case []int32: + escapedValue = makePostgresArray(typedValue) + case []int64: + escapedValue = makePostgresArray(typedValue) + case []string: + escapedValue = makePostgresArray(typedValue) + case []uint: + escapedValue = makePostgresArray(typedValue) + case []uint16: + escapedValue = makePostgresArray(typedValue) + case []uint32: + escapedValue = makePostgresArray(typedValue) + case []uint64: + escapedValue = makePostgresArray(typedValue) + default: + return "", fmt.Errorf("named query parameter @%s is not a supported type", name) + } + + newQuery := strings.Replace(query, "@"+name, escapedValue, 1) + if newQuery == query { + return "", fmt.Errorf("named query parameter @%s not found in query", name) + } + query = newQuery + } + + return query, nil +} + func (e *Executor) JobListFields() string { return "id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags" } diff --git 
a/riverdriver/riverdatabasesql/river_database_sql_driver_test.go b/riverdriver/riverdatabasesql/river_database_sql_driver_test.go index 49d69a38..d6daf359 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver_test.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver_test.go @@ -40,3 +40,50 @@ func TestInterpretError(t *testing.T) { require.ErrorIs(t, interpretError(sql.ErrNoRows), rivertype.ErrNotFound) require.NoError(t, interpretError(nil)) } + +func TestReplaceNamed(t *testing.T) { + t.Parallel() + + testCases := []struct { + Desc string + ExpectedSQL string + InputSQL string + InputArgs map[string]any + }{ + {Desc: "Boolean", ExpectedSQL: "SELECT true", InputSQL: "SELECT @bool", InputArgs: map[string]any{"bool": true}}, + {Desc: "Float32", ExpectedSQL: "SELECT 1.23", InputSQL: "SELECT @float32", InputArgs: map[string]any{"float32": float32(1.23)}}, + {Desc: "Float64", ExpectedSQL: "SELECT 1.23", InputSQL: "SELECT @float64", InputArgs: map[string]any{"float64": float64(1.23)}}, + {Desc: "Int", ExpectedSQL: "SELECT 123", InputSQL: "SELECT @int", InputArgs: map[string]any{"int": 123}}, + {Desc: "Int16", ExpectedSQL: "SELECT 123", InputSQL: "SELECT @int16", InputArgs: map[string]any{"int16": int16(123)}}, + {Desc: "Int32", ExpectedSQL: "SELECT 123", InputSQL: "SELECT @int32", InputArgs: map[string]any{"int32": int32(123)}}, + {Desc: "Int64", ExpectedSQL: "SELECT 123", InputSQL: "SELECT @int64", InputArgs: map[string]any{"int64": int64(123)}}, + {Desc: "String", ExpectedSQL: "SELECT 'string value'", InputSQL: "SELECT @string", InputArgs: map[string]any{"string": "string value"}}, + {Desc: "StringWithQuote", ExpectedSQL: "SELECT 'string value with '' quote'", InputSQL: "SELECT @string", InputArgs: map[string]any{"string": "string value with ' quote"}}, + {Desc: "Uint", ExpectedSQL: "SELECT 123", InputSQL: "SELECT @uint", InputArgs: map[string]any{"uint": uint(123)}}, + {Desc: "Uint16", ExpectedSQL: "SELECT 123", InputSQL: "SELECT @uint16", 
InputArgs: map[string]any{"uint16": uint16(123)}}, + {Desc: "Uint32", ExpectedSQL: "SELECT 123", InputSQL: "SELECT @uint32", InputArgs: map[string]any{"uint32": uint32(123)}}, + {Desc: "Uint64", ExpectedSQL: "SELECT 123", InputSQL: "SELECT @uint64", InputArgs: map[string]any{"uint64": uint64(123)}}, + + {Desc: "SliceBoolean", ExpectedSQL: "SELECT ARRAY[false,true]", InputSQL: "SELECT @slice_bool", InputArgs: map[string]any{"slice_bool": []bool{false, true}}}, + {Desc: "SliceFloat32", ExpectedSQL: "SELECT ARRAY[1.23,1.24]", InputSQL: "SELECT @slice_float32", InputArgs: map[string]any{"slice_float32": []float32{1.23, 1.24}}}, + {Desc: "SliceFloat64", ExpectedSQL: "SELECT ARRAY[1.23,1.24]", InputSQL: "SELECT @slice_float64", InputArgs: map[string]any{"slice_float64": []float64{1.23, 1.24}}}, + {Desc: "SliceInt", ExpectedSQL: "SELECT ARRAY[123,124]", InputSQL: "SELECT @slice_int", InputArgs: map[string]any{"slice_int": []int{123, 124}}}, + {Desc: "SliceInt16", ExpectedSQL: "SELECT ARRAY[123,124]", InputSQL: "SELECT @slice_int16", InputArgs: map[string]any{"slice_int16": []int16{123, 124}}}, + {Desc: "SliceInt32", ExpectedSQL: "SELECT ARRAY[123,124]", InputSQL: "SELECT @slice_int32", InputArgs: map[string]any{"slice_int32": []int32{123, 124}}}, + {Desc: "SliceInt64", ExpectedSQL: "SELECT ARRAY[123,124]", InputSQL: "SELECT @slice_int64", InputArgs: map[string]any{"slice_int64": []int64{123, 124}}}, + {Desc: "SliceString", ExpectedSQL: "SELECT ARRAY['string 1','string 2']", InputSQL: "SELECT @slice_string", InputArgs: map[string]any{"slice_string": []string{"string 1", "string 2"}}}, + {Desc: "SliceUint", ExpectedSQL: "SELECT ARRAY[123,124]", InputSQL: "SELECT @slice_uint", InputArgs: map[string]any{"slice_uint": []uint{123, 124}}}, + {Desc: "SliceUint16", ExpectedSQL: "SELECT ARRAY[123,124]", InputSQL: "SELECT @slice_uint16", InputArgs: map[string]any{"slice_uint16": []uint16{123, 124}}}, + {Desc: "SliceUint32", ExpectedSQL: "SELECT ARRAY[123,124]", InputSQL: "SELECT 
@slice_uint32", InputArgs: map[string]any{"slice_uint32": []uint32{123, 124}}}, + {Desc: "SliceUint64", ExpectedSQL: "SELECT ARRAY[123,124]", InputSQL: "SELECT @slice_uint64", InputArgs: map[string]any{"slice_uint64": []uint64{123, 124}}}, + } + for _, tt := range testCases { + t.Run(tt.Desc, func(t *testing.T) { + t.Parallel() + + actualSQL, err := replaceNamed(tt.InputSQL, tt.InputArgs) + require.NoError(t, err) + require.Equal(t, tt.ExpectedSQL, actualSQL) + }) + } +}