Skip to content

Commit

Permalink
add interface DB.CheckLocalTypeForField for package gdb (#2059)
Browse files Browse the repository at this point in the history
  • Loading branch information
gqcn authored Aug 11, 2022
1 parent 95888e0 commit e4c8cfc
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 195 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cli.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ github.ref }}
release_name: GoFrame CLI Release ${{ github.ref }}
release_name: GoFrame Release ${{ github.ref }}
draft: false
prerelease: false

Expand Down
1 change: 1 addition & 0 deletions cmd/gf/internal/cmd/gendao/gendao_do.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func generateDo(ctx context.Context, db gdb.DB, tableNames, newTableNames []stri
doFilePath = gfile.Join(doDirPath, gstr.CaseSnake(newTableName)+".go")
structDefinition = generateStructDefinition(ctx, generateStructDefinitionInput{
CGenDaoInternalInput: in,
DB: db,
StructName: gstr.CaseCamel(newTableName),
FieldMap: fieldMap,
IsDo: true,
Expand Down
1 change: 1 addition & 0 deletions cmd/gf/internal/cmd/gendao/gendao_entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func generateEntity(ctx context.Context, db gdb.DB, tableNames, newTableNames []
gstr.CaseCamel(newTableName),
generateStructDefinition(ctx, generateStructDefinitionInput{
CGenDaoInternalInput: in,
DB: db,
StructName: gstr.CaseCamel(newTableName),
FieldMap: fieldMap,
IsDo: false,
Expand Down
20 changes: 6 additions & 14 deletions cmd/gf/internal/cmd/gendao/gendao_structure.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,12 @@ import (

type generateStructDefinitionInput struct {
CGenDaoInternalInput
DB gdb.DB // Current DB.
StructName string // Struct name.
FieldMap map[string]*gdb.TableField // Table field map.
IsDo bool // Is generating DTO struct.
}

const (
typeDate = "date"
typeDatetime = "datetime"
typeInt64Bytes = "int64-bytes"
typeUint64Bytes = "uint64-bytes"
typeJson = "json"
typeJsonb = "jsonb"
)

func generateStructDefinition(ctx context.Context, in generateStructDefinitionInput) string {
buffer := bytes.NewBuffer(nil)
array := make([][]string, len(in.FieldMap))
Expand Down Expand Up @@ -67,26 +59,26 @@ func generateStructFieldDefinition(
typeName string
jsonTag = getJsonTagFromCase(field.Name, in.JsonCase)
)
typeName, err = gdb.CheckValueForLocalType(ctx, field.Type, nil)
typeName, err = in.DB.CheckLocalTypeForField(ctx, field.Type, nil)
if err != nil {
panic(err)
}
switch typeName {
case typeDate, typeDatetime:
case gdb.LocalTypeDate, gdb.LocalTypeDatetime:
if in.StdTime {
typeName = "time.Time"
} else {
typeName = "*gtime.Time"
}

case typeInt64Bytes:
case gdb.LocalTypeInt64Bytes:
typeName = "int64"

case typeUint64Bytes:
case gdb.LocalTypeUint64Bytes:
typeName = "uint64"

// Special type handle.
case typeJson, typeJsonb:
case gdb.LocalTypeJson, gdb.LocalTypeJsonb:
if in.GJsonSupport {
typeName = "*gjson.Json"
} else {
Expand Down
37 changes: 37 additions & 0 deletions contrib/drivers/pgsql/pgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,43 @@ func (d *Driver) GetChars() (charLeft string, charRight string) {
return `"`, `"`
}

// CheckLocalTypeForValue checks and returns corresponding local golang type for given db type.
func (d *Driver) CheckLocalTypeForValue(ctx context.Context, fieldType string, fieldValue interface{}) (string, error) {
var typeName string
match, _ := gregex.MatchString(`(.+?)\((.+)\)`, fieldType)
if len(match) == 3 {
typeName = gstr.Trim(match[1])
} else {
typeName = fieldType
}
typeName = strings.ToLower(typeName)
switch typeName {
case
// For pgsql, int2 = smallint.
"int2",
// For pgsql, int4 = integer
"int4":
return gdb.LocalTypeInt, nil

case
// For pgsql, int8 = bigint
"int8":
return gdb.LocalTypeInt64, nil

case
"_int2",
"_int4":
return gdb.LocalTypeIntSlice, nil

case
"_int8":
return gdb.LocalTypeInt64Slice, nil

default:
return d.Core.CheckLocalTypeForField(ctx, fieldType, fieldValue)
}
}

// ConvertValueForLocal converts value to local Golang type of value according field type name from database.
// The parameter `fieldType` is in lower case, like:
// `float(5,2)`, `unsigned double(5,2)`, `decimal(10,2)`, `char(45)`, `varchar(100)`, etc.
Expand Down
22 changes: 22 additions & 0 deletions database/gdb/gdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ type DB interface {
TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields.
ConvertDataForRecord(ctx context.Context, data interface{}) (map[string]interface{}, error) // See Core.ConvertDataForRecord
ConvertValueForLocal(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForLocal
CheckLocalTypeForField(ctx context.Context, fieldType string, fieldValue interface{}) (string, error) // See Core.CheckLocalTypeForField
FilteredLink() string // FilteredLink is used for filtering sensitive information in `Link` configuration before output it to tracing server.
}

Expand Down Expand Up @@ -312,6 +313,27 @@ const (
SqlTypeStmtQueryRowContext = "DB.Statement.QueryRowContext"
)

const (
LocalTypeString = "string"
LocalTypeDate = "date"
LocalTypeDatetime = "datetime"
LocalTypeInt = "int"
LocalTypeUint = "uint"
LocalTypeInt64 = "int64"
LocalTypeUint64 = "uint64"
LocalTypeIntSlice = "[]int"
LocalTypeInt64Slice = "[]int64"
LocalTypeUint64Slice = "[]uint64"
LocalTypeInt64Bytes = "int64-bytes"
LocalTypeUint64Bytes = "uint64-bytes"
LocalTypeFloat32 = "float32"
LocalTypeFloat64 = "float64"
LocalTypeBytes = "[]byte"
LocalTypeBool = "bool"
LocalTypeJson = "json"
LocalTypeJsonb = "jsonb"
)

var (
// instances is the management map for instances.
instances = gmap.NewStrAnyMap(true)
Expand Down
155 changes: 142 additions & 13 deletions database/gdb/gdb_core_structure.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/internal/json"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
)
Expand Down Expand Up @@ -119,6 +121,133 @@ func (c *Core) ConvertDataForRecordValue(ctx context.Context, value interface{})
return convertedValue, nil
}

// CheckLocalTypeForField checks and returns corresponding type for given db type.
func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, fieldValue interface{}) (string, error) {
var (
typeName string
typePattern string
)
match, _ := gregex.MatchString(`(.+?)\((.+)\)`, fieldType)
if len(match) == 3 {
typeName = gstr.Trim(match[1])
typePattern = gstr.Trim(match[2])
} else {
typeName = fieldType
}
typeName = strings.ToLower(typeName)
switch typeName {
case
"binary",
"varbinary",
"blob",
"tinyblob",
"mediumblob",
"longblob":
return LocalTypeBytes, nil

case
"int",
"tinyint",
"small_int",
"smallint",
"medium_int",
"mediumint",
"serial":
if gstr.ContainsI(fieldType, "unsigned") {
return LocalTypeUint, nil
}
return LocalTypeInt, nil

case
"big_int",
"bigint",
"bigserial":
if gstr.ContainsI(fieldType, "unsigned") {
return LocalTypeUint64, nil
}
return LocalTypeInt64, nil

case
"real":
return LocalTypeFloat32, nil

case
"float",
"double",
"decimal",
"money",
"numeric",
"smallmoney":
return LocalTypeFloat64, nil

case
"bit":
// It is suggested using bit(1) as boolean.
if typePattern == "1" {
return LocalTypeBool, nil
}
s := gconv.String(fieldValue)
// mssql is true|false string.
if strings.EqualFold(s, "true") || strings.EqualFold(s, "false") {
return LocalTypeBool, nil
}
if gstr.ContainsI(fieldType, "unsigned") {
return LocalTypeUint64Bytes, nil
}
return LocalTypeInt64Bytes, nil

case
"bool":
return LocalTypeBool, nil

case
"date":
return LocalTypeDate, nil

case
"datetime",
"timestamp",
"timestamptz":
return LocalTypeDatetime, nil

case
"json":
return LocalTypeJson, nil

case
"jsonb":
return LocalTypeJsonb, nil

default:
// Auto-detect field type, using key match.
switch {
case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || strings.Contains(typeName, "character"):
return LocalTypeString, nil

case strings.Contains(typeName, "float") || strings.Contains(typeName, "double") || strings.Contains(typeName, "numeric"):
return LocalTypeFloat64, nil

case strings.Contains(typeName, "bool"):
return LocalTypeBool, nil

case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob"):
return LocalTypeBytes, nil

case strings.Contains(typeName, "int"):
return LocalTypeInt, nil

case strings.Contains(typeName, "time"):
return LocalTypeDatetime, nil

case strings.Contains(typeName, "date"):
return LocalTypeDatetime, nil

default:
return LocalTypeString, nil
}
}
}

// ConvertValueForLocal converts value to local Golang type of value according field type name from database.
// The parameter `fieldType` is in lower case, like:
// `float(5,2)`, `unsigned double(5,2)`, `decimal(10,2)`, `char(45)`, `varchar(100)`, etc.
Expand All @@ -128,42 +257,42 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field
if fieldType == "" {
return fieldValue, nil
}
typeName, err := CheckValueForLocalType(ctx, fieldType, fieldValue)
typeName, err := c.db.CheckLocalTypeForField(ctx, fieldType, fieldValue)
if err != nil {
return nil, err
}
switch typeName {
case typeBytes:
case LocalTypeBytes:
if strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob") {
return fieldValue, nil
}
return gconv.Bytes(fieldValue), nil

case typeInt:
case LocalTypeInt:
return gconv.Int(gconv.String(fieldValue)), nil

case typeUint:
case LocalTypeUint:
return gconv.Uint(gconv.String(fieldValue)), nil

case typeInt64:
case LocalTypeInt64:
return gconv.Int64(gconv.String(fieldValue)), nil

case typeUint64:
case LocalTypeUint64:
return gconv.Uint64(gconv.String(fieldValue)), nil

case typeInt64Bytes:
case LocalTypeInt64Bytes:
return gbinary.BeDecodeToInt64(gconv.Bytes(fieldValue)), nil

case typeUint64Bytes:
case LocalTypeUint64Bytes:
return gbinary.BeDecodeToUint64(gconv.Bytes(fieldValue)), nil

case typeFloat32:
case LocalTypeFloat32:
return gconv.Float32(gconv.String(fieldValue)), nil

case typeFloat64:
case LocalTypeFloat64:
return gconv.Float64(gconv.String(fieldValue)), nil

case typeBool:
case LocalTypeBool:
s := gconv.String(fieldValue)
// mssql is true|false string.
if strings.EqualFold(s, "true") {
Expand All @@ -174,15 +303,15 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field
}
return gconv.Bool(fieldValue), nil

case typeDate:
case LocalTypeDate:
// Date without time.
if t, ok := fieldValue.(time.Time); ok {
return gtime.NewFromTime(t).Format("Y-m-d"), nil
}
t, _ := gtime.StrToTime(gconv.String(fieldValue))
return t.Format("Y-m-d"), nil

case typeDatetime:
case LocalTypeDatetime:
if t, ok := fieldValue.(time.Time); ok {
return gtime.NewFromTime(t), nil
}
Expand Down
Loading

0 comments on commit e4c8cfc

Please sign in to comment.