Skip to content

Commit

Permalink
evalengine: Implement SUBSTRING (#14899)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored Jan 9, 2024
1 parent f2e0ea3 commit fab65bf
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 9 deletions.
12 changes: 12 additions & 0 deletions go/vt/vtgate/evalengine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 60 additions & 0 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (

"github.com/google/uuid"

"vitess.io/vitess/go/mysql/collations/charset/types"
"vitess.io/vitess/go/mysql/collations/colldata"

"vitess.io/vitess/go/hack"
Expand Down Expand Up @@ -2725,6 +2726,65 @@ func (asm *assembler) Fn_TRIM2(col collations.TypedCollation) {
}, "FN TRIM VARCHAR(SP-2) VARCHAR(SP-1)")
}

// Fn_SUBSTRING2 emits the instruction for the two-argument form
// SUBSTRING(str, pos).
//
// At run time the instruction reads the string from SP-2 and the position
// from SP-1, rewrites the string slot in place with the result and pops the
// position. A negative position is rebased against the string length (as
// reported by charset.Length for cs). A position that ends up outside
// [1, length] leaves the slot with nil bytes. The result is always tagged
// with type tt and collation col.
func (asm *assembler) Fn_SUBSTRING2(tt sqltypes.Type, cs types.Charset, col collations.TypedCollation) {
	asm.adjustStack(-1)
	asm.emit(func(env *ExpressionEnv) int {
		input := env.vm.stack[env.vm.sp-2].(*evalBytes)
		start := env.vm.stack[env.vm.sp-1].(*evalInt64)

		length := int64(charset.Length(cs, input.bytes))
		if start.i < 0 {
			// Negative positions count backwards from the end.
			start.i += length + 1
		}

		// out stays nil unless the (rebased) position is inside the string.
		var out []byte
		if start.i >= 1 && start.i <= length {
			out = charset.Slice(cs, input.bytes, int(start.i-1), int(length))
		}

		input.tt = int16(tt)
		input.bytes = out
		input.col = col
		env.vm.sp--
		return 1
	}, "FN SUBSTRING VARCHAR(SP-2) INT64(SP-1)")
}

// Fn_SUBSTRING3 emits the instruction for the three-argument form
// SUBSTRING(str, pos, len).
//
// At run time the instruction reads the string from SP-3, the position from
// SP-2 and the length from SP-1, rewrites the string slot in place with the
// result and pops the other two operands. A negative position is rebased
// against the string length (as reported by charset.Length for cs). If the
// position ends up outside [1, length] or the requested length is below 1,
// the slot is left with nil bytes. A requested length that runs past the end
// of the string is clamped. The result is always tagged with type tt and
// collation col.
func (asm *assembler) Fn_SUBSTRING3(tt sqltypes.Type, cs types.Charset, col collations.TypedCollation) {
	asm.adjustStack(-2)
	asm.emit(func(env *ExpressionEnv) int {
		input := env.vm.stack[env.vm.sp-3].(*evalBytes)
		start := env.vm.stack[env.vm.sp-2].(*evalInt64)
		count := env.vm.stack[env.vm.sp-1].(*evalInt64)

		length := int64(charset.Length(cs, input.bytes))
		if start.i < 0 {
			// Negative positions count backwards from the end.
			start.i += length + 1
		}

		// out stays nil unless both the position and the length are usable.
		var out []byte
		if start.i >= 1 && start.i <= length && count.i >= 1 {
			// Never read past the end of the input.
			if remaining := length - start.i + 1; count.i > remaining {
				count.i = remaining
			}
			stop := start.i + count.i - 1
			out = charset.Slice(cs, input.bytes, int(start.i-1), int(stop))
		}

		input.tt = int16(tt)
		input.bytes = out
		input.col = col
		env.vm.sp -= 2
		return 1
	}, "FN SUBSTRING VARCHAR(SP-3) INT64(SP-2) INT64(SP-1)")
}

func (asm *assembler) Fn_TO_BASE64(t sqltypes.Type, col collations.TypedCollation) {
asm.emit(func(env *ExpressionEnv) int {
str := env.vm.stack[env.vm.sp-1].(*evalBytes)
Expand Down
3 changes: 2 additions & 1 deletion go/vt/vtgate/evalengine/expr_collate.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ func (expr *CollateExpr) compile(c *compiler) (ctype, error) {
case sqltypes.VarBinary:
c.asm.Collate(expr.TypedCollation.Collation)
default:
return ctype{}, c.unsupported(expr)
c.asm.Convert_xc(1, sqltypes.VarChar, expr.TypedCollation.Collation, 0, false)
}

c.asm.jumpDestination(skip)

ct.Type = sqltypes.VarChar
ct.Col = expr.TypedCollation
ct.Flag |= flagExplicitCollation | flagNullable
return ct, nil
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/expr_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ func (expr *InExpr) compile(c *compiler) (ctype, error) {

return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil
case *BindVariable:
return ctype{}, c.unsupported(expr)
return ctype{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "rhs of an In operation should be a tuple")
default:
panic("unreachable")
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/expr_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) {
convt = c.compileToFloat(arg, 1)

case "FLOAT":
return ctype{}, c.unsupported(conv)
return ctype{}, conv.returnUnsupportedError()

case "SIGNED", "SIGNED INTEGER":
convt = c.compileToInt64(arg, 1)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/fn_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) {
case sqltypes.Null:
nullable = true
default:
return ctype{}, c.unsupported(call)
panic("unexpected argument type")
}
}

Expand Down
8 changes: 5 additions & 3 deletions go/vt/vtgate/evalengine/fn_regexp.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,12 @@ func compileRegex(pat eval, c colldata.Charset, flags icuregex.RegexpFlag) (*icu
return nil, err
}

// errNonConstantRegexp is the sentinel error returned by compileConstantRegex
// when the pattern or match-type argument is not a constant expression, or
// when the pattern literal has no inner value.
var errNonConstantRegexp = errors.New("non-constant regexp")

func compileConstantRegex(c *compiler, args TupleExpr, pat, mt int, cs collations.TypedCollation, flags icuregex.RegexpFlag, f string) (*icuregex.Pattern, error) {
pattern := args[pat]
if !pattern.constant() {
return nil, c.unsupported(pattern)
return nil, errNonConstantRegexp
}
var err error
staticEnv := EmptyExpressionEnv(c.collationEnv)
Expand All @@ -225,7 +227,7 @@ func compileConstantRegex(c *compiler, args TupleExpr, pat, mt int, cs collation
if len(args) > mt {
fl := args[mt]
if !fl.constant() {
return nil, c.unsupported(fl)
return nil, errNonConstantRegexp
}
fl, err = simplifyExpr(staticEnv, fl)
if err != nil {
Expand All @@ -238,7 +240,7 @@ func compileConstantRegex(c *compiler, args TupleExpr, pat, mt int, cs collation
}

if pattern.(*Literal).inner == nil {
return nil, c.unsupported(pattern)
return nil, errNonConstantRegexp
}

innerPat, err := evalToVarchar(pattern.(*Literal).inner, cs.Collation, true)
Expand Down
108 changes: 106 additions & 2 deletions go/vt/vtgate/evalengine/fn_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ type (
collate collations.ID
trim sqlparser.TrimType
}

// builtinSubstring implements the SUBSTRING builtin. Its Arguments hold
// the string, the position and, optionally, the length, in that order.
builtinSubstring struct {
CallExpr
// collate is the collation used when the string argument is not already
// an evalBytes value and has to be converted to VARCHAR during evaluation.
collate collations.ID
}
)

var _ IR = (*builtinChangeCase)(nil)
Expand Down Expand Up @@ -817,7 +822,7 @@ func (expr *builtinStrcmp) compile(c *compiler) (ctype, error) {
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: nullableFlags(lt.Flag | rt.Flag)}, nil
}

func (call builtinTrim) eval(env *ExpressionEnv) (eval, error) {
func (call *builtinTrim) eval(env *ExpressionEnv) (eval, error) {
str, err := call.arg1(env)
if err != nil {
return nil, err
Expand Down Expand Up @@ -872,7 +877,7 @@ func (call builtinTrim) eval(env *ExpressionEnv) (eval, error) {
}
}

func (call builtinTrim) compile(c *compiler) (ctype, error) {
func (call *builtinTrim) compile(c *compiler) (ctype, error) {
str, err := call.Arguments[0].compile(c)
if err != nil {
return ctype{}, err
Expand Down Expand Up @@ -932,6 +937,105 @@ func (call builtinTrim) compile(c *compiler) (ctype, error) {
return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil
}

// eval interprets SUBSTRING(str, pos[, len]).
//
// A nil (NULL) value for any argument yields a nil result. A string argument
// that is not already an evalBytes is converted to VARCHAR using call.collate,
// and the result type becomes VARCHAR; otherwise the argument's own SQL type
// is kept. A position of 0, a position outside the string after rebasing a
// negative value, or a length below 1 all produce a value with nil bytes.
func (call *builtinSubstring) eval(env *ExpressionEnv) (eval, error) {
	arg, err := call.Arguments[0].eval(env)
	if err != nil || arg == nil {
		return nil, err
	}

	resultType := arg.SQLType()
	text, isBytes := arg.(*evalBytes)
	if !isBytes {
		text, err = evalToVarchar(arg, call.collate, true)
		if err != nil {
			return nil, err
		}
		resultType = sqltypes.VarChar
	}

	posArg, err := call.Arguments[1].eval(env)
	if err != nil || posArg == nil {
		return nil, err
	}

	// The length is evaluated up front so that a NULL length wins over any
	// of the "empty result" shortcuts below.
	hasLength := len(call.Arguments) > 2
	var lenArg eval
	if hasLength {
		lenArg, err = call.Arguments[2].eval(env)
		if err != nil || lenArg == nil {
			return nil, err
		}
	}

	// empty builds the result used whenever nothing is selected.
	empty := func() (eval, error) {
		return newEvalRaw(resultType, nil, text.col), nil
	}

	start := evalToInt64(posArg).i
	if start == 0 {
		return empty()
	}

	cs := colldata.Lookup(text.col.Collation).Charset()
	stop := int64(charset.Length(cs, text.bytes))

	if start < 0 {
		// Negative positions count backwards from the end.
		start += stop + 1
	}
	if start < 1 || start > stop {
		return empty()
	}

	if hasLength {
		count := evalToInt64(lenArg).i
		if count < 1 {
			return empty()
		}
		// Never read past the end of the input.
		if remaining := stop - start + 1; count > remaining {
			count = remaining
		}
		stop = start + count - 1
	}

	return newEvalRaw(resultType, charset.Slice(cs, text.bytes, int(start-1), int(stop)), text.col), nil
}

// compile emits the instructions for SUBSTRING(str, pos[, len]).
//
// The string and position are compiled first and guarded by a null check.
// A non-textual string argument is converted to VARCHAR in the default
// coercion collation; a textual one keeps its own type and collation. The
// position (and the optional length) are coerced to INT64 before the
// matching Fn_SUBSTRING2 / Fn_SUBSTRING3 instruction is emitted.
//
// The returned ctype is always nullable and carries the result type and the
// collation that the emitted instruction stamps onto its output.
func (call *builtinSubstring) compile(c *compiler) (ctype, error) {
	str, err := call.Arguments[0].compile(c)
	if err != nil {
		return ctype{}, err
	}

	p, err := call.Arguments[1].compile(c)
	if err != nil {
		return ctype{}, err
	}

	tt := str.Type
	skip1 := c.compileNullCheck2(str, p)

	col := typedCoercionCollation(sqltypes.VarChar, c.collation)
	switch {
	case str.isTextual():
		col = str.Col
	default:
		tt = sqltypes.VarChar
		c.asm.Convert_xc(2, tt, col.Collation, 0, false)
	}
	_ = c.compileToInt64(p, 1)

	// Derive the charset from the collation the bytes on the stack actually
	// have at this point. For textual input col == str.Col, so nothing
	// changes; for converted input the bytes were just re-encoded to
	// col.Collation, so str.Col (the pre-conversion collation) would be the
	// wrong charset to measure and slice with. This also matches eval, which
	// looks up the charset of the converted text.
	cs := colldata.Lookup(col.Collation).Charset()
	var skip2 *jump
	if len(call.Arguments) > 2 {
		l, err := call.Arguments[2].compile(c)
		if err != nil {
			return ctype{}, err
		}
		// NOTE(review): this check is handed `str` but runs after `l` has
		// been pushed on top of `p`; confirm it inspects the intended stack
		// slots and leaves the stack balanced on the NULL path.
		skip2 = c.compileNullCheck2(str, l)
		_ = c.compileToInt64(l, 1)
		c.asm.Fn_SUBSTRING3(tt, cs, col)
	} else {
		c.asm.Fn_SUBSTRING2(tt, cs, col)
	}

	c.asm.jumpDestination(skip1, skip2)
	return ctype{Type: tt, Col: col, Flag: flagNullable}, nil
}

type builtinConcat struct {
CallExpr
collate collations.ID
Expand Down
32 changes: 32 additions & 0 deletions go/vt/vtgate/evalengine/testcases/cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ var Cases = []TestCase{
{Run: FnLTrim},
{Run: FnRTrim},
{Run: FnTrim},
{Run: FnSubstr},
{Run: FnConcat},
{Run: FnConcatWs},
{Run: FnHex},
Expand Down Expand Up @@ -1436,6 +1437,37 @@ func FnTrim(yield Query) {
}
}

// FnSubstr yields the test queries for SUBSTRING / SUBSTR.
//
// It first yields a fixed set of sample calls under both function spellings,
// then the two- and three-argument forms of SUBSTRING for every combination
// of inputStrings and radianInputs.
func FnSubstr(yield Query) {
	// Argument lists shared by the SUBSTRING and SUBSTR spellings.
	sampleArgs := []string{
		`('Quadratically',5)`,
		`('foobarbar' FROM 4)`,
		`('Quadratically',5,6)`,
		`('Sakila', -3)`,
		`('Sakila', -5, 3)`,
		`('Sakila' FROM -4 FOR 2)`,
	}

	for _, fn := range []string{"SUBSTRING", "SUBSTR"} {
		for _, args := range sampleArgs {
			yield(fn+args, nil)
		}
	}

	for _, input := range inputStrings {
		for _, pos := range radianInputs {
			yield(fmt.Sprintf("SUBSTRING(%s, %s)", input, pos), nil)

			for _, length := range radianInputs {
				yield(fmt.Sprintf("SUBSTRING(%s, %s, %s)", input, pos, length), nil)
			}
		}
	}
}

func FnConcat(yield Query) {
for _, str := range inputStrings {
yield(fmt.Sprintf("CONCAT(%s)", str), nil)
Expand Down
25 changes: 25 additions & 0 deletions go/vt/vtgate/evalengine/translate_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,31 @@ func (ast *astCompiler) translateCallable(call sqlparser.Callable) (IR, error) {
trim: call.Type,
}, nil

case *sqlparser.SubstrExpr:
var args []IR
str, err := ast.translateExpr(call.Name)
if err != nil {
return nil, err
}
args = append(args, str)
pos, err := ast.translateExpr(call.From)
if err != nil {
return nil, err
}
args = append(args, pos)

if call.To != nil {
to, err := ast.translateExpr(call.To)
if err != nil {
return nil, err
}
args = append(args, to)
}
var cexpr = CallExpr{Arguments: args, Method: "SUBSTRING"}
return &builtinSubstring{
CallExpr: cexpr,
collate: ast.cfg.Collation,
}, nil
case *sqlparser.IntervalDateExpr:
var err error
args := make([]IR, 2)
Expand Down

0 comments on commit fab65bf

Please sign in to comment.