Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for AVG on sharded queries #14419

Merged
merged 6 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func TestAggregateTypes(t *testing.T) {
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select sum(val1) from aggr_test", `[[FLOAT64(0)]]`)
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
}

func TestGroupBy(t *testing.T) {
Expand Down Expand Up @@ -172,6 +173,13 @@ func TestAggrOnJoin(t *testing.T) {

mcmp.AssertMatches("select a.val1 from aggr_test a join t3 t on a.val2 = t.id7 group by a.val1 having count(*) = 4",
`[[VARCHAR("a")]]`)

mcmp.AssertMatches(`select avg(a1.val2), avg(a2.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7`,
"[[DECIMAL(1.5000) DECIMAL(1.0000)]]")

mcmp.AssertMatches(`select a1.val1, avg(a1.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7 group by a1.val1`,
`[[VARCHAR("a") DECIMAL(1.0000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.0000)]]`)

}

func TestNotEqualFilterOnScatter(t *testing.T) {
Expand Down Expand Up @@ -314,22 +322,26 @@ func TestAggOnTopOfLimit(t *testing.T) {
for _, workload := range []string{"oltp", "olap"} {
t.Run(workload, func(t *testing.T) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))
mcmp.AssertMatches(" select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches(" select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches(" select count(*) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches(" select count(val1) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1)]]")
mcmp.AssertMatches(" select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]")
mcmp.AssertMatches(" select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`)
mcmp.AssertMatchesNoOrder(" select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`)
mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1)]]")
mcmp.AssertMatches("select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]")
mcmp.AssertMatches("select avg(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[NULL]]")
mcmp.AssertMatches("select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`)
mcmp.AssertMatchesNoOrder("select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`)
mcmp.AssertMatchesNoOrder("select val1, avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL DECIMAL(2.0000)] [VARCHAR("a") DECIMAL(3.5000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.5000)]]`)

// mysql returns FLOAT64(0), vitess returns DECIMAL(0)
mcmp.AssertMatchesNoCompare(" select count(*), sum(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0)]]", "[[INT64(2) FLOAT64(0)]]")
mcmp.AssertMatches(" select count(val1), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7)]]")
mcmp.AssertMatches(" select count(*), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2) DECIMAL(14)]]")
mcmp.AssertMatches(" select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]")
mcmp.AssertMatches(" select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]")
mcmp.AssertMatches(" select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`)
mcmp.AssertMatchesNoOrder(" select val1, count(val2), sum(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1) DECIMAL(2)] [VARCHAR("a") INT64(2) DECIMAL(7)] [VARCHAR("b") INT64(1) DECIMAL(1)] [VARCHAR("c") INT64(2) DECIMAL(7)]]`)
mcmp.AssertMatches("select count(*), sum(val1), avg(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0) FLOAT64(0)]]")
mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7)]]")
mcmp.AssertMatches("select count(val1), sum(id), avg(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7) DECIMAL(3.5000)]]")
mcmp.AssertMatches("select count(*), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2) DECIMAL(14)]]")
mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]")
mcmp.AssertMatches("select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]")
mcmp.AssertMatches("select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`)
mcmp.AssertMatchesNoOrder("select val1, count(val2), sum(val2), avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1",
`[[NULL INT64(1) DECIMAL(2) DECIMAL(2.0000)] [VARCHAR("a") INT64(2) DECIMAL(7) DECIMAL(3.5000)] [VARCHAR("b") INT64(1) DECIMAL(1) DECIMAL(1.0000)] [VARCHAR("c") INT64(2) DECIMAL(7) DECIMAL(3.5000)]]`)
})
}
}
Expand All @@ -343,6 +355,8 @@ func TestEmptyTableAggr(t *testing.T) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
})
Expand All @@ -355,8 +369,10 @@ func TestEmptyTableAggr(t *testing.T) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
})
}

Expand Down Expand Up @@ -398,6 +414,8 @@ func TestAggregateLeftJoin(t *testing.T) {
mcmp.AssertMatches("SELECT count(*) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[INT64(2)]]`)
mcmp.AssertMatches("SELECT sum(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
mcmp.AssertMatches("SELECT sum(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
mcmp.AssertMatches("SELECT avg(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(0.5000)]]`)
mcmp.AssertMatches("SELECT avg(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1.0000)]]`)
mcmp.AssertMatches("SELECT count(*) FROM t2 LEFT JOIN t1 ON t1.t1_id = t2.id WHERE IFNULL(t1.name, 'NOTSET') = 'r'", `[[INT64(1)]]`)
}

Expand Down Expand Up @@ -426,6 +444,7 @@ func TestScalarAggregate(t *testing.T) {

mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
mcmp.AssertMatches("select count(distinct val1) from aggr_test", `[[INT64(3)]]`)
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
}

func TestAggregationRandomOnAnAggregatedValue(t *testing.T) {
Expand Down Expand Up @@ -478,6 +497,7 @@ func TestComplexAggregation(t *testing.T) {
mcmp.Exec(`SELECT 1+COUNT(t1_id) FROM t1`)
mcmp.Exec(`SELECT COUNT(t1_id)+1 FROM t1`)
mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey) FROM t1`)
mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey)+AVG(t1_id) FROM t1`)
mcmp.Exec(`SELECT shardkey, MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT name+COUNT(t1_id)+1 FROM t1 GROUP BY name`)
Expand Down
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ const (
AggregateAnyValue
AggregateCountStar
AggregateGroupConcat
AggregateAvg
_NumOfOpCodes // This line must be last of the opcodes!
)

Expand All @@ -85,6 +86,7 @@ var (
AggregateCountStar: sqltypes.Int64,
AggregateSumDistinct: sqltypes.Decimal,
AggregateSum: sqltypes.Decimal,
AggregateAvg: sqltypes.Decimal,
AggregateGtid: sqltypes.VarChar,
}
)
Expand All @@ -96,6 +98,7 @@ var SupportedAggregates = map[string]AggregateOpcode{
"sum": AggregateSum,
"min": AggregateMin,
"max": AggregateMax,
"avg": AggregateAvg,
// These functions don't exist in mysql, but are used
// to display the plan.
"count_distinct": AggregateCountDistinct,
Expand All @@ -117,6 +120,7 @@ var AggregateName = map[AggregateOpcode]string{
AggregateCountStar: "count_star",
AggregateGroupConcat: "group_concat",
AggregateAnyValue: "any_value",
AggregateAvg: "avg",
}

func (code AggregateOpcode) String() string {
Expand Down Expand Up @@ -148,7 +152,7 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
return sqltypes.Text
case AggregateMax, AggregateMin, AggregateAnyValue:
return typ
case AggregateSumDistinct, AggregateSum:
case AggregateSumDistinct, AggregateSum, AggregateAvg:
if typ == sqltypes.Unknown {
return sqltypes.Unknown
}
Expand Down
109 changes: 91 additions & 18 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,33 @@ func tryPushAggregator(ctx *plancontext.PlanningContext, aggregator *Aggregator)
if aggregator.Pushed {
return aggregator, rewrite.SameTree, nil
}

// this rewrite is always valid, and we should do it whenever possible
if route, ok := aggregator.Source.(*Route); ok && (route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping)) {
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
}

// other rewrites require us to have reached this phase before we can consider them
if !reachedPhase(ctx, delegateAggregation) {
return aggregator, rewrite.SameTree, nil
}

// if we have not yet been able to push this aggregation down,
// we need to turn AVG into SUM/COUNT to support this over a sharded keyspace
if needAvgBreaking(aggregator.Aggregations) {
return splitAvgAggregations(ctx, aggregator)
}

switch src := aggregator.Source.(type) {
case *Route:
// if we have a single sharded route, we can push it down
output, applyResult, err = pushAggregationThroughRoute(ctx, aggregator, src)
case *ApplyJoin:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
case *Filter:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
case *SubQueryContainer:
if reachedPhase(ctx, delegateAggregation) {
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
}
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
default:
return aggregator, rewrite.SameTree, nil
}
Expand Down Expand Up @@ -135,15 +146,6 @@ func pushAggregationThroughRoute(
aggregator *Aggregator,
route *Route,
) (ops.Operator, *rewrite.ApplyResult, error) {
// If the route is single-shard, or we are grouping by sharding keys, we can just push down the aggregation
if route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping) {
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
}

if !reachedPhase(ctx, delegateAggregation) {
return nil, nil, nil
}

// Create a new aggregator to be placed below the route.
aggrBelowRoute := aggregator.SplitAggregatorBelowRoute(route.Inputs())
aggrBelowRoute.Aggregations = nil
Expand Down Expand Up @@ -806,3 +808,74 @@ func initColReUse(size int) []int {
}

// extractExpr returns the expression held by the given aliased expression.
func extractExpr(expr *sqlparser.AliasedExpr) sqlparser.Expr {
	inner := expr.Expr
	return inner
}

// needAvgBreaking reports whether at least one of the given aggregations
// uses the AVG opcode.
func needAvgBreaking(aggrs []Aggr) bool {
	for i := range aggrs {
		if aggrs[i].OpCode != opcode.AggregateAvg {
			continue
		}
		return true
	}
	return false
}

// splitAvgAggregations takes an aggregator that has AVG aggregations in it and splits
// these into sum/count expressions that can be spread out to shards
//
// A new projection is created on top of the aggregator and returned as the new
// tree root. Every AVG column of the aggregator is rewritten in place to a SUM,
// a matching COUNT column/aggregation is appended to the aggregator, and the
// projection is given the SUM/COUNT division for that column.
//
// Errors from addUnexploredExpr are returned to the caller. The function panics
// with a VT12001 error on AVG(DISTINCT ...), and with a VT13001 error when an AVG
// column has no aggregation whose ColOffset points at it.
func splitAvgAggregations(ctx *plancontext.PlanningContext, aggr *Aggregator) (ops.Operator, *rewrite.ApplyResult, error) {
proj := newAliasedProjection(aggr)

// COUNT columns/aggregations are collected here and only appended to the
// aggregator after the loop, so len(aggr.Columns) stays stable while the
// offsets of the new COUNT columns are being computed.
var columns []*sqlparser.AliasedExpr
var aggregations []Aggr

for offset, col := range aggr.Columns {
avg, ok := col.Expr.(*sqlparser.Avg)
if !ok {
// Not an AVG column: add it to the projection as is.
proj.addColumnWithoutPushing(ctx, col, false /* addToGroupBy */)
continue
}

if avg.Distinct {
panic(vterrors.VT12001("AVG(distinct <>)"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my understanding, supporting this would require loading all the distinct values from the shard, and then perform the avg calculation on the vtgate layer.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure. Have not spent enough time thinking about it to come up with a solution, which is why I just fail here.

}

// We have an AVG that we need to split
// SUM and COUNT are built over the same argument as the original AVG, and
// calcExpr is the division SUM(arg) / COUNT(arg).
sumExpr := &sqlparser.Sum{Arg: avg.Arg}
countExpr := &sqlparser.Count{Args: []sqlparser.Expr{avg.Arg}}
calcExpr := &sqlparser.BinaryExpr{
Operator: sqlparser.DivOp,
Left: sumExpr,
Right: countExpr,
}
Comment on lines +843 to +847
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️


// NOTE(review): outputColumn is built and given an alias, but it is not read
// anywhere else in this function — looks like leftover code; confirm.
outputColumn := aeWrap(col.Expr)
outputColumn.As = sqlparser.NewIdentifierCI(col.ColumnName())
// The column is cloned before col.Expr is overwritten below, so the clone
// still carries the original AVG expression. Presumably this makes the
// projection output calcExpr for that column — confirm against addUnexploredExpr.
_, err := proj.addUnexploredExpr(sqlparser.CloneRefOfAliasedExpr(col), calcExpr)
if err != nil {
return nil, nil, err
}
// Rewrite the aggregator's own column from AVG(arg) to SUM(arg).
col.Expr = sumExpr
found := false
for aggrOffset, aggregation := range aggr.Aggregations {
if offset == aggregation.ColOffset {
// We have found the AVG column. We'll change it to SUM, and then we add a COUNT as well
aggr.Aggregations[aggrOffset].OpCode = opcode.AggregateSum

countExprAlias := aeWrap(countExpr)
countAggr := NewAggr(opcode.AggregateCount, countExpr, countExprAlias, sqlparser.String(countExpr))
// The COUNT column will live after all existing columns, plus any COUNT
// columns already queued by earlier iterations of the outer loop.
countAggr.ColOffset = len(aggr.Columns) + len(columns)
aggregations = append(aggregations, countAggr)
columns = append(columns, countExprAlias)
found = true
break // no need to search the remaining aggregations
systay marked this conversation as resolved.
Show resolved Hide resolved
}
}
if !found {
// if we get here, it's because we didn't find the aggregation. Something is wrong
panic(vterrors.VT13001("no aggregation pointing to this column was found"))
}
}

// Attach the queued COUNT columns and aggregations to the aggregator.
aggr.Columns = append(aggr.Columns, columns...)
aggr.Aggregations = append(aggr.Aggregations, aggregations...)

return proj, rewrite.NewTree("split avg aggregation", proj), nil
}
Loading
Loading