From 5a62cf40a4b3828c1bc9ba23b0669458529e1c55 Mon Sep 17 00:00:00 2001 From: PauliusLozys Date: Mon, 9 Jun 2025 21:15:27 +0300 Subject: [PATCH 1/2] `UnmarshalJSON` and `NewFromString` performance improvements This PR improves `UnmarshalJSON` performance by reducing unnecessary allocation caused by `unquoteIfQuoted` function. This also touches on `Scan` method to split `default` case into `string` and `[]byte` cases. This PR also slightly touches `NewFromString` function by making so scientific notation and dots are checked in a single loop. --- decimal.go | 79 ++++++++++++++++++++----------------------- decimal_bench_test.go | 12 +++++++ 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/decimal.go b/decimal.go index a37a230..0b223a8 100644 --- a/decimal.go +++ b/decimal.go @@ -182,8 +182,23 @@ func NewFromString(value string) (Decimal, error) { var intString string var exp int64 - // Check if number is using scientific notation - eIndex := strings.IndexAny(value, "Ee") + // Check if number is using scientific notation and find dots + eIndex := -1 + pIndex := -1 + for i, r := range value { + if eIndex == -1 && (r == 'E' || r == 'e') { + eIndex = i + continue + } + + if r == '.' { + if pIndex > -1 { + return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value) + } + pIndex = i + } + } + if eIndex != -1 { expInt, err := strconv.ParseInt(value[eIndex+1:], 10, 32) if err != nil { @@ -196,23 +211,12 @@ func NewFromString(value string) (Decimal, error) { exp = expInt } - pIndex := -1 - vLen := len(value) - for i := 0; i < vLen; i++ { - if value[i] == '.' { - if pIndex > -1 { - return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value) - } - pIndex = i - } - } - if pIndex == -1 { // There is no decimal point, we can just parse the original string as // an int intString = value } else { - if pIndex+1 < vLen { + if pIndex+1 < len(value) { intString = value[:pIndex] + value[pIndex+1:] } else { intString = value[:pIndex] @@ -1766,15 +1770,10 @@ func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error { return nil } - str, err := unquoteIfQuoted(decimalBytes) - if err != nil { - return fmt.Errorf("error decoding string '%s': %s", decimalBytes, err) - } - - decimal, err := NewFromString(str) + decimal, err := NewFromString(unquoteIfQuoted(string(decimalBytes))) *d = decimal if err != nil { - return fmt.Errorf("error decoding string '%s': %s", str, err) + return fmt.Errorf("error decoding string '%s': %s", string(decimalBytes), err) } return nil } @@ -1852,14 +1851,18 @@ func (d *Decimal) Scan(value interface{}) error { *d = NewFromUint64(v) return nil - default: - // default is trying to interpret value stored as string - str, err := unquoteIfQuoted(v) - if err != nil { - return err - } - *d, err = NewFromString(str) + case string: + var err error + *d, err = NewFromString(unquoteIfQuoted(v)) return err + + case []byte: + var err error + *d, err = NewFromString(unquoteIfQuoted(string(v))) + return err + + default: + return fmt.Errorf("could not convert value '%+v' to any known type", value) } } @@ -2021,23 +2024,13 @@ func RescalePair(d1 Decimal, d2 Decimal) (Decimal, Decimal) { return d1, d2 } -func unquoteIfQuoted(value interface{}) (string, error) { - var bytes []byte - - switch v := value.(type) { - case string: - bytes = []byte(v) - case []byte: - bytes = v - default: - return "", fmt.Errorf("could not convert value '%+v' to byte array of type '%T'", value, value) - } - +func unquoteIfQuoted(value string) string { // If the amount is quoted, strip the quotes - if len(bytes) > 2 && bytes[0] == '"' && bytes[len(bytes)-1] == '"' { - bytes = bytes[1 : len(bytes)-1] + if len(value) > 2 && value[0] == '"' && value[len(value)-1] == '"' { + return value[1 : len(value)-1] } - return string(bytes), nil + + return value } // NullDecimal represents a nullable decimal with compatibility for diff --git a/decimal_bench_test.go b/decimal_bench_test.go index 34e038f..4e46251 100644 --- a/decimal_bench_test.go +++ b/decimal_bench_test.go @@ -312,3 +312,15 @@ func BenchmarkDecimal_ExpTaylor(b *testing.B) { _, _ = d.ExpTaylor(10) } } + +func BenchmarkDecimal_UnmarshalJSON(b *testing.B) { + b.ResetTimer() + + bstr := []byte("1234.56789") + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = (&Decimal{}).UnmarshalJSON(bstr) + } +} From 2ea3e348eac9c3b90b25259b0ceb19a6074adb38 Mon Sep 17 00:00:00 2001 From: PauliusLozys Date: Tue, 24 Jun 2025 16:22:36 +0300 Subject: [PATCH 2/2] Add malformed scientific notation check and tests --- decimal.go | 5 ++++- decimal_test.go | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/decimal.go b/decimal.go index 0b223a8..b48cfa9 100644 --- a/decimal.go +++ b/decimal.go @@ -186,7 +186,10 @@ func NewFromString(value string) (Decimal, error) { eIndex := -1 pIndex := -1 for i, r := range value { - if eIndex == -1 && (r == 'E' || r == 'e') { + if r == 'E' || r == 'e' { + if eIndex > -1 { + return Decimal{}, fmt.Errorf("can't convert %s to decimal: multiple 'E' characters found", value) + } eIndex = i continue } diff --git a/decimal_test.go b/decimal_test.go index d398f2d..d288821 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -80,6 +80,11 @@ var testTableScientificNotation = map[string]string{ "123.456e10": "1234560000000", } +var testMalformedDecimalStrings = map[string]error{ + "1ee10": fmt.Errorf("can't convert %s to decimal: multiple 'E' characters found", "1ee10"), + "123.45.66": fmt.Errorf("can't convert %s to decimal: too many .s", "123.45.66"), +} + func init() { for _, s := range testTable { s.exact = strconv.FormatFloat(s.float, 'f', 1500, 64) @@ -239,6 +244,15 @@ func TestNewFromString(t *testing.T) { d.value.String(), d.exp) } } + + for s, e := range testMalformedDecimalStrings { + _, err := NewFromString(s) + if err == nil { + t.Errorf("expected an error, got nil %s", s) + } else if err.Error() != e.Error() { + t.Errorf("expected %v error, got %v", e, err) + } + } } func TestNewFromFormattedString(t *testing.T) {