diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 372cf25562c..e7424941680 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -517,6 +517,70 @@ func (node *ComparisonExpr) IsImpossible() bool { return false } +func (op ComparisonExprOperator) Inverse() ComparisonExprOperator { + switch op { + case EqualOp: + return NotEqualOp + case LessThanOp: + return GreaterEqualOp + case GreaterThanOp: + return LessEqualOp + case LessEqualOp: + return GreaterThanOp + case GreaterEqualOp: + return LessThanOp + case NotEqualOp: + return EqualOp + case NullSafeEqualOp: + return NotEqualOp + case InOp: + return NotInOp + case NotInOp: + return InOp + case LikeOp: + return NotLikeOp + case NotLikeOp: + return LikeOp + case RegexpOp: + return NotRegexpOp + case NotRegexpOp: + return RegexpOp + } + panic("unreachable") +} + +// SwitchSides returns the reversed comparison operator if applicable, along with a boolean indicating success. +// For symmetric operators like '=', '!=', and '<=>', it returns the same operator and true. +// For directional comparison operators ('<', '>', '<=', '>='), it returns the opposite operator and true. +// For operators that imply directionality or cannot be logically reversed (such as 'IN', 'LIKE', 'REGEXP'), +// it returns the original operator and false, indicating that switching sides is not valid. +func (op ComparisonExprOperator) SwitchSides() (ComparisonExprOperator, bool) { + switch op { + case EqualOp, NotEqualOp, NullSafeEqualOp: + // These operators are symmetric, so switching sides has no effect + return op, true + case LessThanOp: + return GreaterThanOp, true + case GreaterThanOp: + return LessThanOp, true + case LessEqualOp: + return GreaterEqualOp, true + case GreaterEqualOp: + return LessEqualOp, true + default: + return op, false + } +} + +func (op ComparisonExprOperator) IsCommutative() bool { + switch op { + case EqualOp, NotEqualOp, NullSafeEqualOp: + return true + default: + return false + } +} + // NewStrLiteral builds a new StrVal. func NewStrLiteral(in string) *Literal { return &Literal{Type: StrVal, Val: in} @@ -1498,6 +1562,65 @@ func (op ComparisonExprOperator) ToString() string { } } +func ComparisonExprOperatorFromJson(s string) ComparisonExprOperator { + switch s { + case EqualStr: + return EqualOp + case JsonLessThanStr: + return LessThanOp + case JsonGreaterThanStr: + return GreaterThanOp + case JsonLessThanOrEqualStr: + return LessEqualOp + case JsonGreaterThanOrEqualStr: + return GreaterEqualOp + case NotEqualStr: + return NotEqualOp + case NullSafeEqualStr: + return NullSafeEqualOp + case InStr: + return InOp + case NotInStr: + return NotInOp + case LikeStr: + return LikeOp + case NotLikeStr: + return NotLikeOp + case RegexpStr: + return RegexpOp + case NotRegexpStr: + return NotRegexpOp + default: + return 0 + } +} + +const ( + JsonGreaterThanStr = "gt" + JsonLessThanStr = "lt" + JsonGreaterThanOrEqualStr = "ge" + JsonLessThanOrEqualStr = "le" +) + +// JSONString returns a string representation for this operator that does not need escaping in JSON +func (op ComparisonExprOperator) JSONString() string { + switch op { + case EqualOp, NotEqualOp, NullSafeEqualOp, InOp, NotInOp, LikeOp, NotLikeOp, RegexpOp, NotRegexpOp: + // These operators are safe for JSON output, so we delegate to ToString + return op.ToString() + case LessThanOp: + return JsonLessThanStr + case GreaterThanOp: + return JsonGreaterThanStr + case LessEqualOp: + return JsonLessThanOrEqualStr + case GreaterEqualOp: + return JsonGreaterThanOrEqualStr + default: + panic("unreachable") + } +} + // ToString returns the operator as a string func (op IsExprOperator) ToString() string { switch op { diff --git a/go/vt/sqlparser/constants.go b/go/vt/sqlparser/constants.go index 34189b52380..08538fbd749 100644 --- a/go/vt/sqlparser/constants.go +++ b/go/vt/sqlparser/constants.go @@ -692,47 +692,6 @@ const ( All ) -func (op ComparisonExprOperator) Inverse() ComparisonExprOperator { - switch op { - case EqualOp: - return NotEqualOp - case LessThanOp: - return GreaterEqualOp - case GreaterThanOp: - return LessEqualOp - case LessEqualOp: - return GreaterThanOp - case GreaterEqualOp: - return LessThanOp - case NotEqualOp: - return EqualOp - case NullSafeEqualOp: - return NotEqualOp - case InOp: - return NotInOp - case NotInOp: - return InOp - case LikeOp: - return NotLikeOp - case NotLikeOp: - return LikeOp - case RegexpOp: - return NotRegexpOp - case NotRegexpOp: - return RegexpOp - } - panic("unreachable") -} - -func (op ComparisonExprOperator) IsCommutative() bool { - switch op { - case EqualOp, NotEqualOp, NullSafeEqualOp: - return true - default: - return false - } -} - // Constant for Enum Type - IsExprOperator const ( IsNullOp IsExprOperator = iota diff --git a/go/vt/vtgate/executor_vexplain_test.go b/go/vt/vtgate/executor_vexplain_test.go index a19c353ef5f..443370205a9 100644 --- a/go/vt/vtgate/executor_vexplain_test.go +++ b/go/vt/vtgate/executor_vexplain_test.go @@ -122,54 +122,61 @@ func TestVExplainKeys(t *testing.T) { { query: "select count(*), col2 from music group by col2", expectedRowString: `{ + "statementType": "SELECT", "groupingColumns": [ "music.col2" ], - "statementType": "SELECT" + "selectColumns": [ + "music.col2" + ] }`, }, { query: "select * from user u join user_extra ue on u.id = ue.user_id where u.col1 > 100 and ue.noLimit = 'foo'", expectedRowString: `{ + "statementType": "SELECT", "joinColumns": [ - "user.id", - "user_extra.user_id" + "user.id =", + "user_extra.user_id =" ], "filterColumns": [ - "user.col1", - "user_extra.noLimit" - ], - "statementType": "SELECT" + "user.col1 gt", + "user_extra.noLimit =" + ] }`, }, { // same as above, but written differently query: "select * from user_extra ue, user u where ue.noLimit = 'foo' and u.col1 > 100 and ue.user_id = u.id", expectedRowString: `{ + "statementType": "SELECT", "joinColumns": [ - "user.id", - "user_extra.user_id" + "user.id =", + "user_extra.user_id =" ], "filterColumns": [ - "user.col1", - "user_extra.noLimit" - ], - "statementType": "SELECT" + "user.col1 gt", + "user_extra.noLimit =" + ] }`, }, { query: "select u.foo, ue.bar, count(*) from user u join user_extra ue on u.id = ue.user_id where u.name = 'John Doe' group by 1, 2", expectedRowString: `{ + "statementType": "SELECT", "groupingColumns": [ "user.foo", "user_extra.bar" ], "joinColumns": [ - "user.id", - "user_extra.user_id" + "user.id =", + "user_extra.user_id =" ], "filterColumns": [ - "user.name" + "user.name =" ], - "statementType": "SELECT" + "selectColumns": [ + "user.foo", + "user_extra.bar" + ] }`, }, { @@ -181,47 +188,64 @@ func TestVExplainKeys(t *testing.T) { { query: "select name, sum(amount) from user group by name", expectedRowString: `{ + "statementType": "SELECT", "groupingColumns": [ "user.name" ], - "statementType": "SELECT" + "selectColumns": [ + "user.amount", + "user.name" + ] }`, }, { query: "select name from user where age > 30", expectedRowString: `{ + "statementType": "SELECT", "filterColumns": [ - "user.age" + "user.age gt" ], - "statementType": "SELECT" + "selectColumns": [ + "user.name" + ] }`, }, { query: "select * from user where name = 'apa' union select * from user_extra where name = 'monkey'", expectedRowString: `{ + "statementType": "SELECT", "filterColumns": [ - "user.name", - "user_extra.name" - ], - "statementType": "SELECT" + "user.name =", + "user_extra.name =" + ] }`, }, { query: "update user set name = 'Jane Doe' where id = 1", expectedRowString: `{ + "statementType": "UPDATE", "filterColumns": [ - "user.id" - ], - "statementType": "UPDATE" + "user.id =" + ] }`, }, { query: "delete from user where order_date < '2023-01-01'", expectedRowString: `{ + "statementType": "DELETE", "filterColumns": [ - "user.order_date" - ], - "statementType": "DELETE" + "user.order_date lt" + ] +}`, + }, + { + query: "select * from user where name between 'A' and 'C'", + expectedRowString: `{ + "statementType": "SELECT", + "filterColumns": [ + "user.name ge", + "user.name le" + ] }`, }, } diff --git a/go/vt/vtgate/planbuilder/operators/keys.go b/go/vt/vtgate/planbuilder/operators/keys.go index ccebcbd7c10..c16d9b23b63 100644 --- a/go/vt/vtgate/planbuilder/operators/keys.go +++ b/go/vt/vtgate/planbuilder/operators/keys.go @@ -17,43 +17,124 @@ limitations under the License. package operators import ( + "encoding/json" "fmt" "slices" + "sort" + "strings" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) -type VExplainKeys struct { - GroupingColumns []string `json:"groupingColumns,omitempty"` - TableName []string `json:"tableName,omitempty"` - JoinColumns []string `json:"joinColumns,omitempty"` - FilterColumns []string `json:"filterColumns,omitempty"` - StatementType string `json:"statementType"` +type ( + Column struct { + Table string + Name string + } + ColumnUse struct { + Column Column + Uses sqlparser.ComparisonExprOperator + } + VExplainKeys struct { + StatementType string `json:"statementType"` + TableName []string `json:"tableName,omitempty"` + GroupingColumns []Column `json:"groupingColumns,omitempty"` + JoinColumns []ColumnUse `json:"joinColumns,omitempty"` + FilterColumns []ColumnUse `json:"filterColumns,omitempty"` + SelectColumns []Column `json:"selectColumns,omitempty"` + } +) + +func (c Column) MarshalJSON() ([]byte, error) { + if c.Table != "" { + return json.Marshal(fmt.Sprintf("%s.%s", c.Table, c.Name)) + } + return json.Marshal(c.Name) +} + +func (c *Column) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + parts := strings.Split(s, ".") + if len(parts) > 1 { + c.Table = parts[0] + c.Name = parts[1] + } else { + c.Name = s + } + return nil +} + +func (cu ColumnUse) MarshalJSON() ([]byte, error) { + return json.Marshal(fmt.Sprintf("%s %s", cu.Column, cu.Uses.JSONString())) +} + +func (cu *ColumnUse) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + parts := strings.Fields(s) + if len(parts) != 2 { + return fmt.Errorf("invalid ColumnUse format: %s", s) + } + if err := cu.Column.UnmarshalJSON([]byte(`"` + parts[0] + `"`)); err != nil { + return err + } + cu.Uses = sqlparser.ComparisonExprOperatorFromJson(strings.ToLower(parts[1])) + return nil +} + +func (c Column) String() string { + return fmt.Sprintf("%s.%s", c.Table, c.Name) +} + +func (cu ColumnUse) String() string { + return fmt.Sprintf("%s %s", cu.Column, cu.Uses.JSONString()) +} + +type columnUse struct { + col *sqlparser.ColName + use sqlparser.ComparisonExprOperator } func GetVExplainKeys(ctx *plancontext.PlanningContext, stmt sqlparser.Statement) (result VExplainKeys) { - var filterColumns, joinColumns, groupingColumns []*sqlparser.ColName + var groupingColumns, selectColumns []*sqlparser.ColName + var filterColumns, joinColumns []columnUse addPredicate := func(predicate sqlparser.Expr) { predicates := sqlparser.SplitAndExpression(nil, predicate) for _, expr := range predicates { - cmp, ok := expr.(*sqlparser.ComparisonExpr) - if !ok { - continue - } - lhs, lhsOK := cmp.Left.(*sqlparser.ColName) - rhs, rhsOK := cmp.Right.(*sqlparser.ColName) - if lhsOK && rhsOK && ctx.SemTable.RecursiveDeps(lhs) != ctx.SemTable.RecursiveDeps(rhs) { - joinColumns = append(joinColumns, lhs, rhs) - continue - } - if lhsOK { - filterColumns = append(filterColumns, lhs) - } - if rhsOK { - filterColumns = append(filterColumns, rhs) + switch cmp := expr.(type) { + case *sqlparser.ComparisonExpr: + lhs, lhsOK := cmp.Left.(*sqlparser.ColName) + rhs, rhsOK := cmp.Right.(*sqlparser.ColName) + + var output = &filterColumns + if lhsOK && rhsOK && ctx.SemTable.RecursiveDeps(lhs) != ctx.SemTable.RecursiveDeps(rhs) { + // If the columns are from different tables, they are considered join columns + output = &joinColumns + } + + if lhsOK { + *output = append(*output, columnUse{lhs, cmp.Operator}) + } + + if switchedOp, ok := cmp.Operator.SwitchSides(); rhsOK && ok { + *output = append(*output, columnUse{rhs, switchedOp}) + } + case *sqlparser.BetweenExpr: + if col, ok := cmp.Left.(*sqlparser.ColName); ok { + // a BETWEEN 100 AND 200 is equivalent to a >= 100 AND a <= 200 + filterColumns = append(filterColumns, + columnUse{col, sqlparser.GreaterEqualOp}, + columnUse{col, sqlparser.LessEqualOp}) + } } + } } @@ -65,41 +146,68 @@ func GetVExplainKeys(ctx *plancontext.PlanningContext, stmt sqlparser.Statement) addPredicate(node.On) case *sqlparser.GroupBy: for _, expr := range node.Exprs { - predicates := sqlparser.SplitAndExpression(nil, expr) - for _, expr := range predicates { - col, ok := expr.(*sqlparser.ColName) - if ok { - groupingColumns = append(groupingColumns, col) - } + col, ok := expr.(*sqlparser.ColName) + if ok { + groupingColumns = append(groupingColumns, col) } } + case *sqlparser.AliasedExpr: + _ = sqlparser.VisitSQLNode(node, func(e sqlparser.SQLNode) (kontinue bool, err error) { + if col, ok := e.(*sqlparser.ColName); ok { + selectColumns = append(selectColumns, col) + } + return true, nil + }) } return true, nil }) return VExplainKeys{ + SelectColumns: getUniqueColNames(ctx, selectColumns), GroupingColumns: getUniqueColNames(ctx, groupingColumns), - JoinColumns: getUniqueColNames(ctx, joinColumns), - FilterColumns: getUniqueColNames(ctx, filterColumns), + JoinColumns: getUniqueColUsages(ctx, joinColumns), + FilterColumns: getUniqueColUsages(ctx, filterColumns), StatementType: sqlparser.ASTToStatementType(stmt).String(), } } -func getUniqueColNames(ctx *plancontext.PlanningContext, columns []*sqlparser.ColName) []string { - var colNames []string - for _, col := range columns { - tableInfo, err := ctx.SemTable.TableInfoForExpr(col) - if err != nil { - continue +func getUniqueColNames(ctx *plancontext.PlanningContext, inCols []*sqlparser.ColName) (columns []Column) { + for _, col := range inCols { + column := createColumn(ctx, col) + if column != nil { + columns = append(columns, *column) } - table := tableInfo.GetVindexTable() - if table == nil { - continue + } + sort.Slice(columns, func(i, j int) bool { + return columns[i].String() < columns[j].String() + }) + + return slices.Compact(columns) +} + +func getUniqueColUsages(ctx *plancontext.PlanningContext, inCols []columnUse) (columns []ColumnUse) { + for _, col := range inCols { + column := createColumn(ctx, col.col) + if column != nil { + columns = append(columns, ColumnUse{Column: *column, Uses: col.use}) } - colNames = append(colNames, fmt.Sprintf("%s.%s", table.Name.String(), col.Name.String())) } - slices.Sort(colNames) - return slices.Compact(colNames) + sort.Slice(columns, func(i, j int) bool { + return columns[i].Column.String() < columns[j].Column.String() + }) + return slices.Compact(columns) +} + +func createColumn(ctx *plancontext.PlanningContext, col *sqlparser.ColName) *Column { + tableInfo, err := ctx.SemTable.TableInfoForExpr(col) + if err != nil { + return nil + } + table := tableInfo.GetVindexTable() + if table == nil { + return nil + } + return &Column{Table: table.Name.String(), Name: col.Name.String()} } diff --git a/go/vt/vtgate/planbuilder/operators/keys_test.go b/go/vt/vtgate/planbuilder/operators/keys_test.go new file mode 100644 index 00000000000..6f53a33da5c --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/keys_test.go @@ -0,0 +1,66 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" +) + +func TestMarshalUnmarshal(t *testing.T) { + // Test that marshalling and unmarshalling a struct works as expected + original := VExplainKeys{ + StatementType: "SELECT", + TableName: []string{"users", "orders"}, + GroupingColumns: []Column{ + {Table: "", Name: "category"}, + {Table: "users", Name: "department"}, + }, + JoinColumns: []ColumnUse{ + {Column: Column{Table: "users", Name: "id"}, Uses: sqlparser.EqualOp}, + {Column: Column{Table: "orders", Name: "user_id"}, Uses: sqlparser.EqualOp}, + }, + FilterColumns: []ColumnUse{ + {Column: Column{Table: "", Name: "age"}, Uses: sqlparser.GreaterThanOp}, + {Column: Column{Table: "orders", Name: "total"}, Uses: sqlparser.LessThanOp}, + }, + SelectColumns: []Column{ + {Table: "users", Name: "name"}, + {Table: "", Name: "email"}, + {Table: "orders", Name: "amount"}, + }, + } + + jsonData, err := json.Marshal(original) + require.NoError(t, err) + + t.Logf("Marshalled JSON: %s", jsonData) + + var unmarshalled VExplainKeys + err = json.Unmarshal(jsonData, &unmarshalled) + require.NoError(t, err) + + if diff := cmp.Diff(original, unmarshalled); diff != "" { + t.Errorf("Unmarshalled struct does not match original (-want +got):\n%s", diff) + t.FailNow() + } +}