From da964cfee085ba8060a9a4d6241fc1ea510ee1ec Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Fri, 9 Feb 2024 15:20:12 +0100 Subject: [PATCH 1/3] evalengine: Implement LOCATE and friends This implements `LOCATE`, `POSITION` and `INSTR` to find substrings inside another string. It diverges in behavior for multibyte characters because of the bugs in MySQL identified in https://bugs.mysql.com/bug.php?id=113933. Signed-off-by: Dirkjan Bussink --- go/mysql/collations/charset/eightbit/8bit.go | 14 ++ .../collations/charset/eightbit/binary.go | 14 ++ .../collations/charset/eightbit/latin1.go | 14 ++ go/mysql/collations/colldata/collation.go | 44 ++++++ go/vt/vtgate/evalengine/cached_size.go | 12 ++ go/vt/vtgate/evalengine/compiler_asm.go | 88 +++++++++++ go/vt/vtgate/evalengine/compiler_test.go | 12 ++ go/vt/vtgate/evalengine/fn_string.go | 139 ++++++++++++++++++ go/vt/vtgate/evalengine/testcases/cases.go | 29 ++++ go/vt/vtgate/evalengine/testcases/inputs.go | 36 +++++ go/vt/vtgate/evalengine/translate_builtin.go | 33 ++++- 11 files changed, 434 insertions(+), 1 deletion(-) diff --git a/go/mysql/collations/charset/eightbit/8bit.go b/go/mysql/collations/charset/eightbit/8bit.go index 5bd930c61cb..12630749d5d 100644 --- a/go/mysql/collations/charset/eightbit/8bit.go +++ b/go/mysql/collations/charset/eightbit/8bit.go @@ -81,3 +81,17 @@ func (Charset_8bit) Length(src []byte) int { func (Charset_8bit) MaxWidth() int { return 1 } + +func (Charset_8bit) Slice(src []byte, from, to int) []byte { + if from >= len(src) { + return nil + } + if to > len(src) { + to = len(src) + } + return src[from:to] +} + +func (Charset_8bit) Validate(src []byte) bool { + return true +} diff --git a/go/mysql/collations/charset/eightbit/binary.go b/go/mysql/collations/charset/eightbit/binary.go index 44824bbc342..fa36fcf66a5 100644 --- a/go/mysql/collations/charset/eightbit/binary.go +++ b/go/mysql/collations/charset/eightbit/binary.go @@ -62,3 +62,17 @@ func (Charset_binary) Length(src []byte) int { func (Charset_binary) MaxWidth() int { return 1 } + +func (Charset_binary) Slice(src []byte, from, to int) []byte { + if from >= len(src) { + return nil + } + if to > len(src) { + to = len(src) + } + return src[from:to] +} + +func (Charset_binary) Validate(src []byte) bool { + return true +} diff --git a/go/mysql/collations/charset/eightbit/latin1.go b/go/mysql/collations/charset/eightbit/latin1.go index 67fa07c62c2..f32b4523a18 100644 --- a/go/mysql/collations/charset/eightbit/latin1.go +++ b/go/mysql/collations/charset/eightbit/latin1.go @@ -230,3 +230,17 @@ func (Charset_latin1) Length(src []byte) int { func (Charset_latin1) MaxWidth() int { return 1 } + +func (Charset_latin1) Slice(src []byte, from, to int) []byte { + if from >= len(src) { + return nil + } + if to > len(src) { + to = len(src) + } + return src[from:to] +} + +func (Charset_latin1) Validate(src []byte) bool { + return true +} diff --git a/go/mysql/collations/colldata/collation.go b/go/mysql/collations/colldata/collation.go index 7697c08cbed..a041006ddc7 100644 --- a/go/mysql/collations/colldata/collation.go +++ b/go/mysql/collations/colldata/collation.go @@ -17,6 +17,7 @@ limitations under the License. package colldata import ( + "bytes" "fmt" "math" @@ -380,3 +381,46 @@ coerceToRight: return charset.Convert(dst, rightCS, in, leftCS) }, nil, nil } + +func Index(col Collation, str, sub []byte, offset int) int { + cs := col.Charset() + if offset > 0 { + l := charset.Length(cs, str) + if offset > l { + return -1 + } + str = charset.Slice(cs, str, offset, len(str)) + } + + pos := instr(col, str, sub) + if pos < 0 { + return -1 + } + return offset + pos +} + +func instr(col Collation, str, sub []byte) int { + if len(sub) == 0 { + return 0 + } + + if len(str) == 0 { + return -1 + } + + if col.IsBinary() && col.Charset().MaxWidth() == 1 { + return bytes.Index(str, sub) + } + + var pos int + cs := col.Charset() + for len(str) > 0 { + if col.Collate(str, sub, true) == 0 { + return pos + } + _, size := cs.DecodeRune(str) + str = str[size:] + pos++ + } + return -1 +} diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 7525bfdaec4..c80fabb5dca 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1147,6 +1147,18 @@ func (cached *builtinLn) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinLocate) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinLog) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index e017a949a07..f7ad584683e 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2986,6 +2986,94 @@ func (asm *assembler) Like_collate(expr *LikeExpr, collation colldata.Collation) }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) } +func (asm *assembler) Locate_coerce3(coercion *compiledCoercion) { + asm.adjustStack(-2) + + asm.emit(func(env *ExpressionEnv) int { + substr := env.vm.stack[env.vm.sp-3].(*evalBytes) + str := env.vm.stack[env.vm.sp-2].(*evalBytes) + pos := env.vm.stack[env.vm.sp-1].(*evalInt64) + env.vm.sp -= 2 + + if pos.i < 1 { + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(0) + return 1 + } + + var bsub, bstr []byte + bsub, env.vm.err = coercion.left(nil, substr.bytes) + if env.vm.err != nil { + return 0 + } + bstr, env.vm.err = coercion.right(nil, str.bytes) + if env.vm.err != nil { + return 0 + } + + found := colldata.Index(coercion.col, bstr, bsub, int(pos.i)-1) + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(found) + 1) + return 1 + }, "LOCATE VARCHAR(SP-3), VARCHAR(SP-2) INT64(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) +} + +func (asm *assembler) Locate_coerce2(coercion *compiledCoercion) { + asm.adjustStack(-1) + + asm.emit(func(env *ExpressionEnv) int { + substr := env.vm.stack[env.vm.sp-2].(*evalBytes) + str := env.vm.stack[env.vm.sp-1].(*evalBytes) + env.vm.sp-- + + var bsub, bstr []byte + bsub, env.vm.err = coercion.left(nil, substr.bytes) + if env.vm.err != nil { + return 0 + } + bstr, env.vm.err = coercion.right(nil, str.bytes) + if env.vm.err != nil { + return 0 + } + + found := colldata.Index(coercion.col, bstr, bsub, 0) + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(found) + 1) + return 1 + }, "LOCATE VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) +} + +func (asm *assembler) Locate_collate3(collation colldata.Collation) { + asm.adjustStack(-2) + + asm.emit(func(env *ExpressionEnv) int { + substr := env.vm.stack[env.vm.sp-3].(*evalBytes) + str := env.vm.stack[env.vm.sp-2].(*evalBytes) + pos := env.vm.stack[env.vm.sp-1].(*evalInt64) + env.vm.sp -= 2 + + if pos.i < 1 || pos.i > math.MaxInt { + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(0) + return 1 + } + + found := colldata.Index(collation, str.bytes, substr.bytes, int(pos.i)-1) + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(found) + 1) + return 1 + }, "LOCATE VARCHAR(SP-3), VARCHAR(SP-2) INT64(SP-1) COLLATE '%s'", collation.Name()) +} + +func (asm *assembler) Locate_collate2(collation colldata.Collation) { + asm.adjustStack(-1) + + asm.emit(func(env *ExpressionEnv) int { + substr := env.vm.stack[env.vm.sp-2].(*evalBytes) + str := env.vm.stack[env.vm.sp-1].(*evalBytes) + env.vm.sp-- + + found := colldata.Index(collation, str.bytes, substr.bytes, 0) + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(found) + 1) + return 1 + }, "LOCATE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) +} + func (asm *assembler) Strcmp(collation collations.TypedCollation) { asm.adjustStack(-1) diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 7b2c92783ee..09e08ad0d48 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -615,6 +615,18 @@ func TestCompilerSingle(t *testing.T) { expression: `time('1111:66:56')`, result: `NULL`, }, + { + expression: `locate('Å', 'a')`, + result: `INT64(1)`, + }, + { + expression: `locate('a', 'Å')`, + result: `INT64(1)`, + }, + { + expression: `locate("", "😊😂🤢", 3)`, + result: `INT64(3)`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 23ff1cbdca3..eee8258cefc 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -110,6 +110,11 @@ type ( CallExpr collate collations.ID } + + builtinLocate struct { + CallExpr + collate collations.ID + } ) var _ IR = (*builtinInsert)(nil) @@ -1265,6 +1270,140 @@ func (call *builtinSubstring) compile(c *compiler) (ctype, error) { return ctype{Type: tt, Col: col, Flag: flagNullable}, nil } +func (call *builtinLocate) eval(env *ExpressionEnv) (eval, error) { + substr, err := call.Arguments[0].eval(env) + if err != nil || substr == nil { + return nil, err + } + + str, err := call.Arguments[1].eval(env) + if err != nil || str == nil { + return nil, err + } + + pos := int64(1) + if len(call.Arguments) > 2 { + p, err := call.Arguments[2].eval(env) + if err != nil || p == nil { + return nil, err + } + pos = evalToInt64(p).i + if pos < 1 || pos > math.MaxInt { + return newEvalInt64(0), nil + } + } + + var col collations.TypedCollation + substr, str, col, err = mergeAndCoerceCollations(substr, str, env.collationEnv) + if err != nil { + return nil, err + } + + var coll colldata.Collation + if typeIsTextual(substr.SQLType()) && typeIsTextual(str.SQLType()) { + coll = colldata.Lookup(col.Collation) + } else { + coll = colldata.Lookup(collations.CollationBinaryID) + } + found := colldata.Index(coll, str.ToRawBytes(), substr.ToRawBytes(), int(pos)-1) + return newEvalInt64(int64(found) + 1), nil +} + +func (call *builtinLocate) compile(c *compiler) (ctype, error) { + substr, err := call.Arguments[0].compile(c) + if err != nil { + return ctype{}, err + } + + str, err := call.Arguments[1].compile(c) + if err != nil { + return ctype{}, err + } + + skip1 := c.compileNullCheck2(substr, str) + var skip2 *jump + if len(call.Arguments) > 2 { + l, err := call.Arguments[2].compile(c) + if err != nil { + return ctype{}, err + } + skip2 = c.compileNullCheck2(str, l) + _ = c.compileToInt64(l, 1) + } + + if !substr.isTextual() { + c.asm.Convert_xc(len(call.Arguments), sqltypes.VarChar, c.collation, 0, false) + substr.Col = collations.TypedCollation{ + Collation: c.collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + if !str.isTextual() { + c.asm.Convert_xc(len(call.Arguments)-1, sqltypes.VarChar, c.collation, 0, false) + str.Col = collations.TypedCollation{ + Collation: c.collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + var merged collations.TypedCollation + var coerceLeft colldata.Coercion + var coerceRight colldata.Coercion + + if substr.Col.Collation != str.Col.Collation { + merged, coerceLeft, coerceRight, err = colldata.Merge(c.env.CollationEnv(), substr.Col, str.Col, colldata.CoercionOptions{ + ConvertToSuperset: true, + ConvertWithCoercion: true, + }) + } else { + merged = substr.Col + } + if err != nil { + return ctype{}, err + } + + var coll colldata.Collation + if typeIsTextual(substr.Type) && typeIsTextual(str.Type) { + coll = colldata.Lookup(merged.Collation) + } else { + coll = colldata.Lookup(collations.CollationBinaryID) + } + + if coerceLeft == nil && coerceRight == nil { + if len(call.Arguments) > 2 { + c.asm.Locate_collate3(coll) + } else { + c.asm.Locate_collate2(coll) + } + } else { + if coerceLeft == nil { + coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } + } + if coerceRight == nil { + coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } + } + if len(call.Arguments) > 2 { + c.asm.Locate_coerce3(&compiledCoercion{ + col: colldata.Lookup(merged.Collation), + left: coerceLeft, + right: coerceRight, + }) + } else { + c.asm.Locate_coerce2(&compiledCoercion{ + col: colldata.Lookup(merged.Collation), + left: coerceLeft, + right: coerceRight, + }) + } + } + + c.asm.jumpDestination(skip1, skip2) + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}, nil +} + type builtinConcat struct { CallExpr collate collations.ID diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 9d9cdfa248e..9dbb7276e12 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -82,6 +82,7 @@ var Cases = []TestCase{ {Run: FnRTrim}, {Run: FnTrim}, {Run: FnSubstr}, + {Run: FnLocate}, {Run: FnConcat}, {Run: FnConcatWs}, {Run: FnHex}, @@ -1527,6 +1528,34 @@ func FnSubstr(yield Query) { } } +func FnLocate(yield Query) { + mysqlDocSamples := []string{ + `LOCATE('bar', 'foobarbar')`, + `LOCATE('xbar', 'foobar')`, + `LOCATE('bar', 'foobarbar', 5)`, + `INSTR('foobarbar', 'bar')`, + `INSTR('xbar', 'foobar')`, + `POSITION('bar' IN 'foobarbar')`, + `POSITION('xbar' IN 'foobar')`, + } + + for _, q := range mysqlDocSamples { + yield(q, nil) + } + + for _, substr := range locateStrings { + for _, str := range locateStrings { + yield(fmt.Sprintf("LOCATE(%s, %s)", substr, str), nil) + yield(fmt.Sprintf("INSTR(%s, %s)", str, substr), nil) + yield(fmt.Sprintf("POSITION(%s IN %s)", str, substr), nil) + + for _, i := range radianInputs { + yield(fmt.Sprintf("LOCATE(%s, %s, %s)", substr, str, i), nil) + } + } + } +} + func FnConcat(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("CONCAT(%s)", str), nil) diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index c453f904c96..c4ab2fdb92d 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -234,6 +234,42 @@ var insertStrings = []string{ // "_ucs2 'AabcÅå'", } +var locateStrings = []string{ + "NULL", + "\"\"", + "\"a\"", + "\"abc\"", + "1", + "-1", + "0123", + "0xAACC", + "3.1415926", + // MySQL has broken behavior for these inputs, + // see https://bugs.mysql.com/bug.php?id=113933 + // "\"Å å\"", + // "\"中文测试\"", + // "\"日本語テスト\"", + // "\"한국어 시험\"", + // "\"😊😂🤢\"", + // "_utf8mb4 'abcABCÅå'", + "DATE '2022-10-11'", + "TIME '11:02:23'", + "'123'", + "9223372036854775807", + "-9223372036854775808", + "999999999999999999999999", + "-999999999999999999999999", + "_binary 'Müller' ", + "_utf8mb4 'abcABCÅå'", + "_latin1 0xFF", + // TODO: support other multibyte encodings + // "_dec8 'ÒòÅå'", + // "_utf8mb3 'abcABCÅå'", + // "_utf16 'AabcÅå'", + // "_utf32 'AabcÅå'", + // "_ucs2 'AabcÅå'", +} + var inputConversionTypes = []string{ "BINARY", "BINARY(1)", "BINARY(0)", "BINARY(16)", "BINARY(-1)", "CHAR", "CHAR(1)", "CHAR(0)", "CHAR(16)", "CHAR(-1)", diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 11618bb1d1a..73beb7fd59e 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -604,6 +604,12 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { return nil, argError(method) } return &builtinStrcmp{CallExpr: call, collate: ast.cfg.Collation}, nil + case "instr": + if len(args) != 2 { + return nil, argError(method) + } + call = CallExpr{Arguments: []IR{call.Arguments[1], call.Arguments[0]}, Method: method} + return &builtinLocate{CallExpr: call, collate: ast.cfg.Collation}, nil default: return nil, translateExprNotSupported(fn) } @@ -729,7 +735,7 @@ func (ast *astCompiler) translateCallable(call sqlparser.Callable) (IR, error) { case *sqlparser.CurTimeFuncExpr: if call.Fsp > 6 { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision 12 specified for '%s'. Maximum is 6.", call.Name.String()) + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for '%s'. Maximum is 6.", call.Fsp, call.Name.String()) } var cexpr = CallExpr{Arguments: nil, Method: call.Name.String()} @@ -802,6 +808,31 @@ func (ast *astCompiler) translateCallable(call sqlparser.Callable) (IR, error) { CallExpr: cexpr, collate: ast.cfg.Collation, }, nil + case *sqlparser.LocateExpr: + var args []IR + substr, err := ast.translateExpr(call.SubStr) + if err != nil { + return nil, err + } + args = append(args, substr) + str, err := ast.translateExpr(call.Str) + if err != nil { + return nil, err + } + args = append(args, str) + + if call.Pos != nil { + to, err := ast.translateExpr(call.Pos) + if err != nil { + return nil, err + } + args = append(args, to) + } + var cexpr = CallExpr{Arguments: args, Method: "LOCATE"} + return &builtinLocate{ + CallExpr: cexpr, + collate: ast.cfg.Collation, + }, nil case *sqlparser.IntervalDateExpr: var err error args := make([]IR, 2) From 3ae4423feef0fb74b4cd0a1f1e6498442bb8e850 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Sat, 10 Feb 2024 22:34:04 +0100 Subject: [PATCH 2/3] evalengine: Clean up zero date usage We have a bunch of cases where we convert to zero date with the runtime flag, but then check both nil and zero date. It's equivalent to convert without allowing zero dates to simplify the code here. Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler_asm.go | 13 +------ go/vt/vtgate/evalengine/fn_time.go | 50 ++++++++++++------------- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index f7ad584683e..29ad659249d 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -3921,11 +3921,6 @@ func (asm *assembler) Fn_LAST_DAY() { return 1 } arg := env.vm.stack[env.vm.sp-1].(*evalTemporal) - if arg.dt.IsZero() { - env.vm.stack[env.vm.sp-1] = nil - return 1 - } - d := lastDay(env.currentTimezone(), arg.dt) env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDate(d) return 1 @@ -3938,12 +3933,8 @@ func (asm *assembler) Fn_TO_DAYS() { return 1 } arg := env.vm.stack[env.vm.sp-1].(*evalTemporal) - if arg.dt.Date.IsZero() { - env.vm.stack[env.vm.sp-1] = nil - } else { - numDays := datetime.MysqlDayNumber(arg.dt.Date.Year(), arg.dt.Date.Month(), arg.dt.Date.Day()) - env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(numDays)) - } + numDays := datetime.MysqlDayNumber(arg.dt.Date.Year(), arg.dt.Date.Month(), arg.dt.Date.Day()) + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(numDays)) return 1 }, "FN TO_DAYS DATE(SP-1)") } diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index 319f8d1b328..430e4e123ac 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -282,8 +282,8 @@ func (b *builtinDateFormat) eval(env *ExpressionEnv) (eval, error) { case *evalTemporal: t = e.toDateTime(datetime.DefaultPrecision, env.now) default: - t = evalToDateTime(date, datetime.DefaultPrecision, env.now, env.sqlmode.AllowZeroDate()) - if t == nil || t.isZero() { + t = evalToDateTime(date, datetime.DefaultPrecision, env.now, false) + if t == nil { return nil, nil } } @@ -379,8 +379,8 @@ func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - dt := evalToDateTime(n, -1, env.now, env.sqlmode.AllowZeroDate()) - if dt == nil || dt.isZero() { + dt := evalToDateTime(n, -1, env.now, false) + if dt == nil { return nil, nil } @@ -388,7 +388,7 @@ func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) { if !ok { return nil, nil } - return newEvalDateTime(out, int(dt.prec), env.sqlmode.AllowZeroDate()), nil + return newEvalDateTime(out, int(dt.prec), false), nil } func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { @@ -504,8 +504,8 @@ func (b *builtinDayOfWeek) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } return newEvalInt64(int64(d.dt.Date.Weekday() + 1)), nil @@ -537,8 +537,8 @@ func (b *builtinDayOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } return newEvalInt64(int64(d.dt.Date.ToStdTime(env.currentTimezone()).YearDay())), nil @@ -815,7 +815,7 @@ func (b *builtinMakedate) eval(env *ExpressionEnv) (eval, error) { if t.IsZero() { return nil, nil } - return newEvalDate(datetime.NewDateTimeFromStd(t).Date, env.sqlmode.AllowZeroDate()), nil + return newEvalDate(datetime.NewDateTimeFromStd(t).Date, false), nil } func (call *builtinMakedate) compile(c *compiler) (ctype, error) { @@ -1189,7 +1189,7 @@ func (b *builtinMonthName) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) + d := evalToDate(date, env.now, false) if d == nil { return nil, nil } @@ -1212,7 +1212,7 @@ func (call *builtinMonthName) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1, c.sqlmode.AllowZeroDate()) + c.asm.Convert_xD(1, false) } col := typedCoercionCollation(sqltypes.VarChar, call.collate) c.asm.Fn_MONTHNAME(col) @@ -1272,8 +1272,8 @@ func (b *builtinToDays) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - dt := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if dt == nil || dt.isZero() { + dt := evalToDate(date, env.now, false) + if dt == nil { return nil, nil } @@ -1292,7 +1292,7 @@ func (call *builtinToDays) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1, true) + c.asm.Convert_xD(1, false) } c.asm.Fn_TO_DAYS() c.asm.jumpDestination(skip) @@ -1477,8 +1477,8 @@ func dateTimeUnixTimestamp(env *ExpressionEnv, date eval) evalNumeric { case *evalTemporal: dt = e.toDateTime(int(e.prec), env.now) default: - dt = evalToDateTime(date, -1, env.now, env.sqlmode.AllowZeroDate()) - if dt == nil || dt.isZero() { + dt = evalToDateTime(date, -1, env.now, false) + if dt == nil { var prec int32 switch d := date.(type) { case *evalInt64, *evalUint64: @@ -1584,8 +1584,8 @@ func (b *builtinWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } @@ -1644,8 +1644,8 @@ func (b *builtinWeekDay) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } return newEvalInt64(int64(d.dt.Date.Weekday()+6) % 7), nil @@ -1678,8 +1678,8 @@ func (b *builtinWeekOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } @@ -1750,8 +1750,8 @@ func (b *builtinYearWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } From 789bbad2a3f6aecc6e23324d9f42238914e259a1 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 12 Feb 2024 11:12:22 +0100 Subject: [PATCH 3/3] evalengine: Simplify logic for collation handling Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler_asm.go | 58 +---------------- go/vt/vtgate/evalengine/fn_string.go | 83 ++++++++----------------- 2 files changed, 29 insertions(+), 112 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 29ad659249d..5097d54dbd6 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2986,61 +2986,7 @@ func (asm *assembler) Like_collate(expr *LikeExpr, collation colldata.Collation) }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) } -func (asm *assembler) Locate_coerce3(coercion *compiledCoercion) { - asm.adjustStack(-2) - - asm.emit(func(env *ExpressionEnv) int { - substr := env.vm.stack[env.vm.sp-3].(*evalBytes) - str := env.vm.stack[env.vm.sp-2].(*evalBytes) - pos := env.vm.stack[env.vm.sp-1].(*evalInt64) - env.vm.sp -= 2 - - if pos.i < 1 { - env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(0) - return 1 - } - - var bsub, bstr []byte - bsub, env.vm.err = coercion.left(nil, substr.bytes) - if env.vm.err != nil { - return 0 - } - bstr, env.vm.err = coercion.right(nil, str.bytes) - if env.vm.err != nil { - return 0 - } - - found := colldata.Index(coercion.col, bstr, bsub, int(pos.i)-1) - env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(found) + 1) - return 1 - }, "LOCATE VARCHAR(SP-3), VARCHAR(SP-2) INT64(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) -} - -func (asm *assembler) Locate_coerce2(coercion *compiledCoercion) { - asm.adjustStack(-1) - - asm.emit(func(env *ExpressionEnv) int { - substr := env.vm.stack[env.vm.sp-2].(*evalBytes) - str := env.vm.stack[env.vm.sp-1].(*evalBytes) - env.vm.sp-- - - var bsub, bstr []byte - bsub, env.vm.err = coercion.left(nil, substr.bytes) - if env.vm.err != nil { - return 0 - } - bstr, env.vm.err = coercion.right(nil, str.bytes) - if env.vm.err != nil { - return 0 - } - - found := colldata.Index(coercion.col, bstr, bsub, 0) - env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(found) + 1) - return 1 - }, "LOCATE VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) -} - -func (asm *assembler) Locate_collate3(collation colldata.Collation) { +func (asm *assembler) Locate3(collation colldata.Collation) { asm.adjustStack(-2) asm.emit(func(env *ExpressionEnv) int { @@ -3060,7 +3006,7 @@ func (asm *assembler) Locate_collate3(collation colldata.Collation) { }, "LOCATE VARCHAR(SP-3), VARCHAR(SP-2) INT64(SP-1) COLLATE '%s'", collation.Name()) } -func (asm *assembler) Locate_collate2(collation colldata.Collation) { +func (asm *assembler) Locate2(collation colldata.Collation) { asm.adjustStack(-1) asm.emit(func(env *ExpressionEnv) int { diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index eee8258cefc..e65800c9824 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -1281,6 +1281,19 @@ func (call *builtinLocate) eval(env *ExpressionEnv) (eval, error) { return nil, err } + if _, ok := str.(*evalBytes); !ok { + str, err = evalToVarchar(str, call.collate, true) + if err != nil { + return nil, err + } + } + + col := str.(*evalBytes).col.Collation + substr, err = evalToVarchar(substr, col, true) + if err != nil { + return nil, err + } + pos := int64(1) if len(call.Arguments) > 2 { p, err := call.Arguments[2].eval(env) @@ -1293,15 +1306,9 @@ func (call *builtinLocate) eval(env *ExpressionEnv) (eval, error) { } } - var col collations.TypedCollation - substr, str, col, err = mergeAndCoerceCollations(substr, str, env.collationEnv) - if err != nil { - return nil, err - } - var coll colldata.Collation if typeIsTextual(substr.SQLType()) && typeIsTextual(str.SQLType()) { - coll = colldata.Lookup(col.Collation) + coll = colldata.Lookup(col) } else { coll = colldata.Lookup(collations.CollationBinaryID) } @@ -1331,73 +1338,37 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) { _ = c.compileToInt64(l, 1) } - if !substr.isTextual() { - c.asm.Convert_xc(len(call.Arguments), sqltypes.VarChar, c.collation, 0, false) - substr.Col = collations.TypedCollation{ + if !str.isTextual() { + c.asm.Convert_xce(len(call.Arguments)-1, sqltypes.VarChar, c.collation) + str.Col = collations.TypedCollation{ Collation: c.collation, Coercibility: collations.CoerceCoercible, Repertoire: collations.RepertoireASCII, } } - if !str.isTextual() { - c.asm.Convert_xc(len(call.Arguments)-1, sqltypes.VarChar, c.collation, 0, false) - str.Col = collations.TypedCollation{ - Collation: c.collation, + fromCharset := colldata.Lookup(substr.Col.Collation).Charset() + toCharset := colldata.Lookup(str.Col.Collation).Charset() + if !substr.isTextual() || (fromCharset != toCharset && !toCharset.IsSuperset(fromCharset)) { + c.asm.Convert_xce(len(call.Arguments), sqltypes.VarChar, str.Col.Collation) + substr.Col = collations.TypedCollation{ + Collation: str.Col.Collation, Coercibility: collations.CoerceCoercible, Repertoire: collations.RepertoireASCII, } } - var merged collations.TypedCollation - var coerceLeft colldata.Coercion - var coerceRight colldata.Coercion - - if substr.Col.Collation != str.Col.Collation { - merged, coerceLeft, coerceRight, err = colldata.Merge(c.env.CollationEnv(), substr.Col, str.Col, colldata.CoercionOptions{ - ConvertToSuperset: true, - ConvertWithCoercion: true, - }) - } else { - merged = substr.Col - } - if err != nil { - return ctype{}, err - } - var coll colldata.Collation if typeIsTextual(substr.Type) && typeIsTextual(str.Type) { - coll = colldata.Lookup(merged.Collation) + coll = colldata.Lookup(str.Col.Collation) } else { coll = colldata.Lookup(collations.CollationBinaryID) } - if coerceLeft == nil && coerceRight == nil { - if len(call.Arguments) > 2 { - c.asm.Locate_collate3(coll) - } else { - c.asm.Locate_collate2(coll) - } + if len(call.Arguments) > 2 { + c.asm.Locate3(coll) } else { - if coerceLeft == nil { - coerceLeft = func(dst, in []byte) ([]byte, error) { return in, nil } - } - if coerceRight == nil { - coerceRight = func(dst, in []byte) ([]byte, error) { return in, nil } - } - if len(call.Arguments) > 2 { - c.asm.Locate_coerce3(&compiledCoercion{ - col: colldata.Lookup(merged.Collation), - left: coerceLeft, - right: coerceRight, - }) - } else { - c.asm.Locate_coerce2(&compiledCoercion{ - col: colldata.Lookup(merged.Collation), - left: coerceLeft, - right: coerceRight, - }) - } + c.asm.Locate2(coll) } c.asm.jumpDestination(skip1, skip2)