diff --git a/decode.go b/decode.go index f0bdc3b3..b22a40bf 100644 --- a/decode.go +++ b/decode.go @@ -769,6 +769,26 @@ func (tum TextUnmarshalerMode) valid() bool { return tum >= 0 && tum < maxTextUnmarshalerMode } +// FloatPrecision sets whether float64 CBOR values are decoded into float32 if the number cannot +// be stored exactly. +type FloatPrecisionMode int + +const ( + // FloatPrecisionIgnored will decode float64 values into a float32 Go type even if + // precision is lost. + FloatPrecisionIgnored FloatPrecisionMode = iota + + // FloatPrecisionKept will return an error when trying to decode a float64 into a float32, + // if precision will be lost. + FloatPrecisionKept + + maxFloatPrecisionMode +) + +func (fpm FloatPrecisionMode) valid() bool { + return fpm >= 0 && fpm < maxFloatPrecisionMode +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -912,6 +932,10 @@ type DecOptions struct { // implement json.Unmarshaler but do not also implement cbor.Unmarshaler. If nil, decoding // behavior is not influenced by whether or not a type implements json.Unmarshaler. JSONUnmarshalerTranscoder Transcoder + + // FloatPrecision sets whether float64 CBOR values are decoded into float32 if the number cannot + // be stored exactly. + FloatPrecision FloatPrecisionMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -1128,6 +1152,10 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore return nil, errors.New("cbor: invalid TextUnmarshaler " + strconv.Itoa(int(opts.TextUnmarshaler))) } + if !opts.FloatPrecision.valid() { + return nil, errors.New("cbor: invalid FloatPrecision " + strconv.Itoa(int(opts.FloatPrecision))) + } + dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -1157,6 +1185,7 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore binaryUnmarshaler: opts.BinaryUnmarshaler, textUnmarshaler: opts.TextUnmarshaler, jsonUnmarshalerTranscoder: opts.JSONUnmarshalerTranscoder, + floatPrecision: opts.FloatPrecision, } return &dm, nil @@ -1238,6 +1267,7 @@ type decMode struct { binaryUnmarshaler BinaryUnmarshalerMode textUnmarshaler TextUnmarshalerMode jsonUnmarshalerTranscoder Transcoder + floatPrecision FloatPrecisionMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -1280,6 +1310,7 @@ func (dm *decMode) DecOptions() DecOptions { BinaryUnmarshaler: dm.binaryUnmarshaler, TextUnmarshaler: dm.textUnmarshaler, JSONUnmarshalerTranscoder: dm.jsonUnmarshalerTranscoder, + FloatPrecision: dm.floatPrecision, } } @@ -1592,7 +1623,20 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin case additionalInformationAsFloat64: f := math.Float64frombits(val) - return fillFloat(t, f, v) + err := fillFloat(t, f, v) + if d.dm.floatPrecision == FloatPrecisionIgnored || err != nil { + return err + } + // No error and we need to maintain float precision + if v.Kind() == reflect.Float64 { + return nil + } else if math.IsNaN(f) { + return fillFloat(t, f, v) + } else if f == float64(float32(f)) { + return nil + } + return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String(), + errorMsg: "float64 value would lose precision in float32 type"} default: // ai <= 24 if d.dm.simpleValues.rejected[SimpleValue(val)] { diff --git a/decode_test.go b/decode_test.go index 7ef6f510..5538ef0c 100644 --- a/decode_test.go +++ b/decode_test.go @@ -5511,6 +5511,7 @@ func TestDecOptions(t *testing.T) { BinaryUnmarshaler: BinaryUnmarshalerNone, TextUnmarshaler: TextUnmarshalerTextString, JSONUnmarshalerTranscoder: stubTranscoder{}, + FloatPrecision: FloatPrecisionKept, } ov := reflect.ValueOf(opts1) for i := 0; i < ov.NumField(); i++ { @@ -10910,3 +10911,128 @@ func TestJSONUnmarshalerTranscoder(t *testing.T) { }) } } + +func TestFloatPrecisionMode(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + in []byte + intoType reflect.Type + want any + shouldErr bool + }{ + { + name: "FloatPrecision is not called by default", + opts: DecOptions{}, + in: mustHexDecode("fbc010666666666666"), + intoType: reflect.TypeOf(float32(0)), + want: float32(-4.1), + }, + { + name: "FloatPrecisionKept float64 precise", + opts: DecOptions{FloatPrecision: FloatPrecisionKept}, + in: mustHexDecode("fbc010666666666666"), + intoType: reflect.TypeOf((*any)(nil)).Elem(), + want: float64(-4.1), + }, { + name: "FloatPrecisionKept float64 precise 2", + opts: DecOptions{FloatPrecision: FloatPrecisionKept}, + in: mustHexDecode("fb3ff199999999999a"), + intoType: reflect.TypeOf((*any)(nil)).Elem(), + want: float64(1.1), + }, + { + name: "FloatPrecisionKept float32 precise", + opts: DecOptions{FloatPrecision: FloatPrecisionKept}, + in: mustHexDecode("fb3ff8000000000000"), + intoType: reflect.TypeOf(float32(0)), + want: float32(1.5), + }, + { + name: "FloatPrecisionKept float32 precise 2", + opts: DecOptions{FloatPrecision: FloatPrecisionKept}, + in: mustHexDecode("fb3ff8000000000000"), + intoType: reflect.TypeOf(float64(0)), + want: float64(1.5), + }, + { + name: "FloatPrecisionIgnored float64 precise", + opts: DecOptions{FloatPrecision: FloatPrecisionIgnored}, + in: mustHexDecode("fbc010666666666666"), + intoType: reflect.TypeOf((*any)(nil)).Elem(), + want: float64(-4.1), + }, + { + name: "FloatPrecisionKept float32 err", + opts: DecOptions{FloatPrecision: FloatPrecisionKept}, + in: mustHexDecode("fbc010666666666666"), + intoType: reflect.TypeOf(float32(0)), + shouldErr: true, + }, + { + name: "FloatPrecisionKept float32 inf", + opts: DecOptions{FloatPrecision: FloatPrecisionKept}, + in: mustHexDecode("fb7ff0000000000000"), + intoType: reflect.TypeOf(float32(0)), + want: float32(math.Inf(1)), + }, { + name: "FloatPrecisionKept float32 NaN", + opts: DecOptions{FloatPrecision: FloatPrecisionKept}, + in: mustHexDecode("fb7ff8000000000000"), + intoType: reflect.TypeOf(float32(0)), + want: float32(math.NaN()), + }, { + name: "FloatPrecisionKept float32 signal NaN", + opts: DecOptions{FloatPrecision: FloatPrecisionKept}, + in: mustHexDecode("fb7ff8000000000001"), + intoType: reflect.TypeOf(float32(0)), + want: float32(math.NaN()), + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecMode() + if err != nil { + t.Fatal(err) + } + + gotrv := reflect.New(tc.intoType) + err = dm.Unmarshal(tc.in, gotrv.Interface()) + if tc.shouldErr { + if err == nil { + t.Fatal("expected error") + } + // It should err and it did, done here + return + } else if err != nil { + t.Fatal(err) + } + + got := gotrv.Elem().Interface() + + // Special handling for NaN values since reflect.DeepEqual considers NaN != NaN + wantIsNaN := false + gotIsNaN := false + + switch wantVal := tc.want.(type) { + case float32: + wantIsNaN = math.IsNaN(float64(wantVal)) + case float64: + wantIsNaN = math.IsNaN(wantVal) + } + + switch gotVal := got.(type) { + case float32: + gotIsNaN = math.IsNaN(float64(gotVal)) + case float64: + gotIsNaN = math.IsNaN(gotVal) + } + + if wantIsNaN && gotIsNaN { + // Both are NaN, consider them equal + return + } else if !reflect.DeepEqual(tc.want, got) { + t.Errorf("want: %v (%T), got: %v (%T)", tc.want, tc.want, got, got) + } + }) + } +}