Skip to content

Commit

Permalink
feat: respect straight joins in planning and pass them to MySQL too
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 committed Mar 20, 2024
1 parent d20f3c5 commit bc713aa
Show file tree
Hide file tree
Showing 14 changed files with 148 additions and 65 deletions.
20 changes: 20 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,26 @@ func (node DatabaseOptionType) ToString() string {
}
}

// IsCommutative returns whether the join type supports rearranging or not.
func (joinType JoinType) IsCommutative() bool {
switch joinType {
case StraightJoinType, LeftJoinType, RightJoinType, NaturalLeftJoinType, NaturalRightJoinType:
return false
default:
return true
}
}

// IsInner returns whether the join type is an inner join or not.
func (joinType JoinType) IsInner() bool {
switch joinType {
case StraightJoinType, NaturalJoinType, NormalJoinType:
return true
default:
return false
}
}

// ToString returns the type as a string
func (ty LockType) ToString() string {
switch ty {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ func transformApplyJoinPlan(ctx *plancontext.PlanningContext, n *operators.Apply
return nil, err
}
opCode := engine.InnerJoin
if n.LeftJoin {
if !n.JoinType.IsInner() {
opCode = engine.LeftJoin
}

Expand Down
36 changes: 14 additions & 22 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ var _ FromStatement = (*sqlparser.Select)(nil)
var _ FromStatement = (*sqlparser.Update)(nil)
var _ FromStatement = (*sqlparser.Delete)(nil)

func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser.Expr) {
func (qb *queryBuilder) joinWith(other *queryBuilder, onCondition sqlparser.Expr, joinType sqlparser.JoinType) {
stmt := qb.stmt.(FromStatement)
otherStmt := other.stmt.(FromStatement)

Expand All @@ -222,24 +222,20 @@ func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser
sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...)
}

newFromClause := append(stmt.GetFrom(), otherStmt.GetFrom()...)
stmt.SetFrom(newFromClause)
qb.mergeWhereClauses(stmt, otherStmt)
qb.addPredicate(onCondition)
}

func (qb *queryBuilder) joinOuterWith(other *queryBuilder, onCondition sqlparser.Expr) {
stmt := qb.stmt.(FromStatement)
otherStmt := other.stmt.(FromStatement)

if sel, isSel := stmt.(*sqlparser.Select); isSel {
otherSel := otherStmt.(*sqlparser.Select)
sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...)
var newFromClause []sqlparser.TableExpr
switch joinType {
case sqlparser.NormalJoinType:
newFromClause = append(stmt.GetFrom(), otherStmt.GetFrom()...)
default:
newFromClause = []sqlparser.TableExpr{buildJoin(stmt, otherStmt, onCondition, joinType)}
}

newFromClause := []sqlparser.TableExpr{buildOuterJoin(stmt, otherStmt, onCondition)}
stmt.SetFrom(newFromClause)
qb.mergeWhereClauses(stmt, otherStmt)

if joinType == sqlparser.NormalJoinType {
qb.addPredicate(onCondition)
}
}

func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) {
Expand All @@ -254,7 +250,7 @@ func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) {
}
}

func buildOuterJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr) *sqlparser.JoinTableExpr {
func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr {
var lhs sqlparser.TableExpr
fromClause := stmt.GetFrom()
if len(fromClause) == 1 {
Expand All @@ -273,7 +269,7 @@ func buildOuterJoin(stmt FromStatement, otherStmt FromStatement, onCondition sql
return &sqlparser.JoinTableExpr{
LeftExpr: lhs,
RightExpr: rhs,
Join: sqlparser.LeftJoinType,
Join: joinType,
Condition: &sqlparser.JoinCondition{
On: onCondition,
},
Expand Down Expand Up @@ -539,11 +535,7 @@ func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) {

qbR := &queryBuilder{ctx: qb.ctx}
buildQuery(op.RHS, qbR)
if op.LeftJoin {
qb.joinOuterWith(qbR, pred)
} else {
qb.joinInnerWith(qbR, pred)
}
qb.joinWith(qbR, pred, op.JoinType)
}

func buildUnion(op *Union, qb *queryBuilder) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ func pushAggregationThroughApplyJoin(ctx *plancontext.PlanningContext, rootAggr
rhs := createJoinPusher(rootAggr, join.RHS)

columns := &applyJoinColumns{}
output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, join.LeftJoin, columns, lhs, rhs)
output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, !join.JoinType.IsInner(), columns, lhs, rhs)
join.JoinColumns = columns
if err != nil {
// if we get this error, we just abort the splitting and fall back on simpler ways of solving the same query
Expand Down
14 changes: 10 additions & 4 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ type (
ApplyJoin struct {
LHS, RHS Operator

// JoinType is permitted to store only 3 of the possible values
// NormalJoinType, StraightJoinType and LeftJoinType.
JoinType sqlparser.JoinType
// LeftJoin will be true in the case of an outer join
LeftJoin bool

Expand Down Expand Up @@ -82,12 +85,12 @@ type (
}
)

func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin {
func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, joinType sqlparser.JoinType) *ApplyJoin {
aj := &ApplyJoin{
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
LeftJoin: leftOuterJoin,
JoinType: joinType,
JoinColumns: &applyJoinColumns{},
JoinPredicates: &applyJoinColumns{},
}
Expand Down Expand Up @@ -139,11 +142,14 @@ func (aj *ApplyJoin) SetRHS(operator Operator) {
}

func (aj *ApplyJoin) MakeInner() {
aj.LeftJoin = false
if aj.IsInner() {
panic(vterrors.VT13001("Convert an already inner join"))
}
aj.JoinType = sqlparser.NormalJoinType
}

func (aj *ApplyJoin) IsInner() bool {
return !aj.LeftJoin
return aj.JoinType.IsInner()
}

func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ func getOperatorFromJoinTableExpr(ctx *plancontext.PlanningContext, tableExpr *s
switch tableExpr.Join {
case sqlparser.NormalJoinType:
return createInnerJoin(ctx, tableExpr, lhs, rhs)
case sqlparser.LeftJoinType, sqlparser.RightJoinType:
return createOuterJoin(tableExpr, lhs, rhs)
case sqlparser.LeftJoinType, sqlparser.RightJoinType, sqlparser.StraightJoinType:
return createLeftAndStraightJoin(tableExpr, lhs, rhs)
default:
panic(vterrors.VT13001("unsupported: %s", tableExpr.Join.ToString()))
}
Expand Down
20 changes: 13 additions & 7 deletions go/vt/vtgate/planbuilder/operators/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (
type Join struct {
LHS, RHS Operator
Predicate sqlparser.Expr
LeftJoin bool
// JoinType is permitted to store only 3 of the possible values
// NormalJoinType, StraightJoinType and LeftJoinType.
JoinType sqlparser.JoinType

noColumns
}
Expand All @@ -42,7 +44,7 @@ func (j *Join) Clone(inputs []Operator) Operator {
LHS: inputs[0],
RHS: inputs[1],
Predicate: j.Predicate,
LeftJoin: j.LeftJoin,
JoinType: j.JoinType,
}
}

Expand All @@ -61,7 +63,7 @@ func (j *Join) SetInputs(ops []Operator) {
}

func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult) {
if j.LeftJoin {
if !j.JoinType.IsCommutative() {
// we can't merge outer joins into a single QG
return j, NoRewrite
}
Expand All @@ -83,17 +85,18 @@ func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult
return newOp, Rewrote("merge querygraphs into a single one")
}

func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
func createLeftAndStraightJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
if tableExpr.Join == sqlparser.RightJoinType {
lhs, rhs = rhs, lhs
tableExpr.Join = sqlparser.LeftJoinType
}
subq, _ := getSubQuery(tableExpr.Condition.On)
if subq != nil {
panic(vterrors.VT12001("subquery in outer join predicate"))
}
predicate := tableExpr.Condition.On
sqlparser.RemoveKeyspaceInCol(predicate)
return &Join{LHS: lhs, RHS: rhs, LeftJoin: true, Predicate: predicate}
return &Join{LHS: lhs, RHS: rhs, JoinType: tableExpr.Join, Predicate: predicate}
}

func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator {
Expand Down Expand Up @@ -150,11 +153,14 @@ func (j *Join) SetRHS(operator Operator) {
}

func (j *Join) MakeInner() {
j.LeftJoin = false
if j.IsInner() {
panic(vterrors.VT13001("Convert an already inner join"))
}
j.JoinType = sqlparser.NormalJoinType
}

func (j *Join) IsInner() bool {
return !j.LeftJoin
return j.JoinType.IsInner()
}

func (j *Join) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) {
Expand Down
10 changes: 6 additions & 4 deletions go/vt/vtgate/planbuilder/operators/join_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ type (

joinMerger struct {
predicates []sqlparser.Expr
innerJoin bool
// joinType is permitted to store only 3 of the possible values
// NormalJoinType, StraightJoinType and LeftJoinType.
joinType sqlparser.JoinType
}

routingType int
Expand Down Expand Up @@ -176,10 +178,10 @@ func getRoutingType(r Routing) routingType {
panic(fmt.Sprintf("switch should be exhaustive, got %T", r))
}

func newJoinMerge(predicates []sqlparser.Expr, innerJoin bool) merger {
func newJoinMerge(predicates []sqlparser.Expr, joinType sqlparser.JoinType) merger {
return &joinMerger{
predicates: predicates,
innerJoin: innerJoin,
joinType: joinType,
}
}

Expand All @@ -203,7 +205,7 @@ func mergeShardedRouting(r1 *ShardedRouting, r2 *ShardedRouting) *ShardedRouting
}

func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) *ApplyJoin {
return NewApplyJoin(ctx, op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin)
return NewApplyJoin(ctx, op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), jm.joinType)
}

func (jm *joinMerger) merge(ctx *plancontext.PlanningContext, op1, op2 *Route, r Routing) *Route {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/projection_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func pushProjectionInApplyJoin(
src *ApplyJoin,
) (Operator, *ApplyResult) {
ap, err := p.GetAliasedProjections()
if src.LeftJoin || err != nil {
if !src.IsInner() || err != nil {
// we can't push down expression evaluation to the rhs if we are not sure if it will even be executed
return p, NoRewrite
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ func canPushLeft(ctx *plancontext.PlanningContext, aj *ApplyJoin, order []OrderB

func isOuterTable(op Operator, ts semantics.TableSet) bool {
aj, ok := op.(*ApplyJoin)
if ok && aj.LeftJoin && TableID(aj.RHS).IsOverlapping(ts) {
if ok && !aj.IsInner() && TableID(aj.RHS).IsOverlapping(ts) {
return true
}

Expand Down
18 changes: 9 additions & 9 deletions go/vt/vtgate/planbuilder/operators/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func pushDerived(ctx *plancontext.PlanningContext, op *Horizon) (Operator, *Appl
}

func optimizeJoin(ctx *plancontext.PlanningContext, op *Join) (Operator, *ApplyResult) {
return mergeOrJoin(ctx, op.LHS, op.RHS, sqlparser.SplitAndExpression(nil, op.Predicate), !op.LeftJoin)
return mergeOrJoin(ctx, op.LHS, op.RHS, sqlparser.SplitAndExpression(nil, op.Predicate), op.JoinType)
}

func optimizeQueryGraph(ctx *plancontext.PlanningContext, op *QueryGraph) (result Operator, changed *ApplyResult) {
Expand Down Expand Up @@ -147,7 +147,7 @@ func leftToRightSolve(ctx *plancontext.PlanningContext, qg *QueryGraph) Operator
continue
}
joinPredicates := qg.GetPredicates(TableID(acc), TableID(plan))
acc, _ = mergeOrJoin(ctx, acc, plan, joinPredicates, true)
acc, _ = mergeOrJoin(ctx, acc, plan, joinPredicates, sqlparser.NormalJoinType)
}

return acc
Expand Down Expand Up @@ -262,7 +262,7 @@ func getJoinFor(ctx *plancontext.PlanningContext, cm opCacheMap, lhs, rhs Operat
return cachedPlan
}

join, _ := mergeOrJoin(ctx, lhs, rhs, joinPredicates, true)
join, _ := mergeOrJoin(ctx, lhs, rhs, joinPredicates, sqlparser.NormalJoinType)
cm[solves] = join
return join
}
Expand All @@ -283,29 +283,29 @@ func requiresSwitchingSides(ctx *plancontext.PlanningContext, op Operator) (requ
return
}

func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr, inner bool) (Operator, *ApplyResult) {
newPlan := mergeJoinInputs(ctx, lhs, rhs, joinPredicates, newJoinMerge(joinPredicates, inner))
func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr, joinType sqlparser.JoinType) (Operator, *ApplyResult) {
newPlan := mergeJoinInputs(ctx, lhs, rhs, joinPredicates, newJoinMerge(joinPredicates, joinType))
if newPlan != nil {
return newPlan, Rewrote("merge routes into single operator")
}

if len(joinPredicates) > 0 && requiresSwitchingSides(ctx, rhs) {
if !inner || requiresSwitchingSides(ctx, lhs) {
if !joinType.IsCommutative() || requiresSwitchingSides(ctx, lhs) {
// we can't switch sides, so let's see if we can use a HashJoin to solve it
join := NewHashJoin(lhs, rhs, !inner)
join := NewHashJoin(lhs, rhs, !joinType.IsInner())
for _, pred := range joinPredicates {
join.AddJoinPredicate(ctx, pred)
}
ctx.SemTable.QuerySignature.HashJoin = true
return join, Rewrote("use a hash join because we have LIMIT on the LHS")
}

join := NewApplyJoin(ctx, Clone(rhs), Clone(lhs), nil, !inner)
join := NewApplyJoin(ctx, Clone(rhs), Clone(lhs), nil, joinType)
newOp := pushJoinPredicates(ctx, joinPredicates, join)
return newOp, Rewrote("logical join to applyJoin, switching side because LIMIT")
}

join := NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, !inner)
join := NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, joinType)
newOp := pushJoinPredicates(ctx, joinPredicates, join)
return newOp, Rewrote("logical join to applyJoin ")
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func tryPushSubQueryInJoin(
return outer, Rewrote("push subquery into LHS of join")
}

if outer.LeftJoin || len(inner.Predicates) == 0 {
if !outer.IsInner() || len(inner.Predicates) == 0 {
// we can't push any filters on the RHS of an outer join, and
// we don't want to push uncorrelated subqueries to the RHS of a join
return nil, NoRewrite
Expand Down Expand Up @@ -278,7 +278,7 @@ func extractLHSExpr(

// tryMergeWithRHS attempts to merge a subquery with the RHS of a join
func tryMergeWithRHS(ctx *plancontext.PlanningContext, inner *SubQuery, outer *ApplyJoin) (Operator, *ApplyResult) {
if outer.LeftJoin {
if !outer.IsInner() {
return nil, nil
}
// both sides need to be routes
Expand Down
Loading

0 comments on commit bc713aa

Please sign in to comment.