diff --git a/go/vt/vtgate/planbuilder/operators/keys.go b/go/vt/vtgate/planbuilder/operators/keys.go index f5b592b2291..8d75b260aeb 100644 --- a/go/vt/vtgate/planbuilder/operators/keys.go +++ b/go/vt/vtgate/planbuilder/operators/keys.go @@ -36,16 +36,32 @@ type ( Column Column Uses sqlparser.ComparisonExprOperator } + JoinPredicate struct { + LHS, RHS Column + Uses sqlparser.ComparisonExprOperator + } VExplainKeys struct { - StatementType string `json:"statementType"` - TableName []string `json:"tableName,omitempty"` - GroupingColumns []Column `json:"groupingColumns,omitempty"` - JoinColumns []ColumnUse `json:"joinColumns,omitempty"` - FilterColumns []ColumnUse `json:"filterColumns,omitempty"` - SelectColumns []Column `json:"selectColumns,omitempty"` + StatementType string `json:"statementType"` + TableName []string `json:"tableName,omitempty"` + GroupingColumns []Column `json:"groupingColumns,omitempty"` + FilterColumns []ColumnUse `json:"filterColumns,omitempty"` + SelectColumns []Column `json:"selectColumns,omitempty"` + JoinPredicates []JoinPredicate `json:"joinPredicates,omitempty"` } ) +func newJoinPredicate(lhs, rhs Column, op sqlparser.ComparisonExprOperator) JoinPredicate { + // we want to try to keep the columns in the same order, no matter how the query was written + if lhs.String() > rhs.String() { + var success bool + op, success = op.SwitchSides() + if success { + lhs, rhs = rhs, lhs + } + } + return JoinPredicate{LHS: lhs, RHS: rhs, Uses: op} +} + func (c Column) MarshalJSON() ([]byte, error) { if c.Table != "" { return json.Marshal(fmt.Sprintf("%s.%s", c.Table, c.Name)) @@ -111,6 +127,35 @@ func (cu *ColumnUse) UnmarshalJSON(data []byte) error { return nil } +func (jp *JoinPredicate) MarshalJSON() ([]byte, error) { + return json.Marshal(jp.String()) +} + +func (jp *JoinPredicate) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + subStrings := strings.Split(s, " ") + if len(subStrings) != 3 { + return fmt.Errorf("invalid JoinPredicate format: %s", s) + } + + op, err := sqlparser.ComparisonExprOperatorFromJson(subStrings[1]) + if err != nil { + return fmt.Errorf("invalid comparison operator: %w", err) + } + jp.Uses = op + + if err = jp.LHS.UnmarshalJSON([]byte(`"` + subStrings[0] + `"`)); err != nil { + return err + } + if err = jp.RHS.UnmarshalJSON([]byte(`"` + subStrings[2] + `"`)); err != nil { + return err + } + return nil +} + func (c Column) String() string { return fmt.Sprintf("%s.%s", c.Table, c.Name) } @@ -119,14 +164,25 @@ func (cu ColumnUse) String() string { return fmt.Sprintf("%s %s", cu.Column, cu.Uses.JSONString()) } +func (jp JoinPredicate) String() string { + return fmt.Sprintf("%s %s %s", jp.LHS.String(), jp.Uses.JSONString(), jp.RHS.String()) +} + type columnUse struct { col *sqlparser.ColName use sqlparser.ComparisonExprOperator } +type joinPredicate struct { + lhs *sqlparser.ColName + rhs *sqlparser.ColName + uses sqlparser.ComparisonExprOperator +} + func GetVExplainKeys(ctx *plancontext.PlanningContext, stmt sqlparser.Statement) (result VExplainKeys) { var groupingColumns, selectColumns []*sqlparser.ColName var filterColumns, joinColumns []columnUse + var jps []joinPredicate addPredicate := func(predicate sqlparser.Expr) { predicates := sqlparser.SplitAndExpression(nil, predicate) @@ -140,6 +196,7 @@ func GetVExplainKeys(ctx *plancontext.PlanningContext, stmt sqlparser.Statement) if lhsOK && rhsOK && ctx.SemTable.RecursiveDeps(lhs) != ctx.SemTable.RecursiveDeps(rhs) { // If the columns are from different tables, they are considered join columns output = &joinColumns + jps = append(jps, joinPredicate{lhs: lhs, rhs: rhs, uses: cmp.Operator}) } if lhsOK { @@ -189,12 +246,34 @@ func GetVExplainKeys(ctx *plancontext.PlanningContext, stmt sqlparser.Statement) return VExplainKeys{ SelectColumns: getUniqueColNames(ctx, selectColumns), GroupingColumns: getUniqueColNames(ctx, groupingColumns), - JoinColumns: getUniqueColUsages(ctx, joinColumns), FilterColumns: getUniqueColUsages(ctx, filterColumns), StatementType: sqlparser.ASTToStatementType(stmt).String(), + JoinPredicates: getUniqueJoinPredicates(ctx, jps), } } +func getUniqueJoinPredicates(ctx *plancontext.PlanningContext, joinPredicates []joinPredicate) []JoinPredicate { + var result []JoinPredicate + for _, predicate := range joinPredicates { + lhs := createColumn(ctx, predicate.lhs) + rhs := createColumn(ctx, predicate.rhs) + if lhs == nil || rhs == nil { + continue + } + + result = append(result, newJoinPredicate(*lhs, *rhs, predicate.uses)) + } + + sort.Slice(result, func(i, j int) bool { + if result[i].LHS.Name == result[j].LHS.Name { + return result[i].RHS.Name < result[j].RHS.Name + } + return result[i].LHS.Name < result[j].LHS.Name + }) + + return slices.Compact(result) +} + func getUniqueColNames(ctx *plancontext.PlanningContext, inCols []*sqlparser.ColName) (columns []Column) { for _, col := range inCols { column := createColumn(ctx, col) diff --git a/go/vt/vtgate/planbuilder/operators/keys_test.go b/go/vt/vtgate/planbuilder/operators/keys_test.go index 5c60e62c70c..3fcd69f0b5b 100644 --- a/go/vt/vtgate/planbuilder/operators/keys_test.go +++ b/go/vt/vtgate/planbuilder/operators/keys_test.go @@ -35,10 +35,6 @@ func TestMarshalUnmarshal(t *testing.T) { {Table: "orders", Name: "category"}, {Table: "users", Name: "department"}, }, - JoinColumns: []ColumnUse{ - {Column: Column{Table: "users", Name: "id"}, Uses: sqlparser.EqualOp}, - {Column: Column{Table: "orders", Name: "user_id"}, Uses: sqlparser.EqualOp}, - }, FilterColumns: []ColumnUse{ {Column: Column{Table: "users", Name: "age"}, Uses: sqlparser.GreaterThanOp}, {Column: Column{Table: "orders", Name: "total"}, Uses: sqlparser.LessThanOp}, @@ -49,6 +45,9 @@ func TestMarshalUnmarshal(t *testing.T) { {Table: "users", Name: "email"}, {Table: "orders", Name: "amount"}, }, + JoinPredicates: []JoinPredicate{ + {LHS: Column{Table: "users", Name: "id"}, RHS: Column{Table: "orders", Name: "user_id"}, Uses: sqlparser.EqualOp}, + }, } jsonData, err := json.Marshal(original) diff --git a/go/vt/vtgate/testdata/executor_vexplain.json b/go/vt/vtgate/testdata/executor_vexplain.json index 5b70354f158..2b893588aa5 100644 --- a/go/vt/vtgate/testdata/executor_vexplain.json +++ b/go/vt/vtgate/testdata/executor_vexplain.json @@ -1,132 +1,129 @@ [ - { - "query": "select count(*), col2 from music group by col2", - "expected": { - "statementType": "SELECT", - "groupingColumns": [ - "music.col2" - ], - "selectColumns": [ - "music.col2" - ] - } - }, - { - "query": "select * from user u join user_extra ue on u.id = ue.user_id where u.col1 \u003e 100 and ue.noLimit = 'foo'", - "expected": { - "statementType": "SELECT", - "joinColumns": [ - "`user`.id =", - "user_extra.user_id =" - ], - "filterColumns": [ - "`user`.col1 gt", - "user_extra.noLimit =" - ] - } - }, - { - "query": "select * from user_extra ue, user u where ue.noLimit = 'foo' and u.col1 \u003e 100 and ue.user_id = u.id", - "expected": { - "statementType": "SELECT", - "joinColumns": [ - "`user`.id =", - "user_extra.user_id =" - ], - "filterColumns": [ - "`user`.col1 gt", - "user_extra.noLimit =" - ] - } - }, - { - "query": "select u.foo, ue.bar, count(*) from user u join user_extra ue on u.id = ue.user_id where u.name = 'John Doe' group by 1, 2", - "expected": { - "statementType": "SELECT", - "groupingColumns": [ - "`user`.foo", - "user_extra.bar" - ], - "joinColumns": [ - "`user`.id =", - "user_extra.user_id =" - ], - "filterColumns": [ - "`user`.`name` =" - ], - "selectColumns": [ - "`user`.foo", - "user_extra.bar" - ] - } - }, - { - "query": "select * from (select * from user) as derived where derived.amount \u003e 1000", - "expected": { - "statementType": "SELECT" - } - }, - { - "query": "select name, sum(amount) from user group by name", - "expected": { - "statementType": "SELECT", - "groupingColumns": [ - "`user`.`name`" - ], - "selectColumns": [ - "`user`.`name`", - "`user`.amount" - ] - } - }, - { - "query": "select name from user where age \u003e 30", - "expected": { - "statementType": "SELECT", - "filterColumns": [ - "`user`.age gt" - ], - "selectColumns": [ - "`user`.`name`" - ] - } - }, - { - "query": "select * from user where name = 'apa' union select * from user_extra where name = 'monkey'", - "expected": { - "statementType": "SELECT", - "filterColumns": [ - "`user`.`name` =", - "user_extra.`name` =" - ] - } - }, - { - "query": "update user set name = 'Jane Doe' where id = 1", - "expected": { - "statementType": "UPDATE", - "filterColumns": [ - "`user`.id =" - ] - } - }, - { - "query": "delete from user where order_date \u003c '2023-01-01'", - "expected": { - "statementType": "DELETE", - "filterColumns": [ - "`user`.order_date lt" - ] - } - }, - { - "query": "select * from user where name between 'A' and 'C'", - "expected": { - "statementType": "SELECT", - "filterColumns": [ - "`user`.`name` ge", - "`user`.`name` le" - ] - } - } + { + "query": "select count(*), col2 from music group by col2", + "expected": { + "statementType": "SELECT", + "groupingColumns": [ + "music.col2" + ], + "selectColumns": [ + "music.col2" + ] + } + }, + { + "query": "select * from user u join user_extra ue on u.id = ue.user_id where u.col1 \u003e 100 and ue.noLimit = 'foo'", + "expected": { + "statementType": "SELECT", + "filterColumns": [ + "`user`.col1 gt", + "user_extra.noLimit =" + ], + "joinPredicates": [ + "`user`.id = user_extra.user_id" + ] + } + }, + { + "query": "select * from user_extra ue, user u where ue.noLimit = 'foo' and u.col1 \u003e 100 and ue.user_id = u.id", + "expected": { + "statementType": "SELECT", + "filterColumns": [ + "`user`.col1 gt", + "user_extra.noLimit =" + ], + "joinPredicates": [ + "`user`.id = user_extra.user_id" + ] + } + }, + { + "query": "select u.foo, ue.bar, count(*) from user u join user_extra ue on u.id = ue.user_id where u.name = 'John Doe' group by 1, 2", + "expected": { + "statementType": "SELECT", + "groupingColumns": [ + "`user`.foo", + "user_extra.bar" + ], + "filterColumns": [ + "`user`.`name` =" + ], + "selectColumns": [ + "`user`.foo", + "user_extra.bar" + ], + "joinPredicates": [ + "`user`.id = user_extra.user_id" + ] + } + }, + { + "query": "select * from (select * from user) as derived where derived.amount \u003e 1000", + "expected": { + "statementType": "SELECT" + } + }, + { + "query": "select name, sum(amount) from user group by name", + "expected": { + "statementType": "SELECT", + "groupingColumns": [ + "`user`.`name`" + ], + "selectColumns": [ + "`user`.`name`", + "`user`.amount" + ] + } + }, + { + "query": "select name from user where age \u003e 30", + "expected": { + "statementType": "SELECT", + "filterColumns": [ + "`user`.age gt" + ], + "selectColumns": [ + "`user`.`name`" + ] + } + }, + { + "query": "select * from user where name = 'apa' union select * from user_extra where name = 'monkey'", + "expected": { + "statementType": "SELECT", + "filterColumns": [ + "`user`.`name` =", + "user_extra.`name` =" + ] + } + }, + { + "query": "update user set name = 'Jane Doe' where id = 1", + "expected": { + "statementType": "UPDATE", + "filterColumns": [ + "`user`.id =" + ] + } + }, + { + "query": "delete from user where order_date \u003c '2023-01-01'", + "expected": { + "statementType": "DELETE", + "filterColumns": [ + "`user`.order_date lt" + ] + } + }, + { + "query": "select * from user where name between 'A' and 'C'", + "expected": { + "statementType": "SELECT", + "filterColumns": [ + "`user`.`name` ge", + "`user`.`name` le" + ] + } + } ] \ No newline at end of file