Skip to content

Commit

Permalink
Added support for group_concat and count distinct with multiple expre…
Browse files Browse the repository at this point in the history
…ssions (#14851)

Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal authored Dec 26, 2023
1 parent d072adb commit 2176095
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 51 deletions.
14 changes: 13 additions & 1 deletion go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,12 @@ func TestGroupConcatAggregation(t *testing.T) {
compareRow(t, mQr, vtQr, nil, []int{0})
mQr, vtQr = mcmp.ExecNoCompare(`SELECT group_concat(value), t1.name FROM t1, t2 group by t1.name`)
compareRow(t, mQr, vtQr, []int{1}, []int{0})
if versionMet := utils.BinaryIsAtLeastAtVersion(19, "vtgate"); !versionMet {
// skipping
return
}
mQr, vtQr = mcmp.ExecNoCompare(`SELECT group_concat(name, value) FROM t1`)
compareRow(t, mQr, vtQr, nil, []int{0})
}

func compareRow(t *testing.T, mRes *sqltypes.Result, vtRes *sqltypes.Result, grpCols []int, fCols []int) {
Expand Down Expand Up @@ -613,6 +619,7 @@ func TestDistinctAggregation(t *testing.T) {
tcases := []struct {
query string
expectedErr string
minVersion int
}{{
query: `SELECT COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`,
expectedErr: "VT12001: unsupported: only one DISTINCT aggregation is allowed in a SELECT: sum(distinct shardkey) (errno 1235) (sqlstate 42000)",
Expand All @@ -626,10 +633,15 @@ func TestDistinctAggregation(t *testing.T) {
}, {
query: `SELECT a.value, SUM(DISTINCT b.t1_id), min(DISTINCT a.t1_id) FROM t1 a, t1 b group by a.value`,
}, {
query: `SELECT distinct count(*) from t1, (select distinct count(*) from t1) as t2`,
minVersion: 19,
query: `SELECT count(distinct name, shardkey) from t1`,
}}

for _, tc := range tcases {
if versionMet := utils.BinaryIsAtLeastAtVersion(tc.minVersion, "vtgate"); !versionMet {
// skipping
continue
}
mcmp.Run(tc.query, func(mcmp *utils.MySQLCompare) {
_, err := mcmp.ExecAllowError(tc.query)
if tc.expectedErr == "" {
Expand Down
54 changes: 39 additions & 15 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ import (
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
)

func errDistinctAggrWithMultiExpr(f sqlparser.AggrFunc) {
if f == nil {
panic(vterrors.VT12001("distinct aggregation function with multiple expressions"))
}
panic(vterrors.VT12001(fmt.Sprintf("distinct aggregation function with multiple expressions '%s'", sqlparser.String(f))))
}

func tryPushAggregator(ctx *plancontext.PlanningContext, aggregator *Aggregator) (output Operator, applyResult *ApplyResult) {
if aggregator.Pushed {
return aggregator, NoRewrite
Expand Down Expand Up @@ -162,7 +169,7 @@ func pushAggregationThroughRoute(

// pushAggregations splits aggregations between the original aggregator and the one we are pushing down
func pushAggregations(ctx *plancontext.PlanningContext, aggregator *Aggregator, aggrBelowRoute *Aggregator) {
canPushDistinctAggr, distinctExpr := checkIfWeCanPush(ctx, aggregator)
canPushDistinctAggr, distinctExprs := checkIfWeCanPush(ctx, aggregator)

distinctAggrGroupByAdded := false

Expand All @@ -173,54 +180,68 @@ func pushAggregations(ctx *plancontext.PlanningContext, aggregator *Aggregator,
continue
}

if len(distinctExprs) != 1 {
errDistinctAggrWithMultiExpr(aggr.Func)
}

// We handle a distinct aggregation by turning it into a group by and
// doing the aggregating on the vtgate level instead
aeDistinctExpr := aeWrap(distinctExpr)
aeDistinctExpr := aeWrap(distinctExprs[0])
aggrBelowRoute.Columns[aggr.ColOffset] = aeDistinctExpr

// We handle a distinct aggregation by turning it into a group by and
// doing the aggregating on the vtgate level instead
// Adding to group by can be done only once even though there are multiple distinct aggregation with same expression.
if !distinctAggrGroupByAdded {
groupBy := NewGroupBy(distinctExpr, distinctExpr)
groupBy := NewGroupBy(distinctExprs[0], distinctExprs[0])
groupBy.ColOffset = aggr.ColOffset
aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy)
distinctAggrGroupByAdded = true
}
}

if !canPushDistinctAggr {
aggregator.DistinctExpr = distinctExpr
aggregator.DistinctExpr = distinctExprs[0]
}
}

func checkIfWeCanPush(ctx *plancontext.PlanningContext, aggregator *Aggregator) (bool, sqlparser.Expr) {
func checkIfWeCanPush(ctx *plancontext.PlanningContext, aggregator *Aggregator) (bool, sqlparser.Exprs) {
canPush := true
var distinctExpr sqlparser.Expr
var distinctExprs sqlparser.Exprs
var differentExpr *sqlparser.AliasedExpr

for _, aggr := range aggregator.Aggregations {
if !aggr.Distinct {
continue
}

innerExpr := aggr.Func.GetArg()
if !exprHasUniqueVindex(ctx, innerExpr) {
args := aggr.Func.GetArgs()
hasUniqVindex := false
for _, arg := range args {
if exprHasUniqueVindex(ctx, arg) {
hasUniqVindex = true
break
}
}
if !hasUniqVindex {
canPush = false
}
if distinctExpr == nil {
distinctExpr = innerExpr
if len(distinctExprs) == 0 {
distinctExprs = args
}
if !ctx.SemTable.EqualsExpr(distinctExpr, innerExpr) {
differentExpr = aggr.Original
for idx, expr := range distinctExprs {
if !ctx.SemTable.EqualsExpr(expr, args[idx]) {
differentExpr = aggr.Original
break
}
}
}

if !canPush && differentExpr != nil {
panic(vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(differentExpr))))
}

return canPush, distinctExpr
return canPush, distinctExprs
}

func pushAggregationThroughFilter(
Expand Down Expand Up @@ -530,12 +551,15 @@ func splitAggrColumnsToLeftAndRight(
outerJoin: leftJoin,
}

canPushDistinctAggr, distinctExpr := checkIfWeCanPush(ctx, aggregator)
canPushDistinctAggr, distinctExprs := checkIfWeCanPush(ctx, aggregator)

// Distinct aggregation cannot be pushed down in the join.
// We keep node of the distinct aggregation expression to be used later for ordering.
if !canPushDistinctAggr {
aggregator.DistinctExpr = distinctExpr
if len(distinctExprs) != 1 {
errDistinctAggrWithMultiExpr(nil)
}
aggregator.DistinctExpr = distinctExprs[0]
return nil, errAbortAggrPushing
}

Expand Down
14 changes: 9 additions & 5 deletions go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator {
if a.offsetPlanned {
return nil
}
a.checkForInvalidAggregations()
defer func() {
a.offsetPlanned = true
}()
Expand All @@ -281,7 +280,8 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator {
if !aggr.NeedsWeightString(ctx) {
continue
}
offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(aggr.Func.GetArg())), true)
arg := aggr.getPushColumn()
offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(arg)), true)
a.Aggregations[idx].WSOffset = offset
}
return nil
Expand All @@ -295,10 +295,13 @@ func (aggr Aggr) getPushColumn() sqlparser.Expr {
return sqlparser.NewIntLiteral("1")
case opcode.AggregateGroupConcat:
if len(aggr.Func.GetArgs()) > 1 {
panic("more than 1 column")
panic(vterrors.VT12001("group_concat with more than 1 column"))
}
fallthrough
return aggr.Func.GetArg()
default:
if len(aggr.Func.GetArgs()) > 1 {
panic(vterrors.VT03001(sqlparser.String(aggr.Func)))
}
return aggr.Func.GetArg()
}
}
Expand Down Expand Up @@ -380,7 +383,8 @@ func (a *Aggregator) pushRemainingGroupingColumnsAndWeightStrings(ctx *planconte
continue
}

offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(aggr.Func.GetArg())), false)
arg := aggr.getPushColumn()
offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(arg)), false)
a.Aggregations[idx].WSOffset = offset
}
}
Expand Down
22 changes: 12 additions & 10 deletions go/vt/vtgate/planbuilder/operators/horizon_expanding.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel
return op, Rewrote(fmt.Sprintf("expand SELECT horizon into (%s)", strings.Join(extracted, ", ")))
}

func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horizon) (out Operator) {
func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horizon) Operator {
qp := horizon.getQP(ctx)

var dt *DerivedTable
Expand All @@ -131,15 +131,15 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz
if !qp.NeedsAggregation() {
projX := createProjectionWithoutAggr(ctx, qp, horizon.src())
projX.DT = dt
out = projX

return out
return projX
}

aggregations, complexAggr := qp.AggregationExpressions(ctx, true)
return createProjectionWithAggr(ctx, qp, dt, horizon.src())
}

src := horizon.src()
a := &Aggregator{
func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProjection, dt *DerivedTable, src Operator) Operator {
aggregations, complexAggr := qp.AggregationExpressions(ctx, true)
aggrOp := &Aggregator{
Source: src,
Original: true,
QP: qp,
Expand All @@ -148,6 +148,7 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz
DT: dt,
}

// Go through all aggregations and check for any subquery.
sqc := &SubQueryBuilder{}
outerID := TableID(src)
for idx, aggr := range aggregations {
Expand All @@ -157,12 +158,13 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz
aggregations[idx].SubQueryExpression = subqs
}
}
a.Source = sqc.getRootOperator(src, nil)
aggrOp.Source = sqc.getRootOperator(src, nil)

// create the projection columns from aggregator.
if complexAggr {
return createProjectionForComplexAggregation(a, qp)
return createProjectionForComplexAggregation(aggrOp, qp)
}
return createProjectionForSimpleAggregation(ctx, a, qp)
return createProjectionForSimpleAggregation(ctx, aggrOp, qp)
}

func createProjectionForSimpleAggregation(ctx *plancontext.PlanningContext, a *Aggregator, qp *QueryProjection) Operator {
Expand Down
Loading

0 comments on commit 2176095

Please sign in to comment.