Skip to content

Commit

Permalink
evalengine: Implement PERIOD_ADD (#16492)
Browse files Browse the repository at this point in the history
  • Loading branch information
beingnoble03 authored Aug 7, 2024
1 parent d042d7c commit 27f2d73
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 0 deletions.
45 changes: 45 additions & 0 deletions go/mysql/datetime/mydate.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,48 @@ func DateFromDayNumber(daynr int) Date {
d.year, d.month, d.day = mysqlDateFromDayNumber(daynr)
return d
}

// ValidatePeriod validates the MySQL period.
// Returns false if period is non-positive or contains incorrect month value.
func ValidatePeriod(period int64) bool {
if period <= 0 {
return false
}
month := period % 100
if month == 0 || month > 12 {
return false
}
return true
}

// PeriodToMonths converts a MySQL period into number of months.
// This is an algorithm that has been reverse engineered from MySQL.
func PeriodToMonths(period int64) int64 {
p := uint64(period)
if p == 0 {
return 0
}
y := p / 100
if y < 70 {
y += 2000
} else if y < 100 {
y += 1900
}
return int64(y*12 + p%100 - 1)
}

// MonthsToPeriod converts number of months into MySQL period.
// This is an algorithm that has been reverse engineered from MySQL.
func MonthsToPeriod(months int64) int64 {
m := uint64(months)
if m == 0 {
return 0
}
y := m / 12
if y < 70 {
y += 2000
} else if y < 100 {
y += 1900
}
return int64(y*100 + m%12 + 1)
}
54 changes: 54 additions & 0 deletions go/mysql/datetime/mydate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,57 @@ func TestDayNumberFields(t *testing.T) {
assert.Equal(t, wantDate, got)
}
}

func TestValidatePeriod(t *testing.T) {
testCases := []struct {
period int64
want bool
}{
{110112, true},
{101122, false},
{-1112212, false},
{7110, true},
}

for _, tc := range testCases {
got := ValidatePeriod(tc.period)
assert.Equal(t, tc.want, got)
}
}

func TestPeriodToMonths(t *testing.T) {
testCases := []struct {
period int64
want int64
}{
{0, 0},
{110112, 13223},
{100112, 12023},
{7112, 23663},
{200112, 24023},
{112, 24023},
}

for _, tc := range testCases {
got := PeriodToMonths(tc.period)
assert.Equal(t, tc.want, got)
}
}

func TestMonthsToPeriod(t *testing.T) {
testCases := []struct {
months int64
want int64
}{
{0, 0},
{13223, 110112},
{12023, 100112},
{23663, 197112},
{24023, 200112},
}

for _, tc := range testCases {
got := MonthsToPeriod(tc.months)
assert.Equal(t, tc.want, got)
}
}
12 changes: 12 additions & 0 deletions go/vt/vtgate/evalengine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4299,6 +4299,26 @@ func (asm *assembler) Fn_YEARWEEK() {
}, "FN YEARWEEK DATE(SP-1)")
}

func (asm *assembler) Fn_PERIOD_ADD() {
asm.adjustStack(-1)
asm.emit(func(env *ExpressionEnv) int {
if env.vm.stack[env.vm.sp-2] == nil {
env.vm.sp--
return 1
}
period := env.vm.stack[env.vm.sp-2].(*evalInt64).i
months := env.vm.stack[env.vm.sp-1].(*evalInt64).i
res, err := periodAdd(period, months)
if err != nil {
env.vm.err = err
return 0
}
env.vm.stack[env.vm.sp-2] = res
env.vm.sp--
return 1
}, "FN PERIOD_ADD INT64(SP-2) INT64(SP-1)")
}

func (asm *assembler) Interval(l int) {
asm.adjustStack(-l)
asm.emit(func(env *ExpressionEnv) int {
Expand Down
57 changes: 57 additions & 0 deletions go/vt/vtgate/evalengine/fn_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ import (
"vitess.io/vitess/go/mysql/datetime"
"vitess.io/vitess/go/mysql/decimal"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/vterrors"

vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
)

var SystemTime = time.Now
Expand Down Expand Up @@ -174,6 +177,10 @@ type (
CallExpr
}

builtinPeriodAdd struct {
CallExpr
}

builtinDateMath struct {
CallExpr
sub bool
Expand Down Expand Up @@ -214,6 +221,7 @@ var _ IR = (*builtinWeekDay)(nil)
var _ IR = (*builtinWeekOfYear)(nil)
var _ IR = (*builtinYear)(nil)
var _ IR = (*builtinYearWeek)(nil)
var _ IR = (*builtinPeriodAdd)(nil)

func (call *builtinNow) eval(env *ExpressionEnv) (eval, error) {
now := env.time(call.utc)
Expand Down Expand Up @@ -1964,6 +1972,55 @@ func (call *builtinYearWeek) compile(c *compiler) (ctype, error) {
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag | flagNullable}, nil
}

func periodAdd(period, months int64) (*evalInt64, error) {
if !datetime.ValidatePeriod(period) {
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongArguments, "Incorrect arguments to period_add")
}
return newEvalInt64(datetime.MonthsToPeriod(datetime.PeriodToMonths(period) + months)), nil
}

func (b *builtinPeriodAdd) eval(env *ExpressionEnv) (eval, error) {
p, m, err := b.arg2(env)
if err != nil {
return nil, err
}
if p == nil || m == nil {
return nil, nil
}
period := evalToInt64(p)
months := evalToInt64(m)
return periodAdd(period.i, months.i)
}

func (call *builtinPeriodAdd) compile(c *compiler) (ctype, error) {
period, err := call.Arguments[0].compile(c)
if err != nil {
return ctype{}, err
}
months, err := call.Arguments[1].compile(c)
if err != nil {
return ctype{}, err
}

skip := c.compileNullCheck2(period, months)

switch period.Type {
case sqltypes.Int64:
default:
c.asm.Convert_xi(2)
}

switch months.Type {
case sqltypes.Int64:
default:
c.asm.Convert_xi(1)
}

c.asm.Fn_PERIOD_ADD()
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.Int64, Flag: period.Flag | months.Flag | flagNullable}, nil
}

func evalToInterval(itv eval, unit datetime.IntervalType, negate bool) *datetime.Interval {
switch itv := itv.(type) {
case *evalBytes:
Expand Down
22 changes: 22 additions & 0 deletions go/vt/vtgate/evalengine/testcases/cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ var Cases = []TestCase{
{Run: FnWeekOfYear},
{Run: FnYear},
{Run: FnYearWeek},
{Run: FnPeriodAdd},
{Run: FnInetAton},
{Run: FnInetNtoa},
{Run: FnInet6Aton},
Expand Down Expand Up @@ -2223,6 +2224,27 @@ func FnYearWeek(yield Query) {
}
}

func FnPeriodAdd(yield Query) {
for _, p := range inputBitwise {
for _, m := range inputBitwise {
yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil)
}
}
for _, p := range inputPeriods {
for _, m := range inputBitwise {
yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil)
}
}

mysqlDocSamples := []string{
`PERIOD_ADD(200801,2)`,
}

for _, q := range mysqlDocSamples {
yield(q, nil)
}
}

func FnInetAton(yield Query) {
for _, d := range ipInputs {
yield(fmt.Sprintf("INET_ATON(%s)", d), nil)
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/evalengine/testcases/inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ var inputBitwise = []string{
"64", "'64'", "_binary '64'", "X'40'", "_binary X'40'",
}

var inputPeriods = []string{
"110192", "'119812'", "2703", "7111", "200103", "200309", "0309", "-110102", "0",
"'032'", "223", "'-119812'", "-2703", "99999999999999999999999911", "'-0309'",
}

var radianInputs = []string{
"0",
"1",
Expand Down
7 changes: 7 additions & 0 deletions go/vt/vtgate/evalengine/translate_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,13 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) {
default:
return nil, argError(method)
}
case "period_add":
switch len(args) {
case 2:
return &builtinPeriodAdd{CallExpr: call}, nil
default:
return nil, argError(method)
}
case "inet_aton":
if len(args) != 1 {
return nil, argError(method)
Expand Down

0 comments on commit 27f2d73

Please sign in to comment.