diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index ce90aadea8a..b386d3dc915 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1463,6 +1463,18 @@ func (cached *builtinRepeat) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinReplace) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinReverse) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index cddf0790ea7..5eeb9a6300d 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -3022,6 +3022,19 @@ func (asm *assembler) Locate2(collation colldata.Collation) { }, "LOCATE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) } +func (asm *assembler) Replace() { + asm.adjustStack(-2) + + asm.emit(func(env *ExpressionEnv) int { + str := env.vm.stack[env.vm.sp-3].(*evalBytes) + from := env.vm.stack[env.vm.sp-2].(*evalBytes) + to := env.vm.stack[env.vm.sp-1].(*evalBytes) + env.vm.sp -= 2 + str.bytes = replace(str.bytes, from.bytes, to.bytes) + return 1 + }, "REPLACE VARCHAR(SP-3), VARCHAR(SP-2) VARCHAR(SP-1)") +} + func (asm *assembler) Strcmp(collation collations.TypedCollation) { asm.adjustStack(-1) diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 09e08ad0d48..f101bf61c64 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -627,6 +627,10 @@ func TestCompilerSingle(t *testing.T) { expression: `locate("", "😊😂🤢", 3)`, result: `INT64(3)`, }, + { + expression: `REPLACE('www.mysql.com', '', 'Ww')`, + result: `VARCHAR("www.mysql.com")`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 11a18c95300..e0887037c0a 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -114,6 +114,31 @@ type ( CallExpr collate collations.ID } + + builtinChar struct { + CallExpr + collate collations.ID + } + + builtinRepeat struct { + CallExpr + collate collations.ID + } + + builtinConcat struct { + CallExpr + collate collations.ID + } + + builtinConcatWs struct { + CallExpr + collate collations.ID + } + + builtinReplace struct { + CallExpr + collate collations.ID + } ) var _ IR = (*builtinInsert)(nil) @@ -129,7 +154,15 @@ var _ IR = (*builtinCollation)(nil) var _ IR = (*builtinWeightString)(nil) var _ IR = (*builtinLeftRight)(nil) var _ IR = (*builtinPad)(nil) +var _ IR = (*builtinStrcmp)(nil) var _ IR = (*builtinTrim)(nil) +var _ IR = (*builtinSubstring)(nil) +var _ IR = (*builtinLocate)(nil) +var _ IR = (*builtinChar)(nil) +var _ IR = (*builtinRepeat)(nil) +var _ IR = (*builtinConcat)(nil) +var _ IR = (*builtinConcatWs)(nil) +var _ IR = (*builtinReplace)(nil) func insert(str, newstr *evalBytes, pos, l int) []byte { pos-- @@ -555,11 +588,6 @@ func (call *builtinOrd) compile(c *compiler) (ctype, error) { // - `> max_allowed_packet`, no error and returns `NULL`. const maxRepeatLength = 1073741824 -type builtinRepeat struct { - CallExpr - collate collations.ID -} - func (call *builtinRepeat) eval(env *ExpressionEnv) (eval, error) { arg1, arg2, err := call.arg2(env) if err != nil { @@ -1374,11 +1402,6 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}, nil } -type builtinConcat struct { - CallExpr - collate collations.ID -} - func concatSQLType(arg sqltypes.Type, tt sqltypes.Type) sqltypes.Type { if arg == sqltypes.TypeJSON { return sqltypes.Blob @@ -1507,11 +1530,6 @@ func (call *builtinConcat) compile(c *compiler) (ctype, error) { return ctype{Type: tt, Flag: f, Col: tc}, nil } -type builtinConcatWs struct { - CallExpr - collate collations.ID -} - func (call *builtinConcatWs) eval(env *ExpressionEnv) (eval, error) { var ca collationAggregation tt := sqltypes.VarChar @@ -1643,13 +1661,6 @@ func (call *builtinConcatWs) compile(c *compiler) (ctype, error) { return ctype{Type: tt, Flag: args[0].Flag, Col: tc}, nil } -type builtinChar struct { - CallExpr - collate collations.ID -} - -var _ IR = (*builtinChar)(nil) - func (call *builtinChar) eval(env *ExpressionEnv) (eval, error) { vals := make([]eval, 0, len(call.Arguments)) for _, arg := range call.Arguments { @@ -1726,3 +1737,119 @@ func encodeChar(buf []byte, i uint32) []byte { } return buf } + +func (call *builtinReplace) eval(env *ExpressionEnv) (eval, error) { + str, err := call.Arguments[0].eval(env) + if err != nil || str == nil { + return nil, err + } + + fromStr, err := call.Arguments[1].eval(env) + if err != nil || fromStr == nil { + return nil, err + } + + toStr, err := call.Arguments[2].eval(env) + if err != nil || toStr == nil { + return nil, err + } + + if _, ok := str.(*evalBytes); !ok { + str, err = evalToVarchar(str, call.collate, true) + if err != nil { + return nil, err + } + } + + col := str.(*evalBytes).col + fromStr, err = evalToVarchar(fromStr, col.Collation, true) + if err != nil { + return nil, err + } + + toStr, err = evalToVarchar(toStr, col.Collation, true) + if err != nil { + return nil, err + } + + strBytes := str.(*evalBytes).bytes + fromBytes := fromStr.(*evalBytes).bytes + toBytes := toStr.(*evalBytes).bytes + + out := replace(strBytes, fromBytes, toBytes) + return newEvalRaw(str.SQLType(), out, col), nil +} + +func (call *builtinReplace) compile(c *compiler) (ctype, error) { + str, err := call.Arguments[0].compile(c) + if err != nil { + return ctype{}, err + } + + fromStr, err := call.Arguments[1].compile(c) + if err != nil { + return ctype{}, err + } + + toStr, err := call.Arguments[2].compile(c) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck3(str, fromStr, toStr) + if !str.isTextual() { + c.asm.Convert_xce(3, sqltypes.VarChar, c.collation) + str.Col = collations.TypedCollation{ + Collation: c.collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + fromCharset := colldata.Lookup(fromStr.Col.Collation).Charset() + toCharset := colldata.Lookup(toStr.Col.Collation).Charset() + strCharset := colldata.Lookup(str.Col.Collation).Charset() + if !fromStr.isTextual() || (fromCharset != strCharset && !strCharset.IsSuperset(fromCharset)) { + c.asm.Convert_xce(2, sqltypes.VarChar, str.Col.Collation) + fromStr.Col = collations.TypedCollation{ + Collation: str.Col.Collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + if !toStr.isTextual() || (toCharset != strCharset && !strCharset.IsSuperset(toCharset)) { + c.asm.Convert_xce(1, sqltypes.VarChar, str.Col.Collation) + toStr.Col = collations.TypedCollation{ + Collation: str.Col.Collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + c.asm.Replace() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.VarChar, Col: str.Col, Flag: flagNullable}, nil +} + +func replace(str, from, to []byte) []byte { + if len(from) == 0 { + return str + } + n := bytes.Count(str, from) + if n == 0 { + return str + } + + out := make([]byte, len(str)+n*(len(to)-len(from))) + end := 0 + start := 0 + for i := 0; i < n; i++ { + pos := start + bytes.Index(str[start:], from) + end += copy(out[end:], str[start:pos]) + end += copy(out[end:], to) + start = pos + len(from) + } + end += copy(out[end:], str[start:]) + return out[0:end] +} diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index c53faa7f217..b55f6c2c18d 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -83,6 +83,7 @@ var Cases = []TestCase{ {Run: FnTrim}, {Run: FnSubstr}, {Run: FnLocate}, + {Run: FnReplace}, {Run: FnConcat}, {Run: FnConcatWs}, {Run: FnChar}, @@ -1577,6 +1578,33 @@ func FnLocate(yield Query) { } } +func FnReplace(yield Query) { + cases := []string{ + `REPLACE('www.mysql.com', 'w', 'Ww')`, + // MySQL doesn't do collation matching for replace, only + // byte equivalence, but make sure to check. + `REPLACE('straße', 'ss', 'b')`, + `REPLACE('straße', 'ß', 'b')`, + // From / to strings are converted into the collation of + // the input string. + `REPLACE('fooÿbar', _latin1 0xFF, _latin1 0xFE)`, + // First occurence is replaced + `replace('fff', 'ff', 'gg')`, + } + + for _, q := range cases { + yield(q, nil) + } + + for _, substr := range inputStrings { + for _, str := range inputStrings { + for _, i := range inputStrings { + yield(fmt.Sprintf("REPLACE(%s, %s, %s)", substr, str, i), nil) + } + } + } +} + func FnConcat(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("CONCAT(%s)", str), nil) diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 71eff66bc2b..cabe406bfb6 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -627,6 +627,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { } call = CallExpr{Arguments: []IR{call.Arguments[1], call.Arguments[0]}, Method: method} return &builtinLocate{CallExpr: call, collate: ast.cfg.Collation}, nil + case "replace": + if len(args) != 3 { + return nil, argError(method) + } + return &builtinReplace{CallExpr: call, collate: ast.cfg.Collation}, nil default: return nil, translateExprNotSupported(fn) }