Skip to content

Commit

Permalink
Ensure hexval and int don't share BindVar after Normalization (#14451)
Browse files Browse the repository at this point in the history
Signed-off-by: William Martin <[email protected]>
  • Loading branch information
williammartin authored Nov 6, 2023
1 parent aec657b commit a7f0ead
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 21 deletions.
29 changes: 8 additions & 21 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*quer
type normalizer struct {
bindVars map[string]*querypb.BindVariable
reserved *ReservedVars
vals map[string]string
vals map[Literal]string
err error
inDerived bool
}
Expand All @@ -55,7 +55,7 @@ func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVari
return &normalizer{
bindVars: bindVars,
reserved: reserved,
vals: make(map[string]string),
vals: make(map[Literal]string),
}
}

Expand Down Expand Up @@ -198,30 +198,18 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) {
}

// Check if there's a bindvar for that value already.
key := keyFor(bval, node)
bvname, ok := nz.vals[key]
bvname, ok := nz.vals[*node]
if !ok {
// If there's no such bindvar, make a new one.
bvname = nz.reserved.nextUnusedVar()
nz.vals[key] = bvname
nz.vals[*node] = bvname
nz.bindVars[bvname] = bval
}

// Modify the AST node to a bindvar.
cursor.Replace(NewTypedArgument(bvname, node.SQLType()))
}

func keyFor(bval *querypb.BindVariable, lit *Literal) string {
if bval.Type != sqltypes.VarBinary && bval.Type != sqltypes.VarChar {
return lit.Val
}

// Prefixing strings with "'" ensures that a string
// and number that have the same representation don't
// collide.
return "'" + lit.Val
}

// convertLiteral converts an Literal without the dedup.
func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) {
err := validateLiteral(node)
Expand Down Expand Up @@ -279,15 +267,14 @@ func (nz *normalizer) parameterize(left, right Expr) Expr {
if bval == nil {
return nil
}
key := keyFor(bval, lit)
bvname := nz.decideBindVarName(key, lit, col, bval)
bvname := nz.decideBindVarName(lit, col, bval)
return NewTypedArgument(bvname, lit.SQLType())
}

func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, bval *querypb.BindVariable) string {
func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *querypb.BindVariable) string {
if len(lit.Val) <= 256 {
// first we check if we already have a bindvar for this value. if we do, we re-use that bindvar name
bvname, ok := nz.vals[key]
bvname, ok := nz.vals[*lit]
if ok {
return bvname
}
Expand All @@ -297,7 +284,7 @@ func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName,
// Big values are most likely not for vindexes.
// We save a lot of CPU because we avoid building
bvname := nz.reserved.ReserveColName(col)
nz.vals[key] = bvname
nz.vals[*lit] = bvname
nz.bindVars[bvname] = bval

return bvname
Expand Down
8 changes: 8 additions & 0 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,14 @@ func TestNormalize(t *testing.T) {
in: `select * from (select 12) as t`,
outstmt: `select * from (select 12 from dual) as t`,
outbv: map[string]*querypb.BindVariable{},
}, {
// HexVal and Int should not share a bindvar just because they have the same value
in: `select * from t where v1 = x'31' and v2 = 31`,
outstmt: `select * from t where v1 = :v1 /* HEXVAL */ and v2 = :v2 /* INT64 */`,
outbv: map[string]*querypb.BindVariable{
"v1": sqltypes.HexValBindVariable([]byte("x'31'")),
"v2": sqltypes.Int64BindVariable(31),
},
}}
for _, tc := range testcases {
t.Run(tc.in, func(t *testing.T) {
Expand Down

0 comments on commit a7f0ead

Please sign in to comment.