From f0c92b56a8c1e9ebf85cf4e54bb0fa886d4cec0a Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 27 Nov 2023 12:03:29 +0100 Subject: [PATCH] evalengine: Handle zero dates correctly This removes the different configuration for testing to ensure we have the same settings as production defaults and subsequently fixes the failing evalengine issues. Signed-off-by: Dirkjan Bussink --- config/mycnf/test-suite.cnf | 9 +-- go/vt/vtgate/engine/concatenate.go | 22 +++--- go/vt/vtgate/engine/distinct.go | 9 ++- go/vt/vtgate/engine/fake_vcursor_test.go | 4 + go/vt/vtgate/engine/hash_join.go | 3 +- go/vt/vtgate/engine/primitive.go | 1 + go/vt/vtgate/evalengine/api_coerce.go | 4 +- go/vt/vtgate/evalengine/api_hash.go | 20 ++--- go/vt/vtgate/evalengine/api_hash_test.go | 16 ++-- go/vt/vtgate/evalengine/compiler.go | 11 +-- go/vt/vtgate/evalengine/compiler_asm.go | 46 ++--------- go/vt/vtgate/evalengine/compiler_test.go | 8 ++ go/vt/vtgate/evalengine/eval.go | 12 +-- go/vt/vtgate/evalengine/eval_temporal.go | 76 ++++++++++--------- go/vt/vtgate/evalengine/expr_convert.go | 4 +- go/vt/vtgate/evalengine/expr_env.go | 11 ++- go/vt/vtgate/evalengine/expr_logical.go | 4 +- go/vt/vtgate/evalengine/fn_time.go | 68 ++++++++--------- .../evalengine/integration/comparison_test.go | 4 + go/vt/vtgate/evalengine/translate.go | 9 ++- go/vt/vtgate/evalengine/weights.go | 10 +-- go/vt/vtgate/evalengine/weights_test.go | 6 +- go/vt/vtgate/vcursor_impl.go | 5 ++ 23 files changed, 182 insertions(+), 180 deletions(-) diff --git a/config/mycnf/test-suite.cnf b/config/mycnf/test-suite.cnf index e6d0992f6e6..28f4ac16e0d 100644 --- a/config/mycnf/test-suite.cnf +++ b/config/mycnf/test-suite.cnf @@ -1,5 +1,5 @@ # This sets some unsafe settings specifically for -# the test-suite which is currently MySQL 5.7 based +# the test-suite which is currently MySQL 8.0 based # In future it should be renamed testsuite.cnf innodb_buffer_pool_size = 32M @@ -14,13 +14,6 @@ key_buffer_size = 2M sync_binlog=0 innodb_doublewrite=0 -# These two settings are required for the testsuite to pass, -# but enabling them does not spark joy. They should be removed -# in the future. See: -# https://github.com/vitessio/vitess/issues/5396 - -sql_mode = STRICT_TRANS_TABLES - # set a short heartbeat interval in order to detect failures quickly slave_net_timeout = 4 # Disabling `super-read-only`. `test-suite` is mainly used for `vttestserver`. Since `vttestserver` uses a single MySQL for primary and replicas, diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 1e8cb655547..0f52acf6570 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -105,7 +105,7 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars err = c.coerceAndVisitResults(res, fields, func(result *sqltypes.Result) error { rows = append(rows, result.Rows...) return nil - }) + }, vcursor.AllowZeroDate()) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars }, nil } -func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) error { +func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field, allowZeroDate bool) error { if len(row) != len(fields) { return errWrongNumberOfColumnsInSelect } @@ -126,7 +126,7 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) continue } if fields[i].Type != value.Type() { - newValue, err := evalengine.CoerceTo(value, fields[i].Type) + newValue, err := evalengine.CoerceTo(value, fields[i].Type, allowZeroDate) if err != nil { return err } @@ -228,16 +228,17 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV // TryStreamExecute performs a streaming exec. func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool, callback func(*sqltypes.Result) error) error { + allowZeroDate := vcursor.AllowZeroDate() if vcursor.Session().InTransaction() { // as we are in a transaction, we need to execute all queries inside a single connection, // which holds the single transaction we have - return c.sequentialStreamExec(ctx, vcursor, bindVars, callback) + return c.sequentialStreamExec(ctx, vcursor, bindVars, callback, allowZeroDate) } // not in transaction, so execute in parallel. - return c.parallelStreamExec(ctx, vcursor, bindVars, callback) + return c.parallelStreamExec(ctx, vcursor, bindVars, callback, allowZeroDate) } -func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error { +func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error, allowZeroDate bool) error { // Scoped context; any early exit triggers cancel() to clean up ongoing work. ctx, cancel := context.WithCancel(inCtx) defer cancel() @@ -271,7 +272,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // Apply type coercion if needed. if needsCoercion { for _, row := range res.Rows { - if err := c.coerceValuesTo(row, fields); err != nil { + if err := c.coerceValuesTo(row, fields, allowZeroDate); err != nil { return err } } @@ -340,7 +341,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, return wg.Wait() } -func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { +func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error, allowZeroDate bool) error { // all the below fields ensure that the fields are sent only once. results := make([][]*sqltypes.Result, len(c.Sources)) @@ -374,7 +375,7 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, return err } for _, res := range results { - if err = c.coerceAndVisitResults(res, fields, callback); err != nil { + if err = c.coerceAndVisitResults(res, fields, callback, allowZeroDate); err != nil { return err } } @@ -386,6 +387,7 @@ func (c *Concatenate) coerceAndVisitResults( res []*sqltypes.Result, fields []*querypb.Field, callback func(*sqltypes.Result) error, + allowZeroDate bool, ) error { for _, r := range res { if len(r.Rows) > 0 && @@ -402,7 +404,7 @@ func (c *Concatenate) coerceAndVisitResults( } if needsCoercion { for _, row := range r.Rows { - err := c.coerceValuesTo(row, fields) + err := c.coerceValuesTo(row, fields, allowZeroDate) if err != nil { return err } diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index 2d263464a2e..0c990008a0f 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -44,8 +44,9 @@ type ( Type evalengine.Type } probeTable struct { - seenRows map[evalengine.HashCode][]sqltypes.Row - checkCols []CheckCol + seenRows map[evalengine.HashCode][]sqltypes.Row + checkCols []CheckCol + allowZeroDate bool } ) @@ -119,14 +120,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (evalengine.HashCode return 0, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code") } col := inputRow[checkCol.Col] - hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Type.Collation(), col.Type()) + hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Type.Collation(), col.Type(), pt.allowZeroDate) if err != nil { if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil { return 0, err } checkCol = checkCol.SwitchToWeightString() pt.checkCols[i] = checkCol - hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Type.Collation(), col.Type()) + hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Type.Collation(), col.Type(), pt.allowZeroDate) if err != nil { return 0, err } diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 0ee95a72e60..0bb138d8d1a 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -135,6 +135,10 @@ func (t *noopVCursor) TimeZone() *time.Location { return nil } +func (t *noopVCursor) AllowZeroDate() bool { + return false +} + func (t *noopVCursor) ExecutePrimitive(ctx context.Context, primitive Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { return primitive.TryExecute(ctx, t, bindVars, wantfields) } diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go index 4f205f1bcdc..69b20777cd6 100644 --- a/go/vt/vtgate/engine/hash_join.go +++ b/go/vt/vtgate/engine/hash_join.go @@ -75,6 +75,7 @@ type ( lhsKey, rhsKey int cols []int hasher vthash.Hasher + allowZeroDate bool } probeTableEntry struct { @@ -283,7 +284,7 @@ func (pt *hashJoinProbeTable) addLeftRow(r sqltypes.Row) error { } func (pt *hashJoinProbeTable) hash(val sqltypes.Value) (vthash.Hash, error) { - err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ) + err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.allowZeroDate) if err != nil { return vthash.Hash{}, err } diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index dc1259e2267..e6334baa13d 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -88,6 +88,7 @@ type ( ConnCollation() collations.ID TimeZone() *time.Location + AllowZeroDate() bool ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) diff --git a/go/vt/vtgate/evalengine/api_coerce.go b/go/vt/vtgate/evalengine/api_coerce.go index 143d22ab78c..d6e003514d4 100644 --- a/go/vt/vtgate/evalengine/api_coerce.go +++ b/go/vt/vtgate/evalengine/api_coerce.go @@ -23,8 +23,8 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -func CoerceTo(value sqltypes.Value, typ sqltypes.Type) (sqltypes.Value, error) { - cast, err := valueToEvalCast(value, value.Type(), collations.Unknown) +func CoerceTo(value sqltypes.Value, typ sqltypes.Type, allowZeroDate bool) (sqltypes.Value, error) { + cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, allowZeroDate) if err != nil { return sqltypes.Value{}, err } diff --git a/go/vt/vtgate/evalengine/api_hash.go b/go/vt/vtgate/evalengine/api_hash.go index 209f766840d..fefab9161a6 100644 --- a/go/vt/vtgate/evalengine/api_hash.go +++ b/go/vt/vtgate/evalengine/api_hash.go @@ -34,8 +34,8 @@ type HashCode = uint64 // NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same // for two values that are considered equal by `NullsafeCompare`. -func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type) (HashCode, error) { - e, err := valueToEvalCast(v, coerceType, collation) +func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type, allowZeroDate bool) (HashCode, error) { + e, err := valueToEvalCast(v, coerceType, collation, allowZeroDate) if err != nil { return 0, err } @@ -75,7 +75,7 @@ var ErrHashCoercionIsNotExact = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, " // for two values that are considered equal by `NullsafeCompare`. // This can be used to avoid having to do comparison checks after a hash, // since we consider the 128 bits of entropy enough to guarantee uniqueness. -func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type) error { +func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, allowZeroDate bool) error { switch { case v.IsNull(), sqltypes.IsNull(coerceTo): hash.Write16(hashPrefixNil) @@ -97,7 +97,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat case v.IsText(), v.IsBinary(): f, _ = fastparse.ParseFloat64(v.RawStr()) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, allowZeroDate) } if err != nil { return err @@ -137,7 +137,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat } neg = i < 0 default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, allowZeroDate) } if err != nil { return err @@ -180,7 +180,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat u, err = uint64(fval), nil } default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, allowZeroDate) } if err != nil { return err @@ -223,20 +223,20 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, allowZeroDate) } hash.Write16(hashPrefixDecimal) dec.Hash(hash) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, allowZeroDate) } return nil } -func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type) error { +func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, allowZeroDate bool) error { // Slow path to handle all other types. This uses the generic // logic for value casting to ensure we match MySQL here. - e, err := valueToEvalCast(v, coerceTo, collation) + e, err := valueToEvalCast(v, coerceTo, collation, allowZeroDate) if err != nil { return err } diff --git a/go/vt/vtgate/evalengine/api_hash_test.go b/go/vt/vtgate/evalengine/api_hash_test.go index 0add16de89d..1bd44d4d7e4 100644 --- a/go/vt/vtgate/evalengine/api_hash_test.go +++ b/go/vt/vtgate/evalengine/api_hash_test.go @@ -56,10 +56,10 @@ func TestHashCodes(t *testing.T) { require.NoError(t, err) require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) - h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type()) + h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), false) require.NoError(t, err) - h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type()) + h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), false) require.ErrorIs(t, err, tc.err) assert.Equalf(t, tc.equal, h1 == h2, "HASH(%v) %s HASH(%v) (expected %s)", tc.static, equality(h1 == h2).Operator(), tc.dynamic, equality(tc.equal)) @@ -82,9 +82,9 @@ func TestHashCodesRandom(t *testing.T) { typ, err := coerceTo(v1.Type(), v2.Type()) require.NoError(t, err) - hash1, err := NullsafeHashcode(v1, collation, typ) + hash1, err := NullsafeHashcode(v1, collation, typ, false) require.NoError(t, err) - hash2, err := NullsafeHashcode(v2, collation, typ) + hash2, err := NullsafeHashcode(v2, collation, typ, false) require.NoError(t, err) if cmp == 0 { equal++ @@ -142,11 +142,11 @@ func TestHashCodes128(t *testing.T) { require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) hasher1 := vthash.New() - err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type()) + err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), false) require.NoError(t, err) hasher2 := vthash.New() - err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type()) + err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), false) require.ErrorIs(t, err, tc.err) h1 := hasher1.Sum128() @@ -172,10 +172,10 @@ func TestHashCodesRandom128(t *testing.T) { require.NoError(t, err) hasher1 := vthash.New() - err = NullsafeHashcode128(&hasher1, v1, collation, typ) + err = NullsafeHashcode128(&hasher1, v1, collation, typ, false) require.NoError(t, err) hasher2 := vthash.New() - err = NullsafeHashcode128(&hasher2, v2, collation, typ) + err = NullsafeHashcode128(&hasher2, v2, collation, typ, false) require.NoError(t, err) if cmp == 0 { equal++ diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index e7d0e962fab..24a0f4a20d0 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -30,9 +30,10 @@ import ( type frame func(env *ExpressionEnv) int type compiler struct { - collation collations.ID - dynamicTypes []ctype - asm assembler + collation collations.ID + dynamicTypes []ctype + asm assembler + allowZeroDate bool } type CompilerLog interface { @@ -257,7 +258,7 @@ func (c *compiler) compileToDate(doct ctype, offset int) ctype { case sqltypes.Date: return doct default: - c.asm.Convert_xD(offset) + c.asm.Convert_xD(offset, c.allowZeroDate) } return ctype{Type: sqltypes.Date, Col: collationBinary, Flag: flagNullable} } @@ -268,7 +269,7 @@ func (c *compiler) compileToDateTime(doct ctype, offset, prec int) ctype { c.asm.Convert_tp(offset, prec) return doct default: - c.asm.Convert_xDT(offset, prec) + c.asm.Convert_xDT(offset, prec, c.allowZeroDate) } return ctype{Type: sqltypes.Datetime, Size: int32(prec), Col: collationBinary, Flag: flagNullable} } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 19d4c01a399..cbf9df9c57e 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -516,7 +516,7 @@ func (asm *assembler) Cmp_ne_n() { }, "CMPFLAG NE [NULL]") } -func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation) { +func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation, allowZeroDate bool) { elseOffset := 0 if hasElse { elseOffset = 1 @@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll end := env.vm.sp - elseOffset for sp := env.vm.sp - stackDepth; sp < end; sp += 2 { if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now, allowZeroDate) goto done } } if elseOffset != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now, allowZeroDate) } else { env.vm.stack[env.vm.sp-stackDepth] = nil } @@ -1126,12 +1126,12 @@ func (asm *assembler) Convert_xu(offset int) { }, "CONV (SP-%d), UINT64", offset) } -func (asm *assembler) Convert_xD(offset int) { +func (asm *assembler) Convert_xD(offset int, allowZero bool) { asm.emit(func(env *ExpressionEnv) int { // Need to explicitly check here or we otherwise // store a nil wrapper in an interface vs. a direct // nil. - d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now) + d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now, allowZero) if d == nil { env.vm.stack[env.vm.sp-offset] = nil } else { @@ -1141,27 +1141,12 @@ func (asm *assembler) Convert_xD(offset int) { }, "CONV (SP-%d), DATE", offset) } -func (asm *assembler) Convert_xD_nz(offset int) { +func (asm *assembler) Convert_xDT(offset, prec int, allowZero bool) { asm.emit(func(env *ExpressionEnv) int { // Need to explicitly check here or we otherwise // store a nil wrapper in an interface vs. a direct // nil. - d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now) - if d == nil || d.isZero() { - env.vm.stack[env.vm.sp-offset] = nil - } else { - env.vm.stack[env.vm.sp-offset] = d - } - return 1 - }, "CONV (SP-%d), DATE(NOZERO)", offset) -} - -func (asm *assembler) Convert_xDT(offset, prec int) { - asm.emit(func(env *ExpressionEnv) int { - // Need to explicitly check here or we otherwise - // store a nil wrapper in an interface vs. a direct - // nil. - dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now) + dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now, allowZero) if dt == nil { env.vm.stack[env.vm.sp-offset] = nil } else { @@ -1171,21 +1156,6 @@ func (asm *assembler) Convert_xDT(offset, prec int) { }, "CONV (SP-%d), DATETIME", offset) } -func (asm *assembler) Convert_xDT_nz(offset, prec int) { - asm.emit(func(env *ExpressionEnv) int { - // Need to explicitly check here or we otherwise - // store a nil wrapper in an interface vs. a direct - // nil. - dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now) - if dt == nil || dt.isZero() { - env.vm.stack[env.vm.sp-offset] = nil - } else { - env.vm.stack[env.vm.sp-offset] = dt - } - return 1 - }, "CONV (SP-%d), DATETIME(NOZERO)", offset) -} - func (asm *assembler) Convert_xT(offset, prec int) { asm.emit(func(env *ExpressionEnv) int { t := evalToTime(env.vm.stack[env.vm.sp-offset], prec) @@ -4189,7 +4159,7 @@ func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col col goto baddate } - tmp = evalToTemporal(env.vm.stack[env.vm.sp-2]) + tmp = evalToTemporal(env.vm.stack[env.vm.sp-2], true) if tmp == nil { goto baddate } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 77174a01d71..e7b51b41748 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -564,6 +564,14 @@ func TestCompilerSingle(t *testing.T) { expression: `case when null is null then 23 else null end`, result: `INT64(23)`, }, + { + expression: `CAST(0 AS DATE)`, + result: `NULL`, + }, + { + expression: `DAYOFMONTH(0)`, + result: `INT64(0)`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 82f1ec688c7..5afce3090fd 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -176,7 +176,7 @@ func evalIsTruthy(e eval) boolean { } } -func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (eval, error) { +func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, allowZero bool) (eval, error) { if e == nil { return nil, nil } @@ -208,9 +208,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (ev case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64: return evalToInt64(e).toUint64(), nil case sqltypes.Date: - return evalToDate(e, now), nil + return evalToDate(e, now, allowZero), nil case sqltypes.Datetime, sqltypes.Timestamp: - return evalToDateTime(e, -1, now), nil + return evalToDateTime(e, -1, now, allowZero), nil case sqltypes.Time: return evalToTime(e, -1), nil default: @@ -218,7 +218,7 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (ev } } -func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID) (eval, error) { +func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID, allowZeroDate bool) (eval, error) { switch { case typ == sqltypes.Null: return nil, nil @@ -338,7 +338,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, err } // Separate return here to avoid nil wrapped in interface type - d := evalToDate(e, time.Now()) + d := evalToDate(e, time.Now(), allowZeroDate) if d == nil { return nil, nil } @@ -349,7 +349,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, err } // Separate return here to avoid nil wrapped in interface type - dt := evalToDateTime(e, -1, time.Now()) + dt := evalToDateTime(e, -1, time.Now(), allowZeroDate) if dt == nil { return nil, nil } diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index d44839a6853..7952fc5aa5d 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -164,11 +164,17 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations. return tmp } -func newEvalDateTime(dt datetime.DateTime, l int) *evalTemporal { +func newEvalDateTime(dt datetime.DateTime, l int, allowZero bool) *evalTemporal { + if !allowZero && dt.IsZero() { + return nil + } return &evalTemporal{t: sqltypes.Datetime, dt: dt.Round(l), prec: uint8(l)} } -func newEvalDate(d datetime.Date) *evalTemporal { +func newEvalDate(d datetime.Date, allowZero bool) *evalTemporal { + if !allowZero && d.IsZero() { + return nil + } return &evalTemporal{t: sqltypes.Date, dt: datetime.DateTime{Date: d}} } @@ -185,7 +191,7 @@ func parseDate(s []byte) (*evalTemporal, error) { if !ok { return nil, errIncorrectTemporal("DATE", s) } - return newEvalDate(t), nil + return newEvalDate(t, true), nil } func parseDateTime(s []byte) (*evalTemporal, error) { @@ -193,7 +199,7 @@ func parseDateTime(s []byte) (*evalTemporal, error) { if !ok { return nil, errIncorrectTemporal("DATETIME", s) } - return newEvalDateTime(t, l), nil + return newEvalDateTime(t, l, true), nil } func parseTime(s []byte) (*evalTemporal, error) { @@ -211,56 +217,56 @@ func precision(req, got int) int { return req } -func evalToTemporal(e eval) *evalTemporal { +func evalToTemporal(e eval, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: return e case *evalBytes: if t, l, ok := datetime.ParseDateTime(e.string(), -1); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDate(e.string()); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if t, l, ok := datetime.ParseTime(e.string(), -1); ok { return newEvalTime(t, l) } case *evalInt64: if t, ok := datetime.ParseDateTimeInt64(e.i); ok { - return newEvalDateTime(t, 0) + return newEvalDateTime(t, 0, allowZero) } if d, ok := datetime.ParseDateInt64(e.i); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if t, ok := datetime.ParseTimeInt64(e.i); ok { return newEvalTime(t, 0) } case *evalUint64: if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok { - return newEvalDateTime(t, 0) + return newEvalDateTime(t, 0, allowZero) } if d, ok := datetime.ParseDateInt64(int64(e.u)); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if t, ok := datetime.ParseTimeInt64(int64(e.u)); ok { return newEvalTime(t, 0) } case *evalFloat: if t, l, ok := datetime.ParseDateTimeFloat(e.f, -1); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDateFloat(e.f); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if t, l, ok := datetime.ParseTimeFloat(e.f, -1); ok { return newEvalTime(t, l) } case *evalDecimal: if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, -1); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDateDecimal(e.dec); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if d, l, ok := datetime.ParseTimeDecimal(e.dec, e.length, -1); ok { return newEvalTime(d, l) @@ -271,9 +277,9 @@ func evalToTemporal(e eval) *evalTemporal { return newEvalTime(dt.Time, datetime.DefaultPrecision) } if dt.Time.IsZero() { - return newEvalDate(dt.Date) + return newEvalDate(dt.Date, allowZero) } - return newEvalDateTime(dt, datetime.DefaultPrecision) + return newEvalDateTime(dt, datetime.DefaultPrecision, allowZero) } } return nil @@ -326,74 +332,74 @@ func evalToTime(e eval, l int) *evalTemporal { return nil } -func evalToDateTime(e eval, l int, now time.Time) *evalTemporal { +func evalToDateTime(e eval, l int, now time.Time, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: return e.toDateTime(precision(l, int(e.prec)), now) case *evalBytes: if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, _ := datetime.ParseDate(e.string()); !d.IsZero() { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalInt64: if t, ok := datetime.ParseDateTimeInt64(e.i); ok { - return newEvalDateTime(t, precision(l, 0)) + return newEvalDateTime(t, precision(l, 0), allowZero) } if d, ok := datetime.ParseDateInt64(e.i); ok { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalUint64: if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok { - return newEvalDateTime(t, precision(l, 0)) + return newEvalDateTime(t, precision(l, 0), allowZero) } if d, ok := datetime.ParseDateInt64(int64(e.u)); ok { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalFloat: if t, l, ok := datetime.ParseDateTimeFloat(e.f, l); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDateFloat(e.f); ok { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalDecimal: if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, l); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDateDecimal(e.dec); ok { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalJSON: if dt, ok := e.DateTime(); ok { - return newEvalDateTime(dt, precision(l, datetime.DefaultPrecision)) + return newEvalDateTime(dt, precision(l, datetime.DefaultPrecision), allowZero) } } return nil } -func evalToDate(e eval, now time.Time) *evalTemporal { +func evalToDate(e eval, now time.Time, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: return e.toDate(now) case *evalBytes: if t, _ := datetime.ParseDate(e.string()); !t.IsZero() { - return newEvalDate(t) + return newEvalDate(t, allowZero) } if dt, _, _ := datetime.ParseDateTime(e.string(), -1); !dt.IsZero() { - return newEvalDate(dt.Date) + return newEvalDate(dt.Date, allowZero) } case evalNumeric: if t, ok := datetime.ParseDateInt64(e.toInt64().i); ok { - return newEvalDate(t) + return newEvalDate(t, allowZero) } if dt, ok := datetime.ParseDateTimeInt64(e.toInt64().i); ok { - return newEvalDate(dt.Date) + return newEvalDate(dt.Date, allowZero) } case *evalJSON: if d, ok := e.Date(); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } } return nil diff --git a/go/vt/vtgate/evalengine/expr_convert.go b/go/vt/vtgate/evalengine/expr_convert.go index 900d4e37f8f..c186382ee80 100644 --- a/go/vt/vtgate/evalengine/expr_convert.go +++ b/go/vt/vtgate/evalengine/expr_convert.go @@ -124,12 +124,12 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) { case p > 6: return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p) } - if dt := evalToDateTime(e, c.Length, env.now); dt != nil { + if dt := evalToDateTime(e, c.Length, env.now, env.allowZeroDate); dt != nil { return dt, nil } return nil, nil case "DATE": - if d := evalToDate(e, env.now); d != nil { + if d := evalToDate(e, env.now, env.allowZeroDate); d != nil { return d, nil } return nil, nil diff --git a/go/vt/vtgate/evalengine/expr_env.go b/go/vt/vtgate/evalengine/expr_env.go index a6ca1411e74..523c3e905f7 100644 --- a/go/vt/vtgate/evalengine/expr_env.go +++ b/go/vt/vtgate/evalengine/expr_env.go @@ -30,6 +30,7 @@ import ( type VCursor interface { TimeZone() *time.Location GetKeyspace() string + AllowZeroDate() bool } type ( @@ -43,9 +44,10 @@ type ( Fields []*querypb.Field // internal state - now time.Time - vc VCursor - user *querypb.VTGateCallerID + now time.Time + vc VCursor + user *querypb.VTGateCallerID + allowZeroDate bool } ) @@ -121,5 +123,8 @@ func NewExpressionEnv(ctx context.Context, bindVars map[string]*querypb.BindVari env := &ExpressionEnv{BindVars: bindVars, vc: vc} env.user = callerid.ImmediateCallerIDFromContext(ctx) env.SetTime(time.Now()) + if vc != nil { + env.allowZeroDate = vc.AllowZeroDate() + } return env } diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index f22ca091acb..0cb4f7b1080 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -633,7 +633,7 @@ func (c *CaseExpr) eval(env *ExpressionEnv) (eval, error) { if !matched { return nil, nil } - return evalCoerce(result, ta.result(), ca.result().Collation, env.now) + return evalCoerce(result, ta.result(), ca.result().Collation, env.now, env.allowZeroDate) } func (c *CaseExpr) constant() bool { @@ -716,7 +716,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { f |= flagNullable } ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()} - c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col) + c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.allowZeroDate) return ct, nil } diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index b0e7366ce8f..c273d798aaf 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -266,7 +266,7 @@ func (b *builtinDateFormat) eval(env *ExpressionEnv) (eval, error) { case *evalTemporal: t = e.toDateTime(datetime.DefaultPrecision, env.now) default: - t = evalToDateTime(date, datetime.DefaultPrecision, env.now) + t = evalToDateTime(date, datetime.DefaultPrecision, env.now, env.allowZeroDate) if t == nil || t.isZero() { return nil, nil } @@ -291,7 +291,7 @@ func (call *builtinDateFormat) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Datetime, sqltypes.Date: default: - c.asm.Convert_xDT_nz(1, datetime.DefaultPrecision) + c.asm.Convert_xDT(1, datetime.DefaultPrecision, false) } format, err := call.Arguments[1].compile(c) @@ -359,7 +359,7 @@ func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - dt := evalToDateTime(n, -1, env.now) + dt := evalToDateTime(n, -1, env.now, env.allowZeroDate) if dt == nil || dt.isZero() { return nil, nil } @@ -368,7 +368,7 @@ func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) { if !ok { return nil, nil } - return newEvalDateTime(out, int(dt.prec)), nil + return newEvalDateTime(out, int(dt.prec), env.allowZeroDate), nil } func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { @@ -402,7 +402,7 @@ func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { switch n.Type { case sqltypes.Datetime, sqltypes.Date: default: - c.asm.Convert_xDT_nz(3, -1) + c.asm.Convert_xDT(3, -1, false) } c.asm.Fn_CONVERT_TZ() @@ -418,7 +418,7 @@ func (b *builtinDate) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.allowZeroDate) if d == nil { return nil, nil } @@ -436,7 +436,7 @@ func (call *builtinDate) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, c.allowZeroDate) } c.asm.jumpDestination(skip) @@ -451,7 +451,7 @@ func (b *builtinDayOfMonth) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, true) if d == nil { return nil, nil } @@ -469,7 +469,7 @@ func (call *builtinDayOfMonth) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, true) } c.asm.Fn_DAYOFMONTH() c.asm.jumpDestination(skip) @@ -484,7 +484,7 @@ func (b *builtinDayOfWeek) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.allowZeroDate) if d == nil || d.isZero() { return nil, nil } @@ -502,7 +502,7 @@ func (call *builtinDayOfWeek) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } c.asm.Fn_DAYOFWEEK() c.asm.jumpDestination(skip) @@ -517,7 +517,7 @@ func (b *builtinDayOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.allowZeroDate) if d == nil || d.isZero() { return nil, nil } @@ -535,7 +535,7 @@ func (call *builtinDayOfYear) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } c.asm.Fn_DAYOFYEAR() c.asm.jumpDestination(skip) @@ -610,7 +610,7 @@ func (b *builtinFromUnixtime) eval(env *ExpressionEnv) (eval, error) { t = t.In(tz) } - dt := newEvalDateTime(datetime.NewDateTimeFromStd(t), prec) + dt := newEvalDateTime(datetime.NewDateTimeFromStd(t), prec, env.allowZeroDate) if len(b.Arguments) == 1 { return dt, nil @@ -761,7 +761,7 @@ func (b *builtinMakedate) eval(env *ExpressionEnv) (eval, error) { if t.IsZero() { return nil, nil } - return newEvalDate(datetime.NewDateTimeFromStd(t).Date), nil + return newEvalDate(datetime.NewDateTimeFromStd(t).Date, env.allowZeroDate), nil } func (call *builtinMakedate) compile(c *compiler) (ctype, error) { @@ -1102,7 +1102,7 @@ func (b *builtinMonth) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, true) if d == nil { return nil, nil } @@ -1120,7 +1120,7 @@ func (call *builtinMonth) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, true) } c.asm.Fn_MONTH() c.asm.jumpDestination(skip) @@ -1135,7 +1135,7 @@ func (b *builtinMonthName) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.allowZeroDate) if d == nil { return nil, nil } @@ -1158,7 +1158,7 @@ func (call *builtinMonthName) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, c.allowZeroDate) } col := typedCoercionCollation(sqltypes.VarChar, call.collate) c.asm.Fn_MONTHNAME(col) @@ -1174,7 +1174,7 @@ func (b *builtinQuarter) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, true) if d == nil { return nil, nil } @@ -1192,7 +1192,7 @@ func (call *builtinQuarter) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, true) } c.asm.Fn_QUARTER() c.asm.jumpDestination(skip) @@ -1270,7 +1270,7 @@ func dateTimeUnixTimestamp(env *ExpressionEnv, date eval) evalNumeric { case *evalTemporal: dt = e.toDateTime(int(e.prec), env.now) default: - dt = evalToDateTime(date, -1, env.now) + dt = evalToDateTime(date, -1, env.now, env.allowZeroDate) if dt == nil || dt.isZero() { var prec int32 switch d := date.(type) { @@ -1351,7 +1351,7 @@ func (call *builtinUnixTimestamp) compile(c *compiler) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil } if lit, ok := call.Arguments[0].(*Literal); ok { - if dt := evalToDateTime(lit.inner, -1, time.Now()); dt != nil { + if dt := evalToDateTime(lit.inner, -1, time.Now(), c.allowZeroDate); dt != nil { if dt.prec == 0 { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil } @@ -1373,7 +1373,7 @@ func (b *builtinWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.allowZeroDate) if d == nil || d.isZero() { return nil, nil } @@ -1406,7 +1406,7 @@ func (call *builtinWeek) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } if len(call.Arguments) == 1 { @@ -1433,7 +1433,7 @@ func (b *builtinWeekDay) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.allowZeroDate) if d == nil || d.isZero() { return nil, nil } @@ -1451,7 +1451,7 @@ func (call *builtinWeekDay) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } c.asm.Fn_WEEKDAY() @@ -1467,7 +1467,7 @@ func (b *builtinWeekOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.allowZeroDate) if d == nil || d.isZero() { return nil, nil } @@ -1487,7 +1487,7 @@ func (call *builtinWeekOfYear) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } c.asm.Fn_WEEKOFYEAR() @@ -1503,7 +1503,7 @@ func (b *builtinYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, true) if d == nil { return nil, nil } @@ -1522,7 +1522,7 @@ func (call *builtinYear) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, true) } c.asm.Fn_YEAR() @@ -1539,7 +1539,7 @@ func (b *builtinYearWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.allowZeroDate) if d == nil || d.isZero() { return nil, nil } @@ -1572,7 +1572,7 @@ func (call *builtinYearWeek) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } if len(call.Arguments) == 1 { @@ -1624,7 +1624,7 @@ func (call *builtinDateMath) eval(env *ExpressionEnv) (eval, error) { return tmp.addInterval(interval, collations.Unknown, env.now), nil } - if tmp := evalToTemporal(date); tmp != nil { + if tmp := evalToTemporal(date, env.allowZeroDate); tmp != nil { return tmp.addInterval(interval, call.collate, env.now), nil } diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index 649dc7b5583..a59a8042bda 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -215,6 +215,10 @@ func (vc *vcursor) TimeZone() *time.Location { return time.Local } +func (vc *vcursor) AllowZeroDate() bool { + return false +} + func initTimezoneData(t *testing.T, conn *mysql.Conn) { // We load the timezone information into MySQL. The evalengine assumes // our backend MySQL is configured with the timezone information as well diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 7993854db36..58a143eaa45 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -569,6 +569,7 @@ type Config struct { Collation collations.ID NoConstantFolding bool NoCompilation bool + AllowZeroDate bool } func Translate(e sqlparser.Expr, cfg *Config) (Expr, error) { @@ -603,7 +604,7 @@ func Translate(e sqlparser.Expr, cfg *Config) (Expr, error) { } if len(ast.untyped) == 0 && !cfg.NoCompilation { - comp := compiler{collation: cfg.Collation} + comp := compiler{collation: cfg.Collation, allowZeroDate: cfg.AllowZeroDate} return comp.compile(expr) } @@ -626,9 +627,9 @@ type typedExpr struct { err error } -func (typed *typedExpr) compile(expr IR, collation collations.ID) (*CompiledExpr, error) { +func (typed *typedExpr) compile(expr IR, collation collations.ID, allowZeroDate bool) (*CompiledExpr, error) { typed.once.Do(func() { - comp := compiler{collation: collation, dynamicTypes: typed.types} + comp := compiler{collation: collation, dynamicTypes: typed.types, allowZeroDate: allowZeroDate} typed.compiled, typed.err = comp.compile(expr) }) return typed.compiled, typed.err @@ -695,7 +696,7 @@ func (u *UntypedExpr) Compile(env *ExpressionEnv) (*CompiledExpr, error) { if err != nil { return nil, err } - return typed.compile(u.ir, u.collation) + return typed.compile(u.ir, u.collation, env.allowZeroDate) } func (u *UntypedExpr) typeof(env *ExpressionEnv) (ctype, error) { diff --git a/go/vt/vtgate/evalengine/weights.go b/go/vt/vtgate/evalengine/weights.go index fa7fa7e11a6..97a7ae424d8 100644 --- a/go/vt/vtgate/evalengine/weights.go +++ b/go/vt/vtgate/evalengine/weights.go @@ -41,11 +41,11 @@ import ( // externally communicates with the `WEIGHT_STRING` function, so that we // can also use this to order / sort other types like Float and Decimal // as well. -func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int) ([]byte, bool, error) { +func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, allowZeroDate bool) ([]byte, bool, error) { // We optimize here for the case where we already have the desired type. // Otherwise, we fall back to the general evalengine conversion logic. if v.Type() != coerceTo { - return fallbackWeightString(dst, v, coerceTo, col, length, precision) + return fallbackWeightString(dst, v, coerceTo, col, length, precision, allowZeroDate) } switch { @@ -117,12 +117,12 @@ func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col coll } return j.WeightString(dst), false, nil default: - return fallbackWeightString(dst, v, coerceTo, col, length, precision) + return fallbackWeightString(dst, v, coerceTo, col, length, precision, allowZeroDate) } } -func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int) ([]byte, bool, error) { - e, err := valueToEvalCast(v, coerceTo, col) +func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, allowZeroDate bool) ([]byte, bool, error) { + e, err := valueToEvalCast(v, coerceTo, col, allowZeroDate) if err != nil { return dst, false, err } diff --git a/go/vt/vtgate/evalengine/weights_test.go b/go/vt/vtgate/evalengine/weights_test.go index 0dee4c72d03..d257d1bfc11 100644 --- a/go/vt/vtgate/evalengine/weights_test.go +++ b/go/vt/vtgate/evalengine/weights_test.go @@ -136,7 +136,7 @@ func TestWeightStrings(t *testing.T) { items := make([]item, 0, Length) for i := 0; i < Length; i++ { v := tc.gen() - w, _, err := WeightString(nil, v, typ, tc.col, tc.len, tc.prec) + w, _, err := WeightString(nil, v, typ, tc.col, tc.len, tc.prec, false) require.NoError(t, err) items = append(items, item{value: v, weight: string(w)}) @@ -156,9 +156,9 @@ func TestWeightStrings(t *testing.T) { a := items[i] b := items[i+1] - v1, err := valueToEvalCast(a.value, typ, tc.col) + v1, err := valueToEvalCast(a.value, typ, tc.col, false) require.NoError(t, err) - v2, err := valueToEvalCast(b.value, typ, tc.col) + v2, err := valueToEvalCast(b.value, typ, tc.col, false) require.NoError(t, err) cmp, err := evalCompareNullSafe(v1, v2) diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index 9ec4bb0dc03..d5643d13eef 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -211,6 +211,11 @@ func (vc *vcursorImpl) TimeZone() *time.Location { return vc.safeSession.TimeZone() } +func (vc *vcursorImpl) AllowZeroDate() bool { + // TODO: Implement detecting zero date + return false +} + // MaxMemoryRows returns the maxMemoryRows flag value. func (vc *vcursorImpl) MaxMemoryRows() int { return maxMemoryRows