diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 458c8c5e1c3..a22719b4489 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -87,12 +87,12 @@ func transformToPrimitive(ctx *plancontext.PlanningContext, op operators.Operato } func transformPercentBasedMirror(ctx *plancontext.PlanningContext, op *operators.PercentBasedMirror) (engine.Primitive, error) { - primitive, err := transformToPrimitive(ctx, op.Operator) + primitive, err := transformToPrimitive(ctx, op.Operator()) if err != nil { return nil, err } - target, err := transformToPrimitive(ctx.UseMirror(), op.Target) + target, err := transformToPrimitive(ctx.UseMirror(), op.Target()) // Mirroring is best-effort. If we encounter an error while building the // mirror target primitive, proceed without mirroring. if err != nil { @@ -169,7 +169,7 @@ func transformSequential(ctx *plancontext.PlanningContext, op *operators.Sequent } func transformInsertionSelection(ctx *plancontext.PlanningContext, op *operators.InsertSelection) (engine.Primitive, error) { - rb, isRoute := op.Insert.(*operators.Route) + rb, isRoute := op.Insert().(*operators.Route) if !isRoute { return nil, vterrors.VT13001(fmt.Sprintf("Incorrect type encountered: %T (transformInsertionSelection)", op.Insert)) } @@ -198,7 +198,7 @@ func transformInsertionSelection(ctx *plancontext.PlanningContext, op *operators eins.Prefix, _, eins.Suffix = generateInsertShardedQuery(ins.AST) - selectionPlan, err := transformToPrimitive(ctx, op.Select) + selectionPlan, err := transformToPrimitive(ctx, op.Select()) if err != nil { return nil, err } @@ -1000,11 +1000,11 @@ func transformVindexPlan(ctx *plancontext.PlanningContext, op *operators.Vindex) } func transformRecurseCTE(ctx *plancontext.PlanningContext, op *operators.RecurseCTE) (engine.Primitive, error) { - seed, err := transformToPrimitive(ctx, op.Seed) + seed, err := transformToPrimitive(ctx, op.Seed()) if err != nil { return nil, err } - term, err := transformToPrimitive(ctx, op.Term) + term, err := transformToPrimitive(ctx, op.Term()) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 3972ac8290a..20af2a698c3 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -717,9 +717,9 @@ func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) { return jc.Original }) pred := sqlparser.AndExpressions(predicates...) - buildQuery(op.Seed, qb) + buildQuery(op.Seed(), qb) qbR := &queryBuilder{ctx: qb.ctx} - buildQuery(op.Term, qbR) + buildQuery(op.Term(), qbR) qbR.addPredicate(pred) infoFor, err := qb.ctx.SemTable.TableInfoFor(op.OuterID) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index 671f4b78954..73169369a41 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -460,8 +460,8 @@ func createJoinPusher(rootAggr *Aggregator, operator Operator) *joinPusher { return &joinPusher{ orig: rootAggr, pushed: &Aggregator{ - Source: operator, - QP: rootAggr.QP, + unaryOperator: newUnaryOp(operator), + QP: rootAggr.QP, }, columns: initColReUse(len(rootAggr.Columns)), tableID: TableID(operator), diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 63c21ba2bce..f353ee02d1e 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -34,7 +34,7 @@ type ( // Both all aggregations and no grouping, and the inverse // of all grouping and no aggregations are valid configurations of this operator Aggregator struct { - Source Operator + unaryOperator Columns []*sqlparser.AliasedExpr WithRollup bool @@ -75,14 +75,6 @@ func (a *Aggregator) Clone(inputs []Operator) Operator { return &kopy } -func (a *Aggregator) Inputs() []Operator { - return []Operator{a.Source} -} - -func (a *Aggregator) SetInputs(operators []Operator) { - a.Source = operators[0] -} - func (a *Aggregator) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser.Expr) Operator { return newFilter(a, expr) } diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 4c6baab3729..80bf74708a8 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -32,7 +32,7 @@ type ( // ApplyJoin is a nested loop join - for each row on the LHS, // we'll execute the plan on the RHS, feeding data from left to right ApplyJoin struct { - LHS, RHS Operator + binaryOperator // JoinType is permitted to store only 3 of the possible values // NormalJoinType, StraightJoinType and LeftJoinType. @@ -85,8 +85,7 @@ type ( func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, joinType sqlparser.JoinType) *ApplyJoin { aj := &ApplyJoin{ - LHS: lhs, - RHS: rhs, + binaryOperator: newBinaryOp(lhs, rhs), Vars: map[string]int{}, JoinType: joinType, JoinColumns: &applyJoinColumns{}, @@ -113,16 +112,6 @@ func (aj *ApplyJoin) AddPredicate(ctx *plancontext.PlanningContext, expr sqlpars return AddPredicate(ctx, aj, expr, false, newFilterSinglePredicate) } -// Inputs implements the Operator interface -func (aj *ApplyJoin) Inputs() []Operator { - return []Operator{aj.LHS, aj.RHS} -} - -// SetInputs implements the Operator interface -func (aj *ApplyJoin) SetInputs(inputs []Operator) { - aj.LHS, aj.RHS = inputs[0], inputs[1] -} - func (aj *ApplyJoin) GetLHS() Operator { return aj.LHS } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 86d1c6197d4..12c19bb72a6 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -65,16 +65,10 @@ func translateQueryToOpWithMirroring(ctx *plancontext.PlanningContext, stmt sqlp func createOperatorFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) Operator { op := crossJoin(ctx, sel.From) - if sel.Where != nil { - op = addWherePredicates(ctx, sel.Where.Expr, op) - } + op = addWherePredicates(ctx, sel.GetWherePredicate(), op) if sel.Comments != nil || sel.Lock != sqlparser.NoLock { - op = &LockAndComment{ - Source: op, - Comments: sel.Comments, - Lock: sel.Lock, - } + op = newLockAndComment(op, sel.Comments, sel.Lock) } op = newHorizon(op, sel) @@ -88,15 +82,26 @@ func addWherePredicates(ctx *plancontext.PlanningContext, expr sqlparser.Expr, o return sqc.getRootOperator(op, nil) } -func addWherePredsToSubQueryBuilder(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op Operator, sqc *SubQueryBuilder) Operator { +func addWherePredsToSubQueryBuilder(ctx *plancontext.PlanningContext, in sqlparser.Expr, op Operator, sqc *SubQueryBuilder) Operator { outerID := TableID(op) - exprs := sqlparser.SplitAndExpression(nil, expr) - for _, expr := range exprs { + for _, expr := range sqlparser.SplitAndExpression(nil, in) { sqlparser.RemoveKeyspaceInCol(expr) + expr = simplifyPredicates(ctx, expr) subq := sqc.handleSubquery(ctx, expr, outerID) if subq != nil { continue } + boolean := ctx.IsConstantBool(expr) + if boolean != nil { + if *boolean { + // If the predicate is true, we can ignore it. + continue + } + + // If the predicate is false, we push down a false predicate to influence routing + expr = sqlparser.NewIntLiteral("0") + } + op = op.AddPredicate(ctx, expr) addColumnEquality(ctx, expr) } diff --git a/go/vt/vtgate/planbuilder/operators/comments.go b/go/vt/vtgate/planbuilder/operators/comments.go index 7e7749a61b5..9f0202c250a 100644 --- a/go/vt/vtgate/planbuilder/operators/comments.go +++ b/go/vt/vtgate/planbuilder/operators/comments.go @@ -26,25 +26,25 @@ import ( // LockAndComment contains any comments or locking directives we want on all queries down from this operator type LockAndComment struct { - Source Operator + unaryOperator Comments *sqlparser.ParsedComments Lock sqlparser.Lock } +func newLockAndComment(op Operator, comments *sqlparser.ParsedComments, lock sqlparser.Lock) Operator { + return &LockAndComment{ + unaryOperator: newUnaryOp(op), + Comments: comments, + Lock: lock, + } +} + func (l *LockAndComment) Clone(inputs []Operator) Operator { klon := *l klon.Source = inputs[0] return &klon } -func (l *LockAndComment) Inputs() []Operator { - return []Operator{l.Source} -} - -func (l *LockAndComment) SetInputs(operators []Operator) { - l.Source = operators[0] -} - func (l *LockAndComment) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { l.Source = l.Source.AddPredicate(ctx, expr) return l diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go index a6830cbe12b..cb19e06b2a7 100644 --- a/go/vt/vtgate/planbuilder/operators/cte_merging.go +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -22,7 +22,7 @@ import ( ) func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator, *ApplyResult) { - op := tryMergeCTE(ctx, in.Seed, in.Term, in) + op := tryMergeCTE(ctx, in.Seed(), in.Term(), in) if op == nil { return in, NoRewrite } @@ -79,17 +79,17 @@ func mergeCTE(ctx *plancontext.PlanningContext, seed, term *Route, r Routing, in hz := in.Horizon hz.Source = term.Source newTerm, _ := expandHorizon(ctx, hz) + cte := &RecurseCTE{ + binaryOperator: newBinaryOp(seed.Source, newTerm), + Predicates: in.Predicates, + Def: in.Def, + LeftID: in.LeftID, + OuterID: in.OuterID, + Distinct: in.Distinct, + } return &Route{ - Routing: r, - Source: &RecurseCTE{ - Predicates: in.Predicates, - Def: in.Def, - Seed: seed.Source, - Term: newTerm, - LeftID: in.LeftID, - OuterID: in.OuterID, - Distinct: in.Distinct, - }, - MergedWith: []*Route{term}, + Routing: r, + unaryOperator: newUnaryOp(cte), + MergedWith: []*Route{term}, } } diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index ef8403e5603..e4f1fc0e7ae 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -75,10 +75,7 @@ func createOperatorFromDelete(ctx *plancontext.PlanningContext, deleteStmt *sqlp op, vTbl = createDeleteOperator(ctx, deleteStmt) if deleteStmt.Comments != nil { - op = &LockAndComment{ - Source: op, - Comments: deleteStmt.Comments, - } + op = newLockAndComment(op, deleteStmt.Comments, sqlparser.NoLock) } var err error @@ -151,10 +148,7 @@ func createDeleteWithInputOp(ctx *plancontext.PlanningContext, del *sqlparser.De } if del.Comments != nil { - op = &LockAndComment{ - Source: op, - Comments: del.Comments, - } + op = newLockAndComment(op, del.Comments, sqlparser.NoLock) } return op } @@ -261,10 +255,7 @@ func createDeleteOperator(ctx *plancontext.PlanningContext, del *sqlparser.Delet } if del.Limit != nil { - delOp.Source = &Limit{ - Source: addOrdering(ctx, op, del.OrderBy), - AST: del.Limit, - } + delOp.Source = newLimit(addOrdering(ctx, op, del.OrderBy), del.Limit, false) } else { delOp.Source = op } @@ -316,7 +307,7 @@ func addOrdering(ctx *plancontext.PlanningContext, op Operator, orderBy sqlparse if len(order) == 0 { return op } - return &Ordering{Source: op, Order: order} + return newOrdering(op, order) } func updateQueryGraphWithSource(ctx *plancontext.PlanningContext, input Operator, tblID semantics.TableSet, vTbl *vindexes.Table) *vindexes.Table { diff --git a/go/vt/vtgate/planbuilder/operators/distinct.go b/go/vt/vtgate/planbuilder/operators/distinct.go index 4fd53725e10..52221498eea 100644 --- a/go/vt/vtgate/planbuilder/operators/distinct.go +++ b/go/vt/vtgate/planbuilder/operators/distinct.go @@ -26,8 +26,8 @@ import ( type ( Distinct struct { - Source Operator - QP *QueryProjection + unaryOperator + QP *QueryProjection // When we go from AST to operator, we place DISTINCT ops in the required places in the op tree // These are marked as `Required`, because they are semantically important to the results of the query. @@ -45,6 +45,14 @@ type ( } ) +func newDistinct(src Operator, qp *QueryProjection, required bool) *Distinct { + return &Distinct{ + unaryOperator: newUnaryOp(src), + QP: qp, + Required: required, + } +} + func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) Operator { columns := d.GetColumns(ctx) for idx, col := range columns { @@ -66,22 +74,10 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) Operator { } func (d *Distinct) Clone(inputs []Operator) Operator { - return &Distinct{ - Required: d.Required, - Source: inputs[0], - Columns: slices.Clone(d.Columns), - QP: d.QP, - PushedPerformance: d.PushedPerformance, - ResultColumns: d.ResultColumns, - } -} - -func (d *Distinct) Inputs() []Operator { - return []Operator{d.Source} -} - -func (d *Distinct) SetInputs(operators []Operator) { - d.Source = operators[0] + kopy := *d + kopy.Columns = slices.Clone(d.Columns) + kopy.Source = inputs[0] + return &kopy } func (d *Distinct) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 4e920d4312c..38848693775 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -71,3 +71,53 @@ func nothingNeedsFetching(ctx *plancontext.PlanningContext, expr sqlparser.Expr) }, expr) return } + +func simplifyPredicates(ctx *plancontext.PlanningContext, in sqlparser.Expr) sqlparser.Expr { + var replace sqlparser.Expr + + // if expr is constant true, replace with trueReplacement, if constant false, replace with falseReplacement + handleExpr := func(expr, trueReplacement, falseReplacement sqlparser.Expr) bool { + b := ctx.IsConstantBool(expr) + if b != nil { + if *b { + replace = trueReplacement + } else { + replace = falseReplacement + } + return true + } + return false + } + + pre := func(node, _ sqlparser.SQLNode) bool { + switch node := node.(type) { + case *sqlparser.OrExpr: + if handleExpr(node.Left, sqlparser.NewIntLiteral("1"), node.Right) { + return false + } + if handleExpr(node.Right, sqlparser.NewIntLiteral("1"), node.Left) { + return false + } + case *sqlparser.AndExpr: + if handleExpr(node.Left, node.Right, sqlparser.NewIntLiteral("0")) { + return false + } + if handleExpr(node.Right, node.Left, sqlparser.NewIntLiteral("0")) { + return false + } + } + return true + } + post := func(cursor *sqlparser.CopyOnWriteCursor) { + if replace != nil { + cursor.Replace(replace) + replace = nil + } + } + output := sqlparser.CopyOnRewrite(in, pre, post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) + if in != output { + // we need to do this, since one simplification might lead to another + return simplifyPredicates(ctx, output) + } + return output +} diff --git a/go/vt/vtgate/planbuilder/operators/filter.go b/go/vt/vtgate/planbuilder/operators/filter.go index d68b2a43a24..d58d218908e 100644 --- a/go/vt/vtgate/planbuilder/operators/filter.go +++ b/go/vt/vtgate/planbuilder/operators/filter.go @@ -29,7 +29,7 @@ import ( ) type Filter struct { - Source Operator + unaryOperator Predicates []sqlparser.Expr // PredicateWithOffsets is the evalengine expression that will finally be used. @@ -45,28 +45,17 @@ func newFilterSinglePredicate(op Operator, expr sqlparser.Expr) Operator { func newFilter(op Operator, expr ...sqlparser.Expr) Operator { return &Filter{ - Source: op, Predicates: expr, + unaryOperator: newUnaryOp(op), + Predicates: expr, } } // Clone implements the Operator interface func (f *Filter) Clone(inputs []Operator) Operator { - return &Filter{ - Source: inputs[0], - Predicates: slices.Clone(f.Predicates), - PredicateWithOffsets: f.PredicateWithOffsets, - ResultColumns: f.ResultColumns, - } -} - -// Inputs implements the Operator interface -func (f *Filter) Inputs() []Operator { - return []Operator{f.Source} -} - -// SetInputs implements the Operator interface -func (f *Filter) SetInputs(ops []Operator) { - f.Source = ops[0] + klon := *f + klon.Source = inputs[0] + klon.Predicates = slices.Clone(f.Predicates) + return &klon } // UnsolvedPredicates implements the unresolved interface diff --git a/go/vt/vtgate/planbuilder/operators/hash_join.go b/go/vt/vtgate/planbuilder/operators/hash_join.go index 23d0d061e21..3761c4b87a6 100644 --- a/go/vt/vtgate/planbuilder/operators/hash_join.go +++ b/go/vt/vtgate/planbuilder/operators/hash_join.go @@ -31,7 +31,7 @@ import ( type ( HashJoin struct { - LHS, RHS Operator + binaryOperator // LeftJoin will be true in the case of an outer join LeftJoin bool @@ -79,10 +79,9 @@ var _ JoinOp = (*HashJoin)(nil) func NewHashJoin(lhs, rhs Operator, outerJoin bool) *HashJoin { hj := &HashJoin{ - LHS: lhs, - RHS: rhs, - LeftJoin: outerJoin, - columns: &hashJoinColumns{}, + binaryOperator: newBinaryOp(lhs, rhs), + LeftJoin: outerJoin, + columns: &hashJoinColumns{}, } return hj } @@ -97,14 +96,6 @@ func (hj *HashJoin) Clone(inputs []Operator) Operator { return &kopy } -func (hj *HashJoin) Inputs() []Operator { - return []Operator{hj.LHS, hj.RHS} -} - -func (hj *HashJoin) SetInputs(operators []Operator) { - hj.LHS, hj.RHS = operators[0], operators[1] -} - func (hj *HashJoin) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { return AddPredicate(ctx, hj, expr, false, newFilterSinglePredicate) } diff --git a/go/vt/vtgate/planbuilder/operators/hash_join_test.go b/go/vt/vtgate/planbuilder/operators/hash_join_test.go index 2bf1d08d2b6..7325cc015a1 100644 --- a/go/vt/vtgate/planbuilder/operators/hash_join_test.go +++ b/go/vt/vtgate/planbuilder/operators/hash_join_test.go @@ -41,10 +41,9 @@ func TestJoinPredicates(t *testing.T) { lhs := &fakeOp{id: lid} rhs := &fakeOp{id: rid} hj := &HashJoin{ - LHS: lhs, - RHS: rhs, - LeftJoin: false, - columns: &hashJoinColumns{}, + binaryOperator: newBinaryOp(lhs, rhs), + LeftJoin: false, + columns: &hashJoinColumns{}, } cmp := &sqlparser.ComparisonExpr{ @@ -99,10 +98,9 @@ func TestOffsetPlanning(t *testing.T) { for _, test := range tests { t.Run(sqlparser.String(test.expr), func(t *testing.T) { hj := &HashJoin{ - LHS: lhs, - RHS: rhs, - LeftJoin: false, - columns: &hashJoinColumns{}, + binaryOperator: newBinaryOp(lhs, rhs), + LeftJoin: false, + columns: &hashJoinColumns{}, } hj.AddColumn(ctx, true, false, aeWrap(test.expr)) hj.planOffsets(ctx) diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index 3fb72df91b4..292be1b37c5 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -35,7 +35,7 @@ import ( // Project/Aggregate/Sort/Limit operations, some which can be pushed down, // and some that have to be evaluated at the vtgate level. type Horizon struct { - Source Operator + unaryOperator // If this is a derived table, the two following fields will contain the tableID and name of it TableId *semantics.TableSet @@ -55,7 +55,10 @@ type Horizon struct { } func newHorizon(src Operator, query sqlparser.SelectStatement) *Horizon { - return &Horizon{Source: src, Query: query} + return &Horizon{ + unaryOperator: newUnaryOp(src), + Query: query, + } } // Clone implements the Operator interface @@ -78,16 +81,6 @@ func (h *Horizon) IsMergeable(ctx *plancontext.PlanningContext) bool { return isMergeable(ctx, h.Query, h) } -// Inputs implements the Operator interface -func (h *Horizon) Inputs() []Operator { - return []Operator{h.Source} -} - -// SetInputs implements the Operator interface -func (h *Horizon) SetInputs(ops []Operator) { - h.Source = ops[0] -} - func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { if _, isUNion := h.Source.(*Union); isUNion { // If we have a derived table on top of a UNION, we can let the UNION do the expression rewriting diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index 29c1b1033f1..7b058627b17 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -44,17 +44,11 @@ func expandUnionHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, unio qp := horizon.getQP(ctx) if len(qp.OrderExprs) > 0 { - op = &Ordering{ - Source: op, - Order: qp.OrderExprs, - } + op = newOrdering(op, qp.OrderExprs) } if union.Limit != nil { - op = &Limit{ - Source: op, - AST: union.Limit, - } + op = newLimit(op, union.Limit, false) } if horizon.TableId != nil { @@ -94,11 +88,7 @@ func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel } if qp.NeedsDistinct() { - op = &Distinct{ - Required: true, - Source: op, - QP: qp, - } + op = newDistinct(op, qp, true) extracted = append(extracted, "Distinct") } @@ -113,11 +103,7 @@ func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel } if sel.Limit != nil { - op = &Limit{ - Source: op, - AST: sel.Limit, - Top: true, - } + op = newLimit(op, sel.Limit, true) extracted = append(extracted, "Limit") } @@ -144,10 +130,7 @@ func expandOrderBy(ctx *plancontext.PlanningContext, op Operator, qp *QueryProje // If the operator is not a projection, we cannot handle subqueries with aggregation if we are unable to push everything into a single route. if !ok { ctx.SemTable.NotSingleRouteErr = vterrors.VT12001("subquery with aggregation in order by") - return &Ordering{ - Source: op, - Order: qp.OrderExprs, - } + return newOrdering(op, qp.OrderExprs) } else { // Add the new subquery expression to the projection proj.addSubqueryExpr(ctx, aeWrap(newExpr), newExpr, subqs...) @@ -169,10 +152,7 @@ func expandOrderBy(ctx *plancontext.PlanningContext, op Operator, qp *QueryProje } // Return the updated operator with the new order by expressions - return &Ordering{ - Source: op, - Order: newOrder, - } + return newOrdering(op, newOrder) } // exposeOrderingColumn will expose the ordering column to the outer query @@ -220,13 +200,13 @@ func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProject aggregations, complexAggr := qp.AggregationExpressions(ctx, true) src := horizon.Source aggrOp := &Aggregator{ - Source: src, - Original: true, - QP: qp, - Grouping: qp.GetGrouping(), - WithRollup: qp.WithRollup, - Aggregations: aggregations, - DT: dt, + unaryOperator: newUnaryOp(src), + Original: true, + QP: qp, + Grouping: qp.GetGrouping(), + WithRollup: qp.WithRollup, + Aggregations: aggregations, + DT: dt, } // Go through all aggregations and check for any subquery. @@ -372,7 +352,7 @@ func newStarProjection(src Operator, qp *QueryProjection) *Projection { } return &Projection{ - Source: src, - Columns: StarProjections(cols), + unaryOperator: newUnaryOp(src), + Columns: StarProjections(cols), } } diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 6832dc363d5..3176bac50a2 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -50,7 +50,7 @@ type Insert struct { // that will appear in the result set of the select query. VindexValueOffset [][]int - noInputs + nullaryOperator noColumns noPredicates } @@ -363,8 +363,8 @@ func createInsertOperator(ctx *plancontext.PlanningContext, insStmt *sqlparser.I AST: insStmt, } route := &Route{ - Source: insOp, - Routing: routing, + unaryOperator: newUnaryOp(insOp), + Routing: routing, } // Table column list is nil then add all the columns @@ -394,10 +394,7 @@ func createInsertOperator(ctx *plancontext.PlanningContext, insStmt *sqlparser.I op = insertSelectPlan(ctx, insOp, route, insStmt, rows) } if insStmt.Comments != nil { - op = &LockAndComment{ - Source: op, - Comments: insStmt.Comments, - } + op = newLockAndComment(op, insStmt.Comments, sqlparser.NoLock) } return op } @@ -420,11 +417,7 @@ func insertSelectPlan( // output of the select plan will be used to insert rows into the table. insertSelect := &InsertSelection{ - Select: &LockAndComment{ - Source: selOp, - Lock: sqlparser.ShareModeLock, - }, - Insert: routeOp, + binaryOperator: newBinaryOp(newLockAndComment(selOp, nil, sqlparser.ShareModeLock), routeOp), } // When the table you are streaming data from and table you are inserting from are same. diff --git a/go/vt/vtgate/planbuilder/operators/insert_selection.go b/go/vt/vtgate/planbuilder/operators/insert_selection.go index 70bda0a990a..5f806bbda0b 100644 --- a/go/vt/vtgate/planbuilder/operators/insert_selection.go +++ b/go/vt/vtgate/planbuilder/operators/insert_selection.go @@ -23,8 +23,7 @@ import ( // InsertSelection operator represents an INSERT into SELECT FROM query. // It holds the operators for running the selection and insertion. type InsertSelection struct { - Select Operator - Insert Operator + binaryOperator // ForceNonStreaming when true, select first then insert, this is to avoid locking rows by select for insert. ForceNonStreaming bool @@ -33,21 +32,13 @@ type InsertSelection struct { noPredicates } -func (is *InsertSelection) Clone(inputs []Operator) Operator { - return &InsertSelection{ - Select: inputs[0], - Insert: inputs[1], - ForceNonStreaming: is.ForceNonStreaming, - } -} - -func (is *InsertSelection) Inputs() []Operator { - return []Operator{is.Select, is.Insert} -} +var _ Operator = (*InsertSelection)(nil) -func (is *InsertSelection) SetInputs(inputs []Operator) { - is.Select = inputs[0] - is.Insert = inputs[1] +func (is *InsertSelection) Clone(inputs []Operator) Operator { + klone := *is + klone.LHS = inputs[0] + klone.RHS = inputs[1] + return &klone } func (is *InsertSelection) ShortDescription() string { @@ -61,4 +52,10 @@ func (is *InsertSelection) GetOrdering(*plancontext.PlanningContext) []OrderBy { return nil } -var _ Operator = (*InsertSelection)(nil) +func (is *InsertSelection) Select() Operator { + return is.LHS +} + +func (is *InsertSelection) Insert() Operator { + return is.RHS +} diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 35760bceafb..ff4915527a7 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -26,7 +26,7 @@ import ( // Join represents a join. If we have a predicate, this is an inner join. If no predicate exists, it is a cross join type Join struct { - LHS, RHS Operator + binaryOperator Predicate sqlparser.Expr // JoinType is permitted to store only 3 of the possible values // NormalJoinType, StraightJoinType and LeftJoinType. @@ -42,28 +42,13 @@ func (j *Join) Clone(inputs []Operator) Operator { clone := *j clone.LHS = inputs[0] clone.RHS = inputs[1] - return &Join{ - LHS: inputs[0], - RHS: inputs[1], - Predicate: j.Predicate, - JoinType: j.JoinType, - } + return &clone } func (j *Join) GetOrdering(*plancontext.PlanningContext) []OrderBy { return nil } -// Inputs implements the Operator interface -func (j *Join) Inputs() []Operator { - return []Operator{j.LHS, j.RHS} -} - -// SetInputs implements the Operator interface -func (j *Join) SetInputs(ops []Operator) { - j.LHS, j.RHS = ops[0], ops[1] -} - func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult) { if !j.JoinType.IsCommutative() { // if we can't move tables around, we can't merge these inputs @@ -89,7 +74,10 @@ func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult func createStraightJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator { // for inner joins we can treat the predicates as filters on top of the join - joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join} + joinOp := &Join{ + binaryOperator: newBinaryOp(lhs, rhs), + JoinType: join.Join, + } return addJoinPredicates(ctx, join.Condition.On, joinOp) } @@ -105,7 +93,10 @@ func createLeftOuterJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinT join.Join = sqlparser.NaturalLeftJoinType } - joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join} + joinOp := &Join{ + binaryOperator: newBinaryOp(lhs, rhs), + JoinType: join.Join, + } // mark the RHS as outer tables so we know which columns are nullable ctx.OuterTables = ctx.OuterTables.Merge(TableID(rhs)) @@ -197,7 +188,9 @@ func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator { } return op } - return &Join{LHS: LHS, RHS: RHS} + return &Join{ + binaryOperator: newBinaryOp(LHS, RHS), + } } func (j *Join) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { diff --git a/go/vt/vtgate/planbuilder/operators/join_merging.go b/go/vt/vtgate/planbuilder/operators/join_merging.go index 672da551fa6..c035b7d11ed 100644 --- a/go/vt/vtgate/planbuilder/operators/join_merging.go +++ b/go/vt/vtgate/planbuilder/operators/join_merging.go @@ -236,8 +236,8 @@ func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *R func (jm *joinMerger) merge(ctx *plancontext.PlanningContext, op1, op2 *Route, r Routing) *Route { return &Route{ - Source: jm.getApplyJoin(ctx, op1, op2), - MergedWith: []*Route{op2}, - Routing: r, + unaryOperator: newUnaryOp(jm.getApplyJoin(ctx, op1, op2)), + MergedWith: []*Route{op2}, + Routing: r, } } diff --git a/go/vt/vtgate/planbuilder/operators/limit.go b/go/vt/vtgate/planbuilder/operators/limit.go index 1801e57c1c9..4549b85fcda 100644 --- a/go/vt/vtgate/planbuilder/operators/limit.go +++ b/go/vt/vtgate/planbuilder/operators/limit.go @@ -22,8 +22,8 @@ import ( ) type Limit struct { - Source Operator - AST *sqlparser.Limit + unaryOperator + AST *sqlparser.Limit // Top is true if the limit is a top level limit. To optimise, we push LIMIT to the RHS of joins, // but we need to still LIMIT the total result set to the top level limit. @@ -33,21 +33,19 @@ type Limit struct { Pushed bool } -func (l *Limit) Clone(inputs []Operator) Operator { +func newLimit(op Operator, ast *sqlparser.Limit, top bool) *Limit { return &Limit{ - Source: inputs[0], - AST: sqlparser.Clone(l.AST), - Top: l.Top, - Pushed: l.Pushed, + unaryOperator: newUnaryOp(op), + AST: ast, + Top: top, } } -func (l *Limit) Inputs() []Operator { - return []Operator{l.Source} -} - -func (l *Limit) SetInputs(operators []Operator) { - l.Source = operators[0] +func (l *Limit) Clone(inputs []Operator) Operator { + k := *l + k.Source = inputs[0] + k.AST = sqlparser.Clone(l.AST) + return &k } func (l *Limit) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { diff --git a/go/vt/vtgate/planbuilder/operators/mirror.go b/go/vt/vtgate/planbuilder/operators/mirror.go index 3ab1d66e70d..82a431af0a2 100644 --- a/go/vt/vtgate/planbuilder/operators/mirror.go +++ b/go/vt/vtgate/planbuilder/operators/mirror.go @@ -26,19 +26,25 @@ import ( type ( PercentBasedMirror struct { - Percent float32 - Operator Operator - Target Operator + binaryOperator + Percent float32 } ) var _ Operator = (*PercentBasedMirror)(nil) +func (m *PercentBasedMirror) Operator() Operator { + return m.LHS +} + +func (m *PercentBasedMirror) Target() Operator { + return m.RHS +} + func NewPercentBasedMirror(percent float32, operator, target Operator) *PercentBasedMirror { return &PercentBasedMirror{ - percent, - operator, - target, + binaryOperator: newBinaryOp(operator, target), + Percent: percent, } } @@ -49,45 +55,28 @@ func (m *PercentBasedMirror) Clone(inputs []Operator) Operator { return &cloneMirror } -// Inputs returns the inputs for this operator -func (m *PercentBasedMirror) Inputs() []Operator { - return []Operator{ - m.Operator, - m.Target, - } -} - -// SetInputs changes the inputs for this op -func (m *PercentBasedMirror) SetInputs(inputs []Operator) { - if len(inputs) < 2 { - panic(vterrors.VT13001("unexpected number of inputs for PercentBasedMirror operator")) - } - m.Operator = inputs[0] - m.Target = inputs[1] -} - // AddPredicate is used to push predicates. It pushed it as far down as is possible in the tree. // If we encounter a join and the predicate depends on both sides of the join, the predicate will be split into two parts, // where data is fetched from the LHS of the join to be used in the evaluation on the RHS // TODO: we should remove this and replace it with rewriters -func (m *PercentBasedMirror) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { +func (m *PercentBasedMirror) AddPredicate(*plancontext.PlanningContext, sqlparser.Expr) Operator { panic(vterrors.VT13001("not supported")) } -func (m *PercentBasedMirror) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { +func (m *PercentBasedMirror) AddColumn(*plancontext.PlanningContext, bool, bool, *sqlparser.AliasedExpr) int { panic(vterrors.VT13001("not supported")) } func (m *PercentBasedMirror) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { - return m.Operator.FindCol(ctx, expr, underRoute) + return m.Operator().FindCol(ctx, expr, underRoute) } func (m *PercentBasedMirror) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { - return m.Operator.GetColumns(ctx) + return m.Operator().GetColumns(ctx) } func (m *PercentBasedMirror) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { - return m.Operator.GetSelectExprs(ctx) + return m.Operator().GetSelectExprs(ctx) } func (m *PercentBasedMirror) ShortDescription() string { @@ -95,10 +84,10 @@ func (m *PercentBasedMirror) ShortDescription() string { } func (m *PercentBasedMirror) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { - return m.Operator.GetOrdering(ctx) + return m.Operator().GetOrdering(ctx) } // AddWSColumn implements Operator. -func (m *PercentBasedMirror) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int { +func (m *PercentBasedMirror) AddWSColumn(*plancontext.PlanningContext, int, bool) int { panic(vterrors.VT13001("not supported")) } diff --git a/go/vt/vtgate/planbuilder/operators/offset_planning.go b/go/vt/vtgate/planbuilder/operators/offset_planning.go index adf71a47f24..62abff29e57 100644 --- a/go/vt/vtgate/planbuilder/operators/offset_planning.go +++ b/go/vt/vtgate/planbuilder/operators/offset_planning.go @@ -130,10 +130,7 @@ func isolateDistinctFromUnion(_ *plancontext.PlanningContext, root Operator) Ope union.distinct = false - distinct := &Distinct{ - Required: true, - Source: union, - } + distinct := newDistinct(union, nil, true) return distinct, Rewrote("pulled out DISTINCT from union") } diff --git a/go/vt/vtgate/planbuilder/operators/operator.go b/go/vt/vtgate/planbuilder/operators/operator.go index 76797aee906..42658e4c52e 100644 --- a/go/vt/vtgate/planbuilder/operators/operator.go +++ b/go/vt/vtgate/planbuilder/operators/operator.go @@ -86,8 +86,46 @@ type ( // See GroupBy#SimplifiedExpr for more details about this SimplifiedExpr sqlparser.Expr } + + unaryOperator struct { + Operator + Source Operator + } + + binaryOperator struct { + Operator + LHS, RHS Operator + } ) +func newUnaryOp(source Operator) unaryOperator { + return unaryOperator{Source: source} +} + +func newBinaryOp(l, r Operator) binaryOperator { + return binaryOperator{ + LHS: l, + RHS: r, + } +} + +func (s *unaryOperator) Inputs() []Operator { + return []Operator{s.Source} +} + +func (s *unaryOperator) SetInputs(operators []Operator) { + s.Source = operators[0] +} + +func (b *binaryOperator) Inputs() []Operator { + return []Operator{b.LHS, b.RHS} +} + +func (b *binaryOperator) SetInputs(operators []Operator) { + b.LHS = operators[0] + b.RHS = operators[1] +} + // Map takes in a mapping function and applies it to both the expression in OrderBy. func (ob OrderBy) Map(mappingFunc func(sqlparser.Expr) sqlparser.Expr) OrderBy { return OrderBy{ diff --git a/go/vt/vtgate/planbuilder/operators/ordering.go b/go/vt/vtgate/planbuilder/operators/ordering.go index c8f4ccdf853..7e40b420f9e 100644 --- a/go/vt/vtgate/planbuilder/operators/ordering.go +++ b/go/vt/vtgate/planbuilder/operators/ordering.go @@ -26,7 +26,7 @@ import ( ) type Ordering struct { - Source Operator + unaryOperator Offset []int WOffset []int @@ -35,21 +35,20 @@ type Ordering struct { } func (o *Ordering) Clone(inputs []Operator) Operator { - return &Ordering{ - Source: inputs[0], - Offset: slices.Clone(o.Offset), - WOffset: slices.Clone(o.WOffset), - Order: slices.Clone(o.Order), - ResultColumns: o.ResultColumns, - } -} + klone := *o + klone.Source = inputs[0] + klone.Offset = slices.Clone(o.Offset) + klone.WOffset = slices.Clone(o.WOffset) + klone.Order = slices.Clone(o.Order) -func (o *Ordering) Inputs() []Operator { - return []Operator{o.Source} + return &klone } -func (o *Ordering) SetInputs(operators []Operator) { - o.Source = operators[0] +func newOrdering(src Operator, order []OrderBy) Operator { + return &Ordering{ + unaryOperator: newUnaryOp(src), + Order: order, + } } func (o *Ordering) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index d5354e9548f..eb6c42b8724 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -249,10 +249,7 @@ func addOrderingFor(aggrOp *Aggregator) { SimplifiedExpr: aggrOp.DistinctExpr, }) } - aggrOp.Source = &Ordering{ - Source: aggrOp.Source, - Order: orderBys, - } + aggrOp.Source = newOrdering(aggrOp.Source, orderBys) } func needsOrdering(ctx *plancontext.PlanningContext, in *Aggregator) bool { @@ -357,7 +354,7 @@ func planRecursiveCTEHorizons(ctx *plancontext.PlanningContext, root Operator) O return in, NoRewrite } hz := rcte.Horizon - hz.Source = rcte.Term + hz.Source = rcte.Term() newTerm, _ := expandHorizon(ctx, hz) pr := findProjection(newTerm) ap, err := pr.GetAliasedProjections() @@ -372,7 +369,7 @@ func planRecursiveCTEHorizons(ctx *plancontext.PlanningContext, root Operator) O return recurseExpression }) rcte.Projections = projections - rcte.Term = newTerm + rcte.RHS = newTerm return rcte, Rewrote("expanded horizon on term side of recursive CTE") }, stopAtRoute) } diff --git a/go/vt/vtgate/planbuilder/operators/plan_query.go b/go/vt/vtgate/planbuilder/operators/plan_query.go index 40c27d03126..baa9a7883e2 100644 --- a/go/vt/vtgate/planbuilder/operators/plan_query.go +++ b/go/vt/vtgate/planbuilder/operators/plan_query.go @@ -47,7 +47,9 @@ import ( type ( // helper type that implements Inputs() returning nil - noInputs struct{} + nullaryOperator struct { + Operator + } // helper type that implements AddColumn() returning an error noColumns struct{} @@ -94,14 +96,14 @@ func PanicHandler(err *error) { } // Inputs implements the Operator interface -func (noInputs) Inputs() []Operator { +func (nullaryOperator) Inputs() []Operator { return nil } // SetInputs implements the Operator interface -func (noInputs) SetInputs(ops []Operator) { +func (nullaryOperator) SetInputs(ops []Operator) { if len(ops) > 0 { - panic("the noInputs operator does not have inputs") + panic("the nullaryOperator operator does not have inputs") } } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 95ebeadaeb7..e894ab433b4 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -32,7 +32,7 @@ import ( // Projection is used when we need to evaluate expressions on the vtgate // It uses the evalengine to accomplish its goal type Projection struct { - Source Operator + unaryOperator // Columns contain the expressions as viewed from the outside of this operator Columns ProjCols @@ -127,8 +127,8 @@ func newProjExprWithInner(ae *sqlparser.AliasedExpr, in sqlparser.Expr) *ProjExp func newAliasedProjection(src Operator) *Projection { return &Projection{ - Source: src, - Columns: AliasedProjections{}, + unaryOperator: newUnaryOp(src), + Columns: AliasedProjections{}, } } @@ -405,20 +405,9 @@ func (po *EvalEngine) expr() {} func (po SubQueryExpression) expr() {} func (p *Projection) Clone(inputs []Operator) Operator { - return &Projection{ - Source: inputs[0], - Columns: p.Columns, // TODO don't think we need to deep clone here - DT: p.DT, - FromAggr: p.FromAggr, - } -} - -func (p *Projection) Inputs() []Operator { - return []Operator{p.Source} -} - -func (p *Projection) SetInputs(operators []Operator) { - p.Source = operators[0] + klone := *p + klone.Source = inputs[0] + return &klone } func (p *Projection) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index f930a7f4f76..5fe0c7773c1 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -112,8 +112,8 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { if pbm, ok := root.(*PercentBasedMirror); ok { pbm.SetInputs([]Operator{ - runRewriters(ctx, pbm.Operator), - runRewriters(ctx.UseMirror(), pbm.Target), + runRewriters(ctx, pbm.Operator()), + runRewriters(ctx.UseMirror(), pbm.Target()), }) } @@ -160,27 +160,15 @@ func pushLockAndComment(l *LockAndComment) (Operator, *ApplyResult) { src.Lock = l.Lock.GetHighestOrderLock(src.Lock) return src, Rewrote("put lock and comment into route") case *SubQueryContainer: - src.Outer = &LockAndComment{ - Source: src.Outer, - Comments: l.Comments, - Lock: l.Lock, - } + src.Outer = newLockAndComment(src.Outer, l.Comments, l.Lock) for _, sq := range src.Inner { - sq.Subquery = &LockAndComment{ - Source: sq.Subquery, - Comments: l.Comments, - Lock: l.Lock, - } + sq.Subquery = newLockAndComment(sq.Subquery, l.Comments, l.Lock) } return src, Rewrote("push lock and comment into subquery container") default: inputs := src.Inputs() for i, op := range inputs { - inputs[i] = &LockAndComment{ - Source: op, - Comments: l.Comments, - Lock: l.Lock, - } + inputs[i] = newLockAndComment(op, l.Comments, l.Lock) } src.SetInputs(inputs) return src, Rewrote("pushed down lock and comments") @@ -436,10 +424,7 @@ func createPushedLimit(ctx *plancontext.PlanningContext, src Operator, orig *Lim pushedLimit.Rowcount = getLimitExpression(ctx, plus) pushedLimit.Offset = nil } - return &Limit{ - Source: src, - AST: pushedLimit, - } + return newLimit(src, pushedLimit, false) } // getLimitExpression is a helper function to simplify an expression using the evalengine @@ -506,11 +491,8 @@ func setUpperLimit(in *Limit) (Operator, *ApplyResult) { return SkipChildren } case *Route: - newSrc := &Limit{ - Source: op.Source, - AST: &sqlparser.Limit{Rowcount: sqlparser.NewArgument(engine.UpperLimitStr)}, - } - op.Source = newSrc + ast := &sqlparser.Limit{Rowcount: sqlparser.NewArgument(engine.UpperLimitStr)} + op.Source = newLimit(op.Source, ast, false) result = result.Merge(Rewrote("push upper limit under route")) return SkipChildren } @@ -762,7 +744,7 @@ func tryPushDistinct(in *Distinct) (Operator, *ApplyResult) { return in, NoRewrite } - src.Source = &Distinct{Source: src.Source} + src.Source = newDistinct(src.Source, nil, false) in.PushedPerformance = true return in, Rewrote("added distinct under route - kept original") @@ -772,14 +754,14 @@ func tryPushDistinct(in *Distinct) (Operator, *ApplyResult) { return src, Rewrote("remove double distinct") case *Union: for i := range src.Sources { - src.Sources[i] = &Distinct{Source: src.Sources[i]} + src.Sources[i] = newDistinct(src.Sources[i], nil, false) } in.PushedPerformance = true return in, Rewrote("push down distinct under union") case JoinOp: - src.SetLHS(&Distinct{Source: src.GetLHS()}) - src.SetRHS(&Distinct{Source: src.GetRHS()}) + src.SetLHS(newDistinct(src.GetLHS(), nil, false)) + src.SetRHS(newDistinct(src.GetRHS(), nil, false)) in.PushedPerformance = true if in.Required { @@ -830,10 +812,7 @@ func tryPushUnion(ctx *plancontext.PlanningContext, op *Union) (Operator, *Apply return result, Rewrote("push union under route") } - return &Distinct{ - Source: result, - Required: true, - }, Rewrote("push union under route") + return newDistinct(result, nil, true), Rewrote("push union under route") } if len(sources) == len(op.Sources) { diff --git a/go/vt/vtgate/planbuilder/operators/querygraph.go b/go/vt/vtgate/planbuilder/operators/querygraph.go index 8e8572f7dfa..98cd9ada64a 100644 --- a/go/vt/vtgate/planbuilder/operators/querygraph.go +++ b/go/vt/vtgate/planbuilder/operators/querygraph.go @@ -42,7 +42,7 @@ type ( // NoDeps contains the predicates that can be evaluated anywhere. NoDeps sqlparser.Expr - noInputs + nullaryOperator noColumns } diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index ebb7dc54765..61474b663d6 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -32,8 +32,7 @@ import ( // RecurseCTE is used to represent a recursive CTE type RecurseCTE struct { - Seed, // used to describe the non-recursive part that initializes the result set - Term Operator // the part that repeatedly applies the recursion, processing the result set + binaryOperator // Def is the CTE definition according to the semantics Def *semantics.CTE @@ -77,36 +76,26 @@ func newRecurse( ctx.AddJoinPredicates(pred.Original, pred.RightExpr) } return &RecurseCTE{ - Def: def, - Seed: seed, - Term: term, - Predicates: predicates, - Horizon: horizon, - LeftID: leftID, - OuterID: outerID, - Distinct: distinct, + binaryOperator: newBinaryOp(seed, term), + Def: def, + Predicates: predicates, + Horizon: horizon, + LeftID: leftID, + OuterID: outerID, + Distinct: distinct, } } func (r *RecurseCTE) Clone(inputs []Operator) Operator { klone := *r - klone.Seed = inputs[0] - klone.Term = inputs[1] + klone.LHS = inputs[0] + klone.RHS = inputs[1] klone.Vars = maps.Clone(r.Vars) klone.Predicates = slices.Clone(r.Predicates) klone.Projections = slices.Clone(r.Projections) return &klone } -func (r *RecurseCTE) Inputs() []Operator { - return []Operator{r.Seed, r.Term} -} - -func (r *RecurseCTE) SetInputs(operators []Operator) { - r.Seed = operators[0] - r.Term = operators[1] -} - func (r *RecurseCTE) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Expr) Operator { return newFilter(r, e) } @@ -114,7 +103,7 @@ func (r *RecurseCTE) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Ex func (r *RecurseCTE) AddColumn(ctx *plancontext.PlanningContext, _, _ bool, expr *sqlparser.AliasedExpr) int { r.makeSureWeHaveTableInfo(ctx) e := semantics.RewriteDerivedTableExpression(expr.Expr, r.MyTableInfo) - offset := r.Seed.FindCol(ctx, e, false) + offset := r.Seed().FindCol(ctx, e, false) if offset == -1 { panic(vterrors.VT13001("CTE column not found")) } @@ -140,8 +129,8 @@ func (r *RecurseCTE) makeSureWeHaveTableInfo(ctx *plancontext.PlanningContext) { } func (r *RecurseCTE) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int { - seed := r.Seed.AddWSColumn(ctx, offset, underRoute) - term := r.Term.AddWSColumn(ctx, offset, underRoute) + seed := r.Seed().AddWSColumn(ctx, offset, underRoute) + term := r.Term().AddWSColumn(ctx, offset, underRoute) if seed != term { panic(vterrors.VT13001("CTE columns don't match")) } @@ -151,15 +140,15 @@ func (r *RecurseCTE) AddWSColumn(ctx *plancontext.PlanningContext, offset int, u func (r *RecurseCTE) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { r.makeSureWeHaveTableInfo(ctx) expr = semantics.RewriteDerivedTableExpression(expr, r.MyTableInfo) - return r.Seed.FindCol(ctx, expr, underRoute) + return r.Seed().FindCol(ctx, expr, underRoute) } func (r *RecurseCTE) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { - return r.Seed.GetColumns(ctx) + return r.Seed().GetColumns(ctx) } func (r *RecurseCTE) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { - return r.Seed.GetSelectExprs(ctx) + return r.Seed().GetSelectExprs(ctx) } func (r *RecurseCTE) ShortDescription() string { @@ -187,7 +176,7 @@ func (r *RecurseCTE) expressions() []*plancontext.RecurseExpression { func (r *RecurseCTE) planOffsets(ctx *plancontext.PlanningContext) Operator { r.Vars = make(map[string]int) - columns := r.Seed.GetColumns(ctx) + columns := r.Seed().GetColumns(ctx) for _, expr := range r.expressions() { outer: for _, lhsExpr := range expr.LeftExprs { @@ -212,3 +201,11 @@ func (r *RecurseCTE) planOffsets(ctx *plancontext.PlanningContext) Operator { func (r *RecurseCTE) introducesTableID() semantics.TableSet { return r.OuterID } + +func (r *RecurseCTE) Seed() Operator { + return r.LHS +} + +func (r *RecurseCTE) Term() Operator { + return r.RHS +} diff --git a/go/vt/vtgate/planbuilder/operators/rewriters.go b/go/vt/vtgate/planbuilder/operators/rewriters.go index a5743e001e1..1864eb952e1 100644 --- a/go/vt/vtgate/planbuilder/operators/rewriters.go +++ b/go/vt/vtgate/planbuilder/operators/rewriters.go @@ -164,6 +164,16 @@ func TopDown( // Swap takes a tree like a->b->c and swaps `a` and `b`, so we end up with b->a->c func Swap(parent, child Operator, message string) (Operator, *ApplyResult) { + unaryParent, isUnary := parent.(*unaryOperator) + if isUnary { + unaryChild, isUnary := child.(*unaryOperator) + if isUnary { + // If both the parent and child are unary operators, we can just swap the sources + unaryParent.Source, unaryChild.Source = unaryChild.Source, unaryParent.Source + return parent, Rewrote(message) + } + } + c := child.Inputs() if len(c) != 1 { panic(vterrors.VT13001("Swap can only be used on single input operators")) @@ -200,34 +210,60 @@ func bottomUp( return root, NoRewrite } - oldInputs := root.Inputs() var anythingChanged *ApplyResult - newInputs := make([]Operator, len(oldInputs)) - childID := rootID - - // noLHSTableSet is used to mark which operators that do not send data from the LHS to the RHS - // It's only UNION at this moment, but this package can't depend on the actual operators, so - // we use this interface to avoid direct dependencies - type noLHSTableSet interface{ NoLHSTableSet() } - - for i, operator := range oldInputs { - // We merge the table set of all the LHS above the current root so that we can - // send it down to the current RHS. - // We don't want to send the LHS table set to the RHS if the root is a UNION. - // Some operators, like SubQuery, can have multiple child operators on the RHS - if _, isUnion := root.(noLHSTableSet); !isUnion && i > 0 { - childID = childID.Merge(resolveID(oldInputs[0])) + + switch root := root.(type) { + case nullaryOperator: + // no inputs, nothing to do + case *unaryOperator: + newSource, changed := bottomUp(root.Source, rootID, resolveID, rewriter, shouldVisit, false) + if DebugOperatorTree && changed.Changed() { + fmt.Println(ToTree(newSource)) } - in, changed := bottomUp(operator, childID, resolveID, rewriter, shouldVisit, false) + anythingChanged = anythingChanged.Merge(changed) + root.Source = newSource + case *binaryOperator: + newLHS, changed := bottomUp(root.LHS, rootID, resolveID, rewriter, shouldVisit, false) if DebugOperatorTree && changed.Changed() { - fmt.Println(ToTree(in)) + fmt.Println(ToTree(newLHS)) } anythingChanged = anythingChanged.Merge(changed) - newInputs[i] = in - } + root.LHS = newLHS + newRHS, changed := bottomUp(root.RHS, rootID, resolveID, rewriter, shouldVisit, false) + if DebugOperatorTree && changed.Changed() { + fmt.Println(ToTree(newRHS)) + } + anythingChanged = anythingChanged.Merge(changed) + root.RHS = newRHS + default: + oldInputs := root.Inputs() + newInputs := make([]Operator, len(oldInputs)) + childID := rootID + + // noLHSTableSet is used to mark which operators that do not send data from the LHS to the RHS + // It's only UNION at this moment, but this package can't depend on the actual operators, so + // we use this interface to avoid direct dependencies + type noLHSTableSet interface{ NoLHSTableSet() } + + for i, operator := range oldInputs { + // We merge the table set of all the LHS above the current root so that we can + // send it down to the current RHS. + // We don't want to send the LHS table set to the RHS if the root is a UNION. + // Some operators, like SubQuery, can have multiple child operators on the RHS + if _, isUnion := root.(noLHSTableSet); !isUnion && i > 0 { + childID = childID.Merge(resolveID(oldInputs[0])) + } + in, changed := bottomUp(operator, childID, resolveID, rewriter, shouldVisit, false) + if DebugOperatorTree && changed.Changed() { + fmt.Println(ToTree(in)) + } + anythingChanged = anythingChanged.Merge(changed) + newInputs[i] = in + } - if anythingChanged.Changed() { - root.SetInputs(newInputs) + if anythingChanged.Changed() { + root.SetInputs(newInputs) + } } newOp, treeIdentity := rewriter(root, rootID, isRoot) @@ -247,14 +283,39 @@ func breakableTopDown( var anythingChanged *ApplyResult - oldInputs := newOp.Inputs() - newInputs := make([]Operator, len(oldInputs)) - for i, oldInput := range oldInputs { - newInputs[i], identity, err = breakableTopDown(oldInput, rewriter) + switch newOp := newOp.(type) { + case nullaryOperator: + // no inputs, nothing to do + case *unaryOperator: + newSource, identity, err := breakableTopDown(newOp.Source, rewriter) + if err != nil { + return nil, NoRewrite, err + } + anythingChanged = anythingChanged.Merge(identity) + newOp.Source = newSource + case *binaryOperator: + newLHS, identity, err := breakableTopDown(newOp.LHS, rewriter) + if err != nil { + return nil, NoRewrite, err + } anythingChanged = anythingChanged.Merge(identity) + newRHS, identity, err := breakableTopDown(newOp.RHS, rewriter) if err != nil { return nil, NoRewrite, err } + anythingChanged = anythingChanged.Merge(identity) + newOp.LHS = newLHS + newOp.RHS = newRHS + default: + oldInputs := newOp.Inputs() + newInputs := make([]Operator, len(oldInputs)) + for i, oldInput := range oldInputs { + newInputs[i], identity, err = breakableTopDown(oldInput, rewriter) + if err != nil { + return nil, NoRewrite, err + } + anythingChanged = anythingChanged.Merge(identity) + } } return newOp, anythingChanged, nil @@ -281,25 +342,39 @@ func topDown( root = newOp } - oldInputs := root.Inputs() - newInputs := make([]Operator, len(oldInputs)) - childID := rootID - - type noLHSTableSet interface{ NoLHSTableSet() } - - for i, operator := range oldInputs { - if _, isUnion := root.(noLHSTableSet); !isUnion && i > 0 { - childID = childID.Merge(resolveID(oldInputs[0])) - } - in, changed := topDown(operator, childID, resolveID, rewriter, shouldVisit, false) + switch newOp := newOp.(type) { + case nullaryOperator: + // no inputs, nothing to do + case *unaryOperator: + newSource, changed := topDown(newOp.Source, rootID, resolveID, rewriter, shouldVisit, false) anythingChanged = anythingChanged.Merge(changed) - newInputs[i] = in - } + newOp.Source = newSource + case *binaryOperator: + newLHS, changed := topDown(newOp.LHS, rootID, resolveID, rewriter, shouldVisit, false) + anythingChanged = anythingChanged.Merge(changed) + newRHS, changed := topDown(newOp.RHS, rootID, resolveID, rewriter, shouldVisit, false) + anythingChanged = anythingChanged.Merge(changed) + newOp.LHS, newOp.RHS = newLHS, newRHS + default: + oldInputs := root.Inputs() + newInputs := make([]Operator, len(oldInputs)) + childID := rootID + + type noLHSTableSet interface{ NoLHSTableSet() } + + for i, operator := range oldInputs { + if _, isUnion := root.(noLHSTableSet); !isUnion && i > 0 { + childID = childID.Merge(resolveID(oldInputs[0])) + } + in, changed := topDown(operator, childID, resolveID, rewriter, shouldVisit, false) + anythingChanged = anythingChanged.Merge(changed) + newInputs[i] = in + } - if anythingChanged != NoRewrite { - root.SetInputs(newInputs) - return root, anythingChanged + if anythingChanged != NoRewrite { + root.SetInputs(newInputs) + } } - return root, NoRewrite + return root, anythingChanged } diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index 25e3f610a04..a8cf8582851 100644 --- a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -33,7 +33,7 @@ import ( type ( Route struct { - Source Operator + unaryOperator // Routes that have been merged into this one. MergedWith []*Route @@ -119,7 +119,7 @@ func UpdateRoutingLogic(ctx *plancontext.PlanningContext, expr sqlparser.Expr, r } nr := &NoneRouting{keyspace: ks} - if isConstantFalse(ctx, expr) { + if b := ctx.IsConstantBool(expr); b != nil && !*b { return nr } @@ -161,39 +161,6 @@ func UpdateRoutingLogic(ctx *plancontext.PlanningContext, expr sqlparser.Expr, r return exit() } -// isConstantFalse checks whether this predicate can be evaluated at plan-time. If it returns `false` or `null`, -// we know that the query will not return anything, and this can be used to produce better plans -func isConstantFalse(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool { - if !ctx.SemTable.RecursiveDeps(expr).IsEmpty() { - // we have column dependencies, so we can be pretty sure - // we won't be able to use the evalengine to check if this is constant false - return false - } - env := ctx.VSchema.Environment() - collation := ctx.VSchema.ConnCollation() - eenv := evalengine.EmptyExpressionEnv(env) - eexpr, err := evalengine.Translate(expr, &evalengine.Config{ - Collation: collation, - Environment: env, - NoCompilation: true, - }) - if err != nil { - return false - } - eres, err := eenv.Evaluate(eexpr) - if err != nil { - return false - } - if eres.Value(collation).IsNull() { - return false - } - b, err := eres.ToBooleanStrict() - if err != nil { - return false - } - return !b -} - // Cost implements the Operator interface func (r *Route) Cost() int { return r.Routing.Cost() @@ -207,16 +174,6 @@ func (r *Route) Clone(inputs []Operator) Operator { return &cloneRoute } -// Inputs implements the Operator interface -func (r *Route) Inputs() []Operator { - return []Operator{r.Source} -} - -// SetInputs implements the Operator interface -func (r *Route) SetInputs(ops []Operator) { - r.Source = ops[0] -} - func createOption( colVindex *vindexes.ColumnVindex, vfunc func(*vindexes.ColumnVindex) vindexes.Vindex, @@ -465,10 +422,10 @@ func createRouteFromVSchemaTable( } } plan := &Route{ - Source: &Table{ + unaryOperator: newUnaryOp(&Table{ QTable: queryTable, VTable: vschemaTable, - }, + }), } // We create the appropriate Routing struct here, depending on the type of table we are dealing with. @@ -722,8 +679,8 @@ func wrapInDerivedProjection( columns = append(columns, sqlparser.NewIdentifierCI(fmt.Sprintf("c%d", i))) } derivedProj := &Projection{ - Source: op, - Columns: AliasedProjections(slice.Map(unionColumns, newProjExpr)), + unaryOperator: newUnaryOp(op), + Columns: AliasedProjections(slice.Map(unionColumns, newProjExpr)), DT: &DerivedTable{ TableID: ctx.SemTable.NewTableId(), Alias: "dt", diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index 6a242649725..90eb16e1562 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -189,8 +189,8 @@ func createInfSchemaRoute(ctx *plancontext.PlanningContext, table *QueryTable) O routing = UpdateRoutingLogic(ctx, pred, routing) } return &Route{ - Source: src, - Routing: routing, + unaryOperator: newUnaryOp(src), + Routing: routing, } } diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index b919bbfaed9..9610a2b10c9 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -232,11 +232,7 @@ var subqueryNotAtTopErr = vterrors.VT12001("unmergable subquery can not be insid func (sq *SubQuery) addLimit() { // for a correlated subquery, we can add a limit 1 to the subquery - sq.Subquery = &Limit{ - Source: sq.Subquery, - AST: &sqlparser.Limit{Rowcount: sqlparser.NewIntLiteral("1")}, - Top: true, - } + sq.Subquery = newLimit(sq.Subquery, &sqlparser.Limit{Rowcount: sqlparser.NewIntLiteral("1")}, true) } func (sq *SubQuery) settleFilter(ctx *plancontext.PlanningContext, outer Operator) Operator { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 0893afbeead..a2aca74fb6e 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -633,7 +633,7 @@ func (s *subqueryRouteMerger) merge(ctx *plancontext.PlanningContext, inner, out if !s.subq.TopLevel { // if the subquery we are merging isn't a top level predicate, we can't use it for routing return &Route{ - Source: outer.Source, + unaryOperator: newUnaryOp(outer.Source), MergedWith: mergedWith(inner, outer), Routing: outer.Routing, Ordering: outer.Ordering, @@ -651,7 +651,7 @@ func (s *subqueryRouteMerger) merge(ctx *plancontext.PlanningContext, inner, out src = s.rewriteASTExpression(ctx, inner) } return &Route{ - Source: src, + unaryOperator: newUnaryOp(src), MergedWith: mergedWith(inner, outer), Routing: r, Ordering: s.outer.Ordering, diff --git a/go/vt/vtgate/planbuilder/operators/table.go b/go/vt/vtgate/planbuilder/operators/table.go index 3ecd4982ece..6f221f2337a 100644 --- a/go/vt/vtgate/planbuilder/operators/table.go +++ b/go/vt/vtgate/planbuilder/operators/table.go @@ -33,7 +33,7 @@ type ( VTable *vindexes.Table Columns []*sqlparser.ColName - noInputs + nullaryOperator } ColNameColumns interface { GetColNames() []*sqlparser.ColName diff --git a/go/vt/vtgate/planbuilder/operators/union_merging.go b/go/vt/vtgate/planbuilder/operators/union_merging.go index 20c20673665..000d176b61a 100644 --- a/go/vt/vtgate/planbuilder/operators/union_merging.go +++ b/go/vt/vtgate/planbuilder/operators/union_merging.go @@ -222,9 +222,9 @@ func createMergedUnion( union := newUnion([]Operator{lhsRoute.Source, rhsRoute.Source}, []sqlparser.SelectExprs{lhsExprs, rhsExprs}, cols, distinct) selectExprs := unionSelects(lhsExprs) return &Route{ - Source: union, - MergedWith: []*Route{rhsRoute}, - Routing: routing, + unaryOperator: newUnaryOp(union), + MergedWith: []*Route{rhsRoute}, + Routing: routing, }, selectExprs } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index b4f0a37914e..9844b341670 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -110,11 +110,7 @@ func createOperatorFromUpdate(ctx *plancontext.PlanningContext, updStmt *sqlpars var targetTbl TargetTable op, targetTbl, updClone = createUpdateOperator(ctx, updStmt) - op = &LockAndComment{ - Source: op, - Comments: updStmt.Comments, - Lock: sqlparser.ShareModeLock, - } + op = newLockAndComment(op, updStmt.Comments, sqlparser.ShareModeLock) parentFks = ctx.SemTable.GetParentForeignKeysForTableSet(targetTbl.ID) childFks = ctx.SemTable.GetChildForeignKeysForTableSet(targetTbl.ID) @@ -203,10 +199,7 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up } if upd.Comments != nil { - op = &LockAndComment{ - Source: op, - Comments: upd.Comments, - } + op = newLockAndComment(op, upd.Comments, sqlparser.NoLock) } return op } @@ -399,10 +392,7 @@ func createUpdateOperator(ctx *plancontext.PlanningContext, updStmt *sqlparser.U } if updStmt.Limit != nil { - updOp.Source = &Limit{ - Source: updOp.Source, - AST: updStmt.Limit, - } + updOp.Source = newLimit(updOp.Source, updStmt.Limit, false) } return sqc.getRootOperator(updOp, nil), targetTbl, updClone diff --git a/go/vt/vtgate/planbuilder/operators/vindex.go b/go/vt/vtgate/planbuilder/operators/vindex.go index fd907fdad27..fbfbb6c0ccd 100644 --- a/go/vt/vtgate/planbuilder/operators/vindex.go +++ b/go/vt/vtgate/planbuilder/operators/vindex.go @@ -35,7 +35,7 @@ type ( Columns []*sqlparser.ColName Value sqlparser.Expr - noInputs + nullaryOperator } // VindexTable contains information about the vindex table we want to query diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 79536483970..9cf92a91ddf 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -459,11 +459,7 @@ func benchmarkWorkload(b *testing.B, name string) { testCases := readJSONTests(name + "_cases.json") b.ResetTimer() - for _, version := range plannerVersions { - b.Run(version.String(), func(b *testing.B) { - benchmarkPlanner(b, version, testCases, vschemaWrapper) - }) - } + benchmarkPlanner(b, Gen4, testCases, vschemaWrapper) } func (s *planTestSuite) TestBypassPlanningShardTargetFromFile() { @@ -797,9 +793,6 @@ func BenchmarkPlanner(b *testing.B) { b.Run(filename+"-gen4", func(b *testing.B) { benchmarkPlanner(b, Gen4, testCases, vschema) }) - b.Run(filename+"-gen4left2right", func(b *testing.B) { - benchmarkPlanner(b, Gen4Left2Right, testCases, vschema) - }) } } diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 66be6a4c71d..607ca83aa31 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -76,6 +76,9 @@ type PlanningContext struct { // isMirrored indicates that mirrored tables should be used. isMirrored bool + + emptyEnv *evalengine.ExpressionEnv + constantCfg *evalengine.Config } // CreatePlanningContext initializes a new PlanningContext with the given parameters. @@ -439,21 +442,59 @@ func (ctx *PlanningContext) UseMirror() *PlanningContext { return ctx.mirror } ctx.mirror = &PlanningContext{ - ctx.ReservedVars, - ctx.SemTable, - ctx.VSchema, - map[sqlparser.Expr][]sqlparser.Expr{}, - map[sqlparser.Expr]any{}, - ctx.PlannerVersion, - map[sqlparser.Expr]string{}, - ctx.VerifyAllFKs, - ctx.MergedSubqueries, - ctx.CurrentPhase, - ctx.Statement, - ctx.OuterTables, - ctx.CurrentCTE, - nil, - true, + ReservedVars: ctx.ReservedVars, + SemTable: ctx.SemTable, + VSchema: ctx.VSchema, + joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, + skipPredicates: map[sqlparser.Expr]any{}, + PlannerVersion: ctx.PlannerVersion, + ReservedArguments: map[sqlparser.Expr]string{}, + VerifyAllFKs: ctx.VerifyAllFKs, + MergedSubqueries: ctx.MergedSubqueries, + CurrentPhase: ctx.CurrentPhase, + Statement: ctx.Statement, + OuterTables: ctx.OuterTables, + CurrentCTE: ctx.CurrentCTE, + emptyEnv: ctx.emptyEnv, + isMirrored: true, } return ctx.mirror } + +// IsConstantBool checks whether this predicate can be evaluated at plan-time. +// If it can, it returns the constant value. +func (ctx *PlanningContext) IsConstantBool(expr sqlparser.Expr) *bool { + if !ctx.SemTable.RecursiveDeps(expr).IsEmpty() { + // we have column dependencies, so we can be pretty sure + // we won't be able to use the evalengine to check if this is constant false + return nil + } + env := ctx.VSchema.Environment() + collation := ctx.VSchema.ConnCollation() + if ctx.constantCfg == nil { + ctx.constantCfg = &evalengine.Config{ + Collation: collation, + Environment: env, + NoCompilation: true, + } + } + eexpr, err := evalengine.Translate(expr, ctx.constantCfg) + if ctx.emptyEnv == nil { + ctx.emptyEnv = evalengine.EmptyExpressionEnv(env) + } + if err != nil { + return nil + } + eres, err := ctx.emptyEnv.Evaluate(eexpr) + if err != nil { + return nil + } + if eres.Value(collation).IsNull() { + return nil + } + b, err := eres.ToBooleanStrict() + if err != nil { + return nil + } + return &b +} diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index b60e8812dda..3dc379b9aae 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -789,7 +789,7 @@ "Sharded": true }, "FieldQuery": "select Id from `user` where 1 != 1", - "Query": "select Id from `user` where 1 in ('aa', 'bb')", + "Query": "select Id from `user` where 0", "Table": "`user`" }, "TablesUsed": [ @@ -1251,7 +1251,7 @@ "Sharded": true }, "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select `user`.col from `user` where 1 = 1", + "Query": "select `user`.col from `user`", "Table": "`user`" }, { @@ -1262,7 +1262,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */ and 1 = 1", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index 47f10cd273b..799c9bd4420 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -1518,7 +1518,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and cast('foo' as CHAR) is not null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", "Table": "u_tbl8, u_tbl9" }, { @@ -1594,7 +1594,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where u_tbl3.col3 is null and cast('foo' as CHAR) is not null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where u_tbl3.col3 is null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", "Table": "u_tbl3, u_tbl4" }, { @@ -1606,7 +1606,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (cast('foo' as CHAR) is null or (u_tbl9.col9) not in ((cast('foo' as CHAR)))) limit 1 for share", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (u_tbl9.col9) not in ((cast('foo' as CHAR))) limit 1 for share", "Table": "u_tbl4, u_tbl9" }, { @@ -2532,7 +2532,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where 1 != 1", - "Query": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where u_multicol_tbl1.cola is null and 2 is not null and u_multicol_tbl1.colb is null and u_multicol_tbl2.colc - 2 is not null and not (u_multicol_tbl2.cola, u_multicol_tbl2.colb) <=> (2, u_multicol_tbl2.colc - 2) and u_multicol_tbl2.id = 7 limit 1 for share", + "Query": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where u_multicol_tbl1.cola is null and u_multicol_tbl1.colb is null and u_multicol_tbl2.colc - 2 is not null and not (u_multicol_tbl2.cola, u_multicol_tbl2.colb) <=> (2, u_multicol_tbl2.colc - 2) and u_multicol_tbl2.id = 7 limit 1 for share", "Table": "u_multicol_tbl1, u_multicol_tbl2" }, { @@ -4110,7 +4110,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and cast('foo' as CHAR) is not null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", "Table": "u_tbl8, u_tbl9" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json index 7b525b2dcc9..5464ccbd619 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json @@ -1595,7 +1595,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and cast('foo' as CHAR) is not null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", "Table": "u_tbl8, u_tbl9" }, { @@ -1671,7 +1671,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where u_tbl3.col3 is null and cast('foo' as CHAR) is not null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where u_tbl3.col3 is null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", "Table": "u_tbl3, u_tbl4" }, { @@ -1683,7 +1683,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (cast('foo' as CHAR) is null or (u_tbl9.col9) not in ((cast('foo' as CHAR)))) limit 1 for share", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (u_tbl9.col9) not in ((cast('foo' as CHAR))) limit 1 for share", "Table": "u_tbl4, u_tbl9" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index f06a6a50d45..856e56265ca 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -1874,7 +1874,7 @@ "Sharded": false }, "FieldQuery": "select 42 from dual where 1 != 1", - "Query": "select 42 from dual where false", + "Query": "select 42 from dual where 0", "Table": "dual" }, "TablesUsed": [ diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index ee12765e984..3e53ed0816a 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -24,7 +24,6 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/evalengine" ) type earlyRewriter struct { @@ -48,10 +47,6 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case sqlparser.SelectExprs: return r.handleSelectExprs(cursor, node) - case *sqlparser.OrExpr: - rewriteOrExpr(r.env, cursor, node) - case *sqlparser.AndExpr: - rewriteAndExpr(r.env, cursor, node) case *sqlparser.NotExpr: rewriteNotExpr(cursor, node) case *sqlparser.ComparisonExpr: @@ -854,57 +849,6 @@ func (r *earlyRewriter) rewriteGroupByExpr(node *sqlparser.Literal) (sqlparser.E return realCloneOfColNames(aliasedExpr.Expr, false), nil } -// rewriteOrExpr rewrites OR expressions when the right side is FALSE. -func rewriteOrExpr(env *vtenv.Environment, cursor *sqlparser.Cursor, node *sqlparser.OrExpr) { - newNode := rewriteOrFalse(env, *node) - if newNode != nil { - cursor.ReplaceAndRevisit(newNode) - } -} - -// rewriteAndExpr rewrites AND expressions when either side is TRUE. -func rewriteAndExpr(env *vtenv.Environment, cursor *sqlparser.Cursor, node *sqlparser.AndExpr) { - newNode := rewriteAndTrue(env, *node) - if newNode != nil { - cursor.ReplaceAndRevisit(newNode) - } -} - -func rewriteAndTrue(env *vtenv.Environment, andExpr sqlparser.AndExpr) sqlparser.Expr { - // we are looking for the pattern `WHERE c = 1 AND 1 = 1` - isTrue := func(subExpr sqlparser.Expr) bool { - coll := env.CollationEnv().DefaultConnectionCharset() - evalEnginePred, err := evalengine.Translate(subExpr, &evalengine.Config{ - Environment: env, - Collation: coll, - }) - if err != nil { - return false - } - - env := evalengine.EmptyExpressionEnv(env) - res, err := env.Evaluate(evalEnginePred) - if err != nil { - return false - } - - boolValue, err := res.Value(coll).ToBool() - if err != nil { - return false - } - - return boolValue - } - - if isTrue(andExpr.Left) { - return andExpr.Right - } else if isTrue(andExpr.Right) { - return andExpr.Left - } - - return nil -} - // handleComparisonExpr processes Comparison expressions, specifically for tuples with equal length and EqualOp operator. func handleComparisonExpr(cursor *sqlparser.Cursor, node *sqlparser.ComparisonExpr) error { lft, lftOK := node.Left.(sqlparser.ValTuple) @@ -970,41 +914,6 @@ func realCloneOfColNames(expr sqlparser.Expr, union bool) sqlparser.Expr { }, nil).(sqlparser.Expr) } -func rewriteOrFalse(env *vtenv.Environment, orExpr sqlparser.OrExpr) sqlparser.Expr { - // we are looking for the pattern `WHERE c = 1 OR 1 = 0` - isFalse := func(subExpr sqlparser.Expr) bool { - coll := env.CollationEnv().DefaultConnectionCharset() - evalEnginePred, err := evalengine.Translate(subExpr, &evalengine.Config{ - Environment: env, - Collation: coll, - }) - if err != nil { - return false - } - - env := evalengine.EmptyExpressionEnv(env) - res, err := env.Evaluate(evalEnginePred) - if err != nil { - return false - } - - boolValue, err := res.Value(coll).ToBool() - if err != nil { - return false - } - - return !boolValue - } - - if isFalse(orExpr.Left) { - return orExpr.Right - } else if isFalse(orExpr.Right) { - return orExpr.Left - } - - return nil -} - // rewriteJoinUsing rewrites SQL JOINs that use the USING clause to their equivalent // JOINs with the ON condition. This function finds all the tables that have the // specified columns in the USING clause, constructs an equality predicate for diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index fab8211f74e..4f550d46392 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -905,53 +905,6 @@ func TestOrderByDerivedTable(t *testing.T) { } } -// TestConstantFolding tests that the rewriter is able to do various constant foldings properly. -func TestConstantFolding(t *testing.T) { - ks := &vindexes.Keyspace{ - Name: "main", - Sharded: true, - } - schemaInfo := &FakeSI{ - Tables: map[string]*vindexes.Table{ - "t1": { - Keyspace: ks, - Name: sqlparser.NewIdentifierCS("t1"), - Columns: []vindexes.Column{{ - Name: sqlparser.NewIdentifierCI("a"), - Type: sqltypes.VarChar, - }, { - Name: sqlparser.NewIdentifierCI("b"), - Type: sqltypes.VarChar, - }, { - Name: sqlparser.NewIdentifierCI("c"), - Type: sqltypes.VarChar, - }}, - ColumnListAuthoritative: true, - }, - }, - } - cDB := "db" - tcases := []struct { - sql string - expSQL string - }{{ - sql: "select 1 from t1 where (a, b) in ::fkc_vals and (2 is null or (1 is null or a in (1)))", - expSQL: "select 1 from t1 where (a, b) in ::fkc_vals and a in (1)", - }, { - sql: "select 1 from t1 where (false or (false or a in (1)))", - expSQL: "select 1 from t1 where a in (1)", - }} - for _, tcase := range tcases { - t.Run(tcase.sql, func(t *testing.T) { - ast, err := sqlparser.NewTestParser().Parse(tcase.sql) - require.NoError(t, err) - _, err = Analyze(ast, cDB, schemaInfo) - require.NoError(t, err) - require.Equal(t, tcase.expSQL, sqlparser.String(ast)) - }) - } -} - // TestCTEToDerivedTableRewrite checks that CTEs are correctly rewritten to derived tables func TestCTEToDerivedTableRewrite(t *testing.T) { cDB := "db"