Skip to content

Commit

Permalink
feat: use the UDF info when planning
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
systay committed Apr 15, 2024
1 parent 6b25965 commit 28e792a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
1 change: 1 addition & 0 deletions go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func (a *analyzer) lateInit() {
aliasMapCache: map[*sqlparser.Select]map[string]exprContainer{},
reAnalyze: a.reAnalyze,
tables: a.tables,
aggrUDFs: a.si.GetAggregateUDFs(),
}
a.fk = &fkManager{
binder: a.binder,
Expand Down
27 changes: 25 additions & 2 deletions go/vt/vtgate/semantics/early_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type earlyRewriter struct {
// have happened, and we are introducing or changing the AST. We invoke it so all parts of the query have been
// typed, scoped and bound correctly
reAnalyze func(n sqlparser.SQLNode) error
aggrUDFs []string
}

func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error {
Expand Down Expand Up @@ -500,6 +501,15 @@ func (r *earlyRewriter) rewriteAliasesInGroupBy(node sqlparser.Expr, sel *sqlpar
return
}

func (at *aggrTracker) isAggregateUDF(name sqlparser.IdentifierCI) bool {
for _, aggrUDF := range at.aggrUDFs {
if name.EqualString(aggrUDF) {
return true
}
}
return false
}

func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) {
currentScope := r.scoper.currentScope()
if currentScope.isUnion {
Expand All @@ -508,14 +518,22 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars
}

aliases := r.getAliasMap(sel)
aggrTrack := &aggrTracker{}
aggrTrack := &aggrTracker{
insideAggr: false,
aggrUDFs: r.aggrUDFs,
}
output := sqlparser.CopyOnRewrite(node, aggrTrack.down, func(cursor *sqlparser.CopyOnWriteCursor) {
var col *sqlparser.ColName

switch node := cursor.Node().(type) {
case sqlparser.AggrFunc:
aggrTrack.popAggr()
return
case *sqlparser.FuncExpr:
if aggrTrack.isAggregateUDF(node.Name) {
aggrTrack.popAggr()
}
return
case *sqlparser.ColName:
col = node
default:
Expand Down Expand Up @@ -565,14 +583,19 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars

type aggrTracker struct {
insideAggr bool
aggrUDFs []string
}

func (at *aggrTracker) down(node, _ sqlparser.SQLNode) bool {
switch node.(type) {
switch node := node.(type) {
case *sqlparser.Subquery:
return false
case sqlparser.AggrFunc:
at.insideAggr = true
case *sqlparser.FuncExpr:
if at.isAggregateUDF(node.Name) {
at.insideAggr = true
}
}

return true
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/semantics/early_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,11 @@ func TestHavingColumnName(t *testing.T) {
expSQL: "select id, sum(t1.foo) as foo from t1 having sum(foo) > 1",
expDeps: TS0,
warning: "Column 'foo' in having clause is ambiguous",
}, {
sql: "select id, sum(t1.foo) as foo from t1 having custom_udf(foo) > 1",
expSQL: "select id, sum(t1.foo) as foo from t1 having custom_udf(foo) > 1",
expDeps: TS0,
warning: "Column 'foo' in having clause is ambiguous",
}, {
sql: "select id, sum(t1.foo) as XYZ from t1 having sum(XYZ) > 1",
expErr: "Invalid use of group function",
Expand Down

0 comments on commit 28e792a

Please sign in to comment.