Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Respect Straight Join in Vitess query planning #15528

Merged
merged 8 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions go/test/endtoend/vtgate/gen4/gen4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,26 +187,6 @@ func TestSubQueriesOnOuterJoinOnCondition(t *testing.T) {
}
}

func TestPlannerWarning(t *testing.T) {
Copy link
Collaborator

@systay systay Mar 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, straight_join no longer produces a warning, and we don't have reliable ways of getting warnings from the planner without turning off schema tracking

mcmp, closer := start(t)
defer closer()

// straight_join query
_ = utils.Exec(t, mcmp.VtConn, `select 1 from t1 straight_join t2 on t1.id = t2.id`)
utils.AssertMatches(t, mcmp.VtConn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`)

// execute same query again.
_ = utils.Exec(t, mcmp.VtConn, `select 1 from t1 straight_join t2 on t1.id = t2.id`)
utils.AssertMatches(t, mcmp.VtConn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`)

// random query to reset the warning.
_ = utils.Exec(t, mcmp.VtConn, `select 1 from t1`)

// execute same query again.
_ = utils.Exec(t, mcmp.VtConn, `select 1 from t1 straight_join t2 on t1.id = t2.id`)
utils.AssertMatches(t, mcmp.VtConn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`)
}

func TestHashJoin(t *testing.T) {
mcmp, closer := start(t)
defer closer()
Expand Down
27 changes: 27 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,30 @@ func TestAlterTableWithView(t *testing.T) {

mcmp.AssertMatches("select * from v1", `[[INT64(1) INT64(1)]]`)
}

// TestStraightJoin tests that Vitess respects the ordering of join in a STRAIGHT JOIN query.
func TestStraightJoin(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (1,0,10), (2,10,10), (3,4,20), (4,30,20), (5,40,10)")
mcmp.Exec(`insert into t1(id1, id2) values (10, 11), (20, 13)`)

mcmp.AssertMatchesNoOrder("select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 join tbl where t1.id1 = tbl.nonunq_col",
`[[INT64(0) INT64(10) INT64(11)] [INT64(10) INT64(10) INT64(11)] [INT64(4) INT64(20) INT64(13)] [INT64(40) INT64(10) INT64(11)] [INT64(30) INT64(20) INT64(13)]]`,
)
// Verify that in a normal join query, vitess joins tbl with t1.
res, err := mcmp.VtConn.ExecuteFetch("vexplain plan select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 join tbl where t1.id1 = tbl.nonunq_col", 100, false)
require.NoError(t, err)
require.Contains(t, fmt.Sprintf("%v", res.Rows), "tbl_t1")

// Test the same query with a straight join
mcmp.AssertMatchesNoOrder("select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 straight_join tbl where t1.id1 = tbl.nonunq_col",
`[[INT64(0) INT64(10) INT64(11)] [INT64(10) INT64(10) INT64(11)] [INT64(4) INT64(20) INT64(13)] [INT64(40) INT64(10) INT64(11)] [INT64(30) INT64(20) INT64(13)]]`,
)
// Verify that in a straight join query, vitess joins t1 with tbl.
res, err = mcmp.VtConn.ExecuteFetch("vexplain plan select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 straight_join tbl where t1.id1 = tbl.nonunq_col", 100, false)
require.NoError(t, err)
require.Contains(t, fmt.Sprintf("%v", res.Rows), "t1_tbl")
}
20 changes: 20 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,26 @@ func (node DatabaseOptionType) ToString() string {
}
}

// IsCommutative returns whether the join type supports rearranging or not.
func (joinType JoinType) IsCommutative() bool {
switch joinType {
case StraightJoinType, LeftJoinType, RightJoinType, NaturalLeftJoinType, NaturalRightJoinType:
return false
default:
return true
}
}

// IsInner returns whether the join type is an inner join or not.
func (joinType JoinType) IsInner() bool {
switch joinType {
case StraightJoinType, NaturalJoinType, NormalJoinType:
return true
default:
return false
}
}

// ToString returns the type as a string
func (ty LockType) ToString() string {
switch ty {
Expand Down
10 changes: 2 additions & 8 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3343,18 +3343,12 @@ func TestGen4SelectStraightJoin(t *testing.T) {
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{
{
Sql: "select u.id from `user` as u, user2 as u2 where u.id = u2.id",
Sql: "select u.id from `user` as u straight_join user2 as u2 on u.id = u2.id",
BindVariables: map[string]*querypb.BindVariable{},
},
}
wantWarnings := []*querypb.QueryWarning{
{
Code: 1235,
Message: "straight join is converted to normal join",
},
}
utils.MustMatch(t, wantQueries, sbc1.Queries)
utils.MustMatch(t, wantWarnings, session.Warnings)
require.Empty(t, session.Warnings)
}

func TestGen4MultiColumnVindexEqual(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ func transformApplyJoinPlan(ctx *plancontext.PlanningContext, n *operators.Apply
return nil, err
}
opCode := engine.InnerJoin
if n.LeftJoin {
if !n.JoinType.IsInner() {
opCode = engine.LeftJoin
}

Expand Down
32 changes: 11 additions & 21 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ var _ FromStatement = (*sqlparser.Select)(nil)
var _ FromStatement = (*sqlparser.Update)(nil)
var _ FromStatement = (*sqlparser.Delete)(nil)

func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser.Expr) {
func (qb *queryBuilder) joinWith(other *queryBuilder, onCondition sqlparser.Expr, joinType sqlparser.JoinType) {
stmt := qb.stmt.(FromStatement)
otherStmt := other.stmt.(FromStatement)

Expand All @@ -222,24 +222,18 @@ func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser
sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...)
}

newFromClause := append(stmt.GetFrom(), otherStmt.GetFrom()...)
stmt.SetFrom(newFromClause)
qb.mergeWhereClauses(stmt, otherStmt)
qb.addPredicate(onCondition)
}

func (qb *queryBuilder) joinOuterWith(other *queryBuilder, onCondition sqlparser.Expr) {
stmt := qb.stmt.(FromStatement)
otherStmt := other.stmt.(FromStatement)

if sel, isSel := stmt.(*sqlparser.Select); isSel {
otherSel := otherStmt.(*sqlparser.Select)
sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...)
var newFromClause []sqlparser.TableExpr
switch joinType {
case sqlparser.NormalJoinType:
newFromClause = append(stmt.GetFrom(), otherStmt.GetFrom()...)
qb.addPredicate(onCondition)
default:
newFromClause = []sqlparser.TableExpr{buildJoin(stmt, otherStmt, onCondition, joinType)}
}

newFromClause := []sqlparser.TableExpr{buildOuterJoin(stmt, otherStmt, onCondition)}
stmt.SetFrom(newFromClause)
qb.mergeWhereClauses(stmt, otherStmt)
}

func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) {
Expand All @@ -254,7 +248,7 @@ func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) {
}
}

func buildOuterJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr) *sqlparser.JoinTableExpr {
func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr {
var lhs sqlparser.TableExpr
fromClause := stmt.GetFrom()
if len(fromClause) == 1 {
Expand All @@ -273,7 +267,7 @@ func buildOuterJoin(stmt FromStatement, otherStmt FromStatement, onCondition sql
return &sqlparser.JoinTableExpr{
LeftExpr: lhs,
RightExpr: rhs,
Join: sqlparser.LeftJoinType,
Join: joinType,
Condition: &sqlparser.JoinCondition{
On: onCondition,
},
Expand Down Expand Up @@ -539,11 +533,7 @@ func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) {

qbR := &queryBuilder{ctx: qb.ctx}
buildQuery(op.RHS, qbR)
if op.LeftJoin {
qb.joinOuterWith(qbR, pred)
} else {
qb.joinInnerWith(qbR, pred)
}
qb.joinWith(qbR, pred, op.JoinType)
}

func buildUnion(op *Union, qb *queryBuilder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ func pushAggregationThroughApplyJoin(ctx *plancontext.PlanningContext, rootAggr
rhs := createJoinPusher(rootAggr, join.RHS)

columns := &applyJoinColumns{}
output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, join.LeftJoin, columns, lhs, rhs)
output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, !join.JoinType.IsInner(), columns, lhs, rhs)
join.JoinColumns = columns
if err != nil {
// if we get this error, we just abort the splitting and fall back on simpler ways of solving the same query
Expand Down
14 changes: 10 additions & 4 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ type (
ApplyJoin struct {
LHS, RHS Operator

// JoinType is permitted to store only 3 of the possible values
// NormalJoinType, StraightJoinType and LeftJoinType.
JoinType sqlparser.JoinType
// LeftJoin will be true in the case of an outer join
LeftJoin bool

Expand Down Expand Up @@ -82,12 +85,12 @@ type (
}
)

func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin {
func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, joinType sqlparser.JoinType) *ApplyJoin {
aj := &ApplyJoin{
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
LeftJoin: leftOuterJoin,
JoinType: joinType,
JoinColumns: &applyJoinColumns{},
JoinPredicates: &applyJoinColumns{},
}
Expand Down Expand Up @@ -139,11 +142,14 @@ func (aj *ApplyJoin) SetRHS(operator Operator) {
}

func (aj *ApplyJoin) MakeInner() {
aj.LeftJoin = false
if aj.IsInner() {
return
}
aj.JoinType = sqlparser.NormalJoinType
}

func (aj *ApplyJoin) IsInner() bool {
return !aj.LeftJoin
return aj.JoinType.IsInner()
}

func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) {
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ func getOperatorFromJoinTableExpr(ctx *plancontext.PlanningContext, tableExpr *s
case sqlparser.NormalJoinType:
return createInnerJoin(ctx, tableExpr, lhs, rhs)
case sqlparser.LeftJoinType, sqlparser.RightJoinType:
return createOuterJoin(tableExpr, lhs, rhs)
return createLeftOuterJoin(ctx, tableExpr, lhs, rhs)
case sqlparser.StraightJoinType:
return createStraightJoin(ctx, tableExpr, lhs, rhs)
default:
panic(vterrors.VT13001("unsupported: %s", tableExpr.Join.ToString()))
}
Expand Down
83 changes: 58 additions & 25 deletions go/vt/vtgate/planbuilder/operators/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (
type Join struct {
LHS, RHS Operator
Predicate sqlparser.Expr
LeftJoin bool
// JoinType is permitted to store only 3 of the possible values
// NormalJoinType, StraightJoinType and LeftJoinType.
JoinType sqlparser.JoinType

noColumns
}
Expand All @@ -42,7 +44,7 @@ func (j *Join) Clone(inputs []Operator) Operator {
LHS: inputs[0],
RHS: inputs[1],
Predicate: j.Predicate,
LeftJoin: j.LeftJoin,
JoinType: j.JoinType,
}
}

Expand All @@ -61,8 +63,8 @@ func (j *Join) SetInputs(ops []Operator) {
}

func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult) {
if j.LeftJoin {
// we can't merge outer joins into a single QG
if !j.JoinType.IsCommutative() {
// if we can't move tables around, we can't merge these inputs
return j, NoRewrite
}

Expand All @@ -83,38 +85,52 @@ func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult
return newOp, Rewrote("merge querygraphs into a single one")
}

func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
if tableExpr.Join == sqlparser.RightJoinType {
func createStraightJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
// for inner joins we can treat the predicates as filters on top of the join
joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join}

return addJoinPredicates(ctx, join.Condition.On, joinOp)
}

func createLeftOuterJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
// first we switch sides, so we always deal with left outer joins
switch join.Join {
case sqlparser.RightJoinType:
lhs, rhs = rhs, lhs
join.Join = sqlparser.LeftJoinType
case sqlparser.NaturalRightJoinType:
lhs, rhs = rhs, lhs
join.Join = sqlparser.NaturalLeftJoinType
}
subq, _ := getSubQuery(tableExpr.Condition.On)

joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join}

// for outer joins we have to be careful with the predicates we use
var op Operator
subq, _ := getSubQuery(join.Condition.On)
if subq != nil {
panic(vterrors.VT12001("subquery in outer join predicate"))
}
predicate := tableExpr.Condition.On
predicate := join.Condition.On
sqlparser.RemoveKeyspaceInCol(predicate)
return &Join{LHS: lhs, RHS: rhs, LeftJoin: true, Predicate: predicate}
}
joinOp.Predicate = predicate
op = joinOp

func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator {
lqg, lok := LHS.(*QueryGraph)
rqg, rok := RHS.(*QueryGraph)
if lok && rok {
op := &QueryGraph{
Tables: append(lqg.Tables, rqg.Tables...),
innerJoins: append(lqg.innerJoins, rqg.innerJoins...),
NoDeps: ctx.SemTable.AndExpressions(lqg.NoDeps, rqg.NoDeps),
}
return op
}
return &Join{LHS: LHS, RHS: RHS}
return op
}

func createInnerJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
op := createJoin(ctx, lhs, rhs)
return addJoinPredicates(ctx, tableExpr.Condition.On, op)
}

func addJoinPredicates(
ctx *plancontext.PlanningContext,
joinPredicate sqlparser.Expr,
op Operator,
) Operator {
sqc := &SubQueryBuilder{}
outerID := TableID(op)
joinPredicate := tableExpr.Condition.On
sqlparser.RemoveKeyspaceInCol(joinPredicate)
exprs := sqlparser.SplitAndExpression(nil, joinPredicate)
for _, pred := range exprs {
Expand All @@ -127,6 +143,20 @@ func createInnerJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.Join
return sqc.getRootOperator(op, nil)
}

func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator {
lqg, lok := LHS.(*QueryGraph)
rqg, rok := RHS.(*QueryGraph)
if lok && rok {
op := &QueryGraph{
Tables: append(lqg.Tables, rqg.Tables...),
innerJoins: append(lqg.innerJoins, rqg.innerJoins...),
NoDeps: ctx.SemTable.AndExpressions(lqg.NoDeps, rqg.NoDeps),
}
return op
}
return &Join{LHS: LHS, RHS: RHS}
}

func (j *Join) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator {
return AddPredicate(ctx, j, expr, false, newFilterSinglePredicate)
}
Expand All @@ -150,11 +180,14 @@ func (j *Join) SetRHS(operator Operator) {
}

func (j *Join) MakeInner() {
j.LeftJoin = false
if j.IsInner() {
return
}
j.JoinType = sqlparser.NormalJoinType
}

func (j *Join) IsInner() bool {
return !j.LeftJoin
return j.JoinType.IsInner()
}

func (j *Join) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) {
Expand Down
Loading
Loading