Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reference Table DML Join Fix #17414

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/cte_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator
}

func tryMergeCTE(ctx *plancontext.PlanningContext, seed, term Operator, in *RecurseCTE) *Route {
seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term)
seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, seed, term)
if seedRoute == nil {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func createDeleteWithInputOp(ctx *plancontext.PlanningContext, del *sqlparser.De
}

var delOps []dmlOp
for _, target := range ctx.SemTable.Targets.Constituents() {
for _, target := range ctx.SemTable.DMLTargets.Constituents() {
op := createDeleteOpWithTarget(ctx, target, del.Ignore)
delOps = append(delOps, op)
}
Expand Down
14 changes: 8 additions & 6 deletions go/vt/vtgate/planbuilder/operators/join_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// If they can be merged, a new operator with the merged routing is returned
// If they cannot be merged, nil is returned.
func (jm *joinMerger) mergeJoinInputs(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr) *Route {
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs)
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs)
if lhsRoute == nil {
return nil
}
Expand Down Expand Up @@ -102,13 +102,13 @@ func mergeAnyShardRoutings(ctx *plancontext.PlanningContext, a, b *AnyShardRouti
}
}

func prepareInputRoutes(lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) {
func prepareInputRoutes(ctx *plancontext.PlanningContext, lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) {
lhsRoute, rhsRoute := operatorsToRoutes(lhs, rhs)
if lhsRoute == nil || rhsRoute == nil {
return nil, nil, nil, nil, 0, 0, false
}

lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(lhsRoute, rhsRoute)
lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(ctx, lhsRoute, rhsRoute)

a, b := getRoutingType(routingA), getRoutingType(routingB)
return lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace
Expand Down Expand Up @@ -159,7 +159,7 @@ func (rt routingType) String() string {

// getRoutesOrAlternates gets the Routings from each Route. If they are from different keyspaces,
// we check if this is a table with alternates in other keyspaces that we can use
func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) {
func getRoutesOrAlternates(ctx *plancontext.PlanningContext, lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) {
routingA := lhsRoute.Routing
routingB := rhsRoute.Routing
sameKeyspace := routingA.Keyspace() == routingB.Keyspace()
Expand All @@ -171,13 +171,15 @@ func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing,
return lhsRoute, rhsRoute, routingA, routingB, sameKeyspace
}

if refA, ok := routingA.(*AnyShardRouting); ok {
if refA, ok := routingA.(*AnyShardRouting); ok &&
!TableID(lhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) {
if altARoute := refA.AlternateInKeyspace(routingB.Keyspace()); altARoute != nil {
return altARoute, rhsRoute, altARoute.Routing, routingB, true
}
}

if refB, ok := routingB.(*AnyShardRouting); ok {
if refB, ok := routingB.(*AnyShardRouting); ok &&
!TableID(rhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) {
if altBRoute := refB.AlternateInKeyspace(routingA.Keyspace()); altBRoute != nil {
return lhsRoute, altBRoute, routingA, altBRoute.Routing, true
}
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ func rewriteOriginalPushedToRHS(ctx *plancontext.PlanningContext, expression sql
// this is used when we push an operator from above the subquery into the outer side of the subquery
func rewriteColNameToArgument(
ctx *plancontext.PlanningContext,
in sqlparser.Expr, // the expression to rewrite
se SubQueryExpression, // the subquery expression we are rewriting
in sqlparser.Expr, // the expression to rewrite
se SubQueryExpression, // the subquery expression we are rewriting
subqueries ...*SubQuery, // the inner subquery operators
) sqlparser.Expr {
// the visitor function that will rewrite the expression tree
Expand Down Expand Up @@ -730,7 +730,7 @@ func mergeSubqueryInputs(ctx *plancontext.PlanningContext, in, out Operator, joi
return nil
}

inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(inRoute, outRoute)
inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(ctx, inRoute, outRoute)
inner, outer := getRoutingType(inRouting), getRoutingType(outRouting)

switch {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/union_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func mergeUnionInputs(
lhsExprs, rhsExprs sqlparser.SelectExprs,
distinct bool,
) (Operator, sqlparser.SelectExprs) {
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs)
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs)
if lhsRoute == nil {
return nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up
ueMap := prepareUpdateExpressionList(ctx, upd)

var updOps []dmlOp
for _, target := range ctx.SemTable.Targets.Constituents() {
for _, target := range ctx.SemTable.DMLTargets.Constituents() {
op := createUpdateOpWithTarget(ctx, upd, target, ueMap[target])
updOps = append(updOps, op)
}
Expand Down Expand Up @@ -308,7 +308,7 @@ func errIfUpdateNotSupported(ctx *plancontext.PlanningContext, stmt *sqlparser.U
}
}

// Now we check if any of the foreign key columns that are being udpated have dependencies on other updated columns.
// Now we check if any of the foreign key columns that are being updated have dependencies on other updated columns.
// This is unsafe, and we currently don't support this in Vitess.
if err := ctx.SemTable.ErrIfFkDependentColumnUpdated(stmt.Exprs); err != nil {
panic(err)
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ func (s *planTestSuite) TestOne() {
s.addPKsProvided(vschema, "user", []string{"user_extra"}, []string{"id", "user_id"})
s.addPKsProvided(vschema, "ordering", []string{"order"}, []string{"oid", "region_id"})
s.addPKsProvided(vschema, "ordering", []string{"order_event"}, []string{"oid", "ename"})
s.addPKsProvided(vschema, "legacy", []string{"users"}, []string{"idUser"})

s.testFile("onecase.json", vw, false)
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/testdata/onecase.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"comment": "Add your test case here for debugging and run go test -run=One.",
"query": "",
"query": "update legacy.users as u join legacy.users2sites as u2s on u2s.idUser = u.idUser inner join sites2024.sites as s on s.idSite = u2s.idSite set u.email = '[email protected]' where s.idSite = 234234234;",
"plan": {
}
}
Expand Down
35 changes: 35 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/vschemas/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,41 @@
]
}
}
},
"sites2024": {
"sharded": true,
"require_explicit_routing": false,
"vindexes": {
"a_standard_hash": {
"type": "xxhash"
}
},
"tables": {
"users": {
"type": "reference",
"source": "legacy.users"
},
"users2sites": {
"type": "reference",
"source": "legacy.users2sites"
},
"sites": {
"column_vindexes": [
{
"column": "idSite",
"name": "a_standard_hash"
}
]
}
}
},
"legacy": {
"require_explicit_routing": false,
"sharded": false,
"tables": {
"users": {},
"users2sites": {}
}
}
}
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (a *analyzer) newSemTable(
Direct: a.binder.direct,
ExprTypes: a.typer.m,
Tables: a.tables.Tables,
Targets: a.binder.targets,
DMLTargets: a.binder.targets,
NotSingleRouteErr: a.notSingleRouteErr,
NotUnshardedErr: a.unshardedErr,
Warning: a.warning,
Expand Down
14 changes: 7 additions & 7 deletions go/vt/vtgate/semantics/semantic_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ type (
// It doesn't recurse inside derived tables to find the original dependencies.
Direct ExprDependencies

// Targets contains the TableSet of each table getting modified by the update/delete statement.
Targets TableSet
// DMLTargets contains the TableSet of each table getting modified by the update/delete statement.
DMLTargets TableSet

// ColumnEqualities is used for transitive closures (e.g., if a == b and b == c, then a == c).
ColumnEqualities map[columnName][]sqlparser.Expr
Expand Down Expand Up @@ -202,15 +202,15 @@ func (st *SemTable) CopyDependencies(from, to sqlparser.Expr) {

// GetChildForeignKeysForTargets gets the child foreign keys as a list for all the target tables.
func (st *SemTable) GetChildForeignKeysForTargets() (fks []vindexes.ChildFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
fks = append(fks, st.childForeignKeysInvolved[ts]...)
}
return fks
}

// GetChildForeignKeysForTableSet gets the child foreign keys as a listfor the TableSet.
func (st *SemTable) GetChildForeignKeysForTableSet(target TableSet) (fks []vindexes.ChildFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
if target.IsSolvedBy(ts) {
fks = append(fks, st.childForeignKeysInvolved[ts]...)
}
Expand Down Expand Up @@ -238,15 +238,15 @@ func (st *SemTable) GetChildForeignKeysList() []vindexes.ChildFKInfo {

// GetParentForeignKeysForTargets gets the parent foreign keys as a list for all the target tables.
func (st *SemTable) GetParentForeignKeysForTargets() (fks []vindexes.ParentFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
fks = append(fks, st.parentForeignKeysInvolved[ts]...)
}
return fks
}

// GetParentForeignKeysForTableSet gets the parent foreign keys as a list for the TableSet.
func (st *SemTable) GetParentForeignKeysForTableSet(target TableSet) (fks []vindexes.ParentFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
if target.IsSolvedBy(ts) {
fks = append(fks, st.parentForeignKeysInvolved[ts]...)
}
Expand Down Expand Up @@ -970,7 +970,7 @@ func (st *SemTable) UpdateChildFKExpr(origUpdExpr *sqlparser.UpdateExpr, newExpr

// GetTargetTableSetForTableName returns the TableSet for the given table name from the target tables.
func (st *SemTable) GetTargetTableSetForTableName(name sqlparser.TableName) (TableSet, error) {
for _, target := range st.Targets.Constituents() {
for _, target := range st.DMLTargets.Constituents() {
tbl, err := st.Tables[target.TableOffset()].Name()
if err != nil {
return "", err
Expand Down
Loading