diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index bfd5f413f80..a5a38f09641 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -78,6 +78,7 @@ func (a *analyzer) lateInit() { aliasMapCache: map[*sqlparser.Select]map[string]exprContainer{}, reAnalyze: a.reAnalyze, tables: a.tables, + aggrUDFs: a.si.GetAggregateUDFs(), } a.fk = &fkManager{ binder: a.binder, diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 2e67509c06f..6c796289057 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -41,6 +41,7 @@ type earlyRewriter struct { // have happened, and we are introducing or changing the AST. We invoke it so all parts of the query have been // typed, scoped and bound correctly reAnalyze func(n sqlparser.SQLNode) error + aggrUDFs []string } func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { @@ -500,6 +501,15 @@ func (r *earlyRewriter) rewriteAliasesInGroupBy(node sqlparser.Expr, sel *sqlpar return } +func (at *aggrTracker) isAggregateUDF(name sqlparser.IdentifierCI) bool { + for _, aggrUDF := range at.aggrUDFs { + if name.EqualString(aggrUDF) { + return true + } + } + return false +} + func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { currentScope := r.scoper.currentScope() if currentScope.isUnion { @@ -508,7 +518,10 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars } aliases := r.getAliasMap(sel) - aggrTrack := &aggrTracker{} + aggrTrack := &aggrTracker{ + insideAggr: false, + aggrUDFs: r.aggrUDFs, + } output := sqlparser.CopyOnRewrite(node, aggrTrack.down, func(cursor *sqlparser.CopyOnWriteCursor) { var col *sqlparser.ColName @@ -516,6 +529,11 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars case sqlparser.AggrFunc: aggrTrack.popAggr() return + case *sqlparser.FuncExpr: + if aggrTrack.isAggregateUDF(node.Name) { + aggrTrack.popAggr() + } + return case *sqlparser.ColName: col = node default: @@ -565,14 +583,19 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars type aggrTracker struct { insideAggr bool + aggrUDFs []string } func (at *aggrTracker) down(node, _ sqlparser.SQLNode) bool { - switch node.(type) { + switch node := node.(type) { case *sqlparser.Subquery: return false case sqlparser.AggrFunc: at.insideAggr = true + case *sqlparser.FuncExpr: + if at.isAggregateUDF(node.Name) { + at.insideAggr = true + } } return true diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index c44d6f6307d..719d33b83d7 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -536,6 +536,11 @@ func TestHavingColumnName(t *testing.T) { expSQL: "select id, sum(t1.foo) as foo from t1 having sum(foo) > 1", expDeps: TS0, warning: "Column 'foo' in having clause is ambiguous", + }, { + sql: "select id, sum(t1.foo) as foo from t1 having custom_udf(foo) > 1", + expSQL: "select id, sum(t1.foo) as foo from t1 having custom_udf(foo) > 1", + expDeps: TS0, + warning: "Column 'foo' in having clause is ambiguous", }, { sql: "select id, sum(t1.foo) as XYZ from t1 having sum(XYZ) > 1", expErr: "Invalid use of group function",