Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gen4 Planner: support aggregate UDFs #15710

Merged
merged 9 commits into from
Apr 17, 2024
1 change: 1 addition & 0 deletions go/mysql/sqlerror/sql_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{
vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState},
vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation},
vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState},
vterrors.AggregateMustPushDown: {num: ERNotSupportedYet, state: SSUnknownSQLState},
}

func getStateToMySQLState(state vterrors.State) mysqlCode {
Expand Down
4 changes: 4 additions & 0 deletions go/test/vschemawrapper/vschema_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ func (vw *VSchemaWrapper) KeyspaceError(keyspace string) error {
return nil
}

func (vw *VSchemaWrapper) GetAggregateUDFs() (udfs []string) {
return vw.V.GetAggregateUDFs()
}

func (vw *VSchemaWrapper) GetForeignKeyChecksState() *bool {
return vw.ForeignKeyChecksState
}
Expand Down
4 changes: 4 additions & 0 deletions go/vt/schemadiff/semantics.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func (si *declarativeSchemaInformation) KeyspaceError(keyspace string) error {
return nil
}

func (si *declarativeSchemaInformation) GetAggregateUDFs() []string {
return nil
}

func (si *declarativeSchemaInformation) GetForeignKeyChecksState() *bool {
return nil
}
Expand Down
10 changes: 10 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,16 @@ func (node IdentifierCI) EqualString(str string) bool {
return node.Lowered() == strings.ToLower(str)
}

// EqualsAnyString returns true if any of these strings match
func (node IdentifierCI) EqualsAnyString(str []string) bool {
for _, s := range str {
if node.EqualString(s) {
return true
}
}
return false
}

// MarshalJSON marshals into JSON.
func (node IdentifierCI) MarshalJSON() ([]byte, error) {
return json.Marshal(node.val)
Expand Down
3 changes: 2 additions & 1 deletion go/vt/vterrors/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ var (
VT03029 = errorWithState("VT03029", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "column count does not match value count with the row for vindex '%s'", "The number of columns you want to insert do not match the number of columns of your SELECT query.")
VT03030 = errorWithState("VT03030", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "lookup column count does not match value count with the row (columns, count): (%v, %d)", "The number of columns you want to insert do not match the number of columns of your SELECT query.")
VT03031 = errorWithoutState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, "EXPLAIN is only supported for single keyspace", "EXPLAIN has to be sent down as a single query to the underlying MySQL, and this is not possible if it uses tables from multiple keyspaces")
VT03032 = errorWithState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.")
VT03032 = errorWithState("VT03032", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.")
VT03033 = errorWithState("VT03033", vtrpcpb.Code_INVALID_ARGUMENT, AggregateMustPushDown, "aggregate user-defined function %s must be pushed down to mysql", "The aggregate user-defined function must be pushed down to mysql and can't be evaluated on the vtgate. The query contains aggregation that can't be fully pushed down to MySQL.")
systay marked this conversation as resolved.
Show resolved Hide resolved
systay marked this conversation as resolved.
Show resolved Hide resolved

VT05001 = errorWithState("VT05001", vtrpcpb.Code_NOT_FOUND, DbDropExists, "cannot drop database '%s'; database does not exists", "The given database does not exist; Vitess cannot drop it.")
VT05002 = errorWithState("VT05002", vtrpcpb.Code_NOT_FOUND, BadDb, "cannot alter database '%s'; unknown database", "The given database does not exist; Vitess cannot alter it.")
Expand Down
1 change: 1 addition & 0 deletions go/vt/vterrors/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const (
WrongArguments
BadNullError
InvalidGroupFuncUse
AggregateMustPushDown

// failed precondition
NoDB
Expand Down
16 changes: 3 additions & 13 deletions go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,10 @@ const (
AggregateCountStar
AggregateGroupConcat
AggregateAvg
AggregateUDF // This is an opcode used to represent UDFs
_NumOfOpCodes // This line must be last of the opcodes!
)

var (
// OpcodeType keeps track of the known output types for different aggregate functions
OpcodeType = map[AggregateOpcode]querypb.Type{
AggregateCountDistinct: sqltypes.Int64,
AggregateCount: sqltypes.Int64,
AggregateCountStar: sqltypes.Int64,
AggregateSumDistinct: sqltypes.Decimal,
AggregateSum: sqltypes.Decimal,
AggregateAvg: sqltypes.Decimal,
AggregateGtid: sqltypes.VarChar,
}
)

// SupportedAggregates maps the list of supported aggregate
// functions to their opcodes.
var SupportedAggregates = map[string]AggregateOpcode{
Expand Down Expand Up @@ -166,6 +154,8 @@ func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type {
return sqltypes.Int64
case AggregateGtid:
return sqltypes.VarChar
case AggregateUDF:
return sqltypes.Unknown
default:
panic(code.String()) // we have a unit test checking we never reach here
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func gen4DeleteStmtPlanner(
return nil, err
}

err = queryRewrite(ctx.SemTable, reservedVars, deleteStmt)
err = queryRewrite(ctx, deleteStmt)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func gen4InsertStmtPlanner(version querypb.ExecuteOptions_PlannerVersion, insStm
return nil, err
}

err = queryRewrite(ctx.SemTable, reservedVars, insStmt)
err = queryRewrite(ctx, insStmt)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) {

switch stmt := qb.stmt.(type) {
case *sqlparser.Select:
if containsAggr(expr) {
if ContainsAggr(qb.ctx, expr) {
addPred = stmt.AddHaving
} else {
addPred = stmt.AddWhere
Expand Down
8 changes: 7 additions & 1 deletion go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (a *Aggregator) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser
return newFilter(a, expr)
}

func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int {
func (a *Aggregator) addColumnWithoutPushing(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int {
offset := len(a.Columns)
a.Columns = append(a.Columns, expr)

Expand All @@ -96,6 +96,12 @@ func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, exp
switch e := expr.Expr.(type) {
case sqlparser.AggrFunc:
aggr = createAggrFromAggrFunc(e, expr)
case *sqlparser.FuncExpr:
if IsAggr(ctx, e) {
aggr = NewAggr(opcode.AggregateUDF, nil, expr, expr.As.String())
} else {
aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String())
}
default:
aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String())
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func breakExpressionInLHSandRHSForApplyJoin(
) (col applyJoinColumn) {
rewrittenExpr := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) {
nodeExpr, ok := cursor.Node().(sqlparser.Expr)
if !ok || !mustFetchFromInput(nodeExpr) {
if !ok || !mustFetchFromInput(ctx, nodeExpr) {
return
}
deps := ctx.SemTable.RecursiveDeps(nodeExpr)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Exp
}
inOffset := op.FindCol(ctx, expr, false)
if inOffset == -1 {
if !mustFetchFromInput(expr) {
if !mustFetchFromInput(ctx, expr) {
return -1
}

Expand Down Expand Up @@ -398,7 +398,7 @@ func (hj *HashJoin) addSingleSidedColumn(
}
inOffset := op.FindCol(ctx, expr, false)
if inOffset == -1 {
if !mustFetchFromInput(expr) {
if !mustFetchFromInput(ctx, expr) {
return -1
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/horizon.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.
}

newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo)
if sqlparser.ContainsAggregation(newExpr) {
if ContainsAggr(ctx, newExpr) {
return newFilter(h, expr)
}
h.Source = h.Source.AddPredicate(ctx, newExpr)
Expand Down
40 changes: 32 additions & 8 deletions go/vt/vtgate/planbuilder/operators/offset_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package operators
import (
"fmt"

"vitess.io/vitess/go/vt/vtgate/engine/opcode"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand Down Expand Up @@ -56,10 +58,12 @@ func planOffsets(ctx *plancontext.PlanningContext, root Operator) Operator {
}

// mustFetchFromInput returns true for expressions that have to be fetched from the input and cannot be evaluated
func mustFetchFromInput(e sqlparser.SQLNode) bool {
switch e.(type) {
func mustFetchFromInput(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) bool {
switch fun := e.(type) {
case *sqlparser.ColName, sqlparser.AggrFunc:
return true
case *sqlparser.FuncExpr:
return fun.Name.EqualsAnyString(ctx.VSchema.GetAggregateUDFs())
default:
return false
}
Expand Down Expand Up @@ -93,10 +97,10 @@ func useOffsets(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op Operat
return rewritten.(sqlparser.Expr)
}

// addColumnsToInput adds columns needed by an operator to its input.
// This happens only when the filter expression can be retrieved as an offset from the underlying mysql.
func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator {
visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
// addColumnsToInput adds columns needed by an operator to its input.
// This happens only when the filter expression can be retrieved as an offset from the underlying mysql.
addColumnsNeededByFilter := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) {
filter, ok := in.(*Filter)
if !ok {
return in, NoRewrite
Expand Down Expand Up @@ -126,11 +130,31 @@ func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator
return in, NoRewrite
}

// while we are out here walking the operator tree, if we find a UDF in an aggregation, we should fail
failUDFAggregation := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) {
aggrOp, ok := in.(*Aggregator)
if !ok {
return in, NoRewrite
}
for _, aggr := range aggrOp.Aggregations {
if aggr.OpCode == opcode.AggregateUDF {
// we don't support UDFs in aggregation if it's still above a route
panic(vterrors.VT03033(sqlparser.String(aggr.Original.Expr)))
}
}
return in, NoRewrite
}

visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
out, res := addColumnsNeededByFilter(in, semantics.EmptyTableSet(), isRoot)
failUDFAggregation(in, semantics.EmptyTableSet(), isRoot)
return out, res
}

return TopDown(root, TableID, visitor, stopAtRoute)
}

// addColumnsToInput adds columns needed by an operator to its input.
// This happens only when the filter expression can be retrieved as an offset from the underlying mysql.
// pullDistinctFromUNION will pull out the distinct from a union operator
func pullDistinctFromUNION(_ *plancontext.PlanningContext, root Operator) Operator {
systay marked this conversation as resolved.
Show resolved Hide resolved
visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
union, ok := in.(*Union)
Expand Down Expand Up @@ -170,7 +194,7 @@ func getOffsetRewritingVisitor(
return false
}

if mustFetchFromInput(e) {
if mustFetchFromInput(ctx, e) {
notFound(e)
return false
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (Operator,
case *Projection:
// we can move ordering under a projection if it's not introducing a column we're sorting by
for _, by := range in.Order {
if !mustFetchFromInput(by.SimplifiedExpr) {
if !mustFetchFromInput(ctx, by.SimplifiedExpr) {
return in, NoRewrite
}
}
Expand Down Expand Up @@ -459,7 +459,7 @@ func pushFilterUnderProjection(ctx *plancontext.PlanningContext, filter *Filter,
for _, p := range filter.Predicates {
cantPush := false
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
if !mustFetchFromInput(node) {
if !mustFetchFromInput(ctx, node) {
return true, nil
}

Expand Down
Loading
Loading