Skip to content

Commit 0cd6899

Browse files
authored
Merge pull request #213 from erizocosmico/fix/byte-array-scan
scan bytea as []byte and vice-versa
2 parents 96a9f67 + 92523fd commit 0cd6899

File tree

5 files changed

+81
-102
lines changed

5 files changed

+81
-102
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,8 @@ kallax migrate up --dir ./my-migrations --dsn 'user:pass@localhost:5432/dbname?s
724724
| `url.URL` | `text` |
725725
| `time.Time` | `timestamptz` |
726726
| `time.Duration` | `bigint` |
727-
| `[]T` | `T'[]` * where `T'` is the SQL type of type `T` |
727+
| `[]byte` | `bytea` |
728+
| `[]T` | `T'[]` * where `T'` is the SQL type of type `T`, except for `T` = `byte` |
728729
| `map[K]V` | `jsonb` |
729730
| `struct` | `jsonb` |
730731
| `*struct` | `jsonb` |

generator/migration.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ func (s *ColumnSchema) String() string {
197197
type ColumnType string
198198

199199
const (
200+
ByteaColumn ColumnType = "bytea"
200201
SmallIntColumn ColumnType = "smallint"
201202
IntegerColumn ColumnType = "integer"
202203
BigIntColumn ColumnType = "bigint"
@@ -225,6 +226,7 @@ func ArrayColumn(typ ColumnType) ColumnType {
225226
if strings.HasSuffix(string(typ), "[]") {
226227
return typ
227228
}
229+
228230
return typ + "[]"
229231
}
230232

@@ -833,7 +835,12 @@ func (t *packageTransformer) transformType(f *Field, pk bool) (ColumnType, error
833835
}
834836

835837
if f.Kind == Array || f.Kind == Slice {
836-
return ArrayColumn(typeMappings[removeTypePrefix(f.Type)]), nil
838+
typ := removeTypePrefix(f.Type)
839+
if typ == "byte" {
840+
return ByteaColumn, nil
841+
}
842+
843+
return ArrayColumn(typeMappings[typ]), nil
837844
}
838845

839846
if pk {

generator/migration_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ type Profile struct {
527527
// should be added anyway
528528
// should be added as bigint, as it is not a pk
529529
Metadata ProfileMetadata
530+
SomeData []byte
530531
}
531532
532533
type ProfileMetadata struct {
@@ -569,6 +570,7 @@ func (s *PackageTransformerSuite) TestTransform() {
569570
mkCol("background", TextColumn, false, true, nil),
570571
mkCol("user_id", UUIDColumn, false, false, mkRef("users", "id", true)),
571572
mkCol("spouse", UUIDColumn, false, false, nil),
573+
mkCol("some_data", ByteaColumn, false, true, nil),
572574
),
573575
mkTable(
574576
"metadata",

types/slices.go

Lines changed: 14 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ func Slice(v interface{}) SQLType {
7171
return (*Int8Array)(&v)
7272
case *[]int8:
7373
return (*Int8Array)(v)
74-
case []uint8:
75-
return (*Uint8Array)(&v)
76-
case *[]uint8:
77-
return (*Uint8Array)(v)
74+
case []byte:
75+
return (*ByteArray)(&v)
76+
case *[]byte:
77+
return (*ByteArray)(v)
7878
case *[]float32:
7979
return (*Float32Array)(v)
8080
case []float32:
@@ -646,67 +646,29 @@ func (a Int8Array) Value() (driver.Value, error) {
646646
return "{}", nil
647647
}
648648

649-
// Uint8Array represents a one-dimensional array of the PostgreSQL unsigned integer type.
650-
type Uint8Array []uint8
649+
// ByteArray represents a byte array `bytea`.
650+
type ByteArray []uint8
651651

652652
// Scan implements the sql.Scanner interface.
653-
func (a *Uint8Array) Scan(src interface{}) error {
653+
func (a *ByteArray) Scan(src interface{}) error {
654654
switch src := src.(type) {
655655
case []byte:
656-
return a.scanBytes(src)
656+
*(*[]byte)(a) = src
657+
return nil
657658
case string:
658-
return a.scanBytes([]byte(src))
659+
*(*[]byte)(a) = []byte(src)
660+
return nil
659661
case nil:
660662
*a = nil
661663
return nil
662664
}
663665

664-
return fmt.Errorf("kallax: cannot convert %T to Uint8Array", src)
665-
}
666-
667-
func (a *Uint8Array) scanBytes(src []byte) error {
668-
elems, err := scanLinearArray(src, []byte{','}, "Uint8Array")
669-
if err != nil {
670-
return err
671-
}
672-
if *a != nil && len(elems) == 0 {
673-
*a = (*a)[:0]
674-
} else {
675-
b := make(Uint8Array, len(elems))
676-
for i, v := range elems {
677-
val, err := strconv.ParseUint(string(v), 10, 8)
678-
if err != nil {
679-
return fmt.Errorf("kallax: parsing array element index %d: %v", i, err)
680-
}
681-
b[i] = uint8(val)
682-
}
683-
*a = b
684-
}
685-
return nil
666+
return fmt.Errorf("kallax: cannot convert %T to ByteArray", src)
686667
}
687668

688669
// Value implements the driver.Valuer interface.
689-
func (a Uint8Array) Value() (driver.Value, error) {
690-
if a == nil {
691-
return nil, nil
692-
}
693-
694-
if n := len(a); n > 0 {
695-
// There will be at least two curly brackets, N bytes of values,
696-
// and N-1 bytes of delimiters.
697-
b := make([]byte, 1, 1+2*n)
698-
b[0] = '{'
699-
700-
b = strconv.AppendUint(b, uint64(a[0]), 10)
701-
for i := 1; i < n; i++ {
702-
b = append(b, ',')
703-
b = strconv.AppendUint(b, uint64(a[i]), 10)
704-
}
705-
706-
return string(append(b, '}')), nil
707-
}
708-
709-
return "{}", nil
670+
func (a ByteArray) Value() (driver.Value, error) {
671+
return ([]byte)(a), nil
710672
}
711673

712674
// Float32Array represents a one-dimensional array of the PostgreSQL real type.

types/slices_test.go

Lines changed: 55 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ import (
1414
)
1515

1616
func TestSlice(t *testing.T) {
17-
require := require.New(t)
18-
1917
cases := []struct {
2018
v interface{}
2119
input interface{}
@@ -76,16 +74,6 @@ func TestSlice(t *testing.T) {
7674
[]int8{1, 3, 4},
7775
&([]int8{}),
7876
},
79-
{
80-
&([]uint8{1, 3, 4}),
81-
[]uint8{1, 3, 4},
82-
&([]uint8{}),
83-
},
84-
{
85-
&([]byte{1, 3, 4}),
86-
[]byte{1, 3, 4},
87-
&([]byte{}),
88-
},
8977
{
9078
&([]float32{1., 3., .4}),
9179
[]float32{1., 3., .4},
@@ -94,22 +82,35 @@ func TestSlice(t *testing.T) {
9482
}
9583

9684
for _, c := range cases {
97-
arr := Slice(c.v)
98-
val, err := arr.Value()
99-
require.Nil(err)
85+
t.Run(reflect.TypeOf(c.input).String(), func(t *testing.T) {
86+
require := require.New(t)
87+
arr := Slice(c.v)
88+
val, err := arr.Value()
89+
require.NoError(err)
90+
91+
pqArr := pq.Array(c.input)
92+
pqVal, err := pqArr.Value()
93+
require.NoError(err)
94+
95+
require.Equal(pqVal, val)
96+
require.NoError(Slice(c.dest).Scan(val))
97+
require.Equal(c.v, c.dest)
98+
})
99+
}
100100

101-
pqArr := pq.Array(c.input)
102-
pqVal, err := pqArr.Value()
103-
require.Nil(err)
101+
t.Run("[]byte", func(t *testing.T) {
102+
require := require.New(t)
103+
arr := Slice([]byte{1, 2, 3})
104+
val, err := arr.Value()
105+
require.NoError(err)
104106

105-
require.Equal(pqVal, val)
106-
require.Nil(Slice(c.dest).Scan(val))
107-
require.Equal(c.v, c.dest)
108-
}
107+
var b []byte
108+
require.NoError(Slice(&b).Scan(val))
109+
require.Equal([]byte{1, 2, 3}, b)
110+
})
109111
}
110112

111113
func TestSlice_Integration(t *testing.T) {
112-
s := require.New(t)
113114
cases := []struct {
114115
name string
115116
typ string
@@ -118,85 +119,91 @@ func TestSlice_Integration(t *testing.T) {
118119
}{
119120
{
120121
"int8",
121-
"smallint",
122+
"smallint[]",
122123
[]int8{math.MaxInt8, math.MinInt8},
123124
&([]int8{}),
124125
},
125126
{
126-
"unsigned int8",
127-
"smallint",
128-
[]uint8{math.MaxUint8, 0},
129-
&([]uint8{}),
127+
"byte",
128+
"bytea",
129+
[]byte{math.MaxUint8, 0},
130+
&([]byte{}),
130131
},
131132
{
132133
"int16",
133-
"smallint",
134+
"smallint[]",
134135
[]int16{math.MaxInt16, math.MinInt16},
135136
&([]int16{}),
136137
},
137138
{
138139
"unsigned int16",
139-
"integer",
140+
"integer[]",
140141
[]uint16{math.MaxUint16, 0},
141142
&([]uint16{}),
142143
},
143144
{
144145
"int32",
145-
"integer",
146+
"integer[]",
146147
[]int32{math.MaxInt32, math.MinInt32},
147148
&([]int32{}),
148149
},
149150
{
150151
"unsigned int32",
151-
"bigint",
152+
"bigint[]",
152153
[]uint32{math.MaxUint32, 0},
153154
&([]uint32{}),
154155
},
155156
{
156157
"int/int64",
157-
"bigint",
158+
"bigint[]",
158159
[]int{math.MaxInt64, math.MinInt64},
159160
&([]int{}),
160161
},
161162
{
162163
"unsigned int/int64",
163-
"numeric(20)",
164+
"numeric(20)[]",
164165
[]uint{math.MaxUint64, 0},
165166
&([]uint{}),
166167
},
167168
{
168169
"float32",
169-
"decimal(10,3)",
170+
"decimal(10,3)[]",
170171
[]float32{.3, .6},
171172
&([]float32{.3, .6}),
172173
},
173174
}
174175

175176
db, err := openTestDB()
176-
s.Nil(err)
177+
require.NoError(t, err)
177178

178179
defer func() {
179180
_, err = db.Exec("DROP TABLE IF EXISTS foo")
180-
s.Nil(err)
181+
require.NoError(t, err)
181182

182-
s.Nil(db.Close())
183+
require.NoError(t, db.Close())
183184
}()
184185

185186
for _, c := range cases {
186-
_, err := db.Exec(fmt.Sprintf(`CREATE TABLE foo (
187-
testcol %s[]
187+
t.Run(c.name, func(t *testing.T) {
188+
require := require.New(t)
189+
190+
_, err := db.Exec(fmt.Sprintf(`CREATE TABLE foo (
191+
testcol %s
188192
)`, c.typ))
189-
s.Nil(err, c.name)
193+
require.NoError(err, c.name)
190194

191-
_, err = db.Exec("INSERT INTO foo (testcol) VALUES ($1)", Slice(c.input))
192-
s.Nil(err, c.name)
195+
defer func() {
196+
_, err := db.Exec("DROP TABLE foo")
197+
require.NoError(err)
198+
}()
193199

194-
s.Nil(db.QueryRow("SELECT testcol FROM foo LIMIT 1").Scan(Slice(c.dst)), c.name)
195-
slice := reflect.ValueOf(c.dst).Elem().Interface()
196-
s.Equal(c.input, slice, c.name)
200+
_, err = db.Exec("INSERT INTO foo (testcol) VALUES ($1)", Slice(c.input))
201+
require.NoError(err, c.name)
197202

198-
_, err = db.Exec("DROP TABLE foo")
199-
s.Nil(err, c.name)
203+
require.NoError(db.QueryRow("SELECT testcol FROM foo LIMIT 1").Scan(Slice(c.dst)), c.name)
204+
slice := reflect.ValueOf(c.dst).Elem().Interface()
205+
require.Equal(c.input, slice, c.name)
206+
})
200207
}
201208
}
202209

0 commit comments

Comments
 (0)