diff --git a/changelog/20.0/20.0.0/summary.md b/changelog/20.0/20.0.0/summary.md index 304ef5a44f0..16863248e68 100644 --- a/changelog/20.0/20.0.0/summary.md +++ b/changelog/20.0/20.0.0/summary.md @@ -9,7 +9,9 @@ - [Vindex Hints](#vindex-hints) - [Update with Limit Support](#update-limit) - [Update with Multi Table Support](#multi-table-update) + - [Update with Multi Target Support](#update-multi-target) - [Delete with Subquery Support](#delete-subquery) + - [Delete with Multi Target Support](#delete-multi-target) - **[Flag changes](#flag-changes)** - [`pprof-http` default change](#pprof-http-default) - [New `healthcheck-dial-concurrency` flag](#healthcheck-dial-concurrency-flag) @@ -59,12 +61,27 @@ Example: `update t1 join t2 on t1.id = t2.id join t3 on t1.col = t3.col set t1.b More details about how it works is available in [MySQL Docs](https://dev.mysql.com/doc/refman/8.0/en/update.html) +#### Update with Multi Target Support + +Support is added for sharded multi table target update. + +Example: `update t1 join t2 on t1.id = t2.id set t1.foo = 'abc', t2.bar = 23` + +More details about how it works is available in [MySQL Docs](https://dev.mysql.com/doc/refman/8.0/en/update.html) + #### Delete with Subquery Support Support is added for sharded table delete with subquery Example: `delete from t1 where id in (select col from t2 where foo = 32 and bar = 43)` +#### Delete with Multi Target Support + +Support is added for sharded multi table target delete. + +Example: `delete t1, t3 from t1 join t2 on t1.id = t2.id join t3 on t1.col = t3.col` + +More details about how it works is available in [MySQL Docs](https://dev.mysql.com/doc/refman/8.0/en/delete.html) ### Flag Changes diff --git a/go/test/endtoend/vtgate/queries/dml/dml_test.go b/go/test/endtoend/vtgate/queries/dml/dml_test.go index deca3f01caf..98db03bee0c 100644 --- a/go/test/endtoend/vtgate/queries/dml/dml_test.go +++ b/go/test/endtoend/vtgate/queries/dml/dml_test.go @@ -54,12 +54,6 @@ func TestMultiTableDelete(t *testing.T) { mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)") mcmp.Exec("insert into oevent_tbl(oid, ename) values (1,'a'), (2,'b'), (3,'a'), (4,'c')") - // check rows - mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, - `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`) - mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, - `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("c")]]`) - // multi table delete qr := mcmp.Exec(`delete o from order_tbl o join oevent_tbl ev where o.oid = ev.oid and ev.ename = 'a'`) assert.EqualValues(t, 2, qr.RowsAffected) @@ -91,12 +85,6 @@ func TestDeleteWithLimit(t *testing.T) { mcmp.Exec("insert into s_tbl(id, num) values (1,10), (2,10), (3,10), (4,20), (5,5), (6,15), (7,17), (8,80)") mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)") - // check rows - mcmp.AssertMatches(`select id, num from s_tbl order by id`, - `[[INT64(1) INT64(10)] [INT64(2) INT64(10)] [INT64(3) INT64(10)] [INT64(4) INT64(20)] [INT64(5) INT64(5)] [INT64(6) INT64(15)] [INT64(7) INT64(17)] [INT64(8) INT64(80)]]`) - mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, - `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`) - // delete with limit qr := mcmp.Exec(`delete from s_tbl order by num, id limit 3`) require.EqualValues(t, 3, qr.RowsAffected) @@ -152,12 +140,6 @@ func TestUpdateWithLimit(t *testing.T) { mcmp.Exec("insert into s_tbl(id, num) values (1,10), (2,10), (3,10), (4,20), (5,5), (6,15), (7,17), (8,80)") mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)") - // check rows - mcmp.AssertMatches(`select id, num from s_tbl order by id`, - `[[INT64(1) INT64(10)] [INT64(2) INT64(10)] [INT64(3) INT64(10)] [INT64(4) INT64(20)] [INT64(5) INT64(5)] [INT64(6) INT64(15)] [INT64(7) INT64(17)] [INT64(8) INT64(80)]]`) - mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, - `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`) - // update with limit qr := mcmp.Exec(`update s_tbl set num = 12 order by num, id limit 3`) require.EqualValues(t, 3, qr.RowsAffected) @@ -216,13 +198,7 @@ func TestMultiTableUpdate(t *testing.T) { mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)") mcmp.Exec("insert into oevent_tbl(oid, ename) values (1,'a'), (2,'b'), (3,'a'), (4,'c')") - // check rows - mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, - `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`) - mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, - `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("c")]]`) - - // multi table delete + // multi table update qr := mcmp.Exec(`update order_tbl o join oevent_tbl ev on o.oid = ev.oid set ev.ename = 'a' where ev.oid > 3`) assert.EqualValues(t, 1, qr.RowsAffected) @@ -253,12 +229,6 @@ func TestDeleteWithSubquery(t *testing.T) { mcmp.Exec("insert into s_tbl(id, num) values (1,10), (2,10), (3,10), (4,20), (5,5), (6,15), (7,17), (8,80)") mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)") - // check rows - mcmp.AssertMatches(`select id, num from s_tbl order by id`, - `[[INT64(1) INT64(10)] [INT64(2) INT64(10)] [INT64(3) INT64(10)] [INT64(4) INT64(20)] [INT64(5) INT64(5)] [INT64(6) INT64(15)] [INT64(7) INT64(17)] [INT64(8) INT64(80)]]`) - mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, - `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`) - // delete with subquery on s_tbl qr := mcmp.Exec(`delete from s_tbl where id in (select oid from order_tbl)`) require.EqualValues(t, 4, qr.RowsAffected) @@ -305,12 +275,6 @@ func TestMultiTargetDelete(t *testing.T) { mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)") mcmp.Exec("insert into oevent_tbl(oid, ename) values (1,'a'), (2,'b'), (3,'a'), (2,'c')") - // check rows - mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, - `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`) - mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, - `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(2) VARCHAR("c")] [INT64(3) VARCHAR("a")]]`) - // multi table delete qr := mcmp.Exec(`delete o, ev from order_tbl o join oevent_tbl ev where o.oid = ev.oid and ev.ename = 'a'`) assert.EqualValues(t, 4, qr.RowsAffected) @@ -368,3 +332,34 @@ func TestMultiTargetDeleteMore(t *testing.T) { mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(2) VARCHAR("c")] [INT64(3) VARCHAR("a")]]`) } + +// TestMultiTargetUpdate executed multi-target update queries +func TestMultiTargetUpdate(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + + // initial rows + mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)") + mcmp.Exec("insert into oevent_tbl(oid, ename) values (1,'a'), (2,'b'), (3,'a'), (4,'c')") + + // multi target update + qr := mcmp.Exec(`update order_tbl o join oevent_tbl ev on o.oid = ev.oid set ev.ename = 'a', o.cust_no = 1 where ev.oid > 3`) + assert.EqualValues(t, 2, qr.RowsAffected) + + // check rows + mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, + `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(1)]]`) + mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, + `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("a")]]`) + + qr = mcmp.Exec(`update order_tbl o, oevent_tbl ev set ev.ename = 'xyz', o.oid = 40 where o.cust_no = ev.oid and ev.ename = 'b'`) + assert.EqualValues(t, 2, qr.RowsAffected) + + // check rows + mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid, region_id`, + `[[INT64(1) INT64(1) INT64(4)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(1)] [INT64(1) INT64(40) INT64(2)]]`) + mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, + `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("xyz")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("a")]]`) +} diff --git a/go/test/endtoend/vtgate/queries/dml/sharded_schema.sql b/go/test/endtoend/vtgate/queries/dml/sharded_schema.sql index 3310724d420..8ddf9250e45 100644 --- a/go/test/endtoend/vtgate/queries/dml/sharded_schema.sql +++ b/go/test/endtoend/vtgate/queries/dml/sharded_schema.sql @@ -25,7 +25,8 @@ create table order_tbl oid bigint, region_id bigint, cust_no bigint unique key, - primary key (oid, region_id) + primary key (oid, region_id), + unique key (oid) ) Engine = InnoDB; create table oid_vdx_tbl diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 2537cf0020f..ddb5251dbb3 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2686,3 +2686,11 @@ func (node *Update) SetWherePredicate(expr Expr) { Expr: expr, } } + +// GetHighestOrderLock returns the higher level lock between the current lock and the new lock +func (lock Lock) GetHighestOrderLock(newLock Lock) Lock { + if newLock > lock { + return newLock + } + return lock +} diff --git a/go/vt/sqlparser/constants.go b/go/vt/sqlparser/constants.go index becaad2a2fe..cba5f7823c1 100644 --- a/go/vt/sqlparser/constants.go +++ b/go/vt/sqlparser/constants.go @@ -522,11 +522,11 @@ const ( // Constants for Enum Type - Lock const ( NoLock Lock = iota - ForUpdateLock ShareModeLock ForShareLock ForShareLockNoWait ForShareLockSkipLocked + ForUpdateLock ForUpdateLockNoWait ForUpdateLockSkipLocked ) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 831dae4ade2..6d2c2317517 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -26,6 +26,7 @@ import ( "vitess.io/vitess/go/slice" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/sysvars" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" @@ -701,6 +702,9 @@ func buildUpdateLogicalPlan( return nil, vterrors.VT12001("Vindex update should have ORDER BY clause when using LIMIT") } } + if upd.VerifyAll { + stmt.SetComments(stmt.GetParsedComments().SetMySQLSetVarValue(sysvars.ForeignKeyChecks, "OFF")) + } edml := createDMLPrimitive(ctx, rb, hints, upd.Target.VTable, generateQuery(stmt), vindexes, vQuery) diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index 8c7703ef096..bac61c51126 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -17,8 +17,6 @@ limitations under the License. package operators import ( - "sort" - "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -35,13 +33,6 @@ type Delete struct { noPredicates } -// delOp stores intermediary value for Delete Operator with the vindexes.Table for ordering. -type delOp struct { - op Operator - vTbl *vindexes.Table - cols []*sqlparser.ColName -} - // Clone implements the Operator interface func (d *Delete) Clone(inputs []Operator) Operator { newD := *d @@ -73,7 +64,7 @@ func (d *Delete) ShortDescription() string { } func createOperatorFromDelete(ctx *plancontext.PlanningContext, deleteStmt *sqlparser.Delete) (op Operator) { - childFks := ctx.SemTable.GetChildForeignKeysForTable(deleteStmt.Targets[0]) + childFks := ctx.SemTable.GetChildForeignKeysForTargets() // We check if delete with input plan is required. DML with input planning is generally // slower, because it does a selection and then creates a delete statement wherein we have to @@ -136,34 +127,17 @@ func createDeleteWithInputOp(ctx *plancontext.PlanningContext, del *sqlparser.De Lock: sqlparser.ForUpdateLock, } - var delOps []delOp - for _, target := range del.Targets { - op := createDeleteOpWithTarget(ctx, target) + var delOps []dmlOp + for _, target := range ctx.SemTable.Targets.Constituents() { + op := createDeleteOpWithTarget(ctx, target, del.Ignore) delOps = append(delOps, op) } - // sort the operator based on sharding vindex type. - // Unsharded < Lookup Vindex < Any - // This is needed to ensure all the rows are deleted from unowned sharding tables first. - // Otherwise, those table rows will be missed from getting deleted as - // the owned table row won't have matching values. - sort.Slice(delOps, func(i, j int) bool { - a, b := delOps[i], delOps[j] - // Get the first Vindex of a and b, if available - aVdx, bVdx := getFirstVindex(a.vTbl), getFirstVindex(b.vTbl) - - // Sort nil Vindexes to the start - if aVdx == nil || bVdx == nil { - return aVdx != nil // true if bVdx is nil and aVdx is not nil - } - - // Among non-nil Vindexes, those that need VCursor come first - return aVdx.NeedsVCursor() && !bVdx.NeedsVCursor() - }) + delOps = sortDmlOps(delOps) // now map the operator and column list. var colsList [][]*sqlparser.ColName - dmls := slice.Map(delOps, func(from delOp) Operator { + dmls := slice.Map(delOps, func(from dmlOp) Operator { colsList = append(colsList, from.cols) for _, col := range from.cols { selectStmt.SelectExprs = append(selectStmt.SelectExprs, aeWrap(col)) @@ -194,9 +168,8 @@ func getFirstVindex(vTbl *vindexes.Table) vindexes.Vindex { return nil } -func createDeleteOpWithTarget(ctx *plancontext.PlanningContext, target sqlparser.TableName) delOp { - ts := ctx.SemTable.Targets[target.Name] - ti, err := ctx.SemTable.TableInfoFor(ts) +func createDeleteOpWithTarget(ctx *plancontext.PlanningContext, target semantics.TableSet, ignore sqlparser.Ignore) dmlOp { + ti, err := ctx.SemTable.TableInfoFor(target) if err != nil { panic(vterrors.VT13001(err.Error())) } @@ -205,14 +178,18 @@ func createDeleteOpWithTarget(ctx *plancontext.PlanningContext, target sqlparser if len(vTbl.PrimaryKey) == 0 { panic(vterrors.VT09015()) } + tblName, err := ti.Name() + if err != nil { + panic(err) + } var leftComp sqlparser.ValTuple cols := make([]*sqlparser.ColName, 0, len(vTbl.PrimaryKey)) for _, col := range vTbl.PrimaryKey { - colName := sqlparser.NewColNameWithQualifier(col.String(), target) + colName := sqlparser.NewColNameWithQualifier(col.String(), tblName) cols = append(cols, colName) leftComp = append(leftComp, colName) - ctx.SemTable.Recursive[colName] = ts + ctx.SemTable.Recursive[colName] = target } // optimize for case when there is only single column on left hand side. var lhs sqlparser.Expr = leftComp @@ -222,11 +199,12 @@ func createDeleteOpWithTarget(ctx *plancontext.PlanningContext, target sqlparser compExpr := sqlparser.NewComparisonExpr(sqlparser.InOp, lhs, sqlparser.ListArg(engine.DmlVals), nil) del := &sqlparser.Delete{ + Ignore: ignore, TableExprs: sqlparser.TableExprs{ti.GetAliasedTableExpr()}, - Targets: sqlparser.TableNames{target}, + Targets: sqlparser.TableNames{tblName}, Where: sqlparser.NewWhere(sqlparser.WhereClause, compExpr), } - return delOp{ + return dmlOp{ createOperatorFromDelete(ctx, del), vTbl, cols, @@ -241,10 +219,9 @@ func createDeleteOperator(ctx *plancontext.PlanningContext, del *sqlparser.Delet op = addWherePredsToSubQueryBuilder(ctx, del.Where.Expr, op, sqc) } - target := del.Targets[0] - tblID, exists := ctx.SemTable.Targets[target.Name] - if !exists { - panic(vterrors.VT13001("delete target table should be part of semantic analyzer")) + tblID, err := ctx.SemTable.GetTargetTableSetForTableName(del.Targets[0]) + if err != nil { + panic(err) } tblInfo, err := ctx.SemTable.TableInfoFor(tblID) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/dml_planning.go b/go/vt/vtgate/planbuilder/operators/dml_planning.go index b8fa172b87c..6d51a33b4aa 100644 --- a/go/vt/vtgate/planbuilder/operators/dml_planning.go +++ b/go/vt/vtgate/planbuilder/operators/dml_planning.go @@ -18,12 +18,10 @@ package operators import ( "fmt" + "sort" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/engine" - "vitess.io/vitess/go/vt/vtgate/evalengine" - "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -41,6 +39,35 @@ type TargetTable struct { Name sqlparser.TableName } +// dmlOp stores intermediary value for Update/Delete Operator with the vindexes.Table for ordering. +type dmlOp struct { + op Operator + vTbl *vindexes.Table + cols []*sqlparser.ColName +} + +// sortDmlOps sort the operator based on sharding vindex type. +// Unsharded < Lookup Vindex < Any +// This is needed to ensure all the rows are deleted from unowned sharding tables first. +// Otherwise, those table rows will be missed from getting deleted as +// the owned table row won't have matching values. +func sortDmlOps(dmlOps []dmlOp) []dmlOp { + sort.Slice(dmlOps, func(i, j int) bool { + a, b := dmlOps[i], dmlOps[j] + // Get the first Vindex of a and b, if available + aVdx, bVdx := getFirstVindex(a.vTbl), getFirstVindex(b.vTbl) + + // Sort nil Vindexes to the start + if aVdx == nil || bVdx == nil { + return aVdx != nil // true if bVdx is nil and aVdx is not nil + } + + // Among non-nil Vindexes, those that need VCursor come first + return aVdx.NeedsVCursor() && !bVdx.NeedsVCursor() + }) + return dmlOps +} + func shortDesc(target TargetTable, ovq *sqlparser.Select) string { ovqString := "" if ovq != nil { @@ -66,113 +93,3 @@ func getVindexInformation(id semantics.TableSet, table *vindexes.Table) *vindexe } return table.ColumnVindexes[0] } - -func createAssignmentExpressions( - ctx *plancontext.PlanningContext, - assignments []SetExpr, - vcol sqlparser.IdentifierCI, - subQueriesArgOnChangedVindex []string, - vindexValueMap map[string]evalengine.Expr, - compExprs []sqlparser.Expr, -) ([]string, []sqlparser.Expr) { - // Searching in order of columns in colvindex. - found := false - for _, assignment := range assignments { - if !vcol.Equal(assignment.Name.Name) { - continue - } - if found { - panic(vterrors.VT03015(assignment.Name.Name)) - } - found = true - pv, err := evalengine.Translate(assignment.Expr.EvalExpr, &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, - Collation: ctx.SemTable.Collation, - Environment: ctx.VSchema.Environment(), - }) - if err != nil { - panic(invalidUpdateExpr(assignment.Name.Name.String(), assignment.Expr.EvalExpr)) - } - - if assignment.Expr.Info != nil { - sqe, ok := assignment.Expr.Info.(SubQueryExpression) - if ok { - for _, sq := range sqe { - subQueriesArgOnChangedVindex = append(subQueriesArgOnChangedVindex, sq.ArgName) - } - } - } - - vindexValueMap[vcol.String()] = pv - compExprs = append(compExprs, sqlparser.NewComparisonExpr(sqlparser.EqualOp, assignment.Name, assignment.Expr.EvalExpr, nil)) - } - return subQueriesArgOnChangedVindex, compExprs -} - -func buildChangedVindexesValues( - ctx *plancontext.PlanningContext, - update *sqlparser.Update, - table *vindexes.Table, - ksidCols []sqlparser.IdentifierCI, - assignments []SetExpr, -) (vv map[string]*engine.VindexValues, ownedVindexQuery *sqlparser.Select, subQueriesArgOnChangedVindex []string) { - changedVindexes := make(map[string]*engine.VindexValues) - selExprs, offset := initialQuery(ksidCols, table) - for i, vindex := range table.ColumnVindexes { - vindexValueMap := make(map[string]evalengine.Expr) - var compExprs []sqlparser.Expr - for _, vcol := range vindex.Columns { - subQueriesArgOnChangedVindex, compExprs = - createAssignmentExpressions(ctx, assignments, vcol, subQueriesArgOnChangedVindex, vindexValueMap, compExprs) - } - if len(vindexValueMap) == 0 { - // Vindex not changing, continue - continue - } - if i == 0 { - panic(vterrors.VT12001(fmt.Sprintf("you cannot UPDATE primary vindex columns; invalid update on vindex: %v", vindex.Name))) - } - if _, ok := vindex.Vindex.(vindexes.Lookup); !ok { - panic(vterrors.VT12001(fmt.Sprintf("you can only UPDATE lookup vindexes; invalid update on vindex: %v", vindex.Name))) - } - - // Checks done, let's actually add the expressions and the vindex map - selExprs = append(selExprs, aeWrap(sqlparser.AndExpressions(compExprs...))) - changedVindexes[vindex.Name] = &engine.VindexValues{ - EvalExprMap: vindexValueMap, - Offset: offset, - } - offset++ - } - if len(changedVindexes) == 0 { - return nil, nil, nil - } - // generate rest of the owned vindex query. - ovq := &sqlparser.Select{ - SelectExprs: selExprs, - OrderBy: update.OrderBy, - Limit: update.Limit, - Lock: sqlparser.ForUpdateLock, - } - return changedVindexes, ovq, subQueriesArgOnChangedVindex -} - -func initialQuery(ksidCols []sqlparser.IdentifierCI, table *vindexes.Table) (sqlparser.SelectExprs, int) { - var selExprs sqlparser.SelectExprs - offset := 0 - for _, col := range ksidCols { - selExprs = append(selExprs, aeWrap(sqlparser.NewColName(col.String()))) - offset++ - } - for _, cv := range table.Owned { - for _, column := range cv.Columns { - selExprs = append(selExprs, aeWrap(sqlparser.NewColName(column.String()))) - offset++ - } - } - return selExprs, offset -} - -func invalidUpdateExpr(upd string, expr sqlparser.Expr) error { - return vterrors.VT12001(fmt.Sprintf("only values are supported; invalid update on column: `%s` with expr: [%s]", upd, sqlparser.String(expr))) -} diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index e31d06122da..f214cb6512e 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -137,7 +137,7 @@ func pushLockAndComment(l *LockAndComment) (Operator, *ApplyResult) { return l, NoRewrite case *Route: src.Comments = l.Comments - src.Lock = l.Lock + src.Lock = l.Lock.GetHighestOrderLock(src.Lock) return src, Rewrote("put lock and comment into route") case *SubQueryContainer: src.Outer = &LockAndComment{ diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 7f97e62f41e..4c559fcf7f7 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -21,12 +21,14 @@ import ( "maps" "slices" + "vitess.io/vitess/go/slice" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/sysvars" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" @@ -44,6 +46,8 @@ type ( // On merging this information will be lost, so subquery merge is blocked. SubQueriesArgOnChangedVindex []string + VerifyAll bool + noColumns noPredicates } @@ -95,14 +99,14 @@ func (u *Update) ShortDescription() string { func createOperatorFromUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update) (op Operator) { errIfUpdateNotSupported(ctx, updStmt) - parentFks := ctx.SemTable.GetParentForeignKeysList() - childFks := ctx.SemTable.GetChildForeignKeysList() + parentFks := ctx.SemTable.GetParentForeignKeysForTargets() + childFks := ctx.SemTable.GetChildForeignKeysForTargets() // We check if dml with input plan is required. DML with input planning is generally // slower, because it does a selection and then creates a update statement wherein we have to // list all the primary key values. - if updateWithInputPlanningRequired(childFks, parentFks, updStmt) { - return updateWithInputPlanningForFk(ctx, updStmt) + if updateWithInputPlanningRequired(ctx, childFks, parentFks, updStmt) { + return createUpdateWithInputOp(ctx, updStmt) } var updClone *sqlparser.Update @@ -123,10 +127,54 @@ func createOperatorFromUpdate(ctx *plancontext.PlanningContext, updStmt *sqlpars return buildFkOperator(ctx, op, updClone, parentFks, childFks, vTbl) } -func updateWithInputPlanningForFk(ctx *plancontext.PlanningContext, upd *sqlparser.Update) Operator { +func updateWithInputPlanningRequired( + ctx *plancontext.PlanningContext, + childFks []vindexes.ChildFKInfo, + parentFks []vindexes.ParentFKInfo, + updateStmt *sqlparser.Update, +) bool { + if isMultiTargetUpdate(ctx, childFks, parentFks, updateStmt) { + return true + } + // If there are no foreign keys, we don't need to use delete with input. + if len(childFks) == 0 && len(parentFks) == 0 { + return false + } + // Limit requires dml with input. + if updateStmt.Limit != nil { + return true + } + return false +} + +func isMultiTargetUpdate(ctx *plancontext.PlanningContext, childFks []vindexes.ChildFKInfo, parentFks []vindexes.ParentFKInfo, updateStmt *sqlparser.Update) bool { + var targetTS semantics.TableSet + for _, ue := range updateStmt.Exprs { + targetTS = targetTS.Merge(ctx.SemTable.DirectDeps(ue.Name)) + } + if targetTS.NumberOfTables() == 1 { + return false + } + + if len(childFks) > 0 || len(parentFks) > 0 { + panic(vterrors.VT12001("multi table update with foreign keys")) + } + + return true +} + +func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Update) (op Operator) { updClone := ctx.SemTable.Clone(upd).(*sqlparser.Update) upd.Limit = nil + var updOps []dmlOp + for _, target := range ctx.SemTable.Targets.Constituents() { + op := createUpdateOpWithTarget(ctx, target, upd) + updOps = append(updOps, op) + } + + updOps = sortDmlOps(updOps) + selectStmt := &sqlparser.Select{ From: updClone.TableExprs, Where: updClone.Where, @@ -135,25 +183,60 @@ func updateWithInputPlanningForFk(ctx *plancontext.PlanningContext, upd *sqlpars Lock: sqlparser.ForUpdateLock, } - ate, isAliasTableExpr := upd.TableExprs[0].(*sqlparser.AliasedTableExpr) - if !isAliasTableExpr { - panic(vterrors.VT12001("update with limit with foreign key constraints using a complex table")) + // now map the operator and column list. + var colsList [][]*sqlparser.ColName + dmls := slice.Map(updOps, func(from dmlOp) Operator { + colsList = append(colsList, from.cols) + for _, col := range from.cols { + selectStmt.SelectExprs = append(selectStmt.SelectExprs, aeWrap(col)) + } + return from.op + }) + + op = &DMLWithInput{ + DML: dmls, + Source: createOperatorFromSelect(ctx, selectStmt), + cols: colsList, + } + + if upd.Comments != nil { + op = &LockAndComment{ + Source: op, + Comments: upd.Comments, + } } - ts := ctx.SemTable.TableSetFor(ate) - ti, err := ctx.SemTable.TableInfoFor(ts) + return op +} + +func createUpdateOpWithTarget(ctx *plancontext.PlanningContext, target semantics.TableSet, updStmt *sqlparser.Update) dmlOp { + var updExprs sqlparser.UpdateExprs + for _, ue := range updStmt.Exprs { + if ctx.SemTable.DirectDeps(ue.Name) == target { + updExprs = append(updExprs, ue) + } + } + + if len(updExprs) == 0 { + panic(vterrors.VT13001("no update expression for the target")) + } + + ti, err := ctx.SemTable.TableInfoFor(target) if err != nil { panic(vterrors.VT13001(err.Error())) } vTbl := ti.GetVindexTable() + tblName, err := ti.Name() + if err != nil { + panic(err) + } var leftComp sqlparser.ValTuple cols := make([]*sqlparser.ColName, 0, len(vTbl.PrimaryKey)) for _, col := range vTbl.PrimaryKey { - colName := sqlparser.NewColNameWithQualifier(col.String(), vTbl.GetTableName()) - selectStmt.SelectExprs = append(selectStmt.SelectExprs, aeWrap(colName)) + colName := sqlparser.NewColNameWithQualifier(col.String(), tblName) cols = append(cols, colName) leftComp = append(leftComp, colName) - ctx.SemTable.Recursive[colName] = ts + ctx.SemTable.Recursive[colName] = target } // optimize for case when there is only single column on left hand side. var lhs sqlparser.Expr = leftComp @@ -162,28 +245,21 @@ func updateWithInputPlanningForFk(ctx *plancontext.PlanningContext, upd *sqlpars } compExpr := sqlparser.NewComparisonExpr(sqlparser.InOp, lhs, sqlparser.ListArg(engine.DmlVals), nil) - upd.Where = sqlparser.NewWhere(sqlparser.WhereClause, compExpr) - return &DMLWithInput{ - DML: []Operator{createOperatorFromUpdate(ctx, upd)}, - Source: createOperatorFromSelect(ctx, selectStmt), - cols: [][]*sqlparser.ColName{cols}, - } -} - -func updateWithInputPlanningRequired(childFks []vindexes.ChildFKInfo, parentFks []vindexes.ParentFKInfo, updateStmt *sqlparser.Update) bool { - // If there are no foreign keys, we don't need to use delete with input. - if len(childFks) == 0 && len(parentFks) == 0 { - return false + upd := &sqlparser.Update{ + Ignore: updStmt.Ignore, + TableExprs: sqlparser.TableExprs{ti.GetAliasedTableExpr()}, + Exprs: updExprs, + Where: sqlparser.NewWhere(sqlparser.WhereClause, compExpr), + OrderBy: updStmt.OrderBy, } - // Limit requires dml with input. - if updateStmt.Limit != nil { - return true + return dmlOp{ + createOperatorFromUpdate(ctx, upd), + vTbl, + cols, } - return false } func errIfUpdateNotSupported(ctx *plancontext.PlanningContext, stmt *sqlparser.Update) { - var vTbl *vindexes.Table for _, ue := range stmt.Exprs { tblInfo, err := ctx.SemTable.TableInfoForExpr(ue.Name) if err != nil { @@ -197,13 +273,6 @@ func errIfUpdateNotSupported(ctx *plancontext.PlanningContext, stmt *sqlparser.U } panic(vterrors.VT03032(tblName)) } - - if vTbl == nil { - vTbl = tblInfo.GetVindexTable() - } - if vTbl != tblInfo.GetVindexTable() { - panic(vterrors.VT12001("multi-table UPDATE statement with multi-target column update")) - } } // Now we check if any of the foreign key columns that are being udpated have dependencies on other updated columns. @@ -281,6 +350,7 @@ func createUpdateOperator(ctx *plancontext.PlanningContext, updStmt *sqlparser.U Assignments: assignments, ChangedVindexValues: cvv, SubQueriesArgOnChangedVindex: subQueriesArgOnChangedVindex, + VerifyAll: ctx.VerifyAllFKs, } if len(updStmt.OrderBy) > 0 { @@ -926,3 +996,113 @@ func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updatedTable *vin return finalExpr } + +func buildChangedVindexesValues( + ctx *plancontext.PlanningContext, + update *sqlparser.Update, + table *vindexes.Table, + ksidCols []sqlparser.IdentifierCI, + assignments []SetExpr, +) (changedVindexes map[string]*engine.VindexValues, ovq *sqlparser.Select, subQueriesArgOnChangedVindex []string) { + changedVindexes = make(map[string]*engine.VindexValues) + selExprs, offset := initialQuery(ksidCols, table) + for i, vindex := range table.ColumnVindexes { + vindexValueMap := make(map[string]evalengine.Expr) + var compExprs []sqlparser.Expr + for _, vcol := range vindex.Columns { + subQueriesArgOnChangedVindex, compExprs = + createAssignmentExpressions(ctx, assignments, vcol, subQueriesArgOnChangedVindex, vindexValueMap, compExprs) + } + if len(vindexValueMap) == 0 { + // Vindex not changing, continue + continue + } + if i == 0 { + panic(vterrors.VT12001(fmt.Sprintf("you cannot UPDATE primary vindex columns; invalid update on vindex: %v", vindex.Name))) + } + if _, ok := vindex.Vindex.(vindexes.Lookup); !ok { + panic(vterrors.VT12001(fmt.Sprintf("you can only UPDATE lookup vindexes; invalid update on vindex: %v", vindex.Name))) + } + + // Checks done, let's actually add the expressions and the vindex map + selExprs = append(selExprs, aeWrap(sqlparser.AndExpressions(compExprs...))) + changedVindexes[vindex.Name] = &engine.VindexValues{ + EvalExprMap: vindexValueMap, + Offset: offset, + } + offset++ + } + if len(changedVindexes) == 0 { + return nil, nil, nil + } + // generate rest of the owned vindex query. + ovq = &sqlparser.Select{ + SelectExprs: selExprs, + OrderBy: update.OrderBy, + Limit: update.Limit, + Lock: sqlparser.ForUpdateLock, + } + return changedVindexes, ovq, subQueriesArgOnChangedVindex +} + +func initialQuery(ksidCols []sqlparser.IdentifierCI, table *vindexes.Table) (sqlparser.SelectExprs, int) { + var selExprs sqlparser.SelectExprs + offset := 0 + for _, col := range ksidCols { + selExprs = append(selExprs, aeWrap(sqlparser.NewColName(col.String()))) + offset++ + } + for _, cv := range table.Owned { + for _, column := range cv.Columns { + selExprs = append(selExprs, aeWrap(sqlparser.NewColName(column.String()))) + offset++ + } + } + return selExprs, offset +} + +func createAssignmentExpressions( + ctx *plancontext.PlanningContext, + assignments []SetExpr, + vcol sqlparser.IdentifierCI, + subQueriesArgOnChangedVindex []string, + vindexValueMap map[string]evalengine.Expr, + compExprs []sqlparser.Expr, +) ([]string, []sqlparser.Expr) { + // Searching in order of columns in colvindex. + found := false + for _, assignment := range assignments { + if !vcol.Equal(assignment.Name.Name) { + continue + } + if found { + panic(vterrors.VT03015(assignment.Name.Name)) + } + found = true + pv, err := evalengine.Translate(assignment.Expr.EvalExpr, &evalengine.Config{ + ResolveType: ctx.SemTable.TypeForExpr, + Collation: ctx.SemTable.Collation, + Environment: ctx.VSchema.Environment(), + }) + if err != nil { + panic(invalidUpdateExpr(assignment.Name.Name.String(), assignment.Expr.EvalExpr)) + } + + if assignment.Expr.Info != nil { + sqe, ok := assignment.Expr.Info.(SubQueryExpression) + if ok { + for _, sq := range sqe { + subQueriesArgOnChangedVindex = append(subQueriesArgOnChangedVindex, sq.ArgName) + } + } + } + + vindexValueMap[vcol.String()] = pv + compExprs = append(compExprs, sqlparser.NewComparisonExpr(sqlparser.EqualOp, assignment.Name, assignment.Expr.EvalExpr, nil)) + } + return subQueriesArgOnChangedVindex, compExprs +} + +func invalidUpdateExpr(upd string, expr sqlparser.Expr) error { + return vterrors.VT12001(fmt.Sprintf("only values are supported; invalid update on column: `%s` with expr: [%s]", upd, sqlparser.String(expr))) +} diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index 7fb3a577729..0b77104ee4c 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -6667,5 +6667,90 @@ "ordering.order_event" ] } + }, + { + "comment": "update with multi table reference with multi target update", + "query": "update ignore user u, music m set u.foo = 21, m.bar = 'abc' where u.col = m.col", + "plan": { + "QueryType": "UPDATE", + "Original": "update ignore user u, music m set u.foo = 21, m.bar = 'abc' where u.col = m.col", + "Instructions": { + "OperatorType": "DMLWithInput", + "TargetTabletType": "PRIMARY", + "Offset": [ + "0:[0]", + "1:[1]" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "u_col": 1 + }, + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.id, u.col from `user` as u where 1 != 1", + "Query": "select u.id, u.col from `user` as u for update", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select m.id from music as m where 1 != 1", + "Query": "select m.id from music as m where m.col = :u_col for update", + "Table": "music" + } + ] + }, + { + "OperatorType": "Update", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "update ignore `user` as u set u.foo = 21 where u.id in ::dml_vals", + "Table": "user", + "Values": [ + "::dml_vals" + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Update", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "update ignore music as m set m.bar = 'abc' where m.id in ::dml_vals", + "Table": "music", + "Values": [ + "::dml_vals" + ], + "Vindex": "music_user_map" + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index bf95af52f1e..68f49f41e64 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -3518,7 +3518,7 @@ "Sharded": false }, "FieldQuery": "select col14 from u_tbl1 where 1 != 1", - "Query": "select /*+ SET_VAR(foreign_key_checks=OFF) */ col14 from u_tbl1 where x = 2 and y = 4 lock in share mode", + "Query": "select col14 from u_tbl1 where x = 2 and y = 4 lock in share mode", "Table": "u_tbl1" }, { @@ -3534,7 +3534,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl1 on u_tbl1.col14 = cast(:__sq1 as SIGNED) where 1 != 1", - "Query": "select /*+ SET_VAR(foreign_key_checks=OFF) */ 1 from u_tbl4 left join u_tbl1 on u_tbl1.col14 = cast(:__sq1 as SIGNED) where u_tbl1.col14 is null and cast(:__sq1 as SIGNED) is not null and not (u_tbl4.col41) <=> (cast(:__sq1 as SIGNED)) and u_tbl4.col4 = 3 limit 1 lock in share mode", + "Query": "select 1 from u_tbl4 left join u_tbl1 on u_tbl1.col14 = cast(:__sq1 as SIGNED) where u_tbl1.col14 is null and cast(:__sq1 as SIGNED) is not null and not (u_tbl4.col41) <=> (cast(:__sq1 as SIGNED)) and u_tbl4.col4 = 3 limit 1 for share", "Table": "u_tbl1, u_tbl4" }, { @@ -3581,7 +3581,7 @@ "Sharded": false }, "FieldQuery": "select foo from u_tbl1 where 1 != 1", - "Query": "select /*+ SET_VAR(foreign_key_checks=OFF) */ foo from u_tbl1 where id = 1 lock in share mode", + "Query": "select foo from u_tbl1 where id = 1 lock in share mode", "Table": "u_tbl1" }, { @@ -3597,7 +3597,7 @@ "Sharded": false }, "FieldQuery": "select u_tbl1.col1 from u_tbl1 where 1 != 1", - "Query": "select /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl1.col1 from u_tbl1 order by id desc lock in share mode", + "Query": "select u_tbl1.col1 from u_tbl1 order by id desc for update", "Table": "u_tbl1" }, { @@ -3617,7 +3617,7 @@ "Sharded": false }, "FieldQuery": "select u_tbl2.col2 from u_tbl2 where 1 != 1", - "Query": "select /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl2.col2 from u_tbl2 where (col2) in ::fkc_vals lock in share mode", + "Query": "select u_tbl2.col2 from u_tbl2 where (col2) in ::fkc_vals for update", "Table": "u_tbl2" }, { @@ -3633,7 +3633,7 @@ "Cols": [ 0 ], - "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl3 set col3 = null where (col3) in ::fkc_vals1 and (cast(:__sq1 as CHAR) is null or (col3) not in ((cast(:__sq1 as CHAR))))", + "Query": "update u_tbl3 set col3 = null where (col3) in ::fkc_vals1 and (cast(:__sq1 as CHAR) is null or (col3) not in ((cast(:__sq1 as CHAR))))", "Table": "u_tbl3" }, { @@ -3667,7 +3667,7 @@ "Sharded": false }, "FieldQuery": "select u_tbl9.col9 from u_tbl9 where 1 != 1", - "Query": "select /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl9.col9 from u_tbl9 where (col9) in ::fkc_vals2 and (cast(:__sq1 as CHAR) is null or (col9) not in ((cast(:__sq1 as CHAR)))) lock in share mode", + "Query": "select u_tbl9.col9 from u_tbl9 where (col9) in ::fkc_vals2 and (cast(:__sq1 as CHAR) is null or (col9) not in ((cast(:__sq1 as CHAR)))) for update nowait", "Table": "u_tbl9" }, { @@ -3683,7 +3683,7 @@ "Cols": [ 0 ], - "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl8 set col8 = null where (col8) in ::fkc_vals3", + "Query": "update u_tbl8 set col8 = null where (col8) in ::fkc_vals3", "Table": "u_tbl8" }, { @@ -3695,7 +3695,7 @@ "Sharded": false }, "TargetTabletType": "PRIMARY", - "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl9 set col9 = null where (col9) in ::fkc_vals2 and (cast(:__sq1 as CHAR) is null or (col9) not in ((cast(:__sq1 as CHAR))))", + "Query": "update u_tbl9 set col9 = null where (col9) in ::fkc_vals2 and (cast(:__sq1 as CHAR) is null or (col9) not in ((cast(:__sq1 as CHAR))))", "Table": "u_tbl9" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json index 7ade2be3954..d6829962f64 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json @@ -814,7 +814,7 @@ "Sharded": false }, "FieldQuery": "select u_tbl2.id from u_tbl2 where 1 != 1", - "Query": "select u_tbl2.id from u_tbl2 limit 2 for update", + "Query": "select /*+ SET_VAR(foreign_key_checks=On) */ u_tbl2.id from u_tbl2 limit 2 for update", "Table": "u_tbl2" }, { @@ -829,7 +829,7 @@ "Sharded": false }, "FieldQuery": "select u_tbl2.col2 from u_tbl2 where 1 != 1", - "Query": "select u_tbl2.col2 from u_tbl2 where u_tbl2.id in ::dml_vals for update", + "Query": "select /*+ SET_VAR(foreign_key_checks=On) */ u_tbl2.col2 from u_tbl2 where u_tbl2.id in ::dml_vals for update", "Table": "u_tbl2" }, { @@ -845,7 +845,7 @@ "Cols": [ 0 ], - "Query": "update /*+ SET_VAR(foreign_key_checks=ON) */ u_tbl3 set col3 = null where (col3) in ::fkc_vals and (col3) not in ((cast('bar' as CHAR)))", + "Query": "update /*+ SET_VAR(foreign_key_checks=On) */ u_tbl3 set col3 = null where (col3) in ::fkc_vals and (col3) not in ((cast('bar' as CHAR)))", "Table": "u_tbl3" }, { @@ -1338,7 +1338,7 @@ "Sharded": false }, "FieldQuery": "select u_tbl2.id from u_tbl2 where 1 != 1", - "Query": "select /*+ SET_VAR(foreign_key_checks=On) */ u_tbl2.id from u_tbl2 limit 2", + "Query": "select /*+ SET_VAR(foreign_key_checks=On) */ u_tbl2.id from u_tbl2 limit 2 for update", "Table": "u_tbl2" }, { @@ -1353,7 +1353,7 @@ "Sharded": false }, "FieldQuery": "select u_tbl2.col2 from u_tbl2 where 1 != 1", - "Query": "select /*+ SET_VAR(foreign_key_checks=On) */ u_tbl2.col2 from u_tbl2 where u_tbl2.id in ::dml_vals", + "Query": "select /*+ SET_VAR(foreign_key_checks=On) */ u_tbl2.col2 from u_tbl2 where u_tbl2.id in ::dml_vals for update", "Table": "u_tbl2" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 887b59c50db..c12430a3df4 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -344,11 +344,6 @@ "query": "SELECT COUNT(DISTINCT col), SUM(DISTINCT id) FROM user", "plan": "VT12001: unsupported: only one DISTINCT aggregation is allowed in a SELECT: sum(distinct id)" }, - { - "comment": "update with multi table reference with multi target update", - "query": "update ignore user u, music m set u.foo = 21, m.bar = 'abc' where u.col = m.col", - "plan": "VT12001: unsupported: multi-table UPDATE statement with multi-target column update" - }, { "comment": "Over clause isn't supported in sharded cases", "query": "SELECT val, CUME_DIST() OVER w, ROW_NUMBER() OVER w, DENSE_RANK() OVER w, PERCENT_RANK() OVER w, RANK() OVER w AS 'cd' FROM user", diff --git a/go/vt/vtgate/planbuilder/update.go b/go/vt/vtgate/planbuilder/update.go index eb21546224c..124eaf87310 100644 --- a/go/vt/vtgate/planbuilder/update.go +++ b/go/vt/vtgate/planbuilder/update.go @@ -19,7 +19,6 @@ package planbuilder import ( querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/sysvars" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators" @@ -51,7 +50,6 @@ func gen4UpdateStmtPlanner( if ctx.SemTable.HasNonLiteralForeignKeyUpdate(updStmt.Exprs) { // Since we are running the query with foreign key checks off, we have to verify all the foreign keys validity on vtgate. ctx.VerifyAllFKs = true - updStmt.SetComments(updStmt.GetParsedComments().SetMySQLSetVarValue(sysvars.ForeignKeyChecks, "OFF")) } // Remove all the foreign keys that don't require any handling. diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index f604f2a4ec7..bfd5f413f80 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -145,7 +145,6 @@ func (a *analyzer) newSemTable( NotUnshardedErr: a.unshardedErr, Recursive: ExprDependencies{}, Direct: ExprDependencies{}, - Targets: map[sqlparser.IdentifierCS]TableSet{}, ColumnEqualities: map[columnName][]sqlparser.Expr{}, ExpandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, columns: map[*sqlparser.Union]sqlparser.SelectExprs{}, diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index 27a34a427f1..deb84538740 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -289,7 +289,8 @@ func TestBindingDelete(t *testing.T) { ts := semTable.TableSetFor(t1) assert.Equal(t, SingleTableSet(0), ts) - actualTs := semTable.Targets[del.Targets[0].Name] + actualTs, err := semTable.GetTargetTableSetForTableName(del.Targets[0]) + require.NoError(t, err) assert.Equal(t, ts, actualTs) }) } diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index f93dd579898..d77811860a7 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -31,7 +31,7 @@ import ( type binder struct { recursive ExprDependencies direct ExprDependencies - targets map[sqlparser.IdentifierCS]TableSet + targets TableSet scoper *scoper tc *tableCollector org originable @@ -47,7 +47,6 @@ func newBinder(scoper *scoper, org originable, tc *tableCollector, typer *typer) return &binder{ recursive: map[sqlparser.Expr]TableSet{}, direct: map[sqlparser.Expr]TableSet{}, - targets: map[sqlparser.IdentifierCS]TableSet{}, scoper: scoper, org: org, tc: tc, @@ -70,11 +69,22 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { return b.bindUnion(node) case sqlparser.TableNames: return b.bindTableNames(cursor, node) + case *sqlparser.UpdateExpr: + return b.bindUpdateExpr(node) default: return nil } } +func (b *binder) bindUpdateExpr(ue *sqlparser.UpdateExpr) error { + ts, ok := b.direct[ue.Name] + if !ok { + return nil + } + b.targets = b.targets.Merge(ts) + return nil +} + func (b *binder) bindTableNames(cursor *sqlparser.Cursor, tables sqlparser.TableNames) error { _, isDelete := cursor.Parent().(*sqlparser.Delete) if !isDelete { @@ -86,7 +96,7 @@ func (b *binder) bindTableNames(cursor *sqlparser.Cursor, tables sqlparser.Table if err != nil { return err } - b.targets[target.Name] = finalDep.direct + b.targets = b.targets.Merge(finalDep.direct) } return nil } diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 91c535ffaff..a0bf0624044 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -119,7 +119,8 @@ type ( // It doesn't recurse inside derived tables to find the original dependencies. Direct ExprDependencies - Targets map[sqlparser.IdentifierCS]TableSet + // Targets contains the TableSet of each table getting modified by the update/delete statement. + Targets TableSet // ColumnEqualities is used for transitive closures (e.g., if a == b and b == c, then a == c). ColumnEqualities map[columnName][]sqlparser.Expr @@ -189,9 +190,17 @@ func (st *SemTable) CopyDependencies(from, to sqlparser.Expr) { } } -// GetChildForeignKeysForTable gets the child foreign keys as a list for the specified table. -func (st *SemTable) GetChildForeignKeysForTable(tableName sqlparser.TableName) []vindexes.ChildFKInfo { - return st.childForeignKeysInvolved[st.Targets[tableName.Name]] +// GetChildForeignKeysForTargets gets the child foreign keys as a list for all the target tables. +func (st *SemTable) GetChildForeignKeysForTargets() (fks []vindexes.ChildFKInfo) { + for _, ts := range st.Targets.Constituents() { + fks = append(fks, st.childForeignKeysInvolved[ts]...) + } + return fks +} + +// GetChildForeignKeysForTableSet gets the child foreign keys as a list for the specified TableSet. +func (st *SemTable) GetChildForeignKeysForTableSet(ts TableSet) []vindexes.ChildFKInfo { + return st.childForeignKeysInvolved[ts] } // GetChildForeignKeysList gets the child foreign keys as a list. @@ -203,6 +212,14 @@ func (st *SemTable) GetChildForeignKeysList() []vindexes.ChildFKInfo { return childFkInfos } +// GetParentForeignKeysForTargets gets the parent foreign keys as a list for all the target tables. +func (st *SemTable) GetParentForeignKeysForTargets() (fks []vindexes.ParentFKInfo) { + for _, ts := range st.Targets.Constituents() { + fks = append(fks, st.parentForeignKeysInvolved[ts]...) + } + return fks +} + // GetParentForeignKeysList gets the parent foreign keys as a list. func (st *SemTable) GetParentForeignKeysList() []vindexes.ParentFKInfo { var parentFkInfos []vindexes.ParentFKInfo @@ -928,6 +945,7 @@ func (st *SemTable) Clone(n sqlparser.SQLNode) sqlparser.SQLNode { }, st.CopySemanticInfo) } +// UpdateChildFKExpr updates the child foreign key expression with the new expression. func (st *SemTable) UpdateChildFKExpr(origUpdExpr *sqlparser.UpdateExpr, newExpr sqlparser.Expr) { for _, exprs := range st.childFkToUpdExprs { for idx, updateExpr := range exprs { @@ -937,3 +955,17 @@ func (st *SemTable) UpdateChildFKExpr(origUpdExpr *sqlparser.UpdateExpr, newExpr } } } + +// GetTargetTableSetForTableName returns the TableSet for the given table name from the target tables. +func (st *SemTable) GetTargetTableSetForTableName(name sqlparser.TableName) (TableSet, error) { + for _, target := range st.Targets.Constituents() { + tbl, err := st.Tables[target.TableOffset()].Name() + if err != nil { + return "", err + } + if tbl.Name == name.Name { + return target, nil + } + } + return "", vterrors.Errorf(vtrpcpb.Code_INTERNAL, "target table '%s' not found", sqlparser.String(name)) +}