diff --git a/go/vt/vtgate/executor_vexplain_test.go b/go/vt/vtgate/executor_vexplain_test.go index a89221b91c2..9d3dd3d7fdf 100644 --- a/go/vt/vtgate/executor_vexplain_test.go +++ b/go/vt/vtgate/executor_vexplain_test.go @@ -18,8 +18,11 @@ package vtgate import ( "context" + "fmt" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -112,17 +115,129 @@ func TestSimpleVexplainTrace(t *testing.T) { } func TestVExplainKeys(t *testing.T) { - executor, _, _, _, _ := createExecutorEnv(t) - - query := "vexplain keys select count(*), col2 from music group by col2" - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) - gotResult, err := executor.Execute(context.Background(), nil, "Execute", session, query, nil) - require.NoError(t, err) - - expectedRowString := `{ + tests := []struct { + query string + expectedRowString string + }{ + { + query: "select count(*), col2 from music group by col2", + expectedRowString: `{ + "Grouping Columns": [ + "music.col2" + ], "StatementType": "SELECT" -}` +}`, + }, { + query: "select * from user u join user_extra ue on u.id = ue.user_id where u.col1 > 100 and ue.noLimit = 'foo'", + expectedRowString: `{ + "JoinColumns": [ + "user.id", + "user_extra.user_id" + ], + "FilterColumns": [ + "user.col1", + "user_extra.noLimit" + ], + "StatementType": "SELECT" +}`, + }, { + // same as above, but written differently + query: "select * from user_extra ue, user u where ue.noLimit = 'foo' and u.col1 > 100 and ue.user_id = u.id", + expectedRowString: `{ + "JoinColumns": [ + "user.id", + "user_extra.user_id" + ], + "FilterColumns": [ + "user.col1", + "user_extra.noLimit" + ], + "StatementType": "SELECT" +}`, + }, + { + query: "select u.foo, ue.bar, count(*) from user u join user_extra ue on u.id = ue.user_id where u.name = 'John Doe' group by 1, 2", + expectedRowString: `{ + "Grouping Columns": [ + "user.foo", + "user_extra.bar" + ], + "JoinColumns": [ + "user.id", + "user_extra.user_id" + ], + "FilterColumns": [ + "user.name" + ], + "StatementType": "SELECT" +}`, + }, + { + query: "select * from (select * from user) as derived where derived.amount > 1000", + expectedRowString: `{ + "StatementType": "SELECT" +}`, + }, + { + query: "select name, sum(amount) from user group by name", + expectedRowString: `{ + "Grouping Columns": [ + "user.name" + ], + "StatementType": "SELECT" +}`, + }, + { + query: "select name from user where age > 30", + expectedRowString: `{ + "FilterColumns": [ + "user.age" + ], + "StatementType": "SELECT" +}`, + }, + { + query: "select * from user where name = 'apa' union select * from user_extra where name = 'monkey'", + expectedRowString: `{ + "FilterColumns": [ + "user.name", + "user_extra.name" + ], + "StatementType": "SELECT" +}`, + }, + { + query: "update user set name = 'Jane Doe' where id = 1", + expectedRowString: `{ + "FilterColumns": [ + "user.id" + ], + "StatementType": "UPDATE" +}`, + }, + { + query: "delete from user where order_date < '2023-01-01'", + expectedRowString: `{ + "FilterColumns": [ + "user.order_date" + ], + "StatementType": "DELETE" +}`, + }, + } - gotRowString := gotResult.Rows[0][0].ToString() - require.Equal(t, expectedRowString, gotRowString) + for _, tt := range tests { + t.Run(tt.query, func(t *testing.T) { + executor, _, _, _, _ := createExecutorEnv(t) + session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + gotResult, err := executor.Execute(context.Background(), nil, "Execute", session, "vexplain keys "+tt.query, nil) + require.NoError(t, err) + + gotRowString := gotResult.Rows[0][0].ToString() + assert.Equal(t, tt.expectedRowString, gotRowString) + if t.Failed() { + fmt.Println(gotRowString) + } + }) + } } diff --git a/go/vt/vtgate/planbuilder/operators/keys.go b/go/vt/vtgate/planbuilder/operators/keys.go index e9cba011d0e..1045f208e3a 100644 --- a/go/vt/vtgate/planbuilder/operators/keys.go +++ b/go/vt/vtgate/planbuilder/operators/keys.go @@ -17,6 +17,9 @@ limitations under the License. package operators import ( + "fmt" + "slices" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) @@ -30,6 +33,73 @@ type VExplainKeys struct { } func GetVExplainKeys(ctx *plancontext.PlanningContext, stmt sqlparser.Statement) (result VExplainKeys) { - result.StatementType = sqlparser.ASTToStatementType(stmt).String() - return + var filterColumns, joinColumns, groupingColumns []*sqlparser.ColName + + addPredicate := func(predicate sqlparser.Expr) { + predicates := sqlparser.SplitAndExpression(nil, predicate) + for _, expr := range predicates { + cmp, ok := expr.(*sqlparser.ComparisonExpr) + if !ok { + continue + } + lhs, lhsOK := cmp.Left.(*sqlparser.ColName) + rhs, rhsOK := cmp.Right.(*sqlparser.ColName) + if lhsOK && rhsOK && ctx.SemTable.RecursiveDeps(lhs) != ctx.SemTable.RecursiveDeps(rhs) { + joinColumns = append(joinColumns, lhs, rhs) + continue + } + if lhsOK { + filterColumns = append(filterColumns, lhs) + } + if rhsOK { + filterColumns = append(filterColumns, rhs) + } + } + } + + _ = sqlparser.VisitSQLNode(stmt, func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *sqlparser.Where: + addPredicate(node.Expr) + case *sqlparser.JoinCondition: + addPredicate(node.On) + case *sqlparser.GroupBy: + for _, expr := range node.Exprs { + predicates := sqlparser.SplitAndExpression(nil, expr) + for _, expr := range predicates { + col, ok := expr.(*sqlparser.ColName) + if ok { + groupingColumns = append(groupingColumns, col) + } + } + } + } + + return true, nil + }) + + return VExplainKeys{ + GroupingColumns: getUniqueColNames(ctx, groupingColumns), + JoinColumns: getUniqueColNames(ctx, joinColumns), + FilterColumns: getUniqueColNames(ctx, filterColumns), + StatementType: sqlparser.ASTToStatementType(stmt).String(), + } +} + +func getUniqueColNames(ctx *plancontext.PlanningContext, columns []*sqlparser.ColName) []string { + var colNames []string + for _, col := range columns { + tableInfo, err := ctx.SemTable.TableInfoForExpr(col) + if err != nil { + panic(err.Error()) // WIP this should not be left before merging + } + table := tableInfo.GetVindexTable() + if table == nil { + continue + } + colNames = append(colNames, fmt.Sprintf("%s.%s", table.Name.String(), col.Name.String())) + } + + slices.Sort(colNames) + return slices.Compact(colNames) }