diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 9009b069f5a..c1ed1f9475c 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1409,6 +1409,18 @@ func (cached *builtinPeriodAdd) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinPeriodDiff) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinPi) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 93781ed077b..dfb1a30bffc 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -4319,6 +4319,26 @@ func (asm *assembler) Fn_PERIOD_ADD() { }, "FN PERIOD_ADD INT64(SP-2) INT64(SP-1)") } +func (asm *assembler) Fn_PERIOD_DIFF() { + asm.adjustStack(-1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-2] == nil { + env.vm.sp-- + return 1 + } + period1 := env.vm.stack[env.vm.sp-2].(*evalInt64).i + period2 := env.vm.stack[env.vm.sp-1].(*evalInt64).i + res, err := periodDiff(period1, period2) + if err != nil { + env.vm.err = err + return 0 + } + env.vm.stack[env.vm.sp-2] = res + env.vm.sp-- + return 1 + }, "FN PERIOD_DIFF INT64(SP-2) INT64(SP-1)") +} + func (asm *assembler) Interval(l int) { asm.adjustStack(-l) asm.emit(func(env *ExpressionEnv) int { diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index 90fcda2c32a..2d5e12f518d 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -181,6 +181,10 @@ type ( CallExpr } + builtinPeriodDiff struct { + CallExpr + } + builtinDateMath struct { CallExpr sub bool @@ -222,6 +226,7 @@ var _ IR = (*builtinWeekOfYear)(nil) var _ IR = (*builtinYear)(nil) var _ IR = (*builtinYearWeek)(nil) var _ IR = (*builtinPeriodAdd)(nil) +var _ IR = (*builtinPeriodDiff)(nil) func (call *builtinNow) eval(env *ExpressionEnv) (eval, error) { now := env.time(call.utc) @@ -2021,6 +2026,56 @@ func (call *builtinPeriodAdd) compile(c *compiler) (ctype, error) { return ctype{Type: sqltypes.Int64, Flag: period.Flag | months.Flag | flagNullable}, nil } +func periodDiff(period1, period2 int64) (*evalInt64, error) { + if !datetime.ValidatePeriod(period1) || !datetime.ValidatePeriod(period2) { + return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongArguments, "Incorrect arguments to period_diff") + } + res := datetime.PeriodToMonths(period1) - datetime.PeriodToMonths(period2) + return newEvalInt64(res), nil +} + +func (b *builtinPeriodDiff) eval(env *ExpressionEnv) (eval, error) { + p1, p2, err := b.arg2(env) + if err != nil { + return nil, err + } + if p1 == nil || p2 == nil { + return nil, nil + } + period1 := evalToInt64(p1) + period2 := evalToInt64(p2) + return periodDiff(period1.i, period2.i) +} + +func (call *builtinPeriodDiff) compile(c *compiler) (ctype, error) { + period1, err := call.Arguments[0].compile(c) + if err != nil { + return ctype{}, err + } + period2, err := call.Arguments[1].compile(c) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck2(period1, period2) + + switch period1.Type { + case sqltypes.Int64: + default: + c.asm.Convert_xi(2) + } + + switch period2.Type { + case sqltypes.Int64: + default: + c.asm.Convert_xi(1) + } + + c.asm.Fn_PERIOD_DIFF() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Int64, Flag: period1.Flag | period2.Flag | flagNullable}, nil +} + func evalToInterval(itv eval, unit datetime.IntervalType, negate bool) *datetime.Interval { switch itv := itv.(type) { case *evalBytes: diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 7d5305b21f7..ed1c5ed1f76 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -156,6 +156,7 @@ var Cases = []TestCase{ {Run: FnYear}, {Run: FnYearWeek}, {Run: FnPeriodAdd}, + {Run: FnPeriodDiff}, {Run: FnInetAton}, {Run: FnInetNtoa}, {Run: FnInet6Aton}, @@ -2245,6 +2246,27 @@ func FnPeriodAdd(yield Query) { } } +func FnPeriodDiff(yield Query) { + for _, p1 := range inputBitwise { + for _, p2 := range inputBitwise { + yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil) + } + } + for _, p1 := range inputPeriods { + for _, p2 := range inputPeriods { + yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil) + } + } + + mysqlDocSamples := []string{ + `PERIOD_DIFF(200802,200703)`, + } + + for _, q := range mysqlDocSamples { + yield(q, nil) + } +} + func FnInetAton(yield Query) { for _, d := range ipInputs { yield(fmt.Sprintf("INET_ATON(%s)", d), nil) diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 2c4d887ff19..476ee32483b 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -535,6 +535,13 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { default: return nil, argError(method) } + case "period_diff": + switch len(args) { + case 2: + return &builtinPeriodDiff{CallExpr: call}, nil + default: + return nil, argError(method) + } case "inet_aton": if len(args) != 1 { return nil, argError(method)