From 4cb97bd7cd799d7e9e7f277a5ab6aa752241238a Mon Sep 17 00:00:00 2001 From: Florent Poinsard <35779988+frouioui@users.noreply.github.com> Date: Fri, 8 Sep 2023 09:12:26 -0400 Subject: [PATCH] Rewrite `USING` to `ON` condition for joins (#13931) Co-authored-by: Andres Taylor --- .../endtoend/vtgate/queries/misc/misc_test.go | 8 + go/vt/schemadiff/schema_test.go | 2 +- .../planbuilder/testdata/from_cases.json | 22 +++ .../testdata/unsupported_cases.json | 4 +- go/vt/vtgate/semantics/analyzer.go | 5 + go/vt/vtgate/semantics/binder.go | 7 - go/vt/vtgate/semantics/early_rewriter.go | 151 ++++++++++++------ go/vt/vtgate/semantics/early_rewriter_test.go | 23 ++- 8 files changed, 152 insertions(+), 70 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index fe9699b8267..77cb1784c43 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -296,3 +296,11 @@ func TestBuggyOuterJoin(t *testing.T) { mcmp.Exec("insert into t1(id1, id2) values (1,2), (42,5), (5, 42)") mcmp.Exec("select t1.id1, t2.id1 from t1 left join t1 as t2 on t2.id1 = t2.id2") } + +func TestLeftJoinUsingUnsharded(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + utils.Exec(t, mcmp.VtConn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + utils.Exec(t, mcmp.VtConn, "select * from uks.unsharded as A left join uks.unsharded as B using(id1)") +} diff --git a/go/vt/schemadiff/schema_test.go b/go/vt/schemadiff/schema_test.go index 06139db3b8b..79bf44117e2 100644 --- a/go/vt/schemadiff/schema_test.go +++ b/go/vt/schemadiff/schema_test.go @@ -631,7 +631,7 @@ func TestViewReferences(t *testing.T) { "create table t2(id int primary key, n int, info int)", "create view v1 as select id, c as ch from t1 where id > 0", "create view v2 as select n as num, info from t2", - "create view v3 as select num, v1.id, ch from v1 join v2 using (id) where info > 5", + "create view v3 as select num, v1.id, ch from v1 join v2 on v1.id = v2.num where info > 5", }, }, { diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 46db9519fcc..efa04bfa7ca 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -4017,5 +4017,27 @@ "zlookup_unique.t1" ] } + }, + { + "comment": "left join with using has to be transformed into inner join with on condition", + "query": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)", + "plan": { + "QueryType": "SELECT", + "Original": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)", + "Instructions": { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1 where 1 != 1", + "Query": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1", + "Table": "unsharded_authoritative" + }, + "TablesUsed": [ + "main.unsharded_authoritative" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index e1d07bc58e3..9919b600b23 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -37,12 +37,12 @@ { "comment": "join with USING construct", "query": "select * from user join user_extra using(id)", - "plan": "can't handle JOIN USING without authoritative tables" + "plan": "VT09015: schema tracking required" }, { "comment": "join with USING construct with 3 tables", "query": "select user.id from user join user_extra using(id) join music using(id2)", - "plan": "can't handle JOIN USING without authoritative tables" + "plan": "VT09015: schema tracking required" }, { "comment": "natural left join", diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 5b560ec7075..6955f4bafcd 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -192,6 +192,11 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return false } + if err := a.rewriter.up(cursor); err != nil { + a.setError(err) + return true + } + a.leaveProjection(cursor) return a.shouldContinue() } diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index c43180d1efa..d467a97c130 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -84,13 +84,6 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { } currScope.joinUsing[ident.Lowered()] = deps.direct } - if len(node.Using) > 0 { - err := rewriteJoinUsing(currScope, node.Using, b.org) - if err != nil { - return err - } - node.Using = nil - } case *sqlparser.ColName: currentScope := b.scoper.currentScope() deps, err := b.resolveColumn(node, currentScope, false) diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 7dfdbf78247..ca1ebc6d2f4 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -17,8 +17,8 @@ limitations under the License. package semantics import ( + "fmt" "strconv" - "strings" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -60,6 +60,33 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { return nil } +func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { + // this rewriting is done in the `up` phase, because we need the scope to have been + // filled in with the available tables + node, ok := cursor.Node().(*sqlparser.JoinTableExpr) + if !ok || len(node.Condition.Using) == 0 { + return nil + } + + err := rewriteJoinUsing(r.binder, node) + if err != nil { + return err + } + + // since the binder has already been over the join, we need to invoke it again so it + // can bind columns to the right tables + sqlparser.Rewrite(node.Condition.On, nil, func(cursor *sqlparser.Cursor) bool { + innerErr := r.binder.up(cursor) + if innerErr == nil { + return true + } + + err = innerErr + return false + }) + return err +} + // handleWhereClause processes WHERE clauses, specifically the HAVING clause. func handleWhereClause(node *sqlparser.Where, parent sqlparser.SQLNode) { if node.Type != sqlparser.HavingClause { @@ -344,44 +371,25 @@ func rewriteOrFalse(orExpr sqlparser.OrExpr) sqlparser.Expr { // // This function returns an error if it encounters a non-authoritative table or // if it cannot find a SELECT statement to add the WHERE predicate to. -func rewriteJoinUsing( - current *scope, - using sqlparser.Columns, - org originable, -) error { - predicates, err := buildJoinPredicates(current, using, org) +func rewriteJoinUsing(b *binder, join *sqlparser.JoinTableExpr) error { + predicates, err := buildJoinPredicates(b, join) if err != nil { return err } - // now, we go up the scope until we find a SELECT - // with a where clause we can add this predicate to - for current != nil { - sel, found := current.stmt.(*sqlparser.Select) - if !found { - current = current.parent - continue - } - if sel.Where != nil { - predicates = append(predicates, sel.Where.Expr) - sel.Where = nil - } - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: sqlparser.AndExpressions(predicates...), - } - return nil + if len(predicates) > 0 { + join.Condition.On = sqlparser.AndExpressions(predicates...) + join.Condition.Using = nil } - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "did not find WHERE clause") + return nil } // buildJoinPredicates constructs the join predicates for a given set of USING columns. // It returns a slice of sqlparser.Expr, each representing a join predicate for the given columns. -func buildJoinPredicates(current *scope, using sqlparser.Columns, org originable) ([]sqlparser.Expr, error) { - joinUsing := current.prepareUsingMap() +func buildJoinPredicates(b *binder, join *sqlparser.JoinTableExpr) ([]sqlparser.Expr, error) { var predicates []sqlparser.Expr - for _, column := range using { - foundTables, err := findTablesWithColumn(current, joinUsing, org, column) + for _, column := range join.Condition.Using { + foundTables, err := findTablesWithColumn(b, join, column) if err != nil { return nil, err } @@ -392,42 +400,79 @@ func buildJoinPredicates(current *scope, using sqlparser.Columns, org originable return predicates, nil } -// findTablesWithColumn finds the tables with the specified column in the current scope. -func findTablesWithColumn(current *scope, joinUsing map[TableSet]map[string]TableSet, org originable, column sqlparser.IdentifierCI) ([]sqlparser.TableName, error) { - var foundTables []sqlparser.TableName - - for _, tbl := range current.tables { - if !tbl.authoritative() { - return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "can't handle JOIN USING without authoritative tables") +func findOnlyOneTableInfoThatHasColumn(b *binder, tbl sqlparser.TableExpr, column sqlparser.IdentifierCI) ([]TableInfo, error) { + switch tbl := tbl.(type) { + case *sqlparser.AliasedTableExpr: + ts := b.tc.tableSetFor(tbl) + tblInfo := b.tc.Tables[ts.TableOffset()] + for _, info := range tblInfo.getColumns() { + if column.EqualString(info.Name) { + return []TableInfo{tblInfo}, nil + } } - - currTable := tbl.getTableSet(org) - usingCols := joinUsing[currTable] - if usingCols == nil { - usingCols = map[string]TableSet{} + return nil, nil + case *sqlparser.JoinTableExpr: + tblInfoR, err := findOnlyOneTableInfoThatHasColumn(b, tbl.RightExpr, column) + if err != nil { + return nil, err + } + tblInfoL, err := findOnlyOneTableInfoThatHasColumn(b, tbl.LeftExpr, column) + if err != nil { + return nil, err } - if hasColumnInTable(tbl, usingCols) { - tblName, err := tbl.Name() + return append(tblInfoL, tblInfoR...), nil + case *sqlparser.ParenTableExpr: + var tblInfo []TableInfo + for _, parenTable := range tbl.Exprs { + newTblInfo, err := findOnlyOneTableInfoThatHasColumn(b, parenTable, column) if err != nil { return nil, err } - foundTables = append(foundTables, tblName) + if tblInfo != nil && newTblInfo != nil { + return nil, vterrors.VT03021(column.String()) + } + if newTblInfo != nil { + tblInfo = newTblInfo + } } + return tblInfo, nil + default: + panic(fmt.Sprintf("unsupported TableExpr type in JOIN: %T", tbl)) } - - return foundTables, nil } -// hasColumnInTable checks if the specified table has the given column. -func hasColumnInTable(tbl TableInfo, usingCols map[string]TableSet) bool { - for _, col := range tbl.getColumns() { - _, found := usingCols[strings.ToLower(col.Name)] - if found { - return true +// findTablesWithColumn finds the tables with the specified column in the current scope. +func findTablesWithColumn(b *binder, join *sqlparser.JoinTableExpr, column sqlparser.IdentifierCI) ([]sqlparser.TableName, error) { + leftTableInfo, err := findOnlyOneTableInfoThatHasColumn(b, join.LeftExpr, column) + if err != nil { + return nil, err + } + + rightTableInfo, err := findOnlyOneTableInfoThatHasColumn(b, join.RightExpr, column) + if err != nil { + return nil, err + } + + if leftTableInfo == nil || rightTableInfo == nil { + return nil, ShardedError{Inner: vterrors.VT09015()} + } + var tableNames []sqlparser.TableName + for _, info := range leftTableInfo { + nm, err := info.Name() + if err != nil { + return nil, err + } + tableNames = append(tableNames, nm) + } + for _, info := range rightTableInfo { + nm, err := info.Name() + if err != nil { + return nil, err } + tableNames = append(tableNames, nm) } - return false + return tableNames, nil } // createComparisonPredicates creates a list of comparison predicates between the given column and foundTables. diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index f1b16853cfc..2846bfd9366 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -144,26 +144,32 @@ func TestExpandStar(t *testing.T) { }, { sql: "select * from t1 join t2 on t1.a = t2.c1", expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 join t2 on t1.a = t2.c1", + }, { + sql: "select * from t1 left join t2 on t1.a = t2.c1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 left join t2 on t1.a = t2.c1", + }, { + sql: "select * from t1 right join t2 on t1.a = t2.c1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 right join t2 on t1.a = t2.c1", }, { sql: "select * from t2 join t4 using (c1)", - expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 where t2.c1 = t4.c1", + expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 on t2.c1 = t4.c1", expanded: "main.t2.c1, main.t2.c2, main.t4.c4", }, { sql: "select * from t2 join t4 using (c1) join t2 as X using (c1)", - expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 join t2 as X where t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1", + expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 on t2.c1 = t4.c1 join t2 as X on t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1", }, { sql: "select * from t2 join t4 using (c1), t2 as t2b join t4 as t4b using (c1)", - expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4, t2 as t2b join t4 as t4b where t2b.c1 = t4b.c1 and t2.c1 = t4.c1", + expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4 on t2.c1 = t4.c1, t2 as t2b join t4 as t4b on t2b.c1 = t4b.c1", }, { sql: "select * from t1 join t5 using (b)", - expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b", + expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b", expanded: "main.t1.a, main.t1.b, main.t1.c, main.t5.a", }, { sql: "select * from t1 join t5 using (b) having b = 12", - expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b having b = 12", + expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b having b = 12", }, { sql: "select 1 from t1 join t5 using (b) having b = 12", - expSQL: "select 1 from t1 join t5 where t1.b = t5.b having t1.b = 12", + expSQL: "select 1 from t1 join t5 on t1.b = t5.b having t1.b = 12", }, { sql: "select * from (select 12) as t", expSQL: "select t.`12` from (select 12 from dual) as t", @@ -265,13 +271,16 @@ func TestRewriteJoinUsingColumns(t *testing.T) { expErr string }{{ sql: "select 1 from t1 join t2 using (a) where a = 42", - expSQL: "select 1 from t1 join t2 where t1.a = t2.a and t1.a = 42", + expSQL: "select 1 from t1 join t2 on t1.a = t2.a where t1.a = 42", }, { sql: "select 1 from t1 join t2 using (a), t3 where a = 42", expErr: "Column 'a' in field list is ambiguous", }, { sql: "select 1 from t1 join t2 using (a), t1 as b join t3 on (a) where a = 42", expErr: "Column 'a' in field list is ambiguous", + }, { + sql: "select 1 from t1 left join t2 using (a) where a = 42", + expSQL: "select 1 from t1 left join t2 on t1.a = t2.a where t1.a = 42", }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) {