diff --git a/go/test/endtoend/vtgate/queries/dml/dml_test.go b/go/test/endtoend/vtgate/queries/dml/dml_test.go index deca3f01caf..4ed54cad489 100644 --- a/go/test/endtoend/vtgate/queries/dml/dml_test.go +++ b/go/test/endtoend/vtgate/queries/dml/dml_test.go @@ -222,7 +222,7 @@ func TestMultiTableUpdate(t *testing.T) { mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("c")]]`) - // multi table delete + // multi table update qr := mcmp.Exec(`update order_tbl o join oevent_tbl ev on o.oid = ev.oid set ev.ename = 'a' where ev.oid > 3`) assert.EqualValues(t, 1, qr.RowsAffected) @@ -368,3 +368,40 @@ func TestMultiTargetDeleteMore(t *testing.T) { mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(2) VARCHAR("c")] [INT64(3) VARCHAR("a")]]`) } + +// TestMultiTargetUpdate executed multi-target update queries +func TestMultiTargetUpdate(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + + // initial rows + mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)") + mcmp.Exec("insert into oevent_tbl(oid, ename) values (1,'a'), (2,'b'), (3,'a'), (4,'c')") + + // check rows + mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, + `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`) + mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, + `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("c")]]`) + + // multi target update + qr := mcmp.Exec(`update order_tbl o join oevent_tbl ev on o.oid = ev.oid set ev.ename = 'a', o.cust_no = 1 where ev.oid > 3`) + assert.EqualValues(t, 2, qr.RowsAffected) + + // check rows + mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, + `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(1)]]`) + mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, + `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("a")]]`) + + qr = mcmp.Exec(`update order_tbl o, oevent_tbl ev set ev.ename = 'xyz', o.oid = 40 where o.cust_no = ev.oid and ev.ename = 'b'`) + assert.EqualValues(t, 2, qr.RowsAffected) + + // check rows + mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid, region_id`, + `[[INT64(1) INT64(1) INT64(4)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(1)] [INT64(1) INT64(40) INT64(2)]]`) + mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, + `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("xyz")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("a")]]`) +} diff --git a/go/test/endtoend/vtgate/queries/dml/sharded_schema.sql b/go/test/endtoend/vtgate/queries/dml/sharded_schema.sql index 3310724d420..8ddf9250e45 100644 --- a/go/test/endtoend/vtgate/queries/dml/sharded_schema.sql +++ b/go/test/endtoend/vtgate/queries/dml/sharded_schema.sql @@ -25,7 +25,8 @@ create table order_tbl oid bigint, region_id bigint, cust_no bigint unique key, - primary key (oid, region_id) + primary key (oid, region_id), + unique key (oid) ) Engine = InnoDB; create table oid_vdx_tbl diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index b44edc2a9e5..6ac11dd7d73 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -231,11 +231,15 @@ func createUpdateOpWithTarget(ctx *plancontext.PlanningContext, target semantics panic(vterrors.VT13001(err.Error())) } vTbl := ti.GetVindexTable() + tblName, err := ti.Name() + if err != nil { + panic(err) + } var leftComp sqlparser.ValTuple cols := make([]*sqlparser.ColName, 0, len(vTbl.PrimaryKey)) for _, col := range vTbl.PrimaryKey { - colName := sqlparser.NewColNameWithQualifier(col.String(), vTbl.GetTableName()) + colName := sqlparser.NewColNameWithQualifier(col.String(), tblName) cols = append(cols, colName) leftComp = append(leftComp, colName) ctx.SemTable.Recursive[colName] = target diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index 246e9684987..0b77104ee4c 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -6698,8 +6698,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `user`.id, u.col from `user` as u where 1 != 1", - "Query": "select `user`.id, u.col from `user` as u for update", + "FieldQuery": "select u.id, u.col from `user` as u where 1 != 1", + "Query": "select u.id, u.col from `user` as u for update", "Table": "`user`" }, { @@ -6709,8 +6709,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select music.id from music as m where 1 != 1", - "Query": "select music.id from music as m where m.col = :u_col for update", + "FieldQuery": "select m.id from music as m where 1 != 1", + "Query": "select m.id from music as m where m.col = :u_col for update", "Table": "music" } ] @@ -6723,7 +6723,7 @@ "Sharded": true }, "TargetTabletType": "PRIMARY", - "Query": "update ignore `user` as u set u.foo = 21 where `user`.id in ::dml_vals", + "Query": "update ignore `user` as u set u.foo = 21 where u.id in ::dml_vals", "Table": "user", "Values": [ "::dml_vals" @@ -6738,7 +6738,7 @@ "Sharded": true }, "TargetTabletType": "PRIMARY", - "Query": "update ignore music as m set m.bar = 'abc' where music.id in ::dml_vals", + "Query": "update ignore music as m set m.bar = 'abc' where m.id in ::dml_vals", "Table": "music", "Values": [ "::dml_vals"