From fc095356ebc94d7ed473683ce8ff4c0bcdfb7a9d Mon Sep 17 00:00:00 2001 From: shixuan Date: Mon, 2 Dec 2024 16:47:05 +0800 Subject: [PATCH 1/2] feat(contrib/drivers/pgsql): scan and gf gen dao support pg array (varchar[] and text[]) to Go []string --- contrib/drivers/pgsql/pgsql_convert.go | 13 +++++++++++++ database/gdb/gdb.go | 1 + 2 files changed, 14 insertions(+) diff --git a/contrib/drivers/pgsql/pgsql_convert.go b/contrib/drivers/pgsql/pgsql_convert.go index 71657669b24..8ccd6e54fe0 100644 --- a/contrib/drivers/pgsql/pgsql_convert.go +++ b/contrib/drivers/pgsql/pgsql_convert.go @@ -8,6 +8,7 @@ package pgsql import ( "context" + "github.com/lib/pq" "reflect" "strings" @@ -72,6 +73,10 @@ func (d *Driver) CheckLocalTypeForField(ctx context.Context, fieldType string, f "_int8": return gdb.LocalTypeInt64Slice, nil + case + "_varchar", "_text": + return gdb.LocalTypeStringSlice, nil + default: return d.Core.CheckLocalTypeForField(ctx, fieldType, fieldValue) } @@ -116,6 +121,14 @@ func (d *Driver) ConvertValueForLocal(ctx context.Context, fieldType string, fie ), ), nil + // String slice. + case "_varchar", "_text": + var result pq.StringArray + if err := result.Scan(fieldValue); err != nil { + return nil, err + } + return result, nil + default: return d.Core.ConvertValueForLocal(ctx, fieldType, fieldValue) } diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 74d76e5c8a1..f9851dbaca4 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -458,6 +458,7 @@ const ( LocalTypeIntSlice LocalType = "[]int" LocalTypeInt64Slice LocalType = "[]int64" LocalTypeUint64Slice LocalType = "[]uint64" + LocalTypeStringSlice LocalType = "[]string" LocalTypeInt64Bytes LocalType = "int64-bytes" LocalTypeUint64Bytes LocalType = "uint64-bytes" LocalTypeFloat32 LocalType = "float32" From baad29c1cc1774762a4a0e8ecc94a18f7175efb3 Mon Sep 17 00:00:00 2001 From: shixuan Date: Thu, 5 Dec 2024 09:56:57 +0800 Subject: [PATCH 2/2] fixed: add unit test --- .../drivers/pgsql/pgsql_z_unit_init_test.go | 2 + .../drivers/pgsql/pgsql_z_unit_model_test.go | 59 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/contrib/drivers/pgsql/pgsql_z_unit_init_test.go b/contrib/drivers/pgsql/pgsql_z_unit_init_test.go index c2033e30125..023efdee1c4 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_init_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_init_test.go @@ -83,6 +83,8 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) { password varchar(32) NOT NULL, nickname varchar(45) NOT NULL, create_time timestamp NOT NULL, + favorite_movie varchar[], + favorite_music text[], PRIMARY KEY (id) ) ;`, name, )); err != nil { diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index 634de6dc821..d7748f07f17 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -611,3 +611,62 @@ func Test_OrderRandom(t *testing.T) { t.Assert(len(result), TableSize) }) } + +func Test_ConvertSliceString(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime *gtime.Time + FavoriteMovie []string + FavoriteMusic []string + } + + var ( + user User + user2 User + err error + ) + + // slice string not null + _, err = db.Model(table).Data(g.Map{ + "id": 1, + "passport": "p1", + "password": "pw1", + "nickname": "n1", + "create_time": CreateTime, + "favorite_movie": g.Slice{"Iron-Man", "Spider-Man"}, + "favorite_music": g.Slice{"Hey jude", "Let it be"}, + }).Insert() + t.AssertNil(err) + + err = db.Model(table).Where("id", 1).Scan(&user) + t.AssertNil(err) + t.Assert(len(user.FavoriteMusic), 2) + t.Assert(user.FavoriteMusic[0], "Hey jude") + t.Assert(user.FavoriteMusic[1], "Let it be") + t.Assert(len(user.FavoriteMovie), 2) + t.Assert(user.FavoriteMovie[0], "Iron-Man") + t.Assert(user.FavoriteMovie[1], "Spider-Man") + + // slice string null + _, err = db.Model(table).Data(g.Map{ + "id": 2, + "passport": "p1", + "password": "pw1", + "nickname": "n1", + "create_time": CreateTime, + }).Insert() + t.AssertNil(err) + + err = db.Model(table).Where("id", 2).Scan(&user2) + t.AssertNil(err) + t.Assert(user2.FavoriteMusic, nil) + t.Assert(len(user2.FavoriteMovie), 0) + }) +}