diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index c434d07c38f..8bdfd8735a0 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -78,22 +78,28 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op operators.Opera } func transformUpsert(ctx *plancontext.PlanningContext, op *operators.Upsert) (logicalPlan, error) { - ip, err := transformToLogicalPlan(ctx, op.InsertOp) + u := &upsert{} + for _, source := range op.Sources { + iLp, uLp, err := transformOneUpsert(ctx, source) + if err != nil { + return nil, err + } + u.insert = append(u.insert, iLp) + u.update = append(u.update, uLp) + } + return u, nil +} + +func transformOneUpsert(ctx *plancontext.PlanningContext, source operators.UpsertSource) (iLp, uLp logicalPlan, err error) { + iLp, err = transformToLogicalPlan(ctx, source.Insert) if err != nil { - return nil, err + return } - if ins, ok := ip.(*insert); ok { + if ins, ok := iLp.(*insert); ok { ins.eInsert.PreventAutoCommit = true } - up, err := transformToLogicalPlan(ctx, op.UpdateOp) - if err != nil { - return nil, err - } - - return &upsert{ - insert: ip, - update: up, - }, nil + uLp, err = transformToLogicalPlan(ctx, source.Update) + return } func transformSequential(ctx *plancontext.PlanningContext, op *operators.Sequential) (logicalPlan, error) { diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 41e3b489d7d..275849a37f8 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -163,23 +163,18 @@ func checkAndCreateInsertOperator(ctx *plancontext.PlanningContext, ins *sqlpars panic(vterrors.VT12001("REPLACE INTO with foreign keys")) } if len(ins.OnDup) > 0 { - row := getSingleRowOrError(ins) - return createUpsertOperator(ctx, ins, insOp, row, vTbl) + rows := getRowsOrError(ins) + return createUpsertOperator(ctx, ins, insOp, rows, vTbl) } } return insOp } -func getSingleRowOrError(ins *sqlparser.Insert) sqlparser.ValTuple { - switch rows := ins.Rows.(type) { - case sqlparser.SelectStatement: - panic(vterrors.VT12001("ON DUPLICATE KEY UPDATE with foreign keys with select statement")) - case sqlparser.Values: - if len(rows) == 1 { - return rows[0] - } +func getRowsOrError(ins *sqlparser.Insert) sqlparser.Values { + if rows, ok := ins.Rows.(sqlparser.Values); ok { + return rows } - panic(vterrors.VT12001("ON DUPLICATE KEY UPDATE with foreign keys with multiple rows")) + panic(vterrors.VT12001("ON DUPLICATE KEY UPDATE with foreign keys with select statement")) } func getWhereCondExpr(compExprs []*sqlparser.ComparisonExpr) sqlparser.Expr { diff --git a/go/vt/vtgate/planbuilder/operators/upsert.go b/go/vt/vtgate/planbuilder/operators/upsert.go index 5f2b9778742..8cd3cd9a521 100644 --- a/go/vt/vtgate/planbuilder/operators/upsert.go +++ b/go/vt/vtgate/planbuilder/operators/upsert.go @@ -27,27 +27,43 @@ var _ Operator = (*Upsert)(nil) // Upsert represents an insert on duplicate key operation on a table. type Upsert struct { - InsertOp Operator - UpdateOp Operator + Sources []UpsertSource noColumns noPredicates } +type UpsertSource struct { + Insert Operator + Update Operator +} + func (u *Upsert) Clone(inputs []Operator) Operator { - return &Upsert{ - InsertOp: inputs[0], - UpdateOp: inputs[1], + up := &Upsert{} + up.setInputs(inputs) + return up +} + +func (u *Upsert) setInputs(inputs []Operator) { + for i := 0; i < len(inputs); i += 2 { + u.Sources = append(u.Sources, UpsertSource{ + Insert: inputs[i], + Update: inputs[i+1], + }) } } func (u *Upsert) Inputs() []Operator { - return []Operator{u.InsertOp, u.UpdateOp} + var inputs []Operator + for _, source := range u.Sources { + inputs = append(inputs, source.Insert, source.Update) + } + return inputs } -func (u *Upsert) SetInputs(operators []Operator) { - u.InsertOp = operators[0] - u.UpdateOp = operators[1] +func (u *Upsert) SetInputs(inputs []Operator) { + u.Sources = nil + u.setInputs(inputs) } func (u *Upsert) ShortDescription() string { @@ -58,7 +74,7 @@ func (u *Upsert) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { return nil } -func createUpsertOperator(ctx *plancontext.PlanningContext, ins *sqlparser.Insert, insOp Operator, row sqlparser.ValTuple, vTbl *vindexes.Table) Operator { +func createUpsertOperator(ctx *plancontext.PlanningContext, ins *sqlparser.Insert, insOp Operator, rows sqlparser.Values, vTbl *vindexes.Table) Operator { if len(vTbl.UniqueKeys) != 0 { panic(vterrors.VT12001("ON DUPLICATE KEY UPDATE with foreign keys with unique keys")) } @@ -70,56 +86,61 @@ func createUpsertOperator(ctx *plancontext.PlanningContext, ins *sqlparser.Inser return insOp } - var whereExpr sqlparser.Expr - for _, pIdx := range pIndexes { - var expr sqlparser.Expr - if pIdx.idx == -1 { - expr = pIdx.def - } else { - expr = row[pIdx.idx] + upsert := &Upsert{} + for _, row := range rows { + var whereExpr sqlparser.Expr + for _, pIdx := range pIndexes { + var expr sqlparser.Expr + if pIdx.idx == -1 { + expr = pIdx.def + } else { + expr = row[pIdx.idx] + } + equalExpr := sqlparser.NewComparisonExpr(sqlparser.EqualOp, sqlparser.NewColName(pIdx.col.String()), expr, nil) + if whereExpr == nil { + whereExpr = equalExpr + } else { + whereExpr = sqlparser.AndExpressions(whereExpr, equalExpr) + } } - equalExpr := sqlparser.NewComparisonExpr(sqlparser.EqualOp, sqlparser.NewColName(pIdx.col.String()), expr, nil) - if whereExpr == nil { - whereExpr = equalExpr - } else { - whereExpr = sqlparser.AndExpressions(whereExpr, equalExpr) + + var updExprs sqlparser.UpdateExprs + for _, ue := range ins.OnDup { + expr := sqlparser.CopyOnRewrite(ue.Expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { + vfExpr, ok := cursor.Node().(*sqlparser.ValuesFuncExpr) + if !ok { + return + } + idx := ins.Columns.FindColumn(vfExpr.Name.Name) + if idx == -1 { + panic(vterrors.VT03014(sqlparser.String(vfExpr.Name), "field list")) + } + cursor.Replace(row[idx]) + }, nil).(sqlparser.Expr) + updExprs = append(updExprs, &sqlparser.UpdateExpr{ + Name: ue.Name, + Expr: expr, + }) } - } - var updExprs sqlparser.UpdateExprs - for _, ue := range ins.OnDup { - expr := sqlparser.CopyOnRewrite(ue.Expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { - vfExpr, ok := cursor.Node().(*sqlparser.ValuesFuncExpr) - if !ok { - return - } - idx := ins.Columns.FindColumn(vfExpr.Name.Name) - if idx == -1 { - panic(vterrors.VT03014(sqlparser.String(vfExpr.Name), "field list")) - } - cursor.Replace(row[idx]) - }, nil).(sqlparser.Expr) - updExprs = append(updExprs, &sqlparser.UpdateExpr{ - Name: ue.Name, - Expr: expr, + upd := &sqlparser.Update{ + Comments: ins.Comments, + TableExprs: sqlparser.TableExprs{ins.Table}, + Exprs: updExprs, + Where: sqlparser.NewWhere(sqlparser.WhereClause, whereExpr), + } + updOp := createOpFromStmt(ctx, upd, false, "") + + // replan insert statement without on duplicate key update. + newInsert := sqlparser.CloneRefOfInsert(ins) + newInsert.OnDup = nil + newInsert.Rows = sqlparser.Values{row} + insOp = createOpFromStmt(ctx, newInsert, false, "") + upsert.Sources = append(upsert.Sources, UpsertSource{ + Insert: insOp, + Update: updOp, }) } - upd := &sqlparser.Update{ - Comments: ins.Comments, - TableExprs: sqlparser.TableExprs{ins.Table}, - Exprs: updExprs, - Where: sqlparser.NewWhere(sqlparser.WhereClause, whereExpr), - } - updOp := createOpFromStmt(ctx, upd, false, "") - - // replan insert statement without on duplicate key update. - ins = sqlparser.CloneRefOfInsert(ins) - ins.OnDup = nil - insOp = createOpFromStmt(ctx, ins, false, "") - - return &Upsert{ - InsertOp: insOp, - UpdateOp: updOp, - } + return upsert } diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index 9e59d96b797..3ab6ef96118 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -2834,5 +2834,200 @@ "unsharded_fk_allow.u_tbl9" ] } + }, + { + "comment": "insert with on duplicate key update with multiple rows", + "query": "insert into u_tbl2 (id, col2) values (:v1, :v2),(:v3, :v4), (:v5, :v6) on duplicate key update col2 = values(col2)", + "plan": { + "QueryType": "INSERT", + "Original": "insert into u_tbl2 (id, col2) values (:v1, :v2),(:v3, :v4), (:v5, :v6) on duplicate key update col2 = values(col2)", + "Instructions": { + "OperatorType": "Upsert", + "TargetTabletType": "PRIMARY", + "Inputs": [ + { + "InputName": "Insert-1", + "OperatorType": "Insert", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "NoAutoCommit": true, + "Query": "insert into u_tbl2(id, col2) values (:v1, :v2)", + "TableName": "u_tbl2" + }, + { + "InputName": "Update-1", + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col2 from u_tbl2 where 1 != 1", + "Query": "select col2 from u_tbl2 where id = :v1 for update nowait", + "Table": "u_tbl2" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "BvName": "fkc_vals", + "Cols": [ + 0 + ], + "Query": "update u_tbl3 set col3 = null where (col3) in ::fkc_vals and (cast(:v2 as CHAR) is null or (col3) not in ((cast(:v2 as CHAR))))", + "Table": "u_tbl3" + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update u_tbl2 set col2 = :v2 where id = :v1", + "Table": "u_tbl2" + } + ] + }, + { + "InputName": "Insert-2", + "OperatorType": "Insert", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "NoAutoCommit": true, + "Query": "insert into u_tbl2(id, col2) values (:v3, :v4)", + "TableName": "u_tbl2" + }, + { + "InputName": "Update-2", + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col2 from u_tbl2 where 1 != 1", + "Query": "select col2 from u_tbl2 where id = :v3 for update nowait", + "Table": "u_tbl2" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "BvName": "fkc_vals1", + "Cols": [ + 0 + ], + "Query": "update u_tbl3 set col3 = null where (col3) in ::fkc_vals1 and (cast(:v4 as CHAR) is null or (col3) not in ((cast(:v4 as CHAR))))", + "Table": "u_tbl3" + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update u_tbl2 set col2 = :v4 where id = :v3", + "Table": "u_tbl2" + } + ] + }, + { + "InputName": "Insert-3", + "OperatorType": "Insert", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "NoAutoCommit": true, + "Query": "insert into u_tbl2(id, col2) values (:v5, :v6)", + "TableName": "u_tbl2" + }, + { + "InputName": "Update-3", + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col2 from u_tbl2 where 1 != 1", + "Query": "select col2 from u_tbl2 where id = :v5 for update nowait", + "Table": "u_tbl2" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "BvName": "fkc_vals2", + "Cols": [ + 0 + ], + "Query": "update u_tbl3 set col3 = null where (col3) in ::fkc_vals2 and (cast(:v6 as CHAR) is null or (col3) not in ((cast(:v6 as CHAR))))", + "Table": "u_tbl3" + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update u_tbl2 set col2 = :v6 where id = :v5", + "Table": "u_tbl2" + } + ] + } + ] + }, + "TablesUsed": [ + "unsharded_fk_allow.u_tbl2", + "unsharded_fk_allow.u_tbl3" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/upsert.go b/go/vt/vtgate/planbuilder/upsert.go index 097ac3c13fc..cd9c127635c 100644 --- a/go/vt/vtgate/planbuilder/upsert.go +++ b/go/vt/vtgate/planbuilder/upsert.go @@ -21,8 +21,8 @@ import ( ) type upsert struct { - insert logicalPlan - update logicalPlan + insert []logicalPlan + update []logicalPlan } var _ logicalPlan = (*upsert)(nil) @@ -30,6 +30,8 @@ var _ logicalPlan = (*upsert)(nil) // Primitive implements the logicalPlan interface func (u *upsert) Primitive() engine.Primitive { up := &engine.Upsert{} - up.AddUpsert(u.insert.Primitive(), u.update.Primitive()) + for i := 0; i < len(u.insert); i++ { + up.AddUpsert(u.insert[i].Primitive(), u.update[i].Primitive()) + } return up }