diff --git a/go/mysql/fastparse/fastparse.go b/go/mysql/fastparse/fastparse.go index f9aca692abd..a17401b7fdd 100644 --- a/go/mysql/fastparse/fastparse.go +++ b/go/mysql/fastparse/fastparse.go @@ -123,6 +123,9 @@ func ParseInt64(s string, base int) (int64, error) { i++ } + if i >= uint(len(s)) { + return 0, fmt.Errorf("cannot parse int64 from %q", s) + } minus := s[i] == '-' if minus { i++ @@ -160,21 +163,15 @@ next: default: cutoff = math.MaxInt64/uint64(base) + 1 } - if d >= cutoff { - if minus { - return math.MinInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow) - } + if !minus && d >= cutoff { return math.MaxInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow) } - v := d*uint64(base) + uint64(b) - if v < d { - if minus { - return math.MinInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow) - } - return math.MaxInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow) + if minus && d > cutoff { + return math.MinInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow) } - d = v + + d = d*uint64(base) + uint64(b) i++ } diff --git a/go/mysql/fastparse/fastparse_test.go b/go/mysql/fastparse/fastparse_test.go index bec312b0bb5..d1bf351284c 100644 --- a/go/mysql/fastparse/fastparse_test.go +++ b/go/mysql/fastparse/fastparse_test.go @@ -17,6 +17,8 @@ package fastparse import ( "math" + "math/big" + "strconv" "testing" "github.com/stretchr/testify/require" @@ -190,6 +192,48 @@ func TestParseInt64(t *testing.T) { expected: 42, err: `unparsed tail left after parsing int64 from "\t 42 \n": "\n"`, }, + { + input: "", + base: 10, + expected: 0, + err: `cannot parse int64 from empty string`, + }, + { + input: "256", + base: 1, + expected: 0, + err: `invalid base 1; must be in [2, 36]`, + }, + { + input: "256", + base: 37, + expected: 0, + err: `invalid base 37; must be in [2, 36]`, + }, + { + input: " -", + base: 10, + expected: 0, + err: `cannot parse int64 from " -"`, + }, + { + input: "-18446744073709551615", + base: 10, + expected: -9223372036854775808, + err: `cannot parse int64 from "-18446744073709551615": overflow`, + }, + { + input: " ", + base: 10, + expected: 0, + err: `cannot parse int64 from " "`, + }, + { + input: " :", + base: 10, + expected: 0, + err: `cannot parse int64 from " :"`, + }, } for _, tc := range testcases { t.Run(tc.input, func(t *testing.T) { @@ -205,6 +249,69 @@ func TestParseInt64(t *testing.T) { } } +func TestParseEdgeInt64(t *testing.T) { + for i := int64(math.MinInt64); i < math.MinInt64+1000; i++ { + for base := 2; base <= 36; base++ { + val, err := ParseInt64(strconv.FormatInt(i, base), base) + require.NoError(t, err, "base %d", base) + require.Equal(t, int64(i), val) + } + } + for i := int64(math.MaxInt64 - 1000); i < math.MaxInt64; i++ { + for base := 2; base <= 36; base++ { + val, err := ParseInt64(strconv.FormatInt(i, base), base) + require.NoError(t, err) + require.NoError(t, err, "base %d", base) + require.Equal(t, int64(i), val) + } + } +} + +func TestParseOverflowInt64(t *testing.T) { + for i := int64(1); i <= 1000; i++ { + b := big.NewInt(math.MinInt64) + b.Sub(b, big.NewInt(i)) + for base := 2; base <= 36; base++ { + val, err := ParseInt64(b.Text(base), base) + require.Error(t, err) + require.Equal(t, int64(math.MinInt64), val) + } + } + + for i := int64(1); i <= 1000; i++ { + b := big.NewInt(math.MaxInt64) + b.Add(b, big.NewInt(i)) + for base := 2; base <= 36; base++ { + val, err := ParseInt64(b.Text(base), base) + require.Error(t, err) + require.Equal(t, int64(math.MaxInt64), val) + } + } +} + +func TestParseEdgeUint64(t *testing.T) { + for i := uint64(math.MaxUint64 - 1000); i < math.MaxUint64; i++ { + for base := 2; base <= 36; base++ { + val, err := ParseUint64(strconv.FormatUint(i, base), base) + require.NoError(t, err, "base %d", base) + require.Equal(t, uint64(i), val) + } + } +} + +func TestParseOverflowUint64(t *testing.T) { + var b big.Int + for i := int64(1); i <= 1000; i++ { + b.SetUint64(math.MaxUint64) + b.Add(&b, big.NewInt(i)) + for base := 2; base <= 36; base++ { + val, err := ParseUint64(b.Text(base), base) + require.Error(t, err) + require.Equal(t, uint64(math.MaxUint64), val) + } + } +} + func TestParseUint64(t *testing.T) { testcases := []struct { input string @@ -326,6 +433,36 @@ func TestParseUint64(t *testing.T) { expected: 42, err: `unparsed tail left after parsing uint64 from "\t 42 \n": "\n"`, }, + { + input: "", + base: 10, + expected: 0, + err: `cannot parse uint64 from empty string`, + }, + { + input: "256", + base: 1, + expected: 0, + err: `invalid base 1; must be in [2, 36]`, + }, + { + input: "256", + base: 37, + expected: 0, + err: `invalid base 37; must be in [2, 36]`, + }, + { + input: " ", + base: 10, + expected: 0, + err: `cannot parse uint64 from " "`, + }, + { + input: " :", + base: 10, + expected: 0, + err: `cannot parse uint64 from " :"`, + }, } for _, tc := range testcases { t.Run(tc.input, func(t *testing.T) {