Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cte naming conflict #2812

Merged
merged 3 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion enginetest/join_planning_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ with recursive rec(x) as (
)
select * from uv
where u in (select * from rec);`,
types: []plan.JoinType{plan.JoinTypeSemi, plan.JoinTypeHash},
types: []plan.JoinType{plan.JoinTypeSemi, plan.JoinTypeInner},
exp: []sql.Row{{1, 1}},
},
{
Expand Down
5 changes: 2 additions & 3 deletions enginetest/plangen/testdata/spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ plans:
path: enginetest/queries/index_query_plans.go
- name: IntegrationPlanTests
path: enginetest/queries/integration_plans.go
# this suite is throwing a missing table error, despite the table being in the set up
# - name: TpchPlanTests
# path: enginetest/queries/tpch_plans.go
- name: TpchPlanTests
path: enginetest/queries/tpch_plans.go
- name: TpccPlanTests
path: enginetest/queries/tpcc_plans.go
# - name: TpcdsPlanTests
Expand Down
46,544 changes: 22,664 additions & 23,880 deletions enginetest/queries/imdb_plans.go

Large diffs are not rendered by default.

100 changes: 50 additions & 50 deletions enginetest/queries/integration_plans.go

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,12 @@ var SpatialQueryTests = []QueryTest{
}

var QueryTests = []QueryTest{
{
Query: "WITH cte AS (SELECT * FROM xy) SELECT *, (SELECT SUM(x) FROM cte) AS xy FROM cte",
Expected: []sql.Row{
{0, 2, float64(6)}, {1, 0, float64(6)}, {2, 1, float64(6)}, {3, 3, float64(6)},
},
},
{
Query: "select 0 as col1, 1 as col2, 2 as col2 group by col2 having col2 = 1",
Expected: []sql.Row{
Expand Down
646 changes: 350 additions & 296 deletions enginetest/queries/query_plans.go

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions enginetest/queries/tpch_plans.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 0 additions & 8 deletions enginetest/scriptgen/setup/scripts/tpch
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
exec
CREATE DATABASE tpch character set utf8mb4;
----

exec
USE tpch;
----

exec
CREATE TABLE nation ( N_NATIONKEY INTEGER primary key,
N_NAME CHAR(25) NOT NULL,
Expand Down
4 changes: 2 additions & 2 deletions enginetest/scriptgen/setup/setup_data.sg.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions sql/planbuilder/cte.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func (b *Builder) buildCte(inScope *scope, e ast.TableExpr, name string, columns

func (b *Builder) buildRecursiveCte(inScope *scope, union *ast.SetOp, name string, columns []string) *scope {
l, r := splitRecursiveCteUnion(name, union)
scopeMapping := make(map[sql.ColumnId]sql.Expression)
if r == nil {
// not recursive
sqScope := inScope.pushSubquery()
Expand All @@ -104,9 +105,10 @@ func (b *Builder) buildRecursiveCte(inScope *scope, union *ast.SetOp, name strin
c.tableId = tabId
cteScope.cols[i] = c
colset.Add(sql.ColumnId(c.id))
scopeMapping[sql.ColumnId(c.id)] = c.scalarGf()
}

cteScope.node = sq.WithId(tabId).WithColumns(colset)
cteScope.node = sq.WithScopeMapping(scopeMapping).WithId(tabId).WithColumns(colset)
}
return cteScope
}
Expand All @@ -128,7 +130,6 @@ func (b *Builder) buildRecursiveCte(inScope *scope, union *ast.SetOp, name strin
cteScope := leftScope.replace()
tableId := cteScope.addTable(name)
var cols sql.ColSet
scopeMapping := make(map[sql.ColumnId]sql.Expression)
{
rInit = leftScope.node
recSch = make(sql.Schema, len(rInit.Schema()))
Expand All @@ -149,7 +150,6 @@ func (b *Builder) buildRecursiveCte(inScope *scope, union *ast.SetOp, name strin
c.scalar = nil
c.table = name
toId := cteScope.newColumn(c)
scopeMapping[sql.ColumnId(toId)] = c.scalarGf()
cols.Add(sql.ColumnId(toId))
}
b.renameSource(cteScope, name, columns)
Expand Down
123 changes: 95 additions & 28 deletions sql/planbuilder/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,73 @@ func TestPlanBuilder(t *testing.T) {
//rewrite = true

var tests = []planTest{
{
Query: "WITH cte AS (SELECT * FROM xy) SELECT *, (SELECT SUM(x) FROM cte) AS xy FROM cte",
ExpectedPlan: `
Project
├─ columns: [cte.x:7!null, cte.y:8!null, cte.z:9!null, Subquery
│ ├─ cacheable: true
│ ├─ alias-string: select SUM(x) from cte
│ └─ Project
│ ├─ columns: [sum(cte.x):13!null as SUM(x)]
│ └─ GroupBy
│ ├─ select: SUM(cte.x:10!null)
│ ├─ group:
│ └─ SubqueryAlias
│ ├─ name: cte
│ ├─ outerVisibility: false
│ ├─ isLateral: false
│ ├─ cacheable: true
│ ├─ colSet: (10-12)
│ ├─ tableId: 4
│ └─ Project
│ ├─ columns: [xy.x:1!null, xy.y:2!null, xy.z:3!null]
│ └─ Table
│ ├─ name: xy
│ ├─ columns: [x y z]
│ ├─ colSet: (1-3)
│ └─ tableId: 1
│ as xy]
└─ Project
├─ columns: [cte.x:7!null, cte.y:8!null, cte.z:9!null, Subquery
│ ├─ cacheable: true
│ ├─ alias-string: select SUM(x) from cte
│ └─ Project
│ ├─ columns: [sum(cte.x):13!null as SUM(x)]
│ └─ GroupBy
│ ├─ select: SUM(cte.x:10!null)
│ ├─ group:
│ └─ SubqueryAlias
│ ├─ name: cte
│ ├─ outerVisibility: false
│ ├─ isLateral: false
│ ├─ cacheable: true
│ ├─ colSet: (10-12)
│ ├─ tableId: 4
│ └─ Project
│ ├─ columns: [xy.x:1!null, xy.y:2!null, xy.z:3!null]
│ └─ Table
│ ├─ name: xy
│ ├─ columns: [x y z]
│ ├─ colSet: (1-3)
│ └─ tableId: 1
│ as xy]
└─ SubqueryAlias
├─ name: cte
├─ outerVisibility: false
├─ isLateral: false
├─ cacheable: true
├─ colSet: (7-9)
├─ tableId: 3
└─ Project
├─ columns: [xy.x:1!null, xy.y:2!null, xy.z:3!null]
└─ Table
├─ name: xy
├─ columns: [x y z]
├─ colSet: (1-3)
└─ tableId: 1
`,
},
{
Query: "select 0 as col1, 1 as col2, 2 as col2 group by col2 having col2 = 1",
ExpectedPlan: `
Expand Down Expand Up @@ -80,20 +147,20 @@ Project
├─ columns: [1 (tinyint) as x]
└─ Having
├─ GreaterThan
│ ├─ avg(cte.x):4
│ ├─ avg(cte.x):5
│ └─ 0 (tinyint)
└─ Project
├─ columns: [avg(cte.x):4, cte.x:2!null, 1 (tinyint) as x]
├─ columns: [avg(cte.x):5, cte.x:3!null, 1 (tinyint) as x]
└─ GroupBy
├─ select: AVG(cte.x:2!null), cte.x:2!null
├─ select: AVG(cte.x:3!null), cte.x:3!null
├─ group:
└─ SubqueryAlias
├─ name: cte
├─ outerVisibility: false
├─ isLateral: false
├─ cacheable: true
├─ colSet: (2)
├─ tableId: 1
├─ colSet: (3)
├─ tableId: 2
└─ Project
├─ columns: [1 (tinyint) as x]
└─ Table
Expand Down Expand Up @@ -249,8 +316,8 @@ SubqueryAlias
├─ outerVisibility: false
├─ isLateral: false
├─ cacheable: true
├─ colSet: (4,5)
├─ tableId: 2
├─ colSet: (6,7)
├─ tableId: 3
└─ Project
├─ columns: [xy.x:1!null, xy.y:2!null]
└─ Table
Expand Down Expand Up @@ -462,8 +529,8 @@ SubqueryAlias
├─ outerVisibility: false
├─ isLateral: false
├─ cacheable: true
├─ colSet: (2)
├─ tableId: 1
├─ colSet: (3)
├─ tableId: 2
└─ Project
├─ columns: [1 (tinyint)]
└─ Table
Expand All @@ -481,8 +548,8 @@ SubqueryAlias
├─ outerVisibility: false
├─ isLateral: false
├─ cacheable: true
├─ colSet: (4)
├─ tableId: 2
├─ colSet: (9)
├─ tableId: 4
└─ RecursiveCTE
└─ Union distinct
├─ Project
Expand All @@ -493,16 +560,16 @@ SubqueryAlias
│ ├─ colSet: (1-3)
│ └─ tableId: 1
└─ Project
├─ columns: [cte.s:4!null]
├─ columns: [cte.s:5!null]
└─ InnerJoin
├─ Eq
│ ├─ xy.y:6!null
│ └─ cte.s:4!null
│ ├─ xy.y:7!null
│ └─ cte.s:5!null
├─ RecursiveTable(cte)
└─ Table
├─ name: xy
├─ columns: [x y z]
├─ colSet: (5-7)
├─ colSet: (6-8)
└─ tableId: 4
`,
},
Expand Down Expand Up @@ -1587,8 +1654,8 @@ SubqueryAlias
├─ outerVisibility: false
├─ isLateral: false
├─ cacheable: true
├─ colSet: (6,7)
├─ tableId: 4
├─ colSet: (14,15)
├─ tableId: 6
└─ RecursiveCTE
└─ Union all
├─ Project
Expand All @@ -1598,8 +1665,8 @@ SubqueryAlias
│ ├─ outerVisibility: false
│ ├─ isLateral: false
│ ├─ cacheable: true
│ ├─ colSet: (2)
│ ├─ tableId: 1
│ ├─ colSet: (5)
│ ├─ tableId: 3
│ └─ RecursiveCTE
│ └─ Union all
│ ├─ Project
Expand All @@ -1610,27 +1677,27 @@ SubqueryAlias
│ │ ├─ colSet: ()
│ │ └─ tableId: 0
│ └─ Project
│ ├─ columns: [(rt.foo:2!null + 1 (tinyint)) as foo]
│ ├─ columns: [(rt.foo:3!null + 1 (tinyint)) as foo]
│ └─ Filter
│ ├─ LessThan
│ │ ├─ rt.foo:2!null
│ │ ├─ rt.foo:3!null
│ │ └─ 5 (bigint)
│ └─ RecursiveTable(rt)
└─ Project
├─ columns: [(ladder.depth:6!null + 1 (tinyint)) as depth, rt.foo:2!null]
├─ columns: [(ladder.depth:10!null + 1 (tinyint)) as depth, rt.foo:12!null]
└─ Filter
├─ Eq
│ ├─ ladder.foo:7
│ └─ rt.foo:2!null
│ ├─ ladder.foo:11
│ └─ rt.foo:12!null
└─ CrossJoin
├─ RecursiveTable(ladder)
└─ SubqueryAlias
├─ name: rt
├─ outerVisibility: false
├─ isLateral: false
├─ cacheable: true
├─ colSet: (2)
├─ tableId: 1
├─ colSet: (12)
├─ tableId: 4
└─ RecursiveCTE
└─ Union all
├─ Project
Expand All @@ -1641,10 +1708,10 @@ SubqueryAlias
│ ├─ colSet: ()
│ └─ tableId: 0
└─ Project
├─ columns: [(rt.foo:2!null + 1 (tinyint)) as foo]
├─ columns: [(rt.foo:3!null + 1 (tinyint)) as foo]
└─ Filter
├─ LessThan
│ ├─ rt.foo:2!null
│ ├─ rt.foo:3!null
│ └─ 5 (bigint)
└─ RecursiveTable(rt)
`,
Expand Down
15 changes: 11 additions & 4 deletions sql/planbuilder/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,13 +360,19 @@ func (s *scope) aliasCte(alias string) *scope {
return nil
}
outScope := s.copy()
if _, ok := s.tables[alias]; ok || alias == "" {
return outScope
}

sq, _ := outScope.node.(*plan.SubqueryAlias)

tabId := outScope.addTable(alias)
name := strings.ToLower(outScope.node.(sql.NameableNode).Name())

var tabId sql.TableId
if alias != "" {
tabId = outScope.addTable(alias)
} else {
alias = name
tabId = s.tables[strings.ToLower(name)]
}

outScope.cols = nil
var colSet sql.ColSet
scopeMapping := make(map[sql.ColumnId]sql.Expression)
Expand Down Expand Up @@ -462,6 +468,7 @@ func (s *scope) getCte(name string) *scope {
if checkScope.ctes != nil {
cte, ok := checkScope.ctes[strings.ToLower(name)]
if ok {
cte.tables[name] += 1
return cte
}
}
Expand Down
Loading