Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Planner cleaning: cleanup and refactor #16569

Merged
merged 11 commits into from
Aug 15, 2024
28 changes: 13 additions & 15 deletions go/vt/vtgate/endtoend/aggr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,35 @@ import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/utils"

"vitess.io/vitess/go/mysql"
)

func TestAggregateTypes(t *testing.T) {
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer conn.Close()

exec(t, conn, "insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
exec(t, conn, "insert into aggr_test(id, val1, val2) values(6,'d',null), (7,'e',null), (8,'E',1)")

qr := exec(t, conn, "select val1, count(distinct val2), count(*) from aggr_test group by val1")
if got, want := fmt.Sprintf("%v", qr.Rows), `[[VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("d") INT64(0) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)]]`; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
want := `[[VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("d") INT64(0) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)]]`
utils.MustMatch(t, want, fmt.Sprintf("%v", qr.Rows))

qr = exec(t, conn, "select val1, sum(distinct val2), sum(val2) from aggr_test group by val1")
if got, want := fmt.Sprintf("%v", qr.Rows), `[[VARCHAR("a") DECIMAL(1) DECIMAL(2)] [VARCHAR("b") DECIMAL(1) DECIMAL(1)] [VARCHAR("c") DECIMAL(7) DECIMAL(7)] [VARCHAR("d") NULL NULL] [VARCHAR("e") DECIMAL(1) DECIMAL(1)]]`; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
want = `[[VARCHAR("a") DECIMAL(1) DECIMAL(2)] [VARCHAR("b") DECIMAL(1) DECIMAL(1)] [VARCHAR("c") DECIMAL(7) DECIMAL(7)] [VARCHAR("d") NULL NULL] [VARCHAR("e") DECIMAL(1) DECIMAL(1)]]`
utils.MustMatch(t, want, fmt.Sprintf("%v", qr.Rows))

qr = exec(t, conn, "select val1, count(distinct val2) k, count(*) from aggr_test group by val1 order by k desc, val1")
if got, want := fmt.Sprintf("%v", qr.Rows), `[[VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)] [VARCHAR("d") INT64(0) INT64(1)]]`; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
want = `[[VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)] [VARCHAR("d") INT64(0) INT64(1)]]`
utils.MustMatch(t, want, fmt.Sprintf("%v", qr.Rows))

qr = exec(t, conn, "select val1, count(distinct val2) k, count(*) from aggr_test group by val1 order by k desc, val1 limit 4")
if got, want := fmt.Sprintf("%v", qr.Rows), `[[VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)]]`; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
want = `[[VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)]]`
utils.MustMatch(t, want, fmt.Sprintf("%v", qr.Rows))
}
16 changes: 8 additions & 8 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3913,15 +3913,15 @@ func TestSelectAggregationNoData(t *testing.T) {
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64")),
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit 2",
expField: `[name:"count(*)" type:INT64]`,
expSandboxQ: "select 1 from (select col1, col2 from `user`) as x limit 2",
expField: `[name:"count(*)" type:INT64 charset:63 flags:32769]`,
expRow: `[[INT64(0)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 2) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary")),
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit 2",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expSandboxQ: "select x.col1, x.col2, weight_string(x.col2) from (select col1, col2 from `user`) as x limit 2",
expField: `[name:"col2" type:INT64 charset:63 flags:32768 name:"count(*)" type:INT64 charset:63 flags:32769]`,
expRow: `[]`,
},
}
Expand Down Expand Up @@ -4005,15 +4005,15 @@ func TestSelectAggregationData(t *testing.T) {
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64"), "100|200|1", "200|300|1"),
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit 2",
expField: `[name:"count(*)" type:INT64]`,
expSandboxQ: "select 1 from (select col1, col2 from `user`) as x limit 2",
expField: `[name:"count(*)" type:INT64 charset:63 flags:32769]`,
expRow: `[[INT64(2)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 9) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary"), "100|3|1|NULL", "200|2|1|NULL"),
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit 9",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expSandboxQ: "select x.col1, x.col2, weight_string(x.col2) from (select col1, col2 from `user`) as x limit 9",
expField: `[name:"col2" type:INT64 charset:63 flags:32768 name:"count(*)" type:INT64 charset:63 flags:32769]`,
expRow: `[[INT64(2) INT64(4)] [INT64(3) INT64(5)]]`,
},
{
Expand Down
40 changes: 35 additions & 5 deletions go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ type (
// This is used to truncate the columns in the final result
ResultColumns int

// Truncate is set to true if the columns produced by this operator should be truncated if we added any additional columns
Truncate bool

QP *QueryProjection

DT *DerivedTable
Expand Down Expand Up @@ -151,6 +154,8 @@ func (a *Aggregator) checkOffset(offset int) {
}

func (a *Aggregator) AddColumn(ctx *plancontext.PlanningContext, reuse bool, groupBy bool, ae *sqlparser.AliasedExpr) (offset int) {
a.planOffsets(ctx)

defer func() {
a.checkOffset(offset)
}()
Expand Down Expand Up @@ -199,6 +204,10 @@ func (a *Aggregator) AddColumn(ctx *plancontext.PlanningContext, reuse bool, gro
}

func (a *Aggregator) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int {
if !underRoute {
a.planOffsets(ctx)
}

if len(a.Columns) <= offset {
panic(vterrors.VT13001("offset out of range"))
}
Expand All @@ -221,7 +230,7 @@ func (a *Aggregator) AddWSColumn(ctx *plancontext.PlanningContext, offset int, u
}

if expr == nil {
for _, aggr := range a.Aggregations {
for i, aggr := range a.Aggregations {
if aggr.ColOffset != offset {
continue
}
Expand All @@ -230,9 +239,13 @@ func (a *Aggregator) AddWSColumn(ctx *plancontext.PlanningContext, offset int, u
return aggr.WSOffset
}

panic(vterrors.VT13001("expected to find a weight string for aggregation"))
a.Aggregations[i].WSOffset = len(a.Columns)
expr = a.Columns[offset].Expr
break
}
}

if expr == nil {
panic(vterrors.VT13001("could not find expression at offset"))
}

Expand Down Expand Up @@ -515,7 +528,7 @@ func (a *Aggregator) pushRemainingGroupingColumnsAndWeightStrings(ctx *planconte
continue
}

offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(gb.Inner)), false)
offset := a.internalAddWSColumn(ctx, a.Grouping[idx].ColOffset, aeWrap(weightStringFor(gb.Inner)))
a.Grouping[idx].WSOffset = offset
}
for idx, aggr := range a.Aggregations {
Expand All @@ -524,11 +537,28 @@ func (a *Aggregator) pushRemainingGroupingColumnsAndWeightStrings(ctx *planconte
}

arg := aggr.getPushColumn()
offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(arg)), false)
offset := a.internalAddWSColumn(ctx, aggr.ColOffset, aeWrap(weightStringFor(arg)))

a.Aggregations[idx].WSOffset = offset
}
}

// internalAddWSColumn requests a weight-string column from the source for the
// column at inOffset and returns the offset reported by the source.
//
// Before asking the source, ResultColumns is pinned to the current number of
// columns if it has not been set yet and Truncate is enabled, so that columns
// added from here on fall outside the original column list.
//
// aliasedExpr is recorded in a.Columns only when the source reports an offset
// equal to the current length of a.Columns.
func (a *Aggregator) internalAddWSColumn(ctx *plancontext.PlanningContext, inOffset int, aliasedExpr *sqlparser.AliasedExpr) int {
	if a.Truncate && a.ResultColumns == 0 {
		// Nothing has been recorded yet: remember how many columns existed
		// before this helper started adding extra ones.
		a.ResultColumns = len(a.Columns)
	}

	wsOffset := a.Source.AddWSColumn(ctx, inOffset, false)
	if wsOffset != len(a.Columns) {
		// The offset points at a column we already track; nothing to append.
		return wsOffset
	}

	// An offset just past the end of our list means a new column was added.
	a.Columns = append(a.Columns, aliasedExpr)
	return wsOffset
}

// setTruncateColumnCount stores offset in ResultColumns, the field the
// Aggregator uses to truncate the columns in the final result.
func (a *Aggregator) setTruncateColumnCount(offset int) {
a.ResultColumns = offset
}
Expand All @@ -538,7 +568,7 @@ func (a *Aggregator) getTruncateColumnCount() int {
}

func (a *Aggregator) internalAddColumn(ctx *plancontext.PlanningContext, aliasedExpr *sqlparser.AliasedExpr, addToGroupBy bool) int {
if a.ResultColumns == 0 {
if a.ResultColumns == 0 && a.Truncate {
// if we need to use `internalAddColumn`, it means we are adding columns that are not part of the original list,
// so we need to set the ResultColumns to the current length of the columns list
a.ResultColumns = len(a.Columns)
Expand Down
29 changes: 12 additions & 17 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,10 @@ type (
// so they can be used for the result of this expression that is using data from both sides.
// All fields will be used for these
applyJoinColumn struct {
Original sqlparser.Expr // this is the original expression being passed through
LHSExprs []BindVarExpr // These are the expressions we are pushing to the left hand side which we'll receive as bind variables
RHSExpr sqlparser.Expr // This the expression that we'll evaluate on the right hand side. This is nil, if the right hand side has nothing.
DTColName *sqlparser.ColName // This is the output column name that the parent of JOIN will be seeing. If this is unset, then the colname is the String(Original). We set this when we push Projections with derived tables underneath a Join.
GroupBy bool // if this is true, we need to push this down to our inputs with addToGroupBy set to true
Original sqlparser.Expr // this is the original expression being passed through
LHSExprs []BindVarExpr // These are the expressions we are pushing to the left hand side which we'll receive as bind variables
RHSExpr sqlparser.Expr // This the expression that we'll evaluate on the right hand side. This is nil, if the right hand side has nothing.
GroupBy bool // if this is true, we need to push this down to our inputs with addToGroupBy set to true
}

// BindVarExpr is an expression needed from one side of a join/subquery, and the argument name for it.
Expand Down Expand Up @@ -225,8 +224,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq

func applyJoinCompare(ctx *plancontext.PlanningContext, expr sqlparser.Expr) func(e applyJoinColumn) bool {
return func(e applyJoinColumn) bool {
// e.DTColName is how the outside world will be using this expression. So we should check for an equality with that too.
return ctx.SemTable.EqualsExprWithDeps(e.Original, expr) || ctx.SemTable.EqualsExprWithDeps(e.DTColName, expr)
return ctx.SemTable.EqualsExprWithDeps(e.Original, expr)
}
}

Expand Down Expand Up @@ -302,12 +300,18 @@ func (aj *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) Operator {

for _, col := range aj.JoinPredicates.columns {
for _, lhsExpr := range col.LHSExprs {
if _, found := aj.Vars[lhsExpr.Name]; found {
continue
}
offset := aj.LHS.AddColumn(ctx, true, false, aeWrap(lhsExpr.Expr))
aj.Vars[lhsExpr.Name] = offset
}
}

for _, lhsExpr := range aj.ExtraLHSVars {
if _, found := aj.Vars[lhsExpr.Name]; found {
continue
}
offset := aj.LHS.AddColumn(ctx, true, false, aeWrap(lhsExpr.Expr))
aj.Vars[lhsExpr.Name] = offset
}
Expand Down Expand Up @@ -441,11 +445,8 @@ func (jc applyJoinColumn) String() string {
lhs := slice.Map(jc.LHSExprs, func(e BindVarExpr) string {
return sqlparser.String(e.Expr)
})
if jc.DTColName == nil {
return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original))
}

return fmt.Sprintf("[%s | %s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original), sqlparser.String(jc.DTColName))
return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original))
}

func (jc applyJoinColumn) IsPureLeft() bool {
Expand All @@ -461,16 +462,10 @@ func (jc applyJoinColumn) IsMixedLeftAndRight() bool {
}

func (jc applyJoinColumn) GetPureLeftExpr() sqlparser.Expr {
if jc.DTColName != nil {
return jc.DTColName
}
return jc.LHSExprs[0].Expr
}

func (jc applyJoinColumn) GetRHSExpr() sqlparser.Expr {
if jc.DTColName != nil {
return jc.DTColName
}
return jc.RHSExpr
}

Expand Down
12 changes: 10 additions & 2 deletions go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ func createOperatorFromUnion(ctx *plancontext.PlanningContext, node *sqlparser.U
if isRHSUnion {
panic(vterrors.VT12001("nesting of UNIONs on the right-hand side"))
}
opLHS := translateQueryToOp(ctx, node.Left)
opRHS := translateQueryToOp(ctx, node.Right)
opLHS := translateQueryToOpForUnion(ctx, node.Left)
opRHS := translateQueryToOpForUnion(ctx, node.Right)
lexprs := ctx.SemTable.SelectExprs(node.Left)
rexprs := ctx.SemTable.SelectExprs(node.Right)

Expand All @@ -158,6 +158,14 @@ func createOperatorFromUnion(ctx *plancontext.PlanningContext, node *sqlparser.U
return newHorizon(union, node)
}

// translateQueryToOpForUnion translates one side of a UNION into an operator
// via translateQueryToOp. When the resulting operator is a *Horizon, its
// Truncate flag is switched on before the operator is returned; any other
// operator is returned untouched.
func translateQueryToOpForUnion(ctx *plancontext.PlanningContext, node sqlparser.SelectStatement) Operator {
	op := translateQueryToOp(ctx, node)

	horizon, isHorizon := op.(*Horizon)
	if !isHorizon {
		return op
	}

	horizon.Truncate = true
	return op
}

// createOpFromStmt creates an operator from the given statement. It takes in two additional arguments—
// 1. verifyAllFKs: For this given statement, do we need to verify validity of all the foreign keys on the vtgate level.
// 2. fkToIgnore: The foreign key constraint to specifically ignore while planning the statement. This field is used in UPDATE CASCADE planning, wherein while planning the child update
Expand Down
12 changes: 12 additions & 0 deletions go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,15 @@ func breakExpressionInLHSandRHS(
col.Original = expr
return
}

// nothingNeedsFetching walks expr and reports true only when
// mustFetchFromInput is false for every node the walk visits, i.e. no part of
// the expression has to be fetched from the input.
func nothingNeedsFetching(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (constant bool) {
	needsFetch := false

	visit := func(node sqlparser.SQLNode) (kontinue bool, err error) {
		if mustFetchFromInput(ctx, node) {
			needsFetch = true
		}
		// Always keep walking; the visitor never reports an error.
		return true, nil
	}

	// The visitor above never returns an error, so the result of Walk is
	// deliberately discarded.
	_ = sqlparser.Walk(visit, expr)

	return !needsFetch
}
2 changes: 2 additions & 0 deletions go/vt/vtgate/planbuilder/operators/horizon.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ type Horizon struct {
// Columns needed to feed other plans
Columns []*sqlparser.ColName
ColumnsOffset []int

Truncate bool
}

func newHorizon(src Operator, query sqlparser.SelectStatement) *Horizon {
Expand Down
14 changes: 9 additions & 5 deletions go/vt/vtgate/planbuilder/operators/horizon_expanding.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,17 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz
}

if qp.NeedsAggregation() {
return createProjectionWithAggr(ctx, qp, dt, horizon.src())
return createProjectionWithAggr(ctx, qp, dt, horizon)
}

projX := createProjectionWithoutAggr(ctx, qp, horizon.src())
projX.DT = dt
return projX
}

func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProjection, dt *DerivedTable, src Operator) Operator {
func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProjection, dt *DerivedTable, horizon *Horizon) Operator {
aggregations, complexAggr := qp.AggregationExpressions(ctx, true)
src := horizon.Source
aggrOp := &Aggregator{
Source: src,
Original: true,
Expand All @@ -239,7 +240,11 @@ func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProject
if complexAggr {
return createProjectionForComplexAggregation(aggrOp, qp)
}
return createProjectionForSimpleAggregation(ctx, aggrOp, qp)

addAllColumnsToAggregator(ctx, aggrOp, qp)
aggrOp.Truncate = horizon.Truncate

return aggrOp
}

func pullOutValueSubqueries(ctx *plancontext.PlanningContext, aggr Aggr, sqc *SubQueryBuilder, outerID semantics.TableSet) Aggr {
Expand All @@ -261,7 +266,7 @@ func pullOutValueSubqueries(ctx *plancontext.PlanningContext, aggr Aggr, sqc *Su
return aggr
}

func createProjectionForSimpleAggregation(ctx *plancontext.PlanningContext, a *Aggregator, qp *QueryProjection) Operator {
func addAllColumnsToAggregator(ctx *plancontext.PlanningContext, a *Aggregator, qp *QueryProjection) {
outer:
for colIdx, expr := range qp.SelectExprs {
ae, err := expr.GetAliasedExpr()
Expand Down Expand Up @@ -292,7 +297,6 @@ outer:
}
panic(vterrors.VT13001(fmt.Sprintf("Could not find the %s in aggregation in the original query", sqlparser.String(ae))))
}
return a
}

func createProjectionForComplexAggregation(a *Aggregator, qp *QueryProjection) Operator {
Expand Down
Loading
Loading