diff --git a/go/vt/vtgate/evalengine/eval_result.go b/go/vt/vtgate/evalengine/eval_result.go index d9916af03be..5c1973d8eb1 100644 --- a/go/vt/vtgate/evalengine/eval_result.go +++ b/go/vt/vtgate/evalengine/eval_result.go @@ -62,6 +62,7 @@ func (er EvalResult) String() string { // TupleValues allows for retrieval of the value we expose for public consumption func (er EvalResult) TupleValues() []sqltypes.Value { + // TODO: Make this collation-aware switch v := er.v.(type) { case *evalTuple: result := make([]sqltypes.Value, 0, len(v.t)) diff --git a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go index 9bbc98ca2bd..96babb35666 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go +++ b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go @@ -89,6 +89,8 @@ const ( NotEqual // IsNotNull is used to filter a column if it is NULL IsNotNull + // In is used to filter a comparable column if equals any of the values from a specific tuple + In ) // Filter contains opcodes for filtering. @@ -97,6 +99,9 @@ type Filter struct { ColNum int Value sqltypes.Value + // Values will be used to store tuple/list values. + Values []sqltypes.Value + // Parameters for VindexMatch. // Vindex, VindexColumns and KeyRange, if set, will be used // to filter the row. @@ -166,6 +171,8 @@ func getOpcode(comparison *sqlparser.ComparisonExpr) (Opcode, error) { opcode = GreaterThanEqual case sqlparser.NotEqualOp: opcode = NotEqual + case sqlparser.InOp: + opcode = In default: return -1, fmt.Errorf("comparison operator %s not supported", comparison.Operator.ToString()) } @@ -238,6 +245,24 @@ func (plan *Plan) filter(values, result []sqltypes.Value, charsets []collations. if values[filter.ColNum].IsNull() { return false, nil } + case In: + if filter.Values == nil { + return false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected empty filter values when performing IN operator") + } + found := false + for _, filterValue := range filter.Values { + match, err := compare(Equal, values[filter.ColNum], filterValue, plan.env.CollationEnv(), charsets[filter.ColNum]) + if err != nil { + return false, err + } + if match { + found = true + break + } + } + if !found { + return false, nil + } default: match, err := compare(filter.Opcode, values[filter.ColNum], filter.Value, plan.env.CollationEnv(), charsets[filter.ColNum]) if err != nil { @@ -514,6 +539,27 @@ func (plan *Plan) getColumnFuncExpr(columnName string) *sqlparser.FuncExpr { return nil } +func (plan *Plan) appendTupleFilter(values sqlparser.ValTuple, opcode Opcode, colnum int) error { + pv, err := evalengine.Translate(values, &evalengine.Config{ + Collation: plan.env.CollationEnv().DefaultConnectionCharset(), + Environment: plan.env, + }) + if err != nil { + return err + } + env := evalengine.EmptyExpressionEnv(plan.env) + resolved, err := env.Evaluate(pv) + if err != nil { + return err + } + plan.Filters = append(plan.Filters, Filter{ + Opcode: opcode, + ColNum: colnum, + Values: resolved.TupleValues(), + }) + return nil +} + func (plan *Plan) analyzeWhere(vschema *localVSchema, where *sqlparser.Where) error { if where == nil { return nil @@ -537,6 +583,13 @@ func (plan *Plan) analyzeWhere(vschema *localVSchema, where *sqlparser.Where) er if err != nil { return err } + if opcode == In { + values, ok := expr.Right.(sqlparser.ValTuple) + if !ok { + return fmt.Errorf("unexpected: %v", sqlparser.String(expr)) + } + return plan.appendTupleFilter(values, opcode, colnum) + } val, ok := expr.Right.(*sqlparser.Literal) if !ok { return fmt.Errorf("unexpected: %v", sqlparser.String(expr)) diff --git a/go/vt/vttablet/tabletserver/vstreamer/planbuilder_test.go b/go/vt/vttablet/tabletserver/vstreamer/planbuilder_test.go index ba345b2a00b..aba74368802 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/planbuilder_test.go +++ b/go/vt/vttablet/tabletserver/vstreamer/planbuilder_test.go @@ -710,9 +710,15 @@ func TestPlanBuilderFilterComparison(t *testing.T) { outFilters: []Filter{{Opcode: LessThan, ColNum: 0, Value: sqltypes.NewInt64(2)}, {Opcode: LessThanEqual, ColNum: 1, Value: sqltypes.NewVarChar("xyz")}, }, + }, { + name: "in-operator", + inFilter: "select * from t1 where id in (1, 2)", + outFilters: []Filter{ + {Opcode: In, ColNum: 0, Values: []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}}, + }, }, { name: "vindex-and-operators", - inFilter: "select * from t1 where in_keyrange(id, 'hash', '-80') and id = 2 and val <> 'xyz'", + inFilter: "select * from t1 where in_keyrange(id, 'hash', '-80') and id = 2 and val <> 'xyz' and id in (100, 30)", outFilters: []Filter{ { Opcode: VindexMatch, @@ -727,6 +733,7 @@ func TestPlanBuilderFilterComparison(t *testing.T) { }, {Opcode: Equal, ColNum: 0, Value: sqltypes.NewInt64(2)}, {Opcode: NotEqual, ColNum: 1, Value: sqltypes.NewVarChar("xyz")}, + {Opcode: In, ColNum: 0, Values: []sqltypes.Value{sqltypes.NewInt64(100), sqltypes.NewInt64(30)}}, }, }}