Skip to content

Commit

Permalink
Add check to avoid runtime error and add tests for `go/mysql/fastpars…
Browse files Browse the repository at this point in the history
…e` (vitessio#15000)

Signed-off-by: Noble Mittal <[email protected]>
Signed-off-by: Dirkjan Bussink <[email protected]>
Co-authored-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
beingnoble03 and dbussink authored Jan 22, 2024
1 parent ed4b872 commit 4a6b138
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 11 deletions.
19 changes: 8 additions & 11 deletions go/mysql/fastparse/fastparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ func ParseInt64(s string, base int) (int64, error) {
i++
}

if i >= uint(len(s)) {
return 0, fmt.Errorf("cannot parse int64 from %q", s)
}
minus := s[i] == '-'
if minus {
i++
Expand Down Expand Up @@ -160,21 +163,15 @@ next:
default:
cutoff = math.MaxInt64/uint64(base) + 1
}
if d >= cutoff {
if minus {
return math.MinInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow)
}
if !minus && d >= cutoff {
return math.MaxInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow)
}

v := d*uint64(base) + uint64(b)
if v < d {
if minus {
return math.MinInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow)
}
return math.MaxInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow)
if minus && d > cutoff {
return math.MinInt64, fmt.Errorf("cannot parse int64 from %q: %w", s, ErrOverflow)
}
d = v

d = d*uint64(base) + uint64(b)
i++
}

Expand Down
137 changes: 137 additions & 0 deletions go/mysql/fastparse/fastparse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package fastparse

import (
"math"
"math/big"
"strconv"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -190,6 +192,48 @@ func TestParseInt64(t *testing.T) {
expected: 42,
err: `unparsed tail left after parsing int64 from "\t 42 \n": "\n"`,
},
{
input: "",
base: 10,
expected: 0,
err: `cannot parse int64 from empty string`,
},
{
input: "256",
base: 1,
expected: 0,
err: `invalid base 1; must be in [2, 36]`,
},
{
input: "256",
base: 37,
expected: 0,
err: `invalid base 37; must be in [2, 36]`,
},
{
input: " -",
base: 10,
expected: 0,
err: `cannot parse int64 from " -"`,
},
{
input: "-18446744073709551615",
base: 10,
expected: -9223372036854775808,
err: `cannot parse int64 from "-18446744073709551615": overflow`,
},
{
input: " ",
base: 10,
expected: 0,
err: `cannot parse int64 from " "`,
},
{
input: " :",
base: 10,
expected: 0,
err: `cannot parse int64 from " :"`,
},
}
for _, tc := range testcases {
t.Run(tc.input, func(t *testing.T) {
Expand All @@ -205,6 +249,69 @@ func TestParseInt64(t *testing.T) {
}
}

func TestParseEdgeInt64(t *testing.T) {
for i := int64(math.MinInt64); i < math.MinInt64+1000; i++ {
for base := 2; base <= 36; base++ {
val, err := ParseInt64(strconv.FormatInt(i, base), base)
require.NoError(t, err, "base %d", base)
require.Equal(t, int64(i), val)
}
}
for i := int64(math.MaxInt64 - 1000); i < math.MaxInt64; i++ {
for base := 2; base <= 36; base++ {
val, err := ParseInt64(strconv.FormatInt(i, base), base)
require.NoError(t, err)
require.NoError(t, err, "base %d", base)
require.Equal(t, int64(i), val)
}
}
}

func TestParseOverflowInt64(t *testing.T) {
for i := int64(1); i <= 1000; i++ {
b := big.NewInt(math.MinInt64)
b.Sub(b, big.NewInt(i))
for base := 2; base <= 36; base++ {
val, err := ParseInt64(b.Text(base), base)
require.Error(t, err)
require.Equal(t, int64(math.MinInt64), val)
}
}

for i := int64(1); i <= 1000; i++ {
b := big.NewInt(math.MaxInt64)
b.Add(b, big.NewInt(i))
for base := 2; base <= 36; base++ {
val, err := ParseInt64(b.Text(base), base)
require.Error(t, err)
require.Equal(t, int64(math.MaxInt64), val)
}
}
}

func TestParseEdgeUint64(t *testing.T) {
for i := uint64(math.MaxUint64 - 1000); i < math.MaxUint64; i++ {
for base := 2; base <= 36; base++ {
val, err := ParseUint64(strconv.FormatUint(i, base), base)
require.NoError(t, err, "base %d", base)
require.Equal(t, uint64(i), val)
}
}
}

func TestParseOverflowUint64(t *testing.T) {
var b big.Int
for i := int64(1); i <= 1000; i++ {
b.SetUint64(math.MaxUint64)
b.Add(&b, big.NewInt(i))
for base := 2; base <= 36; base++ {
val, err := ParseUint64(b.Text(base), base)
require.Error(t, err)
require.Equal(t, uint64(math.MaxUint64), val)
}
}
}

func TestParseUint64(t *testing.T) {
testcases := []struct {
input string
Expand Down Expand Up @@ -326,6 +433,36 @@ func TestParseUint64(t *testing.T) {
expected: 42,
err: `unparsed tail left after parsing uint64 from "\t 42 \n": "\n"`,
},
{
input: "",
base: 10,
expected: 0,
err: `cannot parse uint64 from empty string`,
},
{
input: "256",
base: 1,
expected: 0,
err: `invalid base 1; must be in [2, 36]`,
},
{
input: "256",
base: 37,
expected: 0,
err: `invalid base 37; must be in [2, 36]`,
},
{
input: " ",
base: 10,
expected: 0,
err: `cannot parse uint64 from " "`,
},
{
input: " :",
base: 10,
expected: 0,
err: `cannot parse uint64 from " :"`,
},
}
for _, tc := range testcases {
t.Run(tc.input, func(t *testing.T) {
Expand Down

0 comments on commit 4a6b138

Please sign in to comment.