Skip to content

Commit d726827

Browse files
committedNov 30, 2020
feature: inbuilt support for scanner and valuer in pq.Array for int32/float32/[]byte slices
1 parent 11a44e2 commit d726827

File tree

3 files changed

+482
-0
lines changed

3 files changed

+482
-0
lines changed
 

‎.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
*.test
33
*~
44
*.swp
5+
.idea
6+
.vscode

‎array.go

+139
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,31 @@ func Array(a interface{}) interface {
3535
return (*BoolArray)(&a)
3636
case []float64:
3737
return (*Float64Array)(&a)
38+
case []float32:
39+
return (*Float32Array)(&a)
3840
case []int64:
3941
return (*Int64Array)(&a)
42+
case []int32:
43+
return (*Int32Array)(&a)
4044
case []string:
4145
return (*StringArray)(&a)
46+
case [][]byte:
47+
return (*ByteaArray)(&a)
4248

4349
case *[]bool:
4450
return (*BoolArray)(a)
4551
case *[]float64:
4652
return (*Float64Array)(a)
53+
case *[]float32:
54+
return (*Float32Array)(a)
4755
case *[]int64:
4856
return (*Int64Array)(a)
57+
case *[]int32:
58+
return (*Int32Array)(a)
4959
case *[]string:
5060
return (*StringArray)(a)
61+
case *[][]byte:
62+
return (*ByteaArray)(a)
5163
}
5264

5365
return GenericArray{a}
@@ -267,6 +279,70 @@ func (a Float64Array) Value() (driver.Value, error) {
267279
return "{}", nil
268280
}
269281

282+
// Float32Array represents a one-dimensional array of the PostgreSQL double
283+
// precision type.
284+
type Float32Array []float32
285+
286+
// Scan implements the sql.Scanner interface.
287+
func (a *Float32Array) Scan(src interface{}) error {
288+
switch src := src.(type) {
289+
case []byte:
290+
return a.scanBytes(src)
291+
case string:
292+
return a.scanBytes([]byte(src))
293+
case nil:
294+
*a = nil
295+
return nil
296+
}
297+
298+
return fmt.Errorf("pq: cannot convert %T to Float32Array", src)
299+
}
300+
301+
func (a *Float32Array) scanBytes(src []byte) error {
302+
elems, err := scanLinearArray(src, []byte{','}, "Float32Array")
303+
if err != nil {
304+
return err
305+
}
306+
if *a != nil && len(elems) == 0 {
307+
*a = (*a)[:0]
308+
} else {
309+
b := make(Float32Array, len(elems))
310+
for i, v := range elems {
311+
var x float64
312+
if x, err = strconv.ParseFloat(string(v), 32); err != nil {
313+
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
314+
}
315+
b[i] = float32(x)
316+
}
317+
*a = b
318+
}
319+
return nil
320+
}
321+
322+
// Value implements the driver.Valuer interface.
323+
func (a Float32Array) Value() (driver.Value, error) {
324+
if a == nil {
325+
return nil, nil
326+
}
327+
328+
if n := len(a); n > 0 {
329+
// There will be at least two curly brackets, N bytes of values,
330+
// and N-1 bytes of delimiters.
331+
b := make([]byte, 1, 1+2*n)
332+
b[0] = '{'
333+
334+
b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32)
335+
for i := 1; i < n; i++ {
336+
b = append(b, ',')
337+
b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32)
338+
}
339+
340+
return string(append(b, '}')), nil
341+
}
342+
343+
return "{}", nil
344+
}
345+
270346
// GenericArray implements the driver.Valuer and sql.Scanner interfaces for
271347
// an array or slice of any dimension.
272348
type GenericArray struct{ A interface{} }
@@ -483,6 +559,69 @@ func (a Int64Array) Value() (driver.Value, error) {
483559
return "{}", nil
484560
}
485561

562+
// Int32Array represents a one-dimensional array of the PostgreSQL integer types.
563+
type Int32Array []int32
564+
565+
// Scan implements the sql.Scanner interface.
566+
func (a *Int32Array) Scan(src interface{}) error {
567+
switch src := src.(type) {
568+
case []byte:
569+
return a.scanBytes(src)
570+
case string:
571+
return a.scanBytes([]byte(src))
572+
case nil:
573+
*a = nil
574+
return nil
575+
}
576+
577+
return fmt.Errorf("pq: cannot convert %T to Int32Array", src)
578+
}
579+
580+
func (a *Int32Array) scanBytes(src []byte) error {
581+
elems, err := scanLinearArray(src, []byte{','}, "Int32Array")
582+
if err != nil {
583+
return err
584+
}
585+
if *a != nil && len(elems) == 0 {
586+
*a = (*a)[:0]
587+
} else {
588+
b := make(Int32Array, len(elems))
589+
for i, v := range elems {
590+
var x int
591+
if x, err = strconv.Atoi(string(v)); err != nil {
592+
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
593+
}
594+
b[i] = int32(x)
595+
}
596+
*a = b
597+
}
598+
return nil
599+
}
600+
601+
// Value implements the driver.Valuer interface.
602+
func (a Int32Array) Value() (driver.Value, error) {
603+
if a == nil {
604+
return nil, nil
605+
}
606+
607+
if n := len(a); n > 0 {
608+
// There will be at least two curly brackets, N bytes of values,
609+
// and N-1 bytes of delimiters.
610+
b := make([]byte, 1, 1+2*n)
611+
b[0] = '{'
612+
613+
b = strconv.AppendInt(b, int64(a[0]), 10)
614+
for i := 1; i < n; i++ {
615+
b = append(b, ',')
616+
b = strconv.AppendInt(b, int64(a[i]), 10)
617+
}
618+
619+
return string(append(b, '}')), nil
620+
}
621+
622+
return "{}", nil
623+
}
624+
486625
// StringArray represents a one-dimensional array of the PostgreSQL character types.
487626
type StringArray []string
488627

‎array_test.go

+341
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,33 @@ func TestArrayScanner(t *testing.T) {
104104
t.Errorf("Expected *Int64Array, got %T", s)
105105
}
106106

107+
s = Array(&[]float32{})
108+
if _, ok := s.(*Float32Array); !ok {
109+
t.Errorf("Expected *Float32Array, got %T", s)
110+
}
111+
112+
s = Array(&[]int32{})
113+
if _, ok := s.(*Int32Array); !ok {
114+
t.Errorf("Expected *Int32Array, got %T", s)
115+
}
116+
107117
s = Array(&[]string{})
108118
if _, ok := s.(*StringArray); !ok {
109119
t.Errorf("Expected *StringArray, got %T", s)
110120
}
111121

122+
s = Array(&[][]byte{})
123+
if _, ok := s.(*ByteaArray); !ok {
124+
t.Errorf("Expected *ByteaArray, got %T", s)
125+
}
126+
112127
for _, tt := range []interface{}{
113128
&[]sql.Scanner{},
114129
&[][]bool{},
115130
&[][]float64{},
116131
&[][]int64{},
132+
&[][]float32{},
133+
&[][]int32{},
117134
&[][]string{},
118135
} {
119136
s = Array(tt)
@@ -139,17 +156,34 @@ func TestArrayValuer(t *testing.T) {
139156
t.Errorf("Expected *Int64Array, got %T", v)
140157
}
141158

159+
v = Array([]float32{})
160+
if _, ok := v.(*Float32Array); !ok {
161+
t.Errorf("Expected *Float32Array, got %T", v)
162+
}
163+
164+
v = Array([]int32{})
165+
if _, ok := v.(*Int32Array); !ok {
166+
t.Errorf("Expected *Int32Array, got %T", v)
167+
}
168+
142169
v = Array([]string{})
143170
if _, ok := v.(*StringArray); !ok {
144171
t.Errorf("Expected *StringArray, got %T", v)
145172
}
146173

174+
v = Array([][]byte{})
175+
if _, ok := v.(*ByteaArray); !ok {
176+
t.Errorf("Expected *ByteaArray, got %T", v)
177+
}
178+
147179
for _, tt := range []interface{}{
148180
nil,
149181
[]driver.Value{},
150182
[][]bool{},
151183
[][]float64{},
152184
[][]int64{},
185+
[][]float32{},
186+
[][]int32{},
153187
[][]string{},
154188
} {
155189
v = Array(tt)
@@ -773,6 +807,313 @@ func BenchmarkInt64ArrayValue(b *testing.B) {
773807
}
774808
}
775809

810+
func TestFloat32ArrayScanUnsupported(t *testing.T) {
811+
var arr Float32Array
812+
err := arr.Scan(true)
813+
814+
if err == nil {
815+
t.Fatal("Expected error when scanning from bool")
816+
}
817+
if !strings.Contains(err.Error(), "bool to Float32Array") {
818+
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
819+
}
820+
}
821+
822+
func TestFloat32ArrayScanEmpty(t *testing.T) {
823+
var arr Float32Array
824+
err := arr.Scan(`{}`)
825+
826+
if err != nil {
827+
t.Fatalf("Expected no error, got %v", err)
828+
}
829+
if arr == nil || len(arr) != 0 {
830+
t.Errorf("Expected empty, got %#v", arr)
831+
}
832+
}
833+
834+
func TestFloat32ArrayScanNil(t *testing.T) {
835+
arr := Float32Array{5, 5, 5}
836+
err := arr.Scan(nil)
837+
838+
if err != nil {
839+
t.Fatalf("Expected no error, got %v", err)
840+
}
841+
if arr != nil {
842+
t.Errorf("Expected nil, got %+v", arr)
843+
}
844+
}
845+
846+
var Float32ArrayStringTests = []struct {
847+
str string
848+
arr Float32Array
849+
}{
850+
{`{}`, Float32Array{}},
851+
{`{1.2}`, Float32Array{1.2}},
852+
{`{3.456,7.89}`, Float32Array{3.456, 7.89}},
853+
{`{3,1,2}`, Float32Array{3, 1, 2}},
854+
}
855+
856+
func TestFloat32ArrayScanBytes(t *testing.T) {
857+
for _, tt := range Float32ArrayStringTests {
858+
bytes := []byte(tt.str)
859+
arr := Float32Array{5, 5, 5}
860+
err := arr.Scan(bytes)
861+
862+
if err != nil {
863+
t.Fatalf("Expected no error for %q, got %v", bytes, err)
864+
}
865+
if !reflect.DeepEqual(arr, tt.arr) {
866+
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
867+
}
868+
}
869+
}
870+
871+
func BenchmarkFloat32ArrayScanBytes(b *testing.B) {
872+
var a Float32Array
873+
var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`)
874+
875+
for i := 0; i < b.N; i++ {
876+
a = Float32Array{}
877+
a.Scan(x)
878+
}
879+
}
880+
881+
func TestFloat32ArrayScanString(t *testing.T) {
882+
for _, tt := range Float32ArrayStringTests {
883+
arr := Float32Array{5, 5, 5}
884+
err := arr.Scan(tt.str)
885+
886+
if err != nil {
887+
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
888+
}
889+
if !reflect.DeepEqual(arr, tt.arr) {
890+
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
891+
}
892+
}
893+
}
894+
895+
func TestFloat32ArrayScanError(t *testing.T) {
896+
for _, tt := range []struct {
897+
input, err string
898+
}{
899+
{``, "unable to parse array"},
900+
{`{`, "unable to parse array"},
901+
{`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float32Array"},
902+
{`{NULL}`, "parsing array element index 0:"},
903+
{`{a}`, "parsing array element index 0:"},
904+
{`{5.6,a}`, "parsing array element index 1:"},
905+
{`{5.6,7.8,a}`, "parsing array element index 2:"},
906+
} {
907+
arr := Float32Array{5, 5, 5}
908+
err := arr.Scan(tt.input)
909+
910+
if err == nil {
911+
t.Fatalf("Expected error for %q, got none", tt.input)
912+
}
913+
if !strings.Contains(err.Error(), tt.err) {
914+
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
915+
}
916+
if !reflect.DeepEqual(arr, Float32Array{5, 5, 5}) {
917+
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
918+
}
919+
}
920+
}
921+
922+
func TestFloat32ArrayValue(t *testing.T) {
923+
result, err := Float32Array(nil).Value()
924+
925+
if err != nil {
926+
t.Fatalf("Expected no error for nil, got %v", err)
927+
}
928+
if result != nil {
929+
t.Errorf("Expected nil, got %q", result)
930+
}
931+
932+
result, err = Float32Array([]float32{}).Value()
933+
934+
if err != nil {
935+
t.Fatalf("Expected no error for empty, got %v", err)
936+
}
937+
if expected := `{}`; !reflect.DeepEqual(result, expected) {
938+
t.Errorf("Expected empty, got %q", result)
939+
}
940+
941+
result, err = Float32Array([]float32{1.2, 3.4, 5.6}).Value()
942+
943+
if err != nil {
944+
t.Fatalf("Expected no error, got %v", err)
945+
}
946+
if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) {
947+
t.Errorf("Expected %q, got %q", expected, result)
948+
}
949+
}
950+
951+
func BenchmarkFloat32ArrayValue(b *testing.B) {
952+
rand.Seed(1)
953+
x := make([]float32, 10)
954+
for i := 0; i < len(x); i++ {
955+
x[i] = rand.Float32()
956+
}
957+
a := Float32Array(x)
958+
959+
for i := 0; i < b.N; i++ {
960+
a.Value()
961+
}
962+
}
963+
964+
func TestInt32ArrayScanUnsupported(t *testing.T) {
965+
var arr Int32Array
966+
err := arr.Scan(true)
967+
968+
if err == nil {
969+
t.Fatal("Expected error when scanning from bool")
970+
}
971+
if !strings.Contains(err.Error(), "bool to Int32Array") {
972+
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
973+
}
974+
}
975+
976+
func TestInt32ArrayScanEmpty(t *testing.T) {
977+
var arr Int32Array
978+
err := arr.Scan(`{}`)
979+
980+
if err != nil {
981+
t.Fatalf("Expected no error, got %v", err)
982+
}
983+
if arr == nil || len(arr) != 0 {
984+
t.Errorf("Expected empty, got %#v", arr)
985+
}
986+
}
987+
988+
func TestInt32ArrayScanNil(t *testing.T) {
989+
arr := Int32Array{5, 5, 5}
990+
err := arr.Scan(nil)
991+
992+
if err != nil {
993+
t.Fatalf("Expected no error, got %v", err)
994+
}
995+
if arr != nil {
996+
t.Errorf("Expected nil, got %+v", arr)
997+
}
998+
}
999+
1000+
var Int32ArrayStringTests = []struct {
1001+
str string
1002+
arr Int32Array
1003+
}{
1004+
{`{}`, Int32Array{}},
1005+
{`{12}`, Int32Array{12}},
1006+
{`{345,678}`, Int32Array{345, 678}},
1007+
}
1008+
1009+
func TestInt32ArrayScanBytes(t *testing.T) {
1010+
for _, tt := range Int32ArrayStringTests {
1011+
bytes := []byte(tt.str)
1012+
arr := Int32Array{5, 5, 5}
1013+
err := arr.Scan(bytes)
1014+
1015+
if err != nil {
1016+
t.Fatalf("Expected no error for %q, got %v", bytes, err)
1017+
}
1018+
if !reflect.DeepEqual(arr, tt.arr) {
1019+
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
1020+
}
1021+
}
1022+
}
1023+
1024+
func BenchmarkInt32ArrayScanBytes(b *testing.B) {
1025+
var a Int32Array
1026+
var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`)
1027+
1028+
for i := 0; i < b.N; i++ {
1029+
a = Int32Array{}
1030+
a.Scan(x)
1031+
}
1032+
}
1033+
1034+
func TestInt32ArrayScanString(t *testing.T) {
1035+
for _, tt := range Int32ArrayStringTests {
1036+
arr := Int32Array{5, 5, 5}
1037+
err := arr.Scan(tt.str)
1038+
1039+
if err != nil {
1040+
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
1041+
}
1042+
if !reflect.DeepEqual(arr, tt.arr) {
1043+
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
1044+
}
1045+
}
1046+
}
1047+
1048+
func TestInt32ArrayScanError(t *testing.T) {
1049+
for _, tt := range []struct {
1050+
input, err string
1051+
}{
1052+
{``, "unable to parse array"},
1053+
{`{`, "unable to parse array"},
1054+
{`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int32Array"},
1055+
{`{NULL}`, "parsing array element index 0:"},
1056+
{`{a}`, "parsing array element index 0:"},
1057+
{`{5,a}`, "parsing array element index 1:"},
1058+
{`{5,6,a}`, "parsing array element index 2:"},
1059+
} {
1060+
arr := Int32Array{5, 5, 5}
1061+
err := arr.Scan(tt.input)
1062+
1063+
if err == nil {
1064+
t.Fatalf("Expected error for %q, got none", tt.input)
1065+
}
1066+
if !strings.Contains(err.Error(), tt.err) {
1067+
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
1068+
}
1069+
if !reflect.DeepEqual(arr, Int32Array{5, 5, 5}) {
1070+
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
1071+
}
1072+
}
1073+
}
1074+
1075+
func TestInt32ArrayValue(t *testing.T) {
1076+
result, err := Int32Array(nil).Value()
1077+
1078+
if err != nil {
1079+
t.Fatalf("Expected no error for nil, got %v", err)
1080+
}
1081+
if result != nil {
1082+
t.Errorf("Expected nil, got %q", result)
1083+
}
1084+
1085+
result, err = Int32Array([]int32{}).Value()
1086+
1087+
if err != nil {
1088+
t.Fatalf("Expected no error for empty, got %v", err)
1089+
}
1090+
if expected := `{}`; !reflect.DeepEqual(result, expected) {
1091+
t.Errorf("Expected empty, got %q", result)
1092+
}
1093+
1094+
result, err = Int32Array([]int32{1, 2, 3}).Value()
1095+
1096+
if err != nil {
1097+
t.Fatalf("Expected no error, got %v", err)
1098+
}
1099+
if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) {
1100+
t.Errorf("Expected %q, got %q", expected, result)
1101+
}
1102+
}
1103+
1104+
func BenchmarkInt32ArrayValue(b *testing.B) {
1105+
rand.Seed(1)
1106+
x := make([]int32, 10)
1107+
for i := 0; i < len(x); i++ {
1108+
x[i] = rand.Int31()
1109+
}
1110+
a := Int32Array(x)
1111+
1112+
for i := 0; i < b.N; i++ {
1113+
a.Value()
1114+
}
1115+
}
1116+
7761117
func TestStringArrayScanUnsupported(t *testing.T) {
7771118
var arr StringArray
7781119
err := arr.Scan(true)

0 commit comments

Comments
 (0)
Please sign in to comment.