Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add foreign key support for insert on duplicate key update #14638

Merged
merged 11 commits into from
Dec 12, 2023
Merged
31 changes: 31 additions & 0 deletions go/test/endtoend/vtgate/foreignkey/fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,37 @@ func TestReplaceWithFK(t *testing.T) {
utils.AssertMatches(t, conn, `select * from u_t2`, `[[INT64(1) NULL] [INT64(2) NULL]]`)
}

// TestInsertWithFKOnDup tests that insertion with on duplicate key update works as expected.
func TestInsertWithFKOnDup(t *testing.T) {
mcmp, closer := start(t)
defer closer()

utils.Exec(t, mcmp.VtConn, "use `uks`")

// insert some data.
mcmp.Exec(`insert into u_t1(id, col1) values (100, 1), (200, 2), (300, 3), (400, 4)`)
mcmp.Exec(`insert into u_t2(id, col2) values (1000, 1), (2000, 2), (3000, 3), (4000, 4)`)

// updating child to an existing value in parent.
mcmp.Exec(`insert into u_t2(id, col2) values (4000, 50) on duplicate key update col2 = 1`)
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) INT64(1)] [INT64(2000) INT64(2)] [INT64(3000) INT64(3)] [INT64(4000) INT64(1)]]`)

// updating parent, value not referred in child.
mcmp.Exec(`insert into u_t1(id, col1) values (400, 50) on duplicate key update col1 = values(col1)`)
mcmp.AssertMatches(`select * from u_t1 order by id`, `[[INT64(100) INT64(1)] [INT64(200) INT64(2)] [INT64(300) INT64(3)] [INT64(400) INT64(50)]]`)
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) INT64(1)] [INT64(2000) INT64(2)] [INT64(3000) INT64(3)] [INT64(4000) INT64(1)]]`)

// updating parent, child updated to null.
mcmp.Exec(`insert into u_t1(id, col1) values (100, 75) on duplicate key update col1 = values(col1)`)
mcmp.AssertMatches(`select * from u_t1 order by id`, `[[INT64(100) INT64(75)] [INT64(200) INT64(2)] [INT64(300) INT64(3)] [INT64(400) INT64(50)]]`)
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) NULL] [INT64(2000) INT64(2)] [INT64(3000) INT64(3)] [INT64(4000) NULL]]`)

// inserting multiple rows in parent, some child rows updated to null.
mcmp.Exec(`insert into u_t1(id, col1) values (100, 42),(600, 2),(300, 24),(200, 2) on duplicate key update col1 = values(col1)`)
mcmp.AssertMatches(`select * from u_t1 order by id`, `[[INT64(100) INT64(42)] [INT64(200) INT64(2)] [INT64(300) INT64(24)] [INT64(400) INT64(50)] [INT64(600) INT64(2)]]`)
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) NULL] [INT64(2000) INT64(2)] [INT64(3000) NULL] [INT64(4000) NULL]]`)
}

// TestDDLFk tests that table is created with fk constraint when foreign_key_checks is off.
func TestDDLFk(t *testing.T) {
mcmp, closer := start(t)
Expand Down
35 changes: 35 additions & 0 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

138 changes: 138 additions & 0 deletions go/vt/vtgate/engine/upsert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
Copyright 2023 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package engine

import (
"context"
"fmt"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

var _ Primitive = (*Upsert)(nil)

// Upsert Primitive will execute the insert primitive first and
// if there is `Duplicate Key` error, it executes the update primitive.
type Upsert struct {
Upserts []upsert

txNeeded
}

type upsert struct {
Insert Primitive
Update Primitive
}

// AddUpsert appends to the Upsert Primitive.
func (u *Upsert) AddUpsert(ins, upd Primitive) {
u.Upserts = append(u.Upserts, upsert{
Insert: ins,
Update: upd,
})
}

// RouteType implements Primitive interface type.
func (u *Upsert) RouteType() string {
return "UPSERT"
}

// GetKeyspaceName implements Primitive interface type.
func (u *Upsert) GetKeyspaceName() string {
if len(u.Upserts) > 0 {
return u.Upserts[0].Insert.GetKeyspaceName()
}
return ""
}

// GetTableName implements Primitive interface type.
func (u *Upsert) GetTableName() string {
if len(u.Upserts) > 0 {
return u.Upserts[0].Insert.GetTableName()
}
return ""
}

// GetFields implements Primitive interface type.
func (u *Upsert) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, vterrors.VT13001("unexpected to receive GetFields call for insert on duplicate key update query")
}

// TryExecute implements Primitive interface type.
func (u *Upsert) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
result := &sqltypes.Result{}
for _, up := range u.Upserts {
qr, err := execOne(ctx, vcursor, bindVars, wantfields, up)
if err != nil {
return nil, err
}
result.RowsAffected += qr.RowsAffected
}
return result, nil
}

func execOne(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, up upsert) (*sqltypes.Result, error) {
insQr, err := vcursor.ExecutePrimitive(ctx, up.Insert, bindVars, wantfields)
if err == nil {
return insQr, nil
}
if vterrors.Code(err) != vtrpcpb.Code_ALREADY_EXISTS {
return nil, err
}
updQr, err := vcursor.ExecutePrimitive(ctx, up.Update, bindVars, wantfields)
if err != nil {
return nil, err
}
// To match mysql, need to report +1 on rows affected if there is any change.
if updQr.RowsAffected > 0 {
updQr.RowsAffected += 1
}
Comment on lines +104 to +107
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that MySQL has this behaviour, but for the life of me, I can't understand why 😆

return updQr, nil
}

// TryStreamExecute implements Primitive interface type.
func (u *Upsert) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
qr, err := u.TryExecute(ctx, vcursor, bindVars, wantfields)
if err != nil {
return err
}
return callback(qr)
}

// Inputs implements Primitive interface type.
func (u *Upsert) Inputs() ([]Primitive, []map[string]any) {
var inputs []Primitive
var inputsMap []map[string]any
for i, up := range u.Upserts {
inputs = append(inputs, up.Insert, up.Update)
inputsMap = append(inputsMap,
map[string]any{inputName: fmt.Sprintf("Insert-%d", i+1)},
map[string]any{inputName: fmt.Sprintf("Update-%d", i+1)})
}
return inputs, inputsMap
}

func (u *Upsert) description() PrimitiveDescription {
return PrimitiveDescription{
OperatorType: "Upsert",
TargetTabletType: topodatapb.TabletType_PRIMARY,
}
}
27 changes: 27 additions & 0 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op operators.Opera
return transformFkVerify(ctx, op)
case *operators.InsertSelection:
return transformInsertionSelection(ctx, op)
case *operators.Upsert:
return transformUpsert(ctx, op)
case *operators.HashJoin:
return transformHashJoin(ctx, op)
case *operators.Sequential:
Expand All @@ -75,6 +77,31 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op operators.Opera
return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToLogicalPlan)", op))
}

func transformUpsert(ctx *plancontext.PlanningContext, op *operators.Upsert) (logicalPlan, error) {
u := &upsert{}
for _, source := range op.Sources {
iLp, uLp, err := transformOneUpsert(ctx, source)
if err != nil {
return nil, err
}
u.insert = append(u.insert, iLp)
u.update = append(u.update, uLp)
}
return u, nil
}

func transformOneUpsert(ctx *plancontext.PlanningContext, source operators.UpsertSource) (iLp, uLp logicalPlan, err error) {
iLp, err = transformToLogicalPlan(ctx, source.Insert)
if err != nil {
return
}
if ins, ok := iLp.(*insert); ok {
ins.eInsert.PreventAutoCommit = true
}
uLp, err = transformToLogicalPlan(ctx, source.Update)
return
}

func transformSequential(ctx *plancontext.PlanningContext, op *operators.Sequential) (logicalPlan, error) {
var lps []logicalPlan
for _, source := range op.Sources {
Expand Down
Loading
Loading