diff --git a/all_test.go b/all_test.go index 25222f9..a6cc2b7 100644 --- a/all_test.go +++ b/all_test.go @@ -3504,69 +3504,117 @@ COMMIT; } func TestSelectDummy(t *testing.T) { - RegisterMemDriver() - db, err := sql.Open("ql-mem", "") + db, err := OpenMem() if err != nil { t.Fatal(err) } defer db.Close() - //int - var i int - err = db.QueryRow("select 1").Scan(&i) - if err != nil { - t.Fatal(err) - } - if i != 1 { - t.Fatalf("expected 1 got %d", i) - } - - i = 0 - err = db.QueryRow("select $1", 1).Scan(&i) - if err != nil { - t.Fatal(err) - } - if i != 1 { - t.Fatalf("expected 1 got %d", i) - } - - //float - var f float64 - err = db.QueryRow("select 1.5").Scan(&f) - if err != nil { - t.Fatal(err) + sample := []struct { + src string + exp []interface{} + }{ + {"select 10", []interface{}{10}}, + {"select 10,20", []interface{}{10, 20}}, } - if f != 1.5 { - t.Fatalf("expected 1.0 got %f", f) + for _, s := range sample { + rst, _, err := db.run(nil, s.src) + if err != nil { + t.Fatal(err) + } + for _, rs := range rst { + d, err := rs.FirstRow() + if err != nil { + t.Fatal(err) + } + for k, val := range d { + if int(val.(idealInt)) != s.exp[k].(int) { + t.Errorf("expected %v got %v", s.exp[k], val) + } + } + } } - f = 0.0 - err = db.QueryRow("select $1", 1.5).Scan(&f) - if err != nil { - t.Fatal(err) + // //float + sample = []struct { + src string + exp []interface{} + }{ + {"select 1.5", []interface{}{1.5}}, + {"select 1.5,2.5", []interface{}{1.5, 2.5}}, } - if f != 1.5 { - t.Fatalf("expected 1.5 got %f", f) + for _, s := range sample { + rst, _, err := db.run(nil, s.src) + if err != nil { + t.Fatal(err) + } + for _, rs := range rst { + d, err := rs.FirstRow() + if err != nil { + t.Fatal(err) + } + for k, val := range d { + if float64(val.(idealFloat)) != s.exp[k].(float64) { + t.Errorf("expected %v got %v", s.exp[k], val) + } + } + } } - //string - var s string - msg := "foo" - err = db.QueryRow(`select "foo"`).Scan(&s) - if err != nil { - t.Fatal(err) + // //string + sample = []struct { + src string + exp []interface{} + }{ + {`select "foo"`, []interface{}{"foo"}}, + {`select "foo","bar"`, []interface{}{"foo", "bar"}}, } - if s != msg { - t.Fatalf("expected %s got %s", msg, s) + for _, s := range sample { + rst, _, err := db.run(nil, s.src) + if err != nil { + t.Fatal(err) + } + for _, rs := range rst { + d, err := rs.FirstRow() + if err != nil { + t.Fatal(err) + } + for k, val := range d { + if val.(string) != s.exp[k].(string) { + t.Errorf("expected %v got %v", s.exp[k], val) + } + } + } } - s = "" - err = db.QueryRow("select $1", msg).Scan(&s) - if err != nil { - t.Fatal(err) + sample = []struct { + src string + exp []interface{} + }{ + {`select "foo",now()`, []interface{}{"foo"}}, } - if s != msg { - t.Fatalf("expected %s got %s", msg, s) + for _, s := range sample { + rst, _, err := db.run(nil, s.src) + if err != nil { + t.Fatal(err) + } + for _, rs := range rst { + d, err := rs.FirstRow() + if err != nil { + t.Fatal(err) + } + for k, val := range d { + if k == 1 { + if _, ok := val.(time.Time); !ok { + t.Fatal("expected time object") + } + continue + } + if val.(string) != s.exp[k].(string) { + t.Errorf("expected %v got %v", s.exp[k], val) + } + } + } } } diff --git a/plan.go b/plan.go index 2388b25..4b435fd 100644 --- a/plan.go +++ b/plan.go @@ -2801,7 +2801,7 @@ func (r *fullJoinDefaultPlan) do(ctx *execCtx, f func(id interface{}, data []int } type selectDummyPlan struct { - fields []interface{} + flds []*fld } func (r *selectDummyPlan) hasID() bool { return true } @@ -2810,13 +2810,22 @@ func (r *selectDummyPlan) explain(w strutil.Formatter) { w.Format("┌Selects values from dummy table\n└Output field names %v\n", qnames(r.fieldNames())) } -func (r *selectDummyPlan) fieldNames() []string { return make([]string, len(r.fields)) } +func (r *selectDummyPlan) fieldNames() []string { return make([]string, len(r.flds)) } func (r *selectDummyPlan) filter(expr expression) (plan, []string, error) { return nil, nil, nil } func (r *selectDummyPlan) do(ctx *execCtx, f func(id interface{}, data []interface{}) (bool, error)) (err error) { - _, err = f(nil, r.fields) + m := map[interface{}]interface{}{} + data := []interface{}{} + for _, v := range r.flds { + rst, err := v.expr.eval(ctx, m) + if err != nil { + return err + } + data = append(data, rst) + } + _, err = f(nil, data) return } diff --git a/stmt.go b/stmt.go index 44ed6a2..044e43e 100644 --- a/stmt.go +++ b/stmt.go @@ -803,13 +803,7 @@ func (s *selectStmt) plan(ctx *execCtx) (plan, error) { //LATER overlapping goro } } if r == nil { - var fds []interface{} - for _, v := range s.flds { - if val, ok := v.expr.(value); ok { - fds = append(fds, val) - } - } - r = &selectDummyPlan{fields: fds} + r = &selectDummyPlan{flds: s.flds} } if w := s.where; w != nil { if r, err = (&whereRset{expr: w.expr, src: r, sel: w.sel, exists: w.exists}).plan(ctx); err != nil {