diff --git a/go/vt/vtgate/engine/dml_with_input.go b/go/vt/vtgate/engine/dml_with_input.go index a52de30ad19..0974f753cef 100644 --- a/go/vt/vtgate/engine/dml_with_input.go +++ b/go/vt/vtgate/engine/dml_with_input.go @@ -89,6 +89,8 @@ func (dml *DMLWithInput) TryExecute(ctx context.Context, vcursor VCursor, bindVa return res, nil } +// executeLiteralUpdate executes the primitive that can be executed with a single bind variable from the input result. +// The column updated have same value for all rows in the input result. func executeLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int) (*sqltypes.Result, error) { var bv *querypb.BindVariable if len(outputCols) == 1 { @@ -122,6 +124,8 @@ func getBVMulti(rows []sqltypes.Row, offsets []int) *querypb.BindVariable { return bv } +// executeNonLiteralUpdate executes the primitive that needs to be executed per row from the input result. +// The column updated might have different value for each row in the input result. func executeNonLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int, vars map[string]int) (qr *sqltypes.Result, err error) { var res *sqltypes.Result for _, row := range inputRes.Rows { @@ -175,7 +179,7 @@ func (dml *DMLWithInput) description() PrimitiveDescription { if len(vars) == 0 { continue } - bvList = append(bvList, fmt.Sprintf("%d:%v", idx, vars)) + bvList = append(bvList, fmt.Sprintf("%d:[%s]", idx, orderedStringIntMap(vars))) } if len(bvList) > 0 { other["BindVars"] = bvList diff --git a/go/vt/vtgate/engine/plan_description.go b/go/vt/vtgate/engine/plan_description.go index 72220fda460..a8daa25ecd0 100644 --- a/go/vt/vtgate/engine/plan_description.go +++ b/go/vt/vtgate/engine/plan_description.go @@ -21,6 +21,7 @@ import ( "encoding/json" "fmt" "sort" + "strings" "vitess.io/vitess/go/tools/graphviz" "vitess.io/vitess/go/vt/key" @@ -266,3 +267,11 @@ func (m orderedMap) MarshalJSON() ([]byte, error) { buf.WriteString("}") return buf.Bytes(), nil } + +func (m orderedMap) String() string { + var output []string + for _, val := range m { + output = append(output, fmt.Sprintf("%s:%v", val.key, val.val)) + } + return strings.Join(output, " ") +} diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 13b97269c73..4abf319ad08 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -219,31 +219,30 @@ func prepareUpdateExpressionList(ctx *plancontext.PlanningContext, upd *sqlparse // E.g. UPDATE t1 join t2 on t1.col = t2.col SET t1.col = t2.col + 1 where t2.col = 10; // SET t1.col = t2.col + 1 -> SET t1.col = :t2_col + 1 (t2_col is the bindvar column which will be provided from the input) ueMap := make(map[semantics.TableSet]updList) - var dependentCols updList for _, ue := range upd.Exprs { target := ctx.SemTable.DirectDeps(ue.Name) exprDeps := ctx.SemTable.RecursiveDeps(ue.Expr) jc := breakExpressionInLHSandRHS(ctx, ue.Expr, exprDeps.Remove(target)) - updCol := updColumn{ue.Name, jc} - ueMap[target] = append(ueMap[target], updCol) - dependentCols = append(dependentCols, updCol) + ueMap[target] = append(ueMap[target], updColumn{ue.Name, jc}) } // Check if any of the dependent columns are updated in the same query. // This can result in a mismatch of rows on how MySQL interprets it and how Vitess would have updated those rows. // It is safe to fail for those cases. - errIfDependentColumnUpdated(ctx, upd, dependentCols) + errIfDependentColumnUpdated(ctx, upd, ueMap) return ueMap } -func errIfDependentColumnUpdated(ctx *plancontext.PlanningContext, upd *sqlparser.Update, dependentCols updList) { +func errIfDependentColumnUpdated(ctx *plancontext.PlanningContext, upd *sqlparser.Update, ueMap map[semantics.TableSet]updList) { for _, ue := range upd.Exprs { - for _, dc := range dependentCols { - for _, bvExpr := range dc.jc.LHSExprs { - if ctx.SemTable.EqualsExprWithDeps(ue.Name, bvExpr.Expr) { - panic(vterrors.VT12001( - fmt.Sprintf("'%s' column referenced in update expression '%s' is itself updated", sqlparser.String(ue.Name), sqlparser.String(dc.jc.Original)))) + for _, list := range ueMap { + for _, dc := range list { + for _, bvExpr := range dc.jc.LHSExprs { + if ctx.SemTable.EqualsExprWithDeps(ue.Name, bvExpr.Expr) { + panic(vterrors.VT12001( + fmt.Sprintf("'%s' column referenced in update expression '%s' is itself updated", sqlparser.String(ue.Name), sqlparser.String(dc.jc.Original)))) + } } } } diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index 591a41e1fc9..9c2ed1920ee 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -5829,7 +5829,7 @@ "OperatorType": "DMLWithInput", "TargetTabletType": "PRIMARY", "BindVars": [ - "0:map[ue_col:1]" + "0:[ue_col:1]" ], "Offset": [ "0:[0]" @@ -5895,6 +5895,82 @@ ] } }, + { + "comment": "update with multi table join with single target having multiple dependent column update", + "query": "update user as u, user_extra as ue set u.col = ue.foo + ue.bar + u.baz where u.id = ue.id", + "plan": { + "QueryType": "UPDATE", + "Original": "update user as u, user_extra as ue set u.col = ue.foo + ue.bar + u.baz where u.id = ue.id", + "Instructions": { + "OperatorType": "DMLWithInput", + "TargetTabletType": "PRIMARY", + "BindVars": [ + "0:[ue_bar:2 ue_foo:1]" + ], + "Offset": [ + "0:[0]" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "R:0,L:0,L:1", + "JoinVars": { + "ue_id": 2 + }, + "TableName": "user_extra_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select ue.foo, ue.bar, ue.id from user_extra as ue where 1 != 1", + "Query": "select ue.foo, ue.bar, ue.id from user_extra as ue for update", + "Table": "user_extra" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.id from `user` as u where 1 != 1", + "Query": "select u.id from `user` as u where u.id = :ue_id for update", + "Table": "`user`", + "Values": [ + ":ue_id" + ], + "Vindex": "user_index" + } + ] + }, + { + "OperatorType": "Update", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "update `user` as u set u.col = :ue_foo + :ue_bar + u.baz where u.id in ::dml_vals", + "Table": "user", + "Values": [ + "::dml_vals" + ], + "Vindex": "user_index" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, { "comment": "update with multi table join with multi target having dependent column update", "query": "update user, user_extra ue set user.name = ue.id + 'foo', ue.bar = user.baz where user.id = ue.id and user.id = 1", @@ -5905,8 +5981,8 @@ "OperatorType": "DMLWithInput", "TargetTabletType": "PRIMARY", "BindVars": [ - "0:map[ue_id:1]", - "1:map[user_baz:3]" + "0:[ue_id:1]", + "1:[user_baz:3]" ], "Offset": [ "0:[0]",