Skip to content

Commit

Permalink
Rewrite USING to ON condition for joins (vitessio#13931)
Browse files Browse the repository at this point in the history
Co-authored-by: Andres Taylor <[email protected]>
  • Loading branch information
frouioui and systay authored Sep 8, 2023
1 parent 04fe231 commit 4cb97bd
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 70 deletions.
8 changes: 8 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,11 @@ func TestBuggyOuterJoin(t *testing.T) {
mcmp.Exec("insert into t1(id1, id2) values (1,2), (42,5), (5, 42)")
mcmp.Exec("select t1.id1, t2.id1 from t1 left join t1 as t2 on t2.id1 = t2.id2")
}

func TestLeftJoinUsingUnsharded(t *testing.T) {
mcmp, closer := start(t)
defer closer()

utils.Exec(t, mcmp.VtConn, "insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)")
utils.Exec(t, mcmp.VtConn, "select * from uks.unsharded as A left join uks.unsharded as B using(id1)")
}
2 changes: 1 addition & 1 deletion go/vt/schemadiff/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ func TestViewReferences(t *testing.T) {
"create table t2(id int primary key, n int, info int)",
"create view v1 as select id, c as ch from t1 where id > 0",
"create view v2 as select n as num, info from t2",
"create view v3 as select num, v1.id, ch from v1 join v2 using (id) where info > 5",
"create view v3 as select num, v1.id, ch from v1 join v2 on v1.id = v2.num where info > 5",
},
},
{
Expand Down
22 changes: 22 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/from_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -4017,5 +4017,27 @@
"zlookup_unique.t1"
]
}
},
{
"comment": "left join with using has to be transformed into inner join with on condition",
"query": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)",
"plan": {
"QueryType": "SELECT",
"Original": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)",
"Instructions": {
"OperatorType": "Route",
"Variant": "Unsharded",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"FieldQuery": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1 where 1 != 1",
"Query": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1",
"Table": "unsharded_authoritative"
},
"TablesUsed": [
"main.unsharded_authoritative"
]
}
}
]
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/testdata/unsupported_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@
{
"comment": "join with USING construct",
"query": "select * from user join user_extra using(id)",
"plan": "can't handle JOIN USING without authoritative tables"
"plan": "VT09015: schema tracking required"
},
{
"comment": "join with USING construct with 3 tables",
"query": "select user.id from user join user_extra using(id) join music using(id2)",
"plan": "can't handle JOIN USING without authoritative tables"
"plan": "VT09015: schema tracking required"
},
{
"comment": "natural left join",
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool {
return false
}

if err := a.rewriter.up(cursor); err != nil {
a.setError(err)
return true
}

a.leaveProjection(cursor)
return a.shouldContinue()
}
Expand Down
7 changes: 0 additions & 7 deletions go/vt/vtgate/semantics/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,6 @@ func (b *binder) up(cursor *sqlparser.Cursor) error {
}
currScope.joinUsing[ident.Lowered()] = deps.direct
}
if len(node.Using) > 0 {
err := rewriteJoinUsing(currScope, node.Using, b.org)
if err != nil {
return err
}
node.Using = nil
}
case *sqlparser.ColName:
currentScope := b.scoper.currentScope()
deps, err := b.resolveColumn(node, currentScope, false)
Expand Down
151 changes: 98 additions & 53 deletions go/vt/vtgate/semantics/early_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ limitations under the License.
package semantics

import (
"fmt"
"strconv"
"strings"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -60,6 +60,33 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error {
return nil
}

func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error {
// this rewriting is done in the `up` phase, because we need the scope to have been
// filled in with the available tables
node, ok := cursor.Node().(*sqlparser.JoinTableExpr)
if !ok || len(node.Condition.Using) == 0 {
return nil
}

err := rewriteJoinUsing(r.binder, node)
if err != nil {
return err
}

// since the binder has already been over the join, we need to invoke it again so it
// can bind columns to the right tables
sqlparser.Rewrite(node.Condition.On, nil, func(cursor *sqlparser.Cursor) bool {
innerErr := r.binder.up(cursor)
if innerErr == nil {
return true
}

err = innerErr
return false
})
return err
}

// handleWhereClause processes WHERE clauses, specifically the HAVING clause.
func handleWhereClause(node *sqlparser.Where, parent sqlparser.SQLNode) {
if node.Type != sqlparser.HavingClause {
Expand Down Expand Up @@ -344,44 +371,25 @@ func rewriteOrFalse(orExpr sqlparser.OrExpr) sqlparser.Expr {
//
// This function returns an error if it encounters a non-authoritative table or
// if it cannot find a SELECT statement to add the WHERE predicate to.
func rewriteJoinUsing(
current *scope,
using sqlparser.Columns,
org originable,
) error {
predicates, err := buildJoinPredicates(current, using, org)
func rewriteJoinUsing(b *binder, join *sqlparser.JoinTableExpr) error {
predicates, err := buildJoinPredicates(b, join)
if err != nil {
return err
}
// now, we go up the scope until we find a SELECT
// with a where clause we can add this predicate to
for current != nil {
sel, found := current.stmt.(*sqlparser.Select)
if !found {
current = current.parent
continue
}
if sel.Where != nil {
predicates = append(predicates, sel.Where.Expr)
sel.Where = nil
}
sel.Where = &sqlparser.Where{
Type: sqlparser.WhereClause,
Expr: sqlparser.AndExpressions(predicates...),
}
return nil
if len(predicates) > 0 {
join.Condition.On = sqlparser.AndExpressions(predicates...)
join.Condition.Using = nil
}
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "did not find WHERE clause")
return nil
}

// buildJoinPredicates constructs the join predicates for a given set of USING columns.
// It returns a slice of sqlparser.Expr, each representing a join predicate for the given columns.
func buildJoinPredicates(current *scope, using sqlparser.Columns, org originable) ([]sqlparser.Expr, error) {
joinUsing := current.prepareUsingMap()
func buildJoinPredicates(b *binder, join *sqlparser.JoinTableExpr) ([]sqlparser.Expr, error) {
var predicates []sqlparser.Expr

for _, column := range using {
foundTables, err := findTablesWithColumn(current, joinUsing, org, column)
for _, column := range join.Condition.Using {
foundTables, err := findTablesWithColumn(b, join, column)
if err != nil {
return nil, err
}
Expand All @@ -392,42 +400,79 @@ func buildJoinPredicates(current *scope, using sqlparser.Columns, org originable
return predicates, nil
}

// findTablesWithColumn finds the tables with the specified column in the current scope.
func findTablesWithColumn(current *scope, joinUsing map[TableSet]map[string]TableSet, org originable, column sqlparser.IdentifierCI) ([]sqlparser.TableName, error) {
var foundTables []sqlparser.TableName

for _, tbl := range current.tables {
if !tbl.authoritative() {
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "can't handle JOIN USING without authoritative tables")
func findOnlyOneTableInfoThatHasColumn(b *binder, tbl sqlparser.TableExpr, column sqlparser.IdentifierCI) ([]TableInfo, error) {
switch tbl := tbl.(type) {
case *sqlparser.AliasedTableExpr:
ts := b.tc.tableSetFor(tbl)
tblInfo := b.tc.Tables[ts.TableOffset()]
for _, info := range tblInfo.getColumns() {
if column.EqualString(info.Name) {
return []TableInfo{tblInfo}, nil
}
}

currTable := tbl.getTableSet(org)
usingCols := joinUsing[currTable]
if usingCols == nil {
usingCols = map[string]TableSet{}
return nil, nil
case *sqlparser.JoinTableExpr:
tblInfoR, err := findOnlyOneTableInfoThatHasColumn(b, tbl.RightExpr, column)
if err != nil {
return nil, err
}
tblInfoL, err := findOnlyOneTableInfoThatHasColumn(b, tbl.LeftExpr, column)
if err != nil {
return nil, err
}

if hasColumnInTable(tbl, usingCols) {
tblName, err := tbl.Name()
return append(tblInfoL, tblInfoR...), nil
case *sqlparser.ParenTableExpr:
var tblInfo []TableInfo
for _, parenTable := range tbl.Exprs {
newTblInfo, err := findOnlyOneTableInfoThatHasColumn(b, parenTable, column)
if err != nil {
return nil, err
}
foundTables = append(foundTables, tblName)
if tblInfo != nil && newTblInfo != nil {
return nil, vterrors.VT03021(column.String())
}
if newTblInfo != nil {
tblInfo = newTblInfo
}
}
return tblInfo, nil
default:
panic(fmt.Sprintf("unsupported TableExpr type in JOIN: %T", tbl))
}

return foundTables, nil
}

// hasColumnInTable checks if the specified table has the given column.
func hasColumnInTable(tbl TableInfo, usingCols map[string]TableSet) bool {
for _, col := range tbl.getColumns() {
_, found := usingCols[strings.ToLower(col.Name)]
if found {
return true
// findTablesWithColumn finds the tables with the specified column in the current scope.
func findTablesWithColumn(b *binder, join *sqlparser.JoinTableExpr, column sqlparser.IdentifierCI) ([]sqlparser.TableName, error) {
leftTableInfo, err := findOnlyOneTableInfoThatHasColumn(b, join.LeftExpr, column)
if err != nil {
return nil, err
}

rightTableInfo, err := findOnlyOneTableInfoThatHasColumn(b, join.RightExpr, column)
if err != nil {
return nil, err
}

if leftTableInfo == nil || rightTableInfo == nil {
return nil, ShardedError{Inner: vterrors.VT09015()}
}
var tableNames []sqlparser.TableName
for _, info := range leftTableInfo {
nm, err := info.Name()
if err != nil {
return nil, err
}
tableNames = append(tableNames, nm)
}
for _, info := range rightTableInfo {
nm, err := info.Name()
if err != nil {
return nil, err
}
tableNames = append(tableNames, nm)
}
return false
return tableNames, nil
}

// createComparisonPredicates creates a list of comparison predicates between the given column and foundTables.
Expand Down
23 changes: 16 additions & 7 deletions go/vt/vtgate/semantics/early_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,26 +144,32 @@ func TestExpandStar(t *testing.T) {
}, {
sql: "select * from t1 join t2 on t1.a = t2.c1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 join t2 on t1.a = t2.c1",
}, {
sql: "select * from t1 left join t2 on t1.a = t2.c1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 left join t2 on t1.a = t2.c1",
}, {
sql: "select * from t1 right join t2 on t1.a = t2.c1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 right join t2 on t1.a = t2.c1",
}, {
sql: "select * from t2 join t4 using (c1)",
expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 where t2.c1 = t4.c1",
expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 on t2.c1 = t4.c1",
expanded: "main.t2.c1, main.t2.c2, main.t4.c4",
}, {
sql: "select * from t2 join t4 using (c1) join t2 as X using (c1)",
expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 join t2 as X where t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1",
expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 on t2.c1 = t4.c1 join t2 as X on t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1",
}, {
sql: "select * from t2 join t4 using (c1), t2 as t2b join t4 as t4b using (c1)",
expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4, t2 as t2b join t4 as t4b where t2b.c1 = t4b.c1 and t2.c1 = t4.c1",
expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4 on t2.c1 = t4.c1, t2 as t2b join t4 as t4b on t2b.c1 = t4b.c1",
}, {
sql: "select * from t1 join t5 using (b)",
expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b",
expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b",
expanded: "main.t1.a, main.t1.b, main.t1.c, main.t5.a",
}, {
sql: "select * from t1 join t5 using (b) having b = 12",
expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b having b = 12",
expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b having b = 12",
}, {
sql: "select 1 from t1 join t5 using (b) having b = 12",
expSQL: "select 1 from t1 join t5 where t1.b = t5.b having t1.b = 12",
expSQL: "select 1 from t1 join t5 on t1.b = t5.b having t1.b = 12",
}, {
sql: "select * from (select 12) as t",
expSQL: "select t.`12` from (select 12 from dual) as t",
Expand Down Expand Up @@ -265,13 +271,16 @@ func TestRewriteJoinUsingColumns(t *testing.T) {
expErr string
}{{
sql: "select 1 from t1 join t2 using (a) where a = 42",
expSQL: "select 1 from t1 join t2 where t1.a = t2.a and t1.a = 42",
expSQL: "select 1 from t1 join t2 on t1.a = t2.a where t1.a = 42",
}, {
sql: "select 1 from t1 join t2 using (a), t3 where a = 42",
expErr: "Column 'a' in field list is ambiguous",
}, {
sql: "select 1 from t1 join t2 using (a), t1 as b join t3 on (a) where a = 42",
expErr: "Column 'a' in field list is ambiguous",
}, {
sql: "select 1 from t1 left join t2 using (a) where a = 42",
expSQL: "select 1 from t1 left join t2 on t1.a = t2.a where t1.a = 42",
}}
for _, tcase := range tcases {
t.Run(tcase.sql, func(t *testing.T) {
Expand Down

0 comments on commit 4cb97bd

Please sign in to comment.