Skip to content

Commit 42cd12f

Browse files
feat(encoder): support bracket encoding form-data object members
1 parent 38be6e4 commit 42cd12f

File tree

2 files changed

+94
-37
lines changed

2 files changed

+94
-37
lines changed

internal/apiform/encoder.go

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ type encoderField struct {
6060
type encoderEntry struct {
6161
reflect.Type
6262
dateFormat string
63+
arrayFmt string
6364
root bool
6465
}
6566

@@ -77,6 +78,7 @@ func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
7778
entry := encoderEntry{
7879
Type: t,
7980
dateFormat: e.dateFormat,
81+
arrayFmt: e.arrayFmt,
8082
root: e.root,
8183
}
8284

@@ -178,34 +180,9 @@ func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
178180
}
179181
}
180182

181-
func arrayKeyEncoder(arrayFmt string) func(string, int) string {
182-
var keyFn func(string, int) string
183-
switch arrayFmt {
184-
case "comma", "repeat":
185-
keyFn = func(k string, _ int) string { return k }
186-
case "brackets":
187-
keyFn = func(key string, _ int) string { return key + "[]" }
188-
case "indices:dots":
189-
keyFn = func(k string, i int) string {
190-
if k == "" {
191-
return strconv.Itoa(i)
192-
}
193-
return k + "." + strconv.Itoa(i)
194-
}
195-
case "indices:brackets":
196-
keyFn = func(k string, i int) string {
197-
if k == "" {
198-
return strconv.Itoa(i)
199-
}
200-
return k + "[" + strconv.Itoa(i) + "]"
201-
}
202-
}
203-
return keyFn
204-
}
205-
206183
func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
207184
itemEncoder := e.typeEncoder(t.Elem())
208-
keyFn := arrayKeyEncoder(e.arrayFmt)
185+
keyFn := e.arrayKeyEncoder()
209186
return func(key string, v reflect.Value, writer *multipart.Writer) error {
210187
if keyFn == nil {
211188
return fmt.Errorf("apiform: unsupported array format")
@@ -303,13 +280,10 @@ func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
303280
})
304281

305282
return func(key string, value reflect.Value, writer *multipart.Writer) error {
306-
if key != "" {
307-
key = key + "."
308-
}
309-
283+
keyFn := e.objKeyEncoder(key)
310284
for _, ef := range encoderFields {
311285
field := value.FieldByIndex(ef.idx)
312-
err := ef.fn(key+ef.tag.name, field, writer)
286+
err := ef.fn(keyFn(ef.tag.name), field, writer)
313287
if err != nil {
314288
return err
315289
}
@@ -405,6 +379,43 @@ func (e *encoder) newReaderTypeEncoder() encoderFunc {
405379
}
406380
}
407381

382+
func (e encoder) arrayKeyEncoder() func(string, int) string {
383+
var keyFn func(string, int) string
384+
switch e.arrayFmt {
385+
case "comma", "repeat":
386+
keyFn = func(k string, _ int) string { return k }
387+
case "brackets":
388+
keyFn = func(key string, _ int) string { return key + "[]" }
389+
case "indices:dots":
390+
keyFn = func(k string, i int) string {
391+
if k == "" {
392+
return strconv.Itoa(i)
393+
}
394+
return k + "." + strconv.Itoa(i)
395+
}
396+
case "indices:brackets":
397+
keyFn = func(k string, i int) string {
398+
if k == "" {
399+
return strconv.Itoa(i)
400+
}
401+
return k + "[" + strconv.Itoa(i) + "]"
402+
}
403+
}
404+
return keyFn
405+
}
406+
407+
func (e encoder) objKeyEncoder(parent string) func(string) string {
408+
if parent == "" {
409+
return func(child string) string { return child }
410+
}
411+
switch e.arrayFmt {
412+
case "brackets":
413+
return func(child string) string { return parent + "[" + child + "]" }
414+
default:
415+
return func(child string) string { return parent + "." + child }
416+
}
417+
}
418+
408419
// Given a []byte of json (may either be an empty object or an object that already contains entries)
409420
// encode all of the entries in the map to the json byte array.
410421
func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error {
@@ -413,10 +424,6 @@ func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipar
413424
value reflect.Value
414425
}
415426

416-
if key != "" {
417-
key = key + "."
418-
}
419-
420427
pairs := []mapPair{}
421428

422429
iter := v.MapRange()
@@ -434,8 +441,9 @@ func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipar
434441
})
435442

436443
elementEncoder := e.typeEncoder(v.Type().Elem())
444+
keyFn := e.objKeyEncoder(key)
437445
for _, p := range pairs {
438-
err := elementEncoder(key+string(p.key), p.value, writer)
446+
err := elementEncoder(keyFn(p.key), p.value, writer)
439447
if err != nil {
440448
return err
441449
}

internal/apiform/form_test.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ type StructUnion struct {
123123
param.APIUnion
124124
}
125125

126+
type MultipartMarshalerParent struct {
127+
Middle MultipartMarshalerMiddleNext `form:"middle"`
128+
}
129+
130+
type MultipartMarshalerMiddleNext struct {
131+
MiddleNext MultipartMarshalerMiddle `form:"middleNext"`
132+
}
133+
134+
type MultipartMarshalerMiddle struct {
135+
Child int `form:"child"`
136+
}
137+
126138
var tests = map[string]struct {
127139
buf string
128140
val any
@@ -366,6 +378,19 @@ true
366378
},
367379
},
368380
},
381+
"recursive_struct,brackets": {
382+
`--xxx
383+
Content-Disposition: form-data; name="child[name]"
384+
385+
Alex
386+
--xxx
387+
Content-Disposition: form-data; name="name"
388+
389+
Robert
390+
--xxx--
391+
`,
392+
Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}},
393+
},
369394

370395
"recursive_struct": {
371396
`--xxx
@@ -529,6 +554,30 @@ Content-Disposition: form-data; name="union"
529554
Union: UnionTime(time.Date(2010, 05, 23, 0, 0, 0, 0, time.UTC)),
530555
},
531556
},
557+
"deeply-nested-struct,brackets": {
558+
`--xxx
559+
Content-Disposition: form-data; name="middle[middleNext][child]"
560+
561+
10
562+
--xxx--
563+
`,
564+
MultipartMarshalerParent{
565+
Middle: MultipartMarshalerMiddleNext{
566+
MiddleNext: MultipartMarshalerMiddle{
567+
Child: 10,
568+
},
569+
},
570+
},
571+
},
572+
"deeply-nested-map,brackets": {
573+
`--xxx
574+
Content-Disposition: form-data; name="middle[middleNext][child]"
575+
576+
10
577+
--xxx--
578+
`,
579+
map[string]any{"middle": map[string]any{"middleNext": map[string]any{"child": 10}}},
580+
},
532581
}
533582

534583
func TestEncode(t *testing.T) {
@@ -553,7 +602,7 @@ func TestEncode(t *testing.T) {
553602
}
554603
raw := buf.Bytes()
555604
if string(raw) != strings.ReplaceAll(test.buf, "\n", "\r\n") {
556-
t.Errorf("expected %+#v to serialize to '%s' but got '%s'", test.val, test.buf, string(raw))
605+
t.Errorf("expected %+#v to serialize to '%s' but got '%s' (with format %s)", test.val, test.buf, string(raw), arrayFmt)
557606
}
558607
})
559608
}

0 commit comments

Comments
 (0)