diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index 6d7ee526bf6..07cc1851c0e 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -44,7 +44,7 @@ func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*quer type normalizer struct { bindVars map[string]*querypb.BindVariable reserved *ReservedVars - vals map[string]string + vals map[Literal]string err error inDerived bool } @@ -53,7 +53,7 @@ func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVari return &normalizer{ bindVars: bindVars, reserved: reserved, - vals: make(map[string]string), + vals: make(map[Literal]string), } } @@ -190,12 +190,11 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { } // Check if there's a bindvar for that value already. - key := keyFor(bval, node) - bvname, ok := nz.vals[key] + bvname, ok := nz.vals[*node] if !ok { // If there's no such bindvar, make a new one. bvname = nz.reserved.nextUnusedVar() - nz.vals[key] = bvname + nz.vals[*node] = bvname nz.bindVars[bvname] = bval } @@ -203,18 +202,6 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { cursor.Replace(NewArgument(bvname)) } -func keyFor(bval *querypb.BindVariable, lit *Literal) string { - if bval.Type != sqltypes.VarBinary && bval.Type != sqltypes.VarChar { - return lit.Val - } - - // Prefixing strings with "'" ensures that a string - // and number that have the same representation don't - // collide. - return "'" + lit.Val - -} - // convertLiteral converts an Literal without the dedup. func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) { err := validateLiteral(node) @@ -273,15 +260,14 @@ func (nz *normalizer) parameterize(left, right Expr) Expr { if bval == nil { return nil } - key := keyFor(bval, lit) - bvname := nz.decideBindVarName(key, lit, col, bval) + bvname := nz.decideBindVarName(lit, col, bval) return Argument(bvname) } -func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, bval *querypb.BindVariable) string { +func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *querypb.BindVariable) string { if len(lit.Val) <= 256 { // first we check if we already have a bindvar for this value. if we do, we re-use that bindvar name - bvname, ok := nz.vals[key] + bvname, ok := nz.vals[*lit] if ok { return bvname } @@ -291,7 +277,7 @@ func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, // Big values are most likely not for vindexes. // We save a lot of CPU because we avoid building bvname := nz.reserved.ReserveColName(col) - nz.vals[key] = bvname + nz.vals[*lit] = bvname nz.bindVars[bvname] = bval return bvname diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index fe3dc85b8c8..27c57ae2e6d 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -336,6 +336,14 @@ func TestNormalize(t *testing.T) { in: `select * from (select 12) as t`, outstmt: `select * from (select 12 from dual) as t`, outbv: map[string]*querypb.BindVariable{}, + }, { + // HexVal and Int should not share a bindvar just because they have the same value + in: `select * from t where v1 = x'31' and v2 = 31`, + outstmt: `select * from t where v1 = :v1 and v2 = :v2`, + outbv: map[string]*querypb.BindVariable{ + "v1": sqltypes.HexValBindVariable([]byte("x'31'")), + "v2": sqltypes.Int64BindVariable(31), + }, }} for _, tc := range testcases { t.Run(tc.in, func(t *testing.T) {