diff --git a/go/vt/vtgate/planbuilder/testdata/ddl_cases.json b/go/vt/vtgate/planbuilder/testdata/ddl_cases.json index c6dad1ab946..a645ae71d37 100644 --- a/go/vt/vtgate/planbuilder/testdata/ddl_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/ddl_cases.json @@ -260,7 +260,7 @@ "Name": "main", "Sharded": false }, - "Query": "create view view_a as select col1, col2 from (select col1, col2 from unsharded where id = 1 union select col1, col2 from unsharded where id = 3) as a" + "Query": "create view view_a as select * from (select col1, col2 from unsharded where id = 1 union select col1, col2 from unsharded where id = 3) as a" }, "TablesUsed": [ "main.view_a" diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index 86ee2638515..aa98353a1c8 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -4062,7 +4062,7 @@ "Sharded": false }, "FieldQuery": "select col + 2 as a from unsharded where 1 != 1", - "Query": "select col + 2 as a from unsharded having col + 2 = 42", + "Query": "select col + 2 as a from unsharded having a = 42", "Table": "unsharded" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index 0d3c5e4745a..c9c0acb3cc7 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -1132,7 +1132,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where (u_tbl8.col8) in ::fkc_vals and u_tbl9.col9 is null limit 1 lock in share mode", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where u_tbl9.col9 is null and (u_tbl8.col8) in ::fkc_vals limit 1 lock in share mode", "Table": "u_tbl8, u_tbl9" }, { @@ -1208,7 +1208,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where (u_tbl4.col4) in ::fkc_vals and u_tbl3.col3 is null limit 1 lock in share mode", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals limit 1 lock in share mode", "Table": "u_tbl3, u_tbl4" }, { @@ -1220,7 +1220,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where (u_tbl4.col4) in ::fkc_vals and (u_tbl9.col9) not in (('foo')) and u_tbl4.col4 = u_tbl9.col9 limit 1 lock in share mode", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and ('foo' is null or (u_tbl9.col9) not in (('foo'))) and u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and ('foo' is null or (u_tbl9.col9) not in (('foo'))) limit 1 lock in share mode", "Table": "u_tbl4, u_tbl9" }, { @@ -1297,7 +1297,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where (u_tbl4.col4) in ::fkc_vals and u_tbl3.col3 is null limit 1 lock in share mode", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals limit 1 lock in share mode", "Table": "u_tbl3, u_tbl4" }, { @@ -1309,7 +1309,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where (u_tbl4.col4) in ::fkc_vals and (:v1 is null or (u_tbl9.col9) not in ((:v1))) and u_tbl4.col4 = u_tbl9.col9 limit 1 lock in share mode", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (:v1 is null or (u_tbl9.col9) not in ((:v1))) and u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (:v1 is null or (u_tbl9.col9) not in ((:v1))) limit 1 lock in share mode", "Table": "u_tbl4, u_tbl9" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index d3df710b909..567a3b5f254 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -711,8 +711,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select m1.col from unsharded as m1 join unsharded as m2 where 1 != 1", - "Query": "select m1.col from unsharded as m1 join unsharded as m2", + "FieldQuery": "select m1.col from unsharded as m1 straight_join unsharded as m2 where 1 != 1", + "Query": "select m1.col from unsharded as m1 straight_join unsharded as m2", "Table": "unsharded" }, "TablesUsed": [ @@ -3989,8 +3989,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1 where 1 != 1", - "Query": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1", + "FieldQuery": "select * from unsharded_authoritative as A left join unsharded_authoritative as B using (col1) where 1 != 1", + "Query": "select * from unsharded_authoritative as A left join unsharded_authoritative as B using (col1)", "Table": "unsharded_authoritative" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 8d7c902ded3..3798b0752cb 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -1428,8 +1428,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select col1, col2 from (select col1, col2 from unsharded where 1 != 1 union select col1, col2 from unsharded where 1 != 1) as a where 1 != 1", - "Query": "select col1, col2 from (select col1, col2 from unsharded where id = 1 union select col1, col2 from unsharded where id = 3) as a", + "FieldQuery": "select * from (select col1, col2 from unsharded where 1 != 1 union select col1, col2 from unsharded where 1 != 1) as a where 1 != 1", + "Query": "select * from (select col1, col2 from unsharded where id = 1 union select col1, col2 from unsharded where id = 3) as a", "Table": "unsharded" }, "TablesUsed": [ @@ -2544,7 +2544,7 @@ "Sharded": false }, "FieldQuery": "select 1 from (select col, count(*) as a from unsharded where 1 != 1 group by col) as f left join unsharded as u on f.col = u.id where 1 != 1", - "Query": "select 1 from (select col, count(*) as a from unsharded group by col having count(*) > 0 limit 0, 12) as f left join unsharded as u on f.col = u.id", + "Query": "select 1 from (select col, count(*) as a from unsharded group by col having a > 0 limit 0, 12) as f left join unsharded as u on f.col = u.id", "Table": "unsharded" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json index fb4c9adda67..d86fe3f296a 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -882,7 +882,7 @@ "Sharded": false }, "FieldQuery": "(select 1 from unsharded where 1 != 1 union select 1 from unsharded where 1 != 1 union all select 1 from unsharded where 1 != 1) union select 1 from unsharded where 1 != 1 union all select 1 from unsharded where 1 != 1", - "Query": "(select 1 from unsharded union select 1 from unsharded union all select 1 from unsharded order by `1` asc) union select 1 from unsharded union all select 1 from unsharded order by `1` asc", + "Query": "(select 1 from unsharded union select 1 from unsharded union all select 1 from unsharded order by 1 asc) union select 1 from unsharded union all select 1 from unsharded order by 1 asc", "Table": "unsharded" }, "TablesUsed": [ diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index e4d566d7191..2f6f66b5d3d 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -25,47 +25,60 @@ import ( // analyzer controls the flow of the analysis. // It starts the tree walking and controls which part of the analysis sees which parts of the tree type analyzer struct { - scoper *scoper - tables *tableCollector - binder *binder - typer *typer - rewriter *earlyRewriter - sig QuerySignature + scoper *scoper + earlyTables *earlyTableCollector + tables *tableCollector + binder *binder + typer *typer + rewriter *earlyRewriter + sig QuerySignature + si SchemaInformation + currentDb string err error inProjection int - projErr error - unshardedErr error - warning string + projErr error + unshardedErr error + warning string + singleUnshardedKeyspace bool + fullAnalysis bool } // newAnalyzer create the semantic analyzer -func newAnalyzer(dbName string, si SchemaInformation) *analyzer { +func newAnalyzer(dbName string, si SchemaInformation, fullAnalysis bool) *analyzer { // TODO dependencies between these components are a little tangled. We should try to clean up s := newScoper() a := &analyzer{ - scoper: s, - tables: newTableCollector(s, si, dbName), - typer: newTyper(), + scoper: s, + earlyTables: newEarlyTableCollector(si, dbName), + typer: newTyper(), + si: si, + currentDb: dbName, + fullAnalysis: fullAnalysis, } s.org = a - a.tables.org = a + return a +} - b := newBinder(s, a, a.tables, a.typer) - a.binder = b +func (a *analyzer) lateInit() { + a.tables = a.earlyTables.newTableCollector(a.scoper, a) + a.binder = newBinder(a.scoper, a, a.tables, a.typer) + a.scoper.binder = a.binder a.rewriter = &earlyRewriter{ - scoper: s, - binder: b, + scoper: a.scoper, + binder: a.binder, expandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, } - s.binder = b - return a } // Analyze analyzes the parsed query. func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) { - analyzer := newAnalyzer(currentDb, newSchemaInfo(si)) + return analyseAndGetSemTable(statement, currentDb, si, false) +} + +func analyseAndGetSemTable(statement sqlparser.Statement, currentDb string, si SchemaInformation, fullAnalysis bool) (*SemTable, error) { + analyzer := newAnalyzer(currentDb, newSchemaInfo(si), fullAnalysis) // Analysis for initial scope err := analyzer.analyze(statement) @@ -79,7 +92,7 @@ func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformati // AnalyzeStrict analyzes the parsed query, and fails the analysis for any possible errors func AnalyzeStrict(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) { - st, err := Analyze(statement, currentDb, si) + st, err := analyseAndGetSemTable(statement, currentDb, si, true) if err != nil { return nil, err } @@ -103,6 +116,27 @@ func (a *analyzer) newSemTable( if isCommented { comments = commentedStmt.GetParsedComments() } + + if a.singleUnshardedKeyspace { + return &SemTable{ + Tables: a.earlyTables.Tables, + Comments: comments, + Warning: a.warning, + Collation: coll, + ExprTypes: map[sqlparser.Expr]Type{}, + NotSingleRouteErr: a.projErr, + NotUnshardedErr: a.unshardedErr, + Recursive: ExprDependencies{}, + Direct: ExprDependencies{}, + ColumnEqualities: map[columnName][]sqlparser.Expr{}, + ExpandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, + columns: map[*sqlparser.Union]sqlparser.SelectExprs{}, + comparator: nil, + StatementIDs: a.scoper.statementIDs, + QuerySignature: QuerySignature{}, + }, nil + } + columns := map[*sqlparser.Union]sqlparser.SelectExprs{} for union, info := range a.tables.unionInfo { columns[union] = info.exprs @@ -280,10 +314,43 @@ func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, } func (a *analyzer) analyze(statement sqlparser.Statement) error { + _ = sqlparser.Rewrite(statement, nil, a.earlyUp) + if a.err != nil { + return a.err + } + + if a.canShortCut(statement) { + return nil + } + + a.lateInit() + _ = sqlparser.Rewrite(statement, a.analyzeDown, a.analyzeUp) return a.err } +// canShortCut checks if we are dealing with a single unsharded keyspace and no tables that have managed foreign keys +// if so, we can stop the analyzer early +func (a *analyzer) canShortCut(statement sqlparser.Statement) bool { + if a.fullAnalysis { + return false + } + ks, _ := singleUnshardedKeyspace(a.earlyTables.Tables) + if ks == nil { + return false + } + + a.singleUnshardedKeyspace = !sqlparser.IsDMLStatement(statement) + return a.singleUnshardedKeyspace +} + +// earlyUp collects tables in the query, so we can check +// if this a single unsharded query we are dealing with +func (a *analyzer) earlyUp(cursor *sqlparser.Cursor) bool { + a.earlyTables.up(cursor) + return true +} + func (a *analyzer) shouldContinue() bool { return a.err == nil } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index fc372909f7c..6512dbf0aed 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -28,7 +28,7 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" ) -var T0 TableSet +var NoTables TableSet var ( // Just here to make outputs more readable @@ -586,7 +586,7 @@ func TestOrderByBindingTable(t *testing.T) { TS0, }, { "select 1 as c from tabl order by c", - T0, + NoTables, }, { "select name, name from t1, t2 order by name", TS1, @@ -664,7 +664,7 @@ func TestGroupByBinding(t *testing.T) { TS0, }, { "select 1 as c from tabl group by c", - T0, + NoTables, }, { "select t1.id from t1, t2 group by id", TS0, @@ -713,13 +713,13 @@ func TestHavingBinding(t *testing.T) { TS0, }, { "select col from tabl having 1 = 1", - T0, + NoTables, }, { "select col as c from tabl having c = 1", TS0, }, { "select 1 as c from tabl having c = 1", - T0, + NoTables, }, { "select t1.id from t1, t2 having id = 1", TS0, @@ -877,109 +877,6 @@ func TestUnionWithOrderBy(t *testing.T) { assert.Equal(t, TS1, d2) } -func TestScopingWDerivedTables(t *testing.T) { - queries := []struct { - query string - errorMessage string - recursiveExpectation TableSet - expectation TableSet - }{ - { - query: "select id from (select x as id from user) as t", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id from (select foo as id from user) as t", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id from (select foo as id from (select x as foo from user) as c) as t", - recursiveExpectation: TS0, - expectation: TS2, - }, { - query: "select t.id from (select foo as id from user) as t", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select t.id2 from (select foo as id from user) as t", - errorMessage: "column 't.id2' not found", - }, { - query: "select id from (select 42 as id) as t", - recursiveExpectation: T0, - expectation: TS1, - }, { - query: "select t.id from (select 42 as id) as t", - recursiveExpectation: T0, - expectation: TS1, - }, { - query: "select ks.t.id from (select 42 as id) as t", - errorMessage: "column 'ks.t.id' not found", - }, { - query: "select * from (select id, id from user) as t", - errorMessage: "Duplicate column name 'id'", - }, { - query: "select t.baz = 1 from (select id as baz from user) as t", - expectation: TS1, - recursiveExpectation: TS0, - }, { - query: "select t.id from (select * from user, music) as t", - expectation: TS2, - recursiveExpectation: MergeTableSets(TS0, TS1), - }, { - query: "select t.id from (select * from user, music) as t order by t.id", - expectation: TS2, - recursiveExpectation: MergeTableSets(TS0, TS1), - }, { - query: "select t.id from (select * from user) as t join user as u on t.id = u.id", - expectation: TS1, - recursiveExpectation: TS0, - }, { - query: "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t", - expectation: TS3, - recursiveExpectation: TS1, - }, { - query: "select uu.test from (select id from t1) uu", - errorMessage: "column 'uu.test' not found", - }, { - query: "select uu.id from (select id as col from t1) uu", - errorMessage: "column 'uu.id' not found", - }, { - query: "select uu.id from (select id as col from t1) uu", - errorMessage: "column 'uu.id' not found", - }, { - query: "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", - expectation: TS1, - recursiveExpectation: TS0, - }, { - query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", - expectation: T0, - recursiveExpectation: T0, - }} - for _, query := range queries { - t.Run(query.query, func(t *testing.T) { - parse, err := sqlparser.Parse(query.query) - require.NoError(t, err) - st, err := Analyze(parse, "user", &FakeSI{ - Tables: map[string]*vindexes.Table{ - "t": {Name: sqlparser.NewIdentifierCS("t")}, - }, - }) - - switch { - case query.errorMessage != "" && err != nil: - require.EqualError(t, err, query.errorMessage) - case query.errorMessage != "": - require.EqualError(t, st.NotUnshardedErr, query.errorMessage) - default: - require.NoError(t, err) - sel := parse.(*sqlparser.Select) - assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(extract(sel, 0)), "RecursiveDeps") - assert.Equal(t, query.expectation, st.DirectDeps(extract(sel, 0)), "DirectDeps") - } - }) - } -} - func TestJoinPredicateDependencies(t *testing.T) { // create table t() // create table t1(id bigint) @@ -995,15 +892,15 @@ func TestJoinPredicateDependencies(t *testing.T) { directExpect: MergeTableSets(TS0, TS1), }, { query: "select 1 from (select * from t1) x join t2 on x.id = t2.uid", - recursiveExpect: MergeTableSets(TS0, TS2), + recursiveExpect: MergeTableSets(TS0, TS1), directExpect: MergeTableSets(TS1, TS2), }, { query: "select 1 from (select id from t1) x join t2 on x.id = t2.uid", - recursiveExpect: MergeTableSets(TS0, TS2), + recursiveExpect: MergeTableSets(TS0, TS1), directExpect: MergeTableSets(TS1, TS2), }, { query: "select 1 from (select id from t1 union select id from t) x join t2 on x.id = t2.uid", - recursiveExpect: MergeTableSets(TS0, TS1, TS3), + recursiveExpect: MergeTableSets(TS0, TS1, TS2), directExpect: MergeTableSets(TS2, TS3), }} for _, query := range queries { @@ -1022,107 +919,6 @@ func TestJoinPredicateDependencies(t *testing.T) { } } -func TestDerivedTablesOrderClause(t *testing.T) { - queries := []struct { - query string - recursiveExpectation TableSet - expectation TableSet - }{{ - query: "select 1 from (select id from user) as t order by id", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id from (select id from user) as t order by id", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id from (select id from user) as t order by t.id", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id as foo from (select id from user) as t order by foo", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar from (select id as bar from user) as t order by bar", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar as foo from (select id as bar from user) as t order by bar", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar as foo from (select id as bar from user) as t order by foo", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar as foo from (select id as bar, oo from user) as t order by oo", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar as foo from (select id, oo from user) as t(bar,oo) order by bar", - recursiveExpectation: TS0, - expectation: TS1, - }} - si := &FakeSI{Tables: map[string]*vindexes.Table{"t": {Name: sqlparser.NewIdentifierCS("t")}}} - for _, query := range queries { - t.Run(query.query, func(t *testing.T) { - parse, err := sqlparser.Parse(query.query) - require.NoError(t, err) - - st, err := Analyze(parse, "user", si) - require.NoError(t, err) - - sel := parse.(*sqlparser.Select) - assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(sel.OrderBy[0].Expr), "RecursiveDeps") - assert.Equal(t, query.expectation, st.DirectDeps(sel.OrderBy[0].Expr), "DirectDeps") - - }) - } -} - -func TestScopingWComplexDerivedTables(t *testing.T) { - queries := []struct { - query string - errorMessage string - rightExpectation TableSet - leftExpectation TableSet - }{ - { - query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", - rightExpectation: TS0, - leftExpectation: TS0, - }, - { - query: "select 1 from user.user uu where exists (select 1 from user.user as uu where exists (select 1 from (select 1 from user.t1) uu where uu.user_id = uu.id))", - rightExpectation: TS1, - leftExpectation: TS1, - }, - } - for _, query := range queries { - t.Run(query.query, func(t *testing.T) { - parse, err := sqlparser.Parse(query.query) - require.NoError(t, err) - st, err := Analyze(parse, "user", &FakeSI{ - Tables: map[string]*vindexes.Table{ - "t": {Name: sqlparser.NewIdentifierCS("t")}, - }, - }) - if query.errorMessage != "" { - require.EqualError(t, err, query.errorMessage) - } else { - require.NoError(t, err) - sel := parse.(*sqlparser.Select) - comparisonExpr := sel.Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ComparisonExpr) - left := comparisonExpr.Left - right := comparisonExpr.Right - assert.Equal(t, query.leftExpectation, st.RecursiveDeps(left), "Left RecursiveDeps") - assert.Equal(t, query.rightExpectation, st.RecursiveDeps(right), "Right RecursiveDeps") - } - }) - } -} - func TestScopingWVindexTables(t *testing.T) { queries := []struct { query string @@ -1242,36 +1038,6 @@ func BenchmarkAnalyzeSubQueries(b *testing.B) { } } -func BenchmarkAnalyzeDerivedTableQueries(b *testing.B) { - queries := []string{ - "select id from (select x as id from user) as t", - "select id from (select foo as id from user) as t", - "select id from (select foo as id from (select x as foo from user) as c) as t", - "select t.id from (select foo as id from user) as t", - "select t.id2 from (select foo as id from user) as t", - "select id from (select 42 as id) as t", - "select t.id from (select 42 as id) as t", - "select ks.t.id from (select 42 as id) as t", - "select * from (select id, id from user) as t", - "select t.baz = 1 from (select id as baz from user) as t", - "select t.id from (select * from user, music) as t", - "select t.id from (select * from user, music) as t order by t.id", - "select t.id from (select * from user) as t join user as u on t.id = u.id", - "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t", - "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", - "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", - } - - for i := 0; i < b.N; i++ { - for _, query := range queries { - parse, err := sqlparser.Parse(query) - require.NoError(b, err) - - _, _ = Analyze(parse, "d", fakeSchemaInfo()) - } - } -} - func BenchmarkAnalyzeHavingQueries(b *testing.B) { queries := []string{ "select col from tabl having col = 1", @@ -1364,43 +1130,30 @@ func TestSingleUnshardedKeyspace(t *testing.T) { tests := []struct { query string unsharded *vindexes.Keyspace - tables []*vindexes.Table }{ { query: "select 1 from t, t1", unsharded: nil, // both tables are unsharded, but from different keyspaces - tables: nil, }, { query: "select 1 from t2", unsharded: nil, - tables: nil, }, { query: "select 1 from t, t2", unsharded: nil, - tables: nil, }, { query: "select 1 from t as A, t as B", - unsharded: ks1, - tables: []*vindexes.Table{ - {Keyspace: ks1, Name: sqlparser.NewIdentifierCS("t")}, - {Keyspace: ks1, Name: sqlparser.NewIdentifierCS("t")}, - }, + unsharded: unsharded, }, { query: "insert into t select * from t", - unsharded: ks1, - tables: []*vindexes.Table{ - {Keyspace: ks1, Name: sqlparser.NewIdentifierCS("t")}, - {Keyspace: ks1, Name: sqlparser.NewIdentifierCS("t")}, - }, + unsharded: unsharded, }, } for _, test := range tests { t.Run(test.query, func(t *testing.T) { _, semTable := parseAndAnalyze(t, test.query, "d") - queryIsUnsharded, tables := semTable.SingleUnshardedKeyspace() + queryIsUnsharded, _ := semTable.SingleUnshardedKeyspace() assert.Equal(t, test.unsharded, queryIsUnsharded) - assert.Equal(t, test.tables, tables) }) } } @@ -1481,13 +1234,13 @@ func TestScopingSubQueryJoinClause(t *testing.T) { } -var ks1 = &vindexes.Keyspace{ - Name: "ks1", +var unsharded = &vindexes.Keyspace{ + Name: "unsharded", Sharded: false, } var ks2 = &vindexes.Keyspace{ Name: "ks2", - Sharded: false, + Sharded: true, } var ks3 = &vindexes.Keyspace{ Name: "ks3", @@ -1498,29 +1251,52 @@ var ks3 = &vindexes.Keyspace{ // create table t1(id bigint) // create table t2(uid bigint, name varchar(255)) func fakeSchemaInfo() *FakeSI { - cols1 := []vindexes.Column{{ - Name: sqlparser.NewIdentifierCI("id"), - Type: querypb.Type_INT64, - }} - cols2 := []vindexes.Column{{ - Name: sqlparser.NewIdentifierCI("uid"), - Type: querypb.Type_INT64, - }, { - Name: sqlparser.NewIdentifierCI("name"), - Type: querypb.Type_VARCHAR, - CollationName: "utf8_bin", - }, { - Name: sqlparser.NewIdentifierCI("textcol"), - Type: querypb.Type_VARCHAR, - CollationName: "big5_bin", - }} - si := &FakeSI{ Tables: map[string]*vindexes.Table{ - "t": {Name: sqlparser.NewIdentifierCS("t"), Keyspace: ks1}, - "t1": {Name: sqlparser.NewIdentifierCS("t1"), Columns: cols1, ColumnListAuthoritative: true, Keyspace: ks2}, - "t2": {Name: sqlparser.NewIdentifierCS("t2"), Columns: cols2, ColumnListAuthoritative: true, Keyspace: ks3}, + "t": tableT(), + "t1": tableT1(), + "t2": tableT2(), }, } return si } + +func tableT() *vindexes.Table { + return &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t"), + Keyspace: unsharded, + } +} +func tableT1() *vindexes.Table { + return &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("id"), + Type: querypb.Type_INT64, + }}, + ColumnListAuthoritative: true, + ColumnVindexes: []*vindexes.ColumnVindex{ + {Name: "id_vindex"}, + }, + Keyspace: ks2, + } +} +func tableT2() *vindexes.Table { + return &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t2"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("uid"), + Type: querypb.Type_INT64, + }, { + Name: sqlparser.NewIdentifierCI("name"), + Type: querypb.Type_VARCHAR, + CollationName: "utf8_bin", + }, { + Name: sqlparser.NewIdentifierCI("textcol"), + Type: querypb.Type_VARCHAR, + CollationName: "big5_bin", + }}, + ColumnListAuthoritative: true, + Keyspace: ks3, + } +} diff --git a/go/vt/vtgate/semantics/derived_test.go b/go/vt/vtgate/semantics/derived_test.go new file mode 100644 index 00000000000..509c9925fb1 --- /dev/null +++ b/go/vt/vtgate/semantics/derived_test.go @@ -0,0 +1,265 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package semantics + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +func TestScopingWDerivedTables(t *testing.T) { + queries := []struct { + query string + errorMessage string + recursiveDeps TableSet + directDeps TableSet + }{ + { + query: "select id from (select x as id from user) as t", + recursiveDeps: TS0, + directDeps: TS1, + }, { + query: "select id from (select foo as id from user) as t", + recursiveDeps: TS0, + directDeps: TS1, + }, { + query: "select id from (select foo as id from (select x as foo from user) as c) as t", + recursiveDeps: TS0, + directDeps: TS2, + }, { + query: "select t.id from (select foo as id from user) as t", + recursiveDeps: TS0, + directDeps: TS1, + }, { + query: "select t.id2 from (select foo as id from user) as t", + errorMessage: "column 't.id2' not found", + }, { + query: "select id from (select 42 as id) as t", + recursiveDeps: NoTables, + directDeps: TS1, + }, { + query: "select t.id from (select 42 as id) as t", + recursiveDeps: NoTables, + directDeps: TS1, + }, { + query: "select ks.t.id from (select 42 as id) as t", + errorMessage: "column 'ks.t.id' not found", + }, { + query: "select * from (select id, id from user) as t", + errorMessage: "Duplicate column name 'id'", + }, { + query: "select t.baz = 1 from (select id as baz from user) as t", + directDeps: TS1, + recursiveDeps: TS0, + }, { + query: "select t.id from (select * from user, music) as t", + directDeps: TS2, + recursiveDeps: MergeTableSets(TS0, TS1), + }, { + query: "select t.id from (select * from user, music) as t order by t.id", + directDeps: TS2, + recursiveDeps: MergeTableSets(TS0, TS1), + }, { + query: "select t.id from (select * from user) as t join user as u on t.id = u.id", + directDeps: TS2, + recursiveDeps: TS0, + }, { + query: "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t", + directDeps: TS3, + recursiveDeps: TS1, + }, { + query: "select uu.test from (select id from t1) uu", + errorMessage: "column 'uu.test' not found", + }, { + query: "select uu.id from (select id as col from t1) uu", + errorMessage: "column 'uu.id' not found", + }, { + query: "select uu.id from (select id as col from t1) uu", + errorMessage: "column 'uu.id' not found", + }, { + query: "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", + directDeps: TS2, + recursiveDeps: TS0, + }, { + query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", + directDeps: NoTables, + recursiveDeps: NoTables, + }, { + query: "select uu.count from (select count(*) as `count` from t1) uu", + directDeps: TS1, + recursiveDeps: TS0, + }} + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + parse, err := sqlparser.Parse(query.query) + require.NoError(t, err) + st, err := Analyze(parse, "user", &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t": {Name: sqlparser.NewIdentifierCS("t"), Keyspace: ks2}, + }, + }) + + switch { + case query.errorMessage != "" && err != nil: + require.EqualError(t, err, query.errorMessage) + case query.errorMessage != "": + require.EqualError(t, st.NotUnshardedErr, query.errorMessage) + default: + require.NoError(t, err) + sel := parse.(*sqlparser.Select) + assert.Equal(t, query.recursiveDeps, st.RecursiveDeps(extract(sel, 0)), "RecursiveDeps") + assert.Equal(t, query.directDeps, st.DirectDeps(extract(sel, 0)), "DirectDeps") + } + }) + } +} + +func TestDerivedTablesOrderClause(t *testing.T) { + queries := []struct { + query string + recursiveExpectation TableSet + expectation TableSet + }{{ + query: "select 1 from (select id from user) as t order by id", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select id from (select id from user) as t order by id", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select id from (select id from user) as t order by t.id", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select id as foo from (select id from user) as t order by foo", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar from (select id as bar from user) as t order by bar", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar as foo from (select id as bar from user) as t order by bar", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar as foo from (select id as bar from user) as t order by foo", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar as foo from (select id as bar, oo from user) as t order by oo", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar as foo from (select id, oo from user) as t(bar,oo) order by bar", + recursiveExpectation: TS0, + expectation: TS1, + }} + si := &FakeSI{Tables: map[string]*vindexes.Table{"t": {Name: sqlparser.NewIdentifierCS("t")}}} + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + parse, err := sqlparser.Parse(query.query) + require.NoError(t, err) + + st, err := Analyze(parse, "user", si) + require.NoError(t, err) + + sel := parse.(*sqlparser.Select) + assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(sel.OrderBy[0].Expr), "RecursiveDeps") + assert.Equal(t, query.expectation, st.DirectDeps(sel.OrderBy[0].Expr), "DirectDeps") + + }) + } +} + +func TestScopingWComplexDerivedTables(t *testing.T) { + queries := []struct { + query string + errorMessage string + rightExpectation TableSet + leftExpectation TableSet + }{ + { + query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", + rightExpectation: TS0, + leftExpectation: TS0, + }, + { + query: "select 1 from user.user uu where exists (select 1 from user.user as uu where exists (select 1 from (select 1 from user.t1) uu where uu.user_id = uu.id))", + rightExpectation: TS1, + leftExpectation: TS1, + }, + } + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + parse, err := sqlparser.Parse(query.query) + require.NoError(t, err) + st, err := Analyze(parse, "user", &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t": {Name: sqlparser.NewIdentifierCS("t")}, + }, + }) + if query.errorMessage != "" { + require.EqualError(t, err, query.errorMessage) + } else { + require.NoError(t, err) + sel := parse.(*sqlparser.Select) + comparisonExpr := sel.Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ComparisonExpr) + left := comparisonExpr.Left + right := comparisonExpr.Right + assert.Equal(t, query.leftExpectation, st.RecursiveDeps(left), "Left RecursiveDeps") + assert.Equal(t, query.rightExpectation, st.RecursiveDeps(right), "Right RecursiveDeps") + } + }) + } +} + +func BenchmarkAnalyzeDerivedTableQueries(b *testing.B) { + queries := []string{ + "select id from (select x as id from user) as t", + "select id from (select foo as id from user) as t", + "select id from (select foo as id from (select x as foo from user) as c) as t", + "select t.id from (select foo as id from user) as t", + "select t.id2 from (select foo as id from user) as t", + "select id from (select 42 as id) as t", + "select t.id from (select 42 as id) as t", + "select ks.t.id from (select 42 as id) as t", + "select * from (select id, id from user) as t", + "select t.baz = 1 from (select id as baz from user) as t", + "select t.id from (select * from user, music) as t", + "select t.id from (select * from user, music) as t order by t.id", + "select t.id from (select * from user) as t join user as u on t.id = u.id", + "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t", + "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", + "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", + } + + for i := 0; i < b.N; i++ { + for _, query := range queries { + parse, err := sqlparser.Parse(query) + require.NoError(b, err) + + _, _ = Analyze(parse, "d", fakeSchemaInfo()) + } + } +} diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index 771d5081c02..ffb8e441348 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -32,7 +32,7 @@ import ( func TestExpandStar(t *testing.T) { ks := &vindexes.Keyspace{ Name: "main", - Sharded: false, + Sharded: true, } schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{ @@ -479,7 +479,7 @@ func TestSemTableDependenciesAfterExpandStar(t *testing.T) { func TestRewriteNot(t *testing.T) { ks := &vindexes.Keyspace{ Name: "main", - Sharded: false, + Sharded: true, } schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{ @@ -531,7 +531,7 @@ func TestRewriteNot(t *testing.T) { func TestConstantFolding(t *testing.T) { ks := &vindexes.Keyspace{ Name: "main", - Sharded: false, + Sharded: true, } schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{ diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 13fd11d961f..1ede4731edd 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -439,11 +439,8 @@ func (st *SemTable) ColumnLookup(col *sqlparser.ColName) (int, error) { return 0, columnNotSupportedErr } -// SingleUnshardedKeyspace returns the single keyspace if all tables in the query are in the same, unsharded keyspace -func (st *SemTable) SingleUnshardedKeyspace() (*vindexes.Keyspace, []*vindexes.Table) { - var ks *vindexes.Keyspace - var tables []*vindexes.Table - for _, table := range st.Tables { +func singleUnshardedKeyspace(in []TableInfo) (ks *vindexes.Keyspace, tables []*vindexes.Table) { + for _, table := range in { vindexTable := table.GetVindexTable() if vindexTable == nil { @@ -484,7 +481,12 @@ func (st *SemTable) SingleUnshardedKeyspace() (*vindexes.Keyspace, []*vindexes.T } tables = append(tables, vindexTable) } - return ks, tables + return +} + +// SingleUnshardedKeyspace returns the single keyspace if all tables in the query are in the same, unsharded keyspace +func (st *SemTable) SingleUnshardedKeyspace() (*vindexes.Keyspace, []*vindexes.Table) { + return singleUnshardedKeyspace(st.Tables) } // EqualsExpr compares two expressions using the semantic analysis information. diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index d6fd4c6efd6..c6af502cb22 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -33,17 +33,71 @@ type tableCollector struct { currentDb string org originable unionInfo map[*sqlparser.Union]unionInfo + done map[*sqlparser.AliasedTableExpr]TableInfo } -func newTableCollector(scoper *scoper, si SchemaInformation, currentDb string) *tableCollector { +type earlyTableCollector struct { + si SchemaInformation + currentDb string + Tables []TableInfo + done map[*sqlparser.AliasedTableExpr]TableInfo + withTables map[sqlparser.IdentifierCS]any +} + +func newEarlyTableCollector(si SchemaInformation, currentDb string) *earlyTableCollector { + return &earlyTableCollector{ + si: si, + currentDb: currentDb, + done: map[*sqlparser.AliasedTableExpr]TableInfo{}, + withTables: map[sqlparser.IdentifierCS]any{}, + } +} + +func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) { + switch node := cursor.Node().(type) { + case *sqlparser.AliasedTableExpr: + etc.visitAliasedTableExpr(node) + } +} + +func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTableExpr) { + tbl, ok := aet.Expr.(sqlparser.TableName) + if !ok { + return + } + etc.handleTableName(tbl, aet) +} + +func (etc *earlyTableCollector) newTableCollector(scoper *scoper, org originable) *tableCollector { return &tableCollector{ + Tables: etc.Tables, scoper: scoper, - si: si, - currentDb: currentDb, + si: etc.si, + currentDb: etc.currentDb, unionInfo: map[*sqlparser.Union]unionInfo{}, + done: etc.done, + org: org, } } +func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sqlparser.AliasedTableExpr) { + if tbl.Qualifier.IsEmpty() { + _, isCTE := etc.withTables[tbl.Name] + if isCTE { + // no need to handle these tables here, we wait for the late phase instead + return + } + } + tableInfo, err := getTableInfo(aet, tbl, etc.si, etc.currentDb) + if err != nil { + // this could just be a CTE that we haven't processed, so we'll give it the benefit of the doubt for now + return + } + + etc.done[aet] = tableInfo + etc.Tables = append(etc.Tables, tableInfo) +} + func (tc *tableCollector) up(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case *sqlparser.AliasedTableExpr: @@ -100,26 +154,43 @@ func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr } case sqlparser.TableName: - var tbl *vindexes.Table - var vindex vindexes.Vindex - isInfSchema := sqlparser.SystemSchema(t.Qualifier.String()) - var err error - tbl, vindex, _, _, _, err = tc.si.FindTableOrVindex(t) - if err != nil && !isInfSchema { - // if we are dealing with a system table, it might not be available in the vschema, but that is OK + return tc.handleTableName(node, t) + } + return nil +} + +func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sqlparser.TableName) (err error) { + var tableInfo TableInfo + var found bool + + tableInfo, found = tc.done[node] + if !found { + tableInfo, err = getTableInfo(node, t, tc.si, tc.currentDb) + if err != nil { return err } - if tbl == nil && vindex != nil { - tbl = newVindexTable(t.Name) - } + tc.Tables = append(tc.Tables, tableInfo) + } - scope := tc.scoper.currentScope() - tableInfo := tc.createTable(t, node, tbl, isInfSchema, vindex) + scope := tc.scoper.currentScope() + return scope.addTable(tableInfo) +} - tc.Tables = append(tc.Tables, tableInfo) - return scope.addTable(tableInfo) +func getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, si SchemaInformation, currentDb string) (TableInfo, error) { + var tbl *vindexes.Table + var vindex vindexes.Vindex + isInfSchema := sqlparser.SystemSchema(t.Qualifier.String()) + var err error + tbl, vindex, _, _, _, err = si.FindTableOrVindex(t) + if err != nil && !isInfSchema { + // if we are dealing with a system table, it might not be available in the vschema, but that is OK + return nil, err } - return nil + if tbl == nil && vindex != nil { + tbl = newVindexTable(t.Name) + } + + return createTable(t, node, tbl, isInfSchema, vindex, currentDb), nil } func (tc *tableCollector) addSelectDerivedTable(sel *sqlparser.Select, node *sqlparser.AliasedTableExpr) error { @@ -207,12 +278,13 @@ func (tc *tableCollector) tableInfoFor(id TableSet) (TableInfo, error) { return tc.Tables[offset], nil } -func (tc *tableCollector) createTable( +func createTable( t sqlparser.TableName, alias *sqlparser.AliasedTableExpr, tbl *vindexes.Table, isInfSchema bool, vindex vindexes.Vindex, + currentDb string, ) TableInfo { table := &RealTable{ tableName: alias.As.String(), @@ -224,7 +296,7 @@ func (tc *tableCollector) createTable( if alias.As.IsEmpty() { dbName := t.Qualifier.String() if dbName == "" { - dbName = tc.currentDb + dbName = currentDb } table.dbName = dbName