diff --git a/decimal.go b/decimal.go index a37a230..b48cfa9 100644 --- a/decimal.go +++ b/decimal.go @@ -182,8 +182,26 @@ func NewFromString(value string) (Decimal, error) { var intString string var exp int64 - // Check if number is using scientific notation - eIndex := strings.IndexAny(value, "Ee") + // Check if number is using scientific notation and find dots + eIndex := -1 + pIndex := -1 + for i, r := range value { + if r == 'E' || r == 'e' { + if eIndex > -1 { + return Decimal{}, fmt.Errorf("can't convert %s to decimal: multiple 'E' characters found", value) + } + eIndex = i + continue + } + + if r == '.' { + if pIndex > -1 { + return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value) + } + pIndex = i + } + } + if eIndex != -1 { expInt, err := strconv.ParseInt(value[eIndex+1:], 10, 32) if err != nil { @@ -196,23 +214,12 @@ func NewFromString(value string) (Decimal, error) { exp = expInt } - pIndex := -1 - vLen := len(value) - for i := 0; i < vLen; i++ { - if value[i] == '.' { - if pIndex > -1 { - return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value) - } - pIndex = i - } - } - if pIndex == -1 { // There is no decimal point, we can just parse the original string as // an int intString = value } else { - if pIndex+1 < vLen { + if pIndex+1 < len(value) { intString = value[:pIndex] + value[pIndex+1:] } else { intString = value[:pIndex] @@ -1766,15 +1773,10 @@ func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error { return nil } - str, err := unquoteIfQuoted(decimalBytes) - if err != nil { - return fmt.Errorf("error decoding string '%s': %s", decimalBytes, err) - } - - decimal, err := NewFromString(str) + decimal, err := NewFromString(unquoteIfQuoted(string(decimalBytes))) *d = decimal if err != nil { - return fmt.Errorf("error decoding string '%s': %s", str, err) + return fmt.Errorf("error decoding string '%s': %s", string(decimalBytes), err) } return nil } @@ -1852,14 +1854,18 @@ func (d *Decimal) Scan(value interface{}) error { *d = NewFromUint64(v) return nil - default: - // default is trying to interpret value stored as string - str, err := unquoteIfQuoted(v) - if err != nil { - return err - } - *d, err = NewFromString(str) + case string: + var err error + *d, err = NewFromString(unquoteIfQuoted(v)) + return err + + case []byte: + var err error + *d, err = NewFromString(unquoteIfQuoted(string(v))) return err + + default: + return fmt.Errorf("could not convert value '%+v' to any known type", value) } } @@ -2021,23 +2027,13 @@ func RescalePair(d1 Decimal, d2 Decimal) (Decimal, Decimal) { return d1, d2 } -func unquoteIfQuoted(value interface{}) (string, error) { - var bytes []byte - - switch v := value.(type) { - case string: - bytes = []byte(v) - case []byte: - bytes = v - default: - return "", fmt.Errorf("could not convert value '%+v' to byte array of type '%T'", value, value) - } - +func unquoteIfQuoted(value string) string { // If the amount is quoted, strip the quotes - if len(bytes) > 2 && bytes[0] == '"' && bytes[len(bytes)-1] == '"' { - bytes = bytes[1 : len(bytes)-1] + if len(value) > 2 && value[0] == '"' && value[len(value)-1] == '"' { + return value[1 : len(value)-1] } - return string(bytes), nil + + return value } // NullDecimal represents a nullable decimal with compatibility for diff --git a/decimal_bench_test.go b/decimal_bench_test.go index 34e038f..4e46251 100644 --- a/decimal_bench_test.go +++ b/decimal_bench_test.go @@ -312,3 +312,15 @@ func BenchmarkDecimal_ExpTaylor(b *testing.B) { _, _ = d.ExpTaylor(10) } } + +func BenchmarkDecimal_UnmarshalJSON(b *testing.B) { + b.ResetTimer() + + bstr := []byte("1234.56789") + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = (&Decimal{}).UnmarshalJSON(bstr) + } +} diff --git a/decimal_test.go b/decimal_test.go index d398f2d..d288821 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -80,6 +80,11 @@ var testTableScientificNotation = map[string]string{ "123.456e10": "1234560000000", } +var testMalformedDecimalStrings = map[string]error{ + "1ee10": fmt.Errorf("can't convert %s to decimal: multiple 'E' characters found", "1ee10"), + "123.45.66": fmt.Errorf("can't convert %s to decimal: too many .s", "123.45.66"), +} + func init() { for _, s := range testTable { s.exact = strconv.FormatFloat(s.float, 'f', 1500, 64) @@ -239,6 +244,15 @@ func TestNewFromString(t *testing.T) { d.value.String(), d.exp) } } + + for s, e := range testMalformedDecimalStrings { + _, err := NewFromString(s) + if err == nil { + t.Errorf("expected an error, got nil %s", s) + } else if err.Error() != e.Error() { + t.Errorf("expected %v error, got %v", e, err) + } + } } func TestNewFromFormattedString(t *testing.T) {