Skip to content

Commit f6d888c

Browse files
committed
add FloatPrecision
- reject NaNs - lint; support go 1.20 Signed-off-by: Cole Anthony Capilongo [email protected]
1 parent 7161b93 commit f6d888c

File tree

2 files changed

+146
-1
lines changed

2 files changed

+146
-1
lines changed

decode.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,26 @@ func (tum TextUnmarshalerMode) valid() bool {
769769
return tum >= 0 && tum < maxTextUnmarshalerMode
770770
}
771771

772+
// FloatPrecision sets whether float64 CBOR values are decoded into float32 if the number cannot
773+
// be stored exactly.
774+
type FloatPrecisionMode int
775+
776+
const (
777+
// FloatPrecisionIgnored will decode float64 values into a float32 Go type even if
778+
// precision is lost.
779+
FloatPrecisionIgnored FloatPrecisionMode = iota
780+
781+
// FloatPrecisionKept will return an error when trying to decode a float64 into a float32,
782+
// if precision will be lost.
783+
FloatPrecisionKept
784+
785+
maxFloatPrecisionMode
786+
)
787+
788+
func (fpm FloatPrecisionMode) valid() bool {
789+
return fpm >= 0 && fpm < maxFloatPrecisionMode
790+
}
791+
772792
// DecOptions specifies decoding options.
773793
type DecOptions struct {
774794
// DupMapKey specifies whether to enforce duplicate map key.
@@ -912,6 +932,10 @@ type DecOptions struct {
912932
// implement json.Unmarshaler but do not also implement cbor.Unmarshaler. If nil, decoding
913933
// behavior is not influenced by whether or not a type implements json.Unmarshaler.
914934
JSONUnmarshalerTranscoder Transcoder
935+
936+
// FloatPrecision sets whether float64 CBOR values are decoded into float32 if the number cannot
937+
// be stored exactly.
938+
FloatPrecision FloatPrecisionMode
915939
}
916940

917941
// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
@@ -1128,6 +1152,10 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
11281152
return nil, errors.New("cbor: invalid TextUnmarshaler " + strconv.Itoa(int(opts.TextUnmarshaler)))
11291153
}
11301154

1155+
if !opts.FloatPrecision.valid() {
1156+
return nil, errors.New("cbor: invalid FloatPrecision " + strconv.Itoa(int(opts.FloatPrecision)))
1157+
}
1158+
11311159
dm := decMode{
11321160
dupMapKey: opts.DupMapKey,
11331161
timeTag: opts.TimeTag,
@@ -1157,6 +1185,7 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
11571185
binaryUnmarshaler: opts.BinaryUnmarshaler,
11581186
textUnmarshaler: opts.TextUnmarshaler,
11591187
jsonUnmarshalerTranscoder: opts.JSONUnmarshalerTranscoder,
1188+
floatPrecision: opts.FloatPrecision,
11601189
}
11611190

11621191
return &dm, nil
@@ -1238,6 +1267,7 @@ type decMode struct {
12381267
binaryUnmarshaler BinaryUnmarshalerMode
12391268
textUnmarshaler TextUnmarshalerMode
12401269
jsonUnmarshalerTranscoder Transcoder
1270+
floatPrecision FloatPrecisionMode
12411271
}
12421272

12431273
var defaultDecMode, _ = DecOptions{}.decMode()
@@ -1280,6 +1310,7 @@ func (dm *decMode) DecOptions() DecOptions {
12801310
BinaryUnmarshaler: dm.binaryUnmarshaler,
12811311
TextUnmarshaler: dm.textUnmarshaler,
12821312
JSONUnmarshalerTranscoder: dm.jsonUnmarshalerTranscoder,
1313+
FloatPrecision: dm.floatPrecision,
12831314
}
12841315
}
12851316

@@ -1592,7 +1623,16 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
15921623

15931624
case additionalInformationAsFloat64:
15941625
f := math.Float64frombits(val)
1595-
return fillFloat(t, f, v)
1626+
err := fillFloat(t, f, v)
1627+
if d.dm.floatPrecision == FloatPrecisionIgnored || err != nil {
1628+
return err
1629+
}
1630+
// No error and we need to maintain float precision
1631+
if v.Kind() == reflect.Float64 || (f == float64(float32(f)) && !math.IsNaN(f)) {
1632+
return nil
1633+
}
1634+
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String(),
1635+
errorMsg: "float64 value would lose precision in float32 type"}
15961636

15971637
default: // ai <= 24
15981638
if d.dm.simpleValues.rejected[SimpleValue(val)] {

decode_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5511,6 +5511,7 @@ func TestDecOptions(t *testing.T) {
55115511
BinaryUnmarshaler: BinaryUnmarshalerNone,
55125512
TextUnmarshaler: TextUnmarshalerTextString,
55135513
JSONUnmarshalerTranscoder: stubTranscoder{},
5514+
FloatPrecision: FloatPrecisionKept,
55145515
}
55155516
ov := reflect.ValueOf(opts1)
55165517
for i := 0; i < ov.NumField(); i++ {
@@ -10910,3 +10911,107 @@ func TestJSONUnmarshalerTranscoder(t *testing.T) {
1091010911
})
1091110912
}
1091210913
}
10914+
10915+
func TestFloatPrecisionMode(t *testing.T) {
10916+
for _, tc := range []struct {
10917+
name string
10918+
opts DecOptions
10919+
in []byte
10920+
intoType reflect.Type
10921+
want any
10922+
shouldErr bool
10923+
}{
10924+
{
10925+
name: "FloatPrecision is not called by default",
10926+
opts: DecOptions{},
10927+
in: mustHexDecode("fbc010666666666666"),
10928+
intoType: reflect.TypeOf(float32(0)),
10929+
want: float32(-4.1),
10930+
},
10931+
{
10932+
name: "FloatPrecisionKept float64 precise",
10933+
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
10934+
in: mustHexDecode("fbc010666666666666"),
10935+
intoType: reflect.TypeOf((*any)(nil)).Elem(),
10936+
want: float64(-4.1),
10937+
}, {
10938+
name: "FloatPrecisionKept float64 precise 2",
10939+
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
10940+
in: mustHexDecode("fb3ff199999999999a"),
10941+
intoType: reflect.TypeOf((*any)(nil)).Elem(),
10942+
want: float64(1.1),
10943+
},
10944+
{
10945+
name: "FloatPrecisionKept float32 precise",
10946+
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
10947+
in: mustHexDecode("fb3ff8000000000000"),
10948+
intoType: reflect.TypeOf(float32(0)),
10949+
want: float32(1.5),
10950+
},
10951+
{
10952+
name: "FloatPrecisionKept float32 precise 2",
10953+
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
10954+
in: mustHexDecode("fb3ff8000000000000"),
10955+
intoType: reflect.TypeOf(float64(0)),
10956+
want: float64(1.5),
10957+
},
10958+
{
10959+
name: "FloatPrecisionIgnored float64 precise",
10960+
opts: DecOptions{FloatPrecision: FloatPrecisionIgnored},
10961+
in: mustHexDecode("fbc010666666666666"),
10962+
intoType: reflect.TypeOf((*any)(nil)).Elem(),
10963+
want: float64(-4.1),
10964+
},
10965+
{
10966+
name: "FloatPrecisionKept float32 err",
10967+
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
10968+
in: mustHexDecode("fbc010666666666666"),
10969+
intoType: reflect.TypeOf(float32(0)),
10970+
shouldErr: true,
10971+
},
10972+
{
10973+
name: "FloatPrecisionKept float32 inf",
10974+
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
10975+
in: mustHexDecode("fb7ff0000000000000"),
10976+
intoType: reflect.TypeOf(float32(0)),
10977+
want: float32(math.Inf(1)),
10978+
}, {
10979+
name: "FloatPrecisionKept float32 NaN",
10980+
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
10981+
in: mustHexDecode("fb7ff8000000000000"),
10982+
intoType: reflect.TypeOf(float32(0)),
10983+
// want: float32(math.NaN()),
10984+
shouldErr: true,
10985+
}, {
10986+
name: "FloatPrecisionKept float32 signal NaN",
10987+
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
10988+
in: mustHexDecode("fb7ff8000000000001"),
10989+
intoType: reflect.TypeOf(float32(0)),
10990+
shouldErr: true,
10991+
},
10992+
} {
10993+
t.Run(tc.name, func(t *testing.T) {
10994+
dm, err := tc.opts.DecMode()
10995+
if err != nil {
10996+
t.Fatal(err)
10997+
}
10998+
10999+
gotrv := reflect.New(tc.intoType)
11000+
err = dm.Unmarshal(tc.in, gotrv.Interface())
11001+
if tc.shouldErr {
11002+
if err == nil {
11003+
t.Fatal("expected error")
11004+
}
11005+
// It should err and it did, done here
11006+
return
11007+
} else if err != nil {
11008+
t.Fatal(err)
11009+
}
11010+
11011+
got := gotrv.Elem().Interface()
11012+
if !reflect.DeepEqual(tc.want, got) {
11013+
t.Errorf("want: %v (%T), got: %v (%T)", tc.want, tc.want, got, got)
11014+
}
11015+
})
11016+
}
11017+
}

0 commit comments

Comments
 (0)