diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 9e0b091b5c0..7525bfdaec4 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -931,6 +931,18 @@ func (cached *builtinInetNtoa) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinInsert) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinIsIPV4) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 5cb1426b55c..0b830568e7b 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -327,6 +327,15 @@ func (c *compiler) compileNullCheck3(arg1, arg2, arg3 ctype) *jump { return nil } +func (c *compiler) compileNullCheck4(arg1, arg2, arg3, arg4 ctype) *jump { + if arg1.nullable() || arg2.nullable() || arg3.nullable() || arg4.nullable() { + j := c.asm.jumpFrom() + c.asm.NullCheck4(j) + return j + } + return nil +} + func (c *compiler) compileNullCheckArg(ct ctype, offset int) *jump { if ct.nullable() { j := c.asm.jumpFrom() diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 39e02d07d55..e017a949a07 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2345,6 +2345,28 @@ func (asm *assembler) Fn_BIT_LENGTH() { }, "FN BIT_LENGTH VARCHAR(SP-1)") } +func (asm *assembler) Fn_INSERT(col collations.TypedCollation) { + asm.adjustStack(-3) + + asm.emit(func(env *ExpressionEnv) int { + str := env.vm.stack[env.vm.sp-4].(*evalBytes) + pos := env.vm.stack[env.vm.sp-3].(*evalInt64).i + l := env.vm.stack[env.vm.sp-2].(*evalInt64).i + newstr := env.vm.stack[env.vm.sp-1].(*evalBytes) + + res := insert(str, newstr, int(pos), int(l)) + if !validMaxLength(int64(len(res)), 1) { + env.vm.stack[env.vm.sp-4] = nil + env.vm.sp -= 3 + return 1 + } + + env.vm.stack[env.vm.sp-4] = env.vm.arena.newEvalText(res, col) + env.vm.sp -= 3 + return 1 + }, "FN INSERT VARCHAR(SP-4) INT64(SP-3) INT64(SP-2) VARCHAR(SP-1)") +} + func (asm *assembler) Fn_LUCASE(upcase bool) { if upcase { asm.emit(func(env *ExpressionEnv) int { @@ -3147,6 +3169,17 @@ func (asm *assembler) NullCheck3(j *jump) { }, "NULLCHECK SP-1, SP-2, SP-3") } +func (asm *assembler) NullCheck4(j *jump) { + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-4] == nil || env.vm.stack[env.vm.sp-3] == nil || env.vm.stack[env.vm.sp-2] == nil || env.vm.stack[env.vm.sp-1] == nil { + env.vm.stack[env.vm.sp-4] = nil + env.vm.sp -= 3 + return j.offset() + } + return 1 + }, "NULLCHECK SP-1, SP-2, SP-3, SP-4") +} + func (asm *assembler) NullCheckArg(j *jump, offset int) { asm.emit(func(env *ExpressionEnv) int { if env.vm.stack[env.vm.sp-1] == nil { diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index a6ab0a1c1cd..23ff1cbdca3 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -18,6 +18,7 @@ package evalengine import ( "bytes" + "math" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" @@ -29,6 +30,11 @@ import ( ) type ( + builtinInsert struct { + CallExpr + collate collations.ID + } + builtinChangeCase struct { CallExpr upcase bool @@ -106,6 +112,7 @@ type ( } ) +var _ IR = (*builtinInsert)(nil) var _ IR = (*builtinChangeCase)(nil) var _ IR = (*builtinCharLength)(nil) var _ IR = (*builtinLength)(nil) @@ -120,6 +127,122 @@ var _ IR = (*builtinLeftRight)(nil) var _ IR = (*builtinPad)(nil) var _ IR = (*builtinTrim)(nil) +func insert(str, newstr *evalBytes, pos, l int) []byte { + pos-- + + cs := colldata.Lookup(str.col.Collation).Charset() + strLen := charset.Length(cs, str.bytes) + + if pos < 0 || strLen <= pos { + return str.bytes + } + if l < 0 { + l = strLen + } + + front := charset.Slice(cs, str.bytes, 0, pos) + var back []byte + if pos <= math.MaxInt-l && pos+l < strLen { + back = charset.Slice(cs, str.bytes, pos+l, strLen) + } + + res := make([]byte, len(front)+len(newstr.bytes)+len(back)) + + copy(res[:len(front)], front) + copy(res[len(front):], newstr.bytes) + copy(res[len(front)+len(newstr.bytes):], back) + + return res +} + +func (call *builtinInsert) eval(env *ExpressionEnv) (eval, error) { + args, err := call.args(env) + if err != nil { + return nil, err + } + if args[0] == nil || args[1] == nil || args[2] == nil || args[3] == nil { + return nil, nil + } + + str, ok := args[0].(*evalBytes) + if !ok { + str, err = evalToVarchar(args[0], call.collate, true) + if err != nil { + return nil, err + } + } + + pos := evalToInt64(args[1]).i + l := evalToInt64(args[2]).i + + newstr, err := evalToVarchar(args[3], str.col.Collation, true) + if err != nil { + return nil, err + } + + res := insert(str, newstr, int(pos), int(l)) + if !validMaxLength(int64(len(res)), 1) { + return nil, nil + } + return newEvalText(res, str.col), nil +} + +func (call *builtinInsert) compile(c *compiler) (ctype, error) { + str, err := call.Arguments[0].compile(c) + if err != nil { + return ctype{}, err + } + + pos, err := call.Arguments[1].compile(c) + if err != nil { + return ctype{}, err + } + + l, err := call.Arguments[2].compile(c) + if err != nil { + return ctype{}, err + } + + newstr, err := call.Arguments[3].compile(c) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck4(str, pos, l, newstr) + + _ = c.compileToInt64(pos, 3) + _ = c.compileToInt64(l, 2) + + if err != nil { + return ctype{}, nil + } + + col := str.Col + + switch { + case str.isTextual(): + default: + c.asm.Convert_xce(4, sqltypes.VarChar, c.collation) + col = typedCoercionCollation(sqltypes.VarChar, c.collation) + } + + switch { + case newstr.isTextual(): + fromCharset := colldata.Lookup(newstr.Col.Collation).Charset() + toCharset := colldata.Lookup(col.Collation).Charset() + if fromCharset != toCharset && !toCharset.IsSuperset(fromCharset) { + c.asm.Convert_xce(1, sqltypes.VarChar, col.Collation) + } + default: + c.asm.Convert_xce(1, sqltypes.VarChar, col.Collation) + } + + c.asm.Fn_INSERT(col) + c.asm.jumpDestination(skip) + + return ctype{Type: sqltypes.VarChar, Col: col, Flag: flagNullable}, nil +} + func (call *builtinChangeCase) eval(env *ExpressionEnv) (eval, error) { arg, err := call.arg1(env) if err != nil { diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 45cfdc4dd10..9d9cdfa248e 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -63,6 +63,7 @@ var Cases = []TestCase{ {Run: TupleComparisons}, {Run: Comparisons}, {Run: InStatement}, + {Run: FnInsert}, {Run: FnLower}, {Run: FnUpper}, {Run: FnCharLength}, @@ -1314,6 +1315,28 @@ var JSONExtract_Schema = []*querypb.Field{ }, } +func FnInsert(yield Query) { + for _, s := range insertStrings { + for _, ns := range insertStrings { + for _, l := range inputBitwise { + for _, p := range inputBitwise { + yield(fmt.Sprintf("INSERT(%s, %s, %s, %s)", s, p, l, ns), nil) + } + } + } + } + + mysqlDocSamples := []string{ + "INSERT('Quadratic', 3, 4, 'What')", + "INSERT('Quadratic', -1, 4, 'What')", + "INSERT('Quadratic', 3, 100, 'What')", + } + + for _, q := range mysqlDocSamples { + yield(q, nil) + } +} + func FnLower(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("LOWER(%s)", str), nil) diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index b4a558d1145..c453f904c96 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -199,6 +199,41 @@ var inputStrings = []string{ // "_ucs2 'AabcÅå'", } +var insertStrings = []string{ + "NULL", + "\"\"", + "\"a\"", + "\"abc\"", + "1", + "-1", + "0123", + "0xAACC", + "3.1415926", + // MySQL has broken behavior for these inputs, + // see https://github.com/mysql/mysql-server/pull/517 + // "\"Å å\"", + // "\"中文测试\"", + // "\"日本語テスト\"", + // "\"한국어 시험\"", + // "\"😊😂🤢\"", + // "_utf8mb4 'abcABCÅå'", + "DATE '2022-10-11'", + "TIME '11:02:23'", + "'123'", + "9223372036854775807", + "-9223372036854775808", + "999999999999999999999999", + "-999999999999999999999999", + "_binary 'Müller' ", + "_latin1 0xFF", + // TODO: support other multibyte encodings + // "_dec8 'ÒòÅå'", + // "_utf8mb3 'abcABCÅå'", + // "_utf16 'AabcÅå'", + // "_utf32 'AabcÅå'", + // "_ucs2 'AabcÅå'", +} + var inputConversionTypes = []string{ "BINARY", "BINARY(1)", "BINARY(0)", "BINARY(16)", "BINARY(-1)", "CHAR", "CHAR(1)", "CHAR(0)", "CHAR(16)", "CHAR(-1)", diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 8b0b8326baa..11618bb1d1a 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -984,6 +984,35 @@ func (ast *astCompiler) translateCallable(call sqlparser.Callable) (IR, error) { return &builtinRegexpReplace{ CallExpr: CallExpr{Arguments: args, Method: "REGEXP_REPLACE"}, }, nil + + case *sqlparser.InsertExpr: + str, err := ast.translateExpr(call.Str) + if err != nil { + return nil, err + } + + pos, err := ast.translateExpr(call.Pos) + if err != nil { + return nil, err + } + + len, err := ast.translateExpr(call.Len) + if err != nil { + return nil, err + } + + newstr, err := ast.translateExpr(call.NewStr) + if err != nil { + return nil, err + } + + args := []IR{str, pos, len, newstr} + + var cexpr = CallExpr{Arguments: args, Method: "INSERT"} + return &builtinInsert{ + CallExpr: cexpr, + collate: ast.cfg.Collation, + }, nil default: return nil, translateExprNotSupported(call) } diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 59a6e4686a6..cee664e6cd9 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -3048,15 +3048,15 @@ "QueryType": "SELECT", "Original": "select insert('Quadratic', 3, 4, 'What')", "Instructions": { - "OperatorType": "Route", - "Variant": "Reference", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select insert('Quadratic', 3, 4, 'What') from dual where 1 != 1", - "Query": "select insert('Quadratic', 3, 4, 'What') from dual", - "Table": "dual" + "OperatorType": "Projection", + "Expressions": [ + "'QuWhattic' as insert('Quadratic', 3, 4, 'What')" + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] }, "TablesUsed": [ "main.dual"