Skip to content

Commit

Permalink
Fix nullability checks in evalengine (vitessio#14556)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
GuptaManan100 authored and ejortegau committed Dec 13, 2023
1 parent ae1dad9 commit 1b6d257
Show file tree
Hide file tree
Showing 21 changed files with 166 additions and 85 deletions.
2 changes: 1 addition & 1 deletion go/mysql/collations/integration/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func mysqlconn(t *testing.T) *mysql.Conn {
if err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(conn.ServerVersion, "8.0.") {
if !strings.HasPrefix(conn.ServerVersion, "8.") {
conn.Close()
t.Skipf("collation integration tests are only supported in MySQL 8.0+")
}
Expand Down
6 changes: 6 additions & 0 deletions go/vt/vtgate/evalengine/api_type_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type typeAggregation struct {
geometry uint16
blob uint16
total uint16

nullable bool
}

func AggregateTypes(types []sqltypes.Type) sqltypes.Type {
Expand All @@ -63,6 +65,7 @@ func (ta *typeAggregation) addEval(e eval) {
switch e := e.(type) {
case nil:
t = sqltypes.Null
ta.nullable = true
case *evalBytes:
t = sqltypes.Type(e.tt)
f = e.flag
Expand All @@ -73,6 +76,9 @@ func (ta *typeAggregation) addEval(e eval) {
}

func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) {
if f&flagNullable != 0 {
ta.nullable = true
}
switch tt {
case sqltypes.Float32, sqltypes.Float64:
ta.double++
Expand Down
24 changes: 17 additions & 7 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll
asm.emit(func(env *ExpressionEnv) int {
end := env.vm.sp - elseOffset
for sp := env.vm.sp - stackDepth; sp < end; sp += 2 {
if env.vm.stack[sp].(*evalInt64).i != 0 {
if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now)
goto done
}
Expand Down Expand Up @@ -782,16 +782,18 @@ func (asm *assembler) Convert_bB(offset int) {
var f float64
if arg != nil {
f, _ = fastparse.ParseFloat64(arg.(*evalBytes).string())
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0)
}
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0)
return 1
}, "CONV VARBINARY(SP-%d), BOOL", offset)
}

func (asm *assembler) Convert_TB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalTemporal).isZero())
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalTemporal).isZero())
}
return 1
}, "CONV SQLTYPES(SP-%d), BOOL", offset)
}
Expand Down Expand Up @@ -839,7 +841,9 @@ func (asm *assembler) Convert_Tj(offset int) {
func (asm *assembler) Convert_dB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero())
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalDecimal).dec.IsZero())
}
return 1
}, "CONV DECIMAL(SP-%d), BOOL", offset)
}
Expand All @@ -859,7 +863,9 @@ func (asm *assembler) Convert_dbit(offset int) {
func (asm *assembler) Convert_fB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0)
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalFloat).f != 0.0)
}
return 1
}, "CONV FLOAT64(SP-%d), BOOL", offset)
}
Expand Down Expand Up @@ -917,7 +923,9 @@ func (asm *assembler) Convert_Tf(offset int) {
func (asm *assembler) Convert_iB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalInt64).i != 0)
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalInt64).i != 0)
}
return 1
}, "CONV INT64(SP-%d), BOOL", offset)
}
Expand Down Expand Up @@ -997,7 +1005,9 @@ func (asm *assembler) Convert_Nj(offset int) {
func (asm *assembler) Convert_uB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalUint64).u != 0)
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalUint64).u != 0)
}
return 1
}, "CONV UINT64(SP-%d), BOOL", offset)
}
Expand Down
32 changes: 32 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,38 @@ func TestCompilerSingle(t *testing.T) {
expression: `UNIX_TIMESTAMP('20000101103458.111111') + 1`,
result: `DECIMAL(946719299.111111)`,
},
{
expression: `cast(null * 1 as CHAR)`,
result: `NULL`,
},
{
expression: `cast(null + 1 as CHAR)`,
result: `NULL`,
},
{
expression: `cast(null - 1 as CHAR)`,
result: `NULL`,
},
{
expression: `cast(null / 1 as CHAR)`,
result: `NULL`,
},
{
expression: `cast(null % 1 as CHAR)`,
result: `NULL`,
},
{
expression: `1 AND NULL * 1`,
result: `NULL`,
},
{
expression: `case 0 when NULL then 1 else 0 end`,
result: `INT64(0)`,
},
{
expression: `case when null is null then 23 else null end`,
result: `INT64(23)`,
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down
12 changes: 8 additions & 4 deletions go/vt/vtgate/evalengine/expr_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) {
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sumtype, Col: collationNumeric}, nil
return ctype{Type: sumtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (op *opArithSub) eval(left, right eval) (eval, error) {
Expand Down Expand Up @@ -210,7 +210,7 @@ func (op *opArithSub) compile(c *compiler, left, right IR) (ctype, error) {
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: subtype, Col: collationNumeric}, nil
return ctype{Type: subtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (op *opArithMul) eval(left, right eval) (eval, error) {
Expand Down Expand Up @@ -270,7 +270,7 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) {
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: multype, Col: collationNumeric}, nil
return ctype{Type: multype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (op *opArithDiv) eval(left, right eval) (eval, error) {
Expand Down Expand Up @@ -525,9 +525,13 @@ func (expr *NegateExpr) compile(c *compiler) (ctype, error) {
c.asm.jumpDestination(skip)
return ctype{
Type: neg,
Flag: arg.Flag & (flagNull | flagNullable),
Flag: nullableFlags(arg.Flag),
Size: arg.Size,
Scale: arg.Scale,
Col: collationNumeric,
}, nil
}

func nullableFlags(flag typeFlag) typeFlag {
return flag & (flagNull | flagNullable)
}
8 changes: 4 additions & 4 deletions go/vt/vtgate/evalengine/expr_bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func (expr *BitwiseExpr) compileBinary(c *compiler, asm_ins_bb, asm_ins_uu func(

asm_ins_uu()
c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil
return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) {
Expand Down Expand Up @@ -299,8 +299,8 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) {
return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil
}

_ = c.compileToBitwiseUint64(lt, 2)
_ = c.compileToUint64(rt, 1)
lt = c.compileToBitwiseUint64(lt, 2)
rt = c.compileToUint64(rt, 1)

if i < 0 {
c.asm.BitShiftLeft_uu()
Expand All @@ -309,7 +309,7 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) {
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil
return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (expr *BitwiseExpr) compile(c *compiler) (ctype, error) {
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/evalengine/expr_bvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ func (bv *BindVariable) typeof(env *ExpressionEnv) (ctype, error) {
case sqltypes.Null:
return ctype{Type: sqltypes.Null, Flag: flagNull | flagNullable, Col: collationNull}, nil
case sqltypes.HexNum, sqltypes.HexVal:
return ctype{Type: sqltypes.VarBinary, Flag: flagHex, Col: collationNumeric}, nil
return ctype{Type: sqltypes.VarBinary, Flag: flagHex | flagNullable, Col: collationNumeric}, nil
case sqltypes.BitNum:
return ctype{Type: sqltypes.VarBinary, Flag: flagBit, Col: collationNumeric}, nil
return ctype{Type: sqltypes.VarBinary, Flag: flagBit | flagNullable, Col: collationNumeric}, nil
default:
return ctype{Type: tt, Flag: 0, Col: typedCoercionCollation(tt, collations.CollationForType(tt, bv.Collation))}, nil
return ctype{Type: tt, Flag: flagNullable, Col: typedCoercionCollation(tt, collations.CollationForType(tt, bv.Collation))}, nil
}
}

Expand Down
11 changes: 9 additions & 2 deletions go/vt/vtgate/evalengine/expr_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,13 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) {

swapped := false
var skip2 *jump
nullable := true

switch expr.Op.(type) {
case compareNullSafeEQ:
skip2 = c.asm.jumpFrom()
c.asm.Cmp_nullsafe(skip2)
nullable = false
default:
skip2 = c.compileNullCheck1r(rt)
}
Expand Down Expand Up @@ -407,6 +409,9 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) {
}

cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}
if nullable {
cmptype.Flag |= nullableFlags(lt.Flag | rt.Flag)
}

switch expr.Op.(type) {
case compareEQ:
Expand Down Expand Up @@ -540,16 +545,18 @@ func (expr *InExpr) compile(c *compiler) (ctype, error) {

switch rhs := expr.Right.(type) {
case TupleExpr:
var rt ctype
if table := expr.compileTable(lhs, rhs); table != nil {
c.asm.In_table(expr.Negate, table)
} else {
_, err := rhs.compile(c)
rt, err = rhs.compile(c)
if err != nil {
return ctype{}, err
}
c.asm.In_slow(expr.Negate)
}
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil

return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil
case *BindVariable:
return ctype{}, c.unsupported(expr)
default:
Expand Down
10 changes: 7 additions & 3 deletions go/vt/vtgate/evalengine/expr_logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func (expr *NotExpr) compile(c *compiler) (ctype, error) {
c.asm.Not_i()
}
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil
return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil
}

func (l *LogicalExpr) eval(env *ExpressionEnv) (eval, error) {
Expand Down Expand Up @@ -450,7 +450,7 @@ func (expr *LogicalExpr) compile(c *compiler) (ctype, error) {

expr.op.compileRight(c)
c.asm.jumpDestination(jump)
return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil
return ctype{Type: sqltypes.Int64, Flag: ((lt.Flag | rt.Flag) & flagNullable) | flagIsBoolean, Col: collationNumeric}, nil
}

func intervalCompare(n, val eval) (int, bool, error) {
Expand Down Expand Up @@ -711,7 +711,11 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
}
}

ct := ctype{Type: ta.result(), Col: ca.result()}
var f typeFlag
if ta.nullable {
f |= flagNullable
}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()}
c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col)
return ct, nil
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/expr_tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ func (tuple TupleExpr) FormatFast(buf *sqlparser.TrackedBuffer) {
}

func (tuple TupleExpr) typeof(*ExpressionEnv) (ctype, error) {
return ctype{Type: sqltypes.Tuple}, nil
return ctype{Type: sqltypes.Tuple, Col: collationBinary}, nil
}
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/fn_base64.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (call *builtinToBase64) compile(c *compiler) (ctype, error) {
c.asm.Fn_TO_BASE64(t, col)
c.asm.jumpDestination(skip)

return ctype{Type: t, Col: col}, nil
return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: col}, nil
}

func (call *builtinFromBase64) eval(env *ExpressionEnv) (eval, error) {
Expand Down Expand Up @@ -155,5 +155,5 @@ func (call *builtinFromBase64) compile(c *compiler) (ctype, error) {
c.asm.Fn_FROM_BASE64(t)
c.asm.jumpDestination(skip)

return ctype{Type: t, Col: collationBinary}, nil
return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: collationBinary}, nil
}
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/fn_bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ func (expr *builtinBitCount) compile(c *compiler) (ctype, error) {
if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() {
c.asm.BitCount_b()
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil
return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil
}

_ = c.compileToBitwiseUint64(ct, 1)
c.asm.BitCount_u()
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil
return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil
}
Loading

0 comments on commit 1b6d257

Please sign in to comment.