diff --git a/go/vt/vtgate/engine/ddl.go b/go/vt/vtgate/engine/ddl.go index d0ac2cb457e..17aa7945537 100644 --- a/go/vt/vtgate/engine/ddl.go +++ b/go/vt/vtgate/engine/ddl.go @@ -95,6 +95,11 @@ func (ddl *DDL) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st return vcursor.ExecutePrimitive(ctx, ddl.NormalDDL, bindVars, wantfields) } + // Commit any open transaction before executing the ddl query. + if err = vcursor.Session().Commit(ctx); err != nil { + return nil, err + } + ddlStrategySetting, err := schema.ParseDDLStrategy(vcursor.Session().GetDDLStrategy()) if err != nil { return nil, err diff --git a/go/vt/vtgate/engine/ddl_test.go b/go/vt/vtgate/engine/ddl_test.go new file mode 100644 index 00000000000..80936b6a913 --- /dev/null +++ b/go/vt/vtgate/engine/ddl_test.go @@ -0,0 +1,84 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/key" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +func TestDDL(t *testing.T) { + ddl := &DDL{ + DDL: &sqlparser.CreateTable{ + Table: sqlparser.NewTableName("a"), + }, + DirectDDLEnabled: true, + OnlineDDL: &OnlineDDL{}, + NormalDDL: &Send{ + Keyspace: &vindexes.Keyspace{ + Name: "ks", + Sharded: true, + }, + TargetDestination: key.DestinationAllShards{}, + Query: "ddl query", + }, + } + + vc := &loggingVCursor{} + _, err := ddl.TryExecute(context.Background(), vc, nil, true) + require.NoError(t, err) + + vc.ExpectLog(t, []string{ + "commit", + "ResolveDestinations ks [] Destinations:DestinationAllShards()", + "ExecuteMultiShard false false", + }) +} + +func TestDDLTempTable(t *testing.T) { + ddl := &DDL{ + CreateTempTable: true, + DDL: &sqlparser.CreateTable{ + Table: sqlparser.NewTableName("a"), + }, + DirectDDLEnabled: true, + OnlineDDL: &OnlineDDL{}, + NormalDDL: &Send{ + Keyspace: &vindexes.Keyspace{ + Name: "ks", + Sharded: true, + }, + TargetDestination: key.DestinationAllShards{}, + Query: "ddl query", + }, + } + + vc := &loggingVCursor{} + _, err := ddl.TryExecute(context.Background(), vc, nil, true) + require.NoError(t, err) + + vc.ExpectLog(t, []string{ + "ResolveDestinations ks [] Destinations:DestinationAllShards()", + "ExecuteMultiShard false false", + }) +} diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 139223d4d09..e7c7bb32e73 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -51,6 +51,10 @@ var _ SessionActions = (*noopVCursor)(nil) type noopVCursor struct { } +func (t *noopVCursor) Commit(ctx context.Context) error { + return nil +} + func (t *noopVCursor) GetUDV(key string) *querypb.BindVariable { // TODO implement me panic("implement me") @@ -156,7 +160,7 @@ func (t *noopVCursor) SetDDLStrategy(strategy string) { } func (t *noopVCursor) GetDDLStrategy() string { - panic("implement me") + return "" } func (t *noopVCursor) SetMigrationContext(migrationContext string) { @@ -389,6 +393,11 @@ type loggingVCursor struct { shardSession []*srvtopo.ResolvedShard } +func (f *loggingVCursor) Commit(_ context.Context) error { + f.log = append(f.log, "commit") + return nil +} + func (f *loggingVCursor) GetUDV(key string) *querypb.BindVariable { // TODO implement me panic("implement me") diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index b5d67c9d994..44654f2850d 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -203,6 +203,8 @@ type ( // InTransaction returns true if the session has already opened transaction or // will start a transaction on the query execution. InTransaction() bool + + Commit(ctx context.Context) error } // Match is used to check if a Primitive matches diff --git a/go/vt/vtgate/planbuilder/ddl.go b/go/vt/vtgate/planbuilder/ddl.go index f366a169d69..e7703630bb3 100644 --- a/go/vt/vtgate/planbuilder/ddl.go +++ b/go/vt/vtgate/planbuilder/ddl.go @@ -44,7 +44,7 @@ func (fk *fkContraint) FkWalk(node sqlparser.SQLNode) (kontinue bool, err error) // and which chooses which of the two to invoke at runtime. func buildGeneralDDLPlan(ctx context.Context, sql string, ddlStatement sqlparser.DDLStatement, reservedVars *sqlparser.ReservedVars, vschema plancontext.VSchema, enableOnlineDDL, enableDirectDDL bool) (*planResult, error) { if vschema.Destination() != nil { - return buildByPassDDLPlan(sql, vschema) + return buildByPassPlan(sql, vschema) } normalDDLPlan, onlineDDLPlan, err := buildDDLPlans(ctx, sql, ddlStatement, reservedVars, vschema, enableOnlineDDL, enableDirectDDL) if err != nil { @@ -79,7 +79,7 @@ func buildGeneralDDLPlan(ctx context.Context, sql string, ddlStatement sqlparser return newPlanResult(eddl, tc.getTables()...), nil } -func buildByPassDDLPlan(sql string, vschema plancontext.VSchema) (*planResult, error) { +func buildByPassPlan(sql string, vschema plancontext.VSchema) (*planResult, error) { keyspace, err := vschema.DefaultKeyspace() if err != nil { return nil, err diff --git a/go/vt/vtgate/planbuilder/show.go b/go/vt/vtgate/planbuilder/show.go index b45ae23bfbc..aba5b1a9016 100644 --- a/go/vt/vtgate/planbuilder/show.go +++ b/go/vt/vtgate/planbuilder/show.go @@ -44,7 +44,7 @@ const ( func buildShowPlan(sql string, stmt *sqlparser.Show, _ *sqlparser.ReservedVars, vschema plancontext.VSchema) (*planResult, error) { if vschema.Destination() != nil { - return buildByPassDDLPlan(sql, vschema) + return buildByPassPlan(sql, vschema) } var prim engine.Primitive diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index 8f65884dba3..7341e3e8c1b 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -978,6 +978,10 @@ func (vc *vcursorImpl) InTransaction() bool { return vc.safeSession.InTransaction() } +func (vc *vcursorImpl) Commit(ctx context.Context) error { + return vc.executor.Commit(ctx, vc.safeSession) +} + // GetDBDDLPluginName implements the VCursor interface func (vc *vcursorImpl) GetDBDDLPluginName() string { return dbDDLPlugin