Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,26 @@ func (tum TextUnmarshalerMode) valid() bool {
return tum >= 0 && tum < maxTextUnmarshalerMode
}

// FloatPrecision sets whether float64 CBOR values are decoded into float32 if the number cannot
// be stored exactly.
type FloatPrecisionMode int

const (
// FloatPrecisionIgnored will decode float64 values into a float32 Go type even if
// precision is lost.
FloatPrecisionIgnored FloatPrecisionMode = iota

// FloatPrecisionKept will return an error when trying to decode a float64 into a float32,
// if precision will be lost.
FloatPrecisionKept

maxFloatPrecisionMode
)

func (fpm FloatPrecisionMode) valid() bool {
return fpm >= 0 && fpm < maxFloatPrecisionMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -912,6 +932,10 @@ type DecOptions struct {
// implement json.Unmarshaler but do not also implement cbor.Unmarshaler. If nil, decoding
// behavior is not influenced by whether or not a type implements json.Unmarshaler.
JSONUnmarshalerTranscoder Transcoder

// FloatPrecision sets whether float64 CBOR values are decoded into float32 if the number cannot
// be stored exactly.
FloatPrecision FloatPrecisionMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -1128,6 +1152,10 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
return nil, errors.New("cbor: invalid TextUnmarshaler " + strconv.Itoa(int(opts.TextUnmarshaler)))
}

if !opts.FloatPrecision.valid() {
return nil, errors.New("cbor: invalid FloatPrecision " + strconv.Itoa(int(opts.FloatPrecision)))
}

dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand Down Expand Up @@ -1157,6 +1185,7 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
binaryUnmarshaler: opts.BinaryUnmarshaler,
textUnmarshaler: opts.TextUnmarshaler,
jsonUnmarshalerTranscoder: opts.JSONUnmarshalerTranscoder,
floatPrecision: opts.FloatPrecision,
}

return &dm, nil
Expand Down Expand Up @@ -1238,6 +1267,7 @@ type decMode struct {
binaryUnmarshaler BinaryUnmarshalerMode
textUnmarshaler TextUnmarshalerMode
jsonUnmarshalerTranscoder Transcoder
floatPrecision FloatPrecisionMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand Down Expand Up @@ -1280,6 +1310,7 @@ func (dm *decMode) DecOptions() DecOptions {
BinaryUnmarshaler: dm.binaryUnmarshaler,
TextUnmarshaler: dm.textUnmarshaler,
JSONUnmarshalerTranscoder: dm.jsonUnmarshalerTranscoder,
FloatPrecision: dm.floatPrecision,
}
}

Expand Down Expand Up @@ -1592,7 +1623,20 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin

case additionalInformationAsFloat64:
f := math.Float64frombits(val)
return fillFloat(t, f, v)
err := fillFloat(t, f, v)
if d.dm.floatPrecision == FloatPrecisionIgnored || err != nil {
return err
}
// No error and we need to maintain float precision
if v.Kind() == reflect.Float64 {
return nil
} else if math.IsNaN(f) {
return fillFloat(t, f, v)
} else if f == float64(float32(f)) {
return nil
}
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String(),
errorMsg: "float64 value would lose precision in float32 type"}

default: // ai <= 24
if d.dm.simpleValues.rejected[SimpleValue(val)] {
Expand Down
126 changes: 126 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5511,6 +5511,7 @@ func TestDecOptions(t *testing.T) {
BinaryUnmarshaler: BinaryUnmarshalerNone,
TextUnmarshaler: TextUnmarshalerTextString,
JSONUnmarshalerTranscoder: stubTranscoder{},
FloatPrecision: FloatPrecisionKept,
}
ov := reflect.ValueOf(opts1)
for i := 0; i < ov.NumField(); i++ {
Expand Down Expand Up @@ -10910,3 +10911,128 @@ func TestJSONUnmarshalerTranscoder(t *testing.T) {
})
}
}

func TestFloatPrecisionMode(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
in []byte
intoType reflect.Type
want any
shouldErr bool
}{
{
name: "FloatPrecision is not called by default",
opts: DecOptions{},
in: mustHexDecode("fbc010666666666666"),
intoType: reflect.TypeOf(float32(0)),
want: float32(-4.1),
},
{
name: "FloatPrecisionKept float64 precise",
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
in: mustHexDecode("fbc010666666666666"),
intoType: reflect.TypeOf((*any)(nil)).Elem(),
want: float64(-4.1),
}, {
name: "FloatPrecisionKept float64 precise 2",
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
in: mustHexDecode("fb3ff199999999999a"),
intoType: reflect.TypeOf((*any)(nil)).Elem(),
want: float64(1.1),
},
{
name: "FloatPrecisionKept float32 precise",
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
in: mustHexDecode("fb3ff8000000000000"),
intoType: reflect.TypeOf(float32(0)),
want: float32(1.5),
},
{
name: "FloatPrecisionKept float32 precise 2",
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
in: mustHexDecode("fb3ff8000000000000"),
intoType: reflect.TypeOf(float64(0)),
want: float64(1.5),
},
{
name: "FloatPrecisionIgnored float64 precise",
opts: DecOptions{FloatPrecision: FloatPrecisionIgnored},
in: mustHexDecode("fbc010666666666666"),
intoType: reflect.TypeOf((*any)(nil)).Elem(),
want: float64(-4.1),
},
{
name: "FloatPrecisionKept float32 err",
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
in: mustHexDecode("fbc010666666666666"),
intoType: reflect.TypeOf(float32(0)),
shouldErr: true,
},
{
name: "FloatPrecisionKept float32 inf",
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
in: mustHexDecode("fb7ff0000000000000"),
intoType: reflect.TypeOf(float32(0)),
want: float32(math.Inf(1)),
}, {
name: "FloatPrecisionKept float32 NaN",
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
in: mustHexDecode("fb7ff8000000000000"),
intoType: reflect.TypeOf(float32(0)),
want: float32(math.NaN()),
}, {
name: "FloatPrecisionKept float32 signal NaN",
opts: DecOptions{FloatPrecision: FloatPrecisionKept},
in: mustHexDecode("fb7ff8000000000001"),
intoType: reflect.TypeOf(float32(0)),
want: float32(math.NaN()),
},
} {
t.Run(tc.name, func(t *testing.T) {
dm, err := tc.opts.DecMode()
if err != nil {
t.Fatal(err)
}

gotrv := reflect.New(tc.intoType)
err = dm.Unmarshal(tc.in, gotrv.Interface())
if tc.shouldErr {
if err == nil {
t.Fatal("expected error")
}
// It should err and it did, done here
return
} else if err != nil {
t.Fatal(err)
}

got := gotrv.Elem().Interface()

// Special handling for NaN values since reflect.DeepEqual considers NaN != NaN
wantIsNaN := false
gotIsNaN := false

switch wantVal := tc.want.(type) {
case float32:
wantIsNaN = math.IsNaN(float64(wantVal))
case float64:
wantIsNaN = math.IsNaN(wantVal)
}

switch gotVal := got.(type) {
case float32:
gotIsNaN = math.IsNaN(float64(gotVal))
case float64:
gotIsNaN = math.IsNaN(gotVal)
}

if wantIsNaN && gotIsNaN {
// Both are NaN, consider them equal
return
} else if !reflect.DeepEqual(tc.want, got) {
t.Errorf("want: %v (%T), got: %v (%T)", tc.want, tc.want, got, got)
}
})
}
}
Loading