Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for AVG on sharded queries #14419

Merged
merged 6 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ const (
AggregateAnyValue
AggregateCountStar
AggregateGroupConcat
AggregateAvg
_NumOfOpCodes // This line must be last of the opcodes!
)

Expand All @@ -85,6 +86,7 @@ var (
AggregateCountStar: sqltypes.Int64,
AggregateSumDistinct: sqltypes.Decimal,
AggregateSum: sqltypes.Decimal,
AggregateAvg: sqltypes.Decimal,
AggregateGtid: sqltypes.VarChar,
}
)
Expand All @@ -96,6 +98,7 @@ var SupportedAggregates = map[string]AggregateOpcode{
"sum": AggregateSum,
"min": AggregateMin,
"max": AggregateMax,
"avg": AggregateAvg,
// These functions don't exist in mysql, but are used
// to display the plan.
"count_distinct": AggregateCountDistinct,
Expand All @@ -117,6 +120,7 @@ var AggregateName = map[AggregateOpcode]string{
AggregateCountStar: "count_star",
AggregateGroupConcat: "group_concat",
AggregateAnyValue: "any_value",
AggregateAvg: "avg",
}

func (code AggregateOpcode) String() string {
Expand Down Expand Up @@ -148,7 +152,7 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
return sqltypes.Text
case AggregateMax, AggregateMin, AggregateAnyValue:
return typ
case AggregateSumDistinct, AggregateSum:
case AggregateSumDistinct, AggregateSum, AggregateAvg:
if typ == sqltypes.Unknown {
return sqltypes.Unknown
}
Expand Down
110 changes: 92 additions & 18 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,39 @@ func tryPushAggregator(ctx *plancontext.PlanningContext, aggregator *Aggregator)
if aggregator.Pushed {
return aggregator, rewrite.SameTree, nil
}

// this rewrite is always valid, and we should do it whenever possible
if route, ok := aggregator.Source.(*Route); ok && (route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping)) {
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
}

// other rewrites require us to have reached this phase before we can consider them
if !reachedPhase(ctx, delegateAggregation) {
return aggregator, rewrite.SameTree, nil
}

// if we have not yet been able to push this aggregation down,
// we need to turn AVG into SUM/COUNT to support this over a sharded keyspace
if needAvgBreaking(aggregator.Aggregations) {
output, err = splitAvgAggregations(ctx, aggregator)
if err != nil {
return nil, nil, err
}

applyResult = rewrite.NewTree("split avg aggregation", output)
return
}

switch src := aggregator.Source.(type) {
case *Route:
// if we have a single sharded route, we can push it down
output, applyResult, err = pushAggregationThroughRoute(ctx, aggregator, src)
case *ApplyJoin:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
case *Filter:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
case *SubQueryContainer:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
default:
return aggregator, rewrite.SameTree, nil
}
Expand Down Expand Up @@ -135,15 +152,6 @@ func pushAggregationThroughRoute(
aggregator *Aggregator,
route *Route,
) (ops.Operator, *rewrite.ApplyResult, error) {
// If the route is single-shard, or we are grouping by sharding keys, we can just push down the aggregation
if route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping) {
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
}

if !reachedPhase(ctx, delegateAggregation) {
return nil, nil, nil
}

// Create a new aggregator to be placed below the route.
aggrBelowRoute := aggregator.SplitAggregatorBelowRoute(route.Inputs())
aggrBelowRoute.Aggregations = nil
Expand Down Expand Up @@ -806,3 +814,69 @@ func initColReUse(size int) []int {
}

// extractExpr returns the expression held by the given aliased expression.
func extractExpr(expr *sqlparser.AliasedExpr) sqlparser.Expr {
	return expr.Expr
}

// needAvgBreaking reports whether at least one of the given aggregations
// uses the AVG opcode. It returns false for an empty or nil slice.
func needAvgBreaking(aggrs []Aggr) bool {
	for i := range aggrs {
		if aggrs[i].OpCode != opcode.AggregateAvg {
			continue
		}
		// A single AVG is enough; no need to inspect the rest.
		return true
	}
	return false
}

// splitAvgAggregations takes an aggregator that has AVG aggregations in it and splits
// these into sum/count expressions that can be spread out to shards
func splitAvgAggregations(ctx *plancontext.PlanningContext, aggr *Aggregator) (ops.Operator, error) {
proj := newAliasedProjection(aggr)

var columns []*sqlparser.AliasedExpr
var aggregations []Aggr

for offset, col := range aggr.Columns {
avg, ok := col.Expr.(*sqlparser.Avg)
if !ok {
proj.addColumnWithoutPushing(ctx, col, false /* addToGroupBy */)
continue
}

if avg.Distinct {
panic(vterrors.VT12001("AVG(distinct <>)"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my understanding: supporting this would require loading all the distinct values from the shards and then performing the avg calculation on the vtgate layer?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure. I have not spent enough time thinking about it to come up with a solution, which is why I just fail here.

}

// We have an AVG that we need to split
sumExpr := &sqlparser.Sum{Arg: avg.Arg}
countExpr := &sqlparser.Count{Args: []sqlparser.Expr{avg.Arg}}
calcExpr := &sqlparser.BinaryExpr{
Operator: sqlparser.DivOp,
Left: sumExpr,
Right: countExpr,
}
Comment on lines +843 to +847
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️


outputColumn := aeWrap(col.Expr)
outputColumn.As = sqlparser.NewIdentifierCI(col.ColumnName())
_, err := proj.addUnexploredExpr(sqlparser.CloneRefOfAliasedExpr(col), calcExpr)
if err != nil {
return nil, err
}
col.Expr = sumExpr

for aggrOffset, aggregation := range aggr.Aggregations {
if offset == aggregation.ColOffset {
// We have found the AVG column. We'll change it to SUM, and then we add a COUNT as well
aggr.Aggregations[aggrOffset].OpCode = opcode.AggregateSum

countExprAlias := aeWrap(countExpr)
countAggr := NewAggr(opcode.AggregateCount, countExpr, countExprAlias, sqlparser.String(countExpr))
countAggr.ColOffset = len(aggr.Columns) + len(columns)
aggregations = append(aggregations, countAggr)
columns = append(columns, countExprAlias)
break // no need to search the remaining aggregations
systay marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

aggr.Columns = append(aggr.Columns, columns...)
aggr.Aggregations = append(aggr.Aggregations, aggregations...)

return proj, nil
}
59 changes: 34 additions & 25 deletions go/vt/vtgate/planbuilder/operators/phases.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package operators
import (
"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand Down Expand Up @@ -56,9 +57,9 @@ func (p Phase) String() string {
return "optimize Distinct operations"
case subquerySettling:
return "settle subqueries"
default:
panic(vterrors.VT13001("unhandled default case"))
}

return "unknown"
}

func (p Phase) shouldRun(s semantics.QuerySignature) bool {
Expand All @@ -73,8 +74,9 @@ func (p Phase) shouldRun(s semantics.QuerySignature) bool {
return s.Distinct
case subquerySettling:
return s.SubQueries
default:
return true
}
return true
}

func (p Phase) act(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Operator, error) {
Expand All @@ -84,14 +86,14 @@ func (p Phase) act(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Opera
case delegateAggregation:
return enableDelegateAggregation(ctx, op)
case addAggrOrdering:
return addOrderBysForAggregations(ctx, op)
return addOrderingForAllAggregations(ctx, op)
case cleanOutPerfDistinct:
return removePerformanceDistinctAboveRoute(ctx, op)
case subquerySettling:
return settleSubqueries(ctx, op), nil
default:
return op, nil
}

return op, nil
}

// getPhases returns the ordered phases that the planner will undergo.
Expand Down Expand Up @@ -120,7 +122,8 @@ func enableDelegateAggregation(ctx *plancontext.PlanningContext, op ops.Operator
return addColumnsToInput(ctx, op)
}

func addOrderBysForAggregations(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) {
// addOrderingForAllAggregations is run after we have pushed down Aggregators as far down as possible.
func addOrderingForAllAggregations(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) {
visitor := func(in ops.Operator, _ semantics.TableSet, isRoot bool) (ops.Operator, *rewrite.ApplyResult, error) {
aggrOp, ok := in.(*Aggregator)
if !ok {
Expand All @@ -131,30 +134,36 @@ func addOrderBysForAggregations(ctx *plancontext.PlanningContext, root ops.Opera
if err != nil {
return nil, nil, err
}
if !requireOrdering {
return in, rewrite.SameTree, nil
}
orderBys := slice.Map(aggrOp.Grouping, func(from GroupBy) ops.OrderBy {
return from.AsOrderBy()
})
if aggrOp.DistinctExpr != nil {
orderBys = append(orderBys, ops.OrderBy{
Inner: &sqlparser.Order{
Expr: aggrOp.DistinctExpr,
},
SimplifiedExpr: aggrOp.DistinctExpr,
})
}
aggrOp.Source = &Ordering{
Source: aggrOp.Source,
Order: orderBys,

var res *rewrite.ApplyResult
if requireOrdering {
addOrderingFor(aggrOp)
res = rewrite.NewTree("added ordering before aggregation", in)
}
return in, rewrite.NewTree("added ordering before aggregation", in), nil
return in, res, nil
}

return rewrite.BottomUp(root, TableID, visitor, stopAtRoute)
}

// addOrderingFor places an Ordering operator between aggrOp and its current
// source. The ordering consists of one entry per grouping expression, in
// grouping order, followed by the aggregator's distinct expression when one
// is set.
func addOrderingFor(aggrOp *Aggregator) {
	ordering := slice.Map(aggrOp.Grouping, func(grouping GroupBy) ops.OrderBy {
		return grouping.AsOrderBy()
	})

	// The distinct expression, if any, is ordered on after all the groupings.
	if distinct := aggrOp.DistinctExpr; distinct != nil {
		distinctOrder := ops.OrderBy{
			Inner:          &sqlparser.Order{Expr: distinct},
			SimplifiedExpr: distinct,
		}
		ordering = append(ordering, distinctOrder)
	}

	// Wrap the existing source so the aggregator now reads from the Ordering.
	previous := aggrOp.Source
	aggrOp.Source = &Ordering{
		Source: previous,
		Order:  ordering,
	}
}

func needsOrdering(ctx *plancontext.PlanningContext, in *Aggregator) (bool, error) {
requiredOrder := slice.Map(in.Grouping, func(from GroupBy) sqlparser.Expr {
return from.SimplifiedExpr
Expand Down
Loading
Loading