diff --git a/go/vt/vtctl/workflow/framework_test.go b/go/vt/vtctl/workflow/framework_test.go index 1d25aafa75f..b5d0a308261 100644 --- a/go/vt/vtctl/workflow/framework_test.go +++ b/go/vt/vtctl/workflow/framework_test.go @@ -256,6 +256,9 @@ type testTMClient struct { readVReplicationWorkflowRequests map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest primaryPositions map[uint32]string + // Stack of ReadVReplicationWorkflowsResponse to return, in order, for each shard + readVReplicationWorkflowsResponses map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse + env *testEnv // For access to the env config from tmc methods. reverse atomic.Bool // Are we reversing traffic? frozen atomic.Bool // Are the workflows frozen? @@ -267,6 +270,7 @@ func newTestTMClient(env *testEnv) *testTMClient { vrQueries: make(map[int][]*queryResult), createVReplicationWorkflowRequests: make(map[uint32]*tabletmanagerdatapb.CreateVReplicationWorkflowRequest), readVReplicationWorkflowRequests: make(map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest), + readVReplicationWorkflowsResponses: make(map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse), primaryPositions: make(map[uint32]string), env: env, } @@ -285,6 +289,10 @@ func (tmc *testTMClient) CreateVReplicationWorkflow(ctx context.Context, tablet return &tabletmanagerdatapb.CreateVReplicationWorkflowResponse{Result: sqltypes.ResultToProto3(res)}, nil } +func (tmc *testTMClient) GetWorkflowKey(keyspace, shard string) string { + return fmt.Sprintf("%s/%s", keyspace, shard) +} + func (tmc *testTMClient) ReadVReplicationWorkflow(ctx context.Context, tablet *topodatapb.Tablet, req *tabletmanagerdatapb.ReadVReplicationWorkflowRequest) (*tabletmanagerdatapb.ReadVReplicationWorkflowResponse, error) { tmc.mu.Lock() defer tmc.mu.Unlock() @@ -463,6 +471,10 @@ func (tmc *testTMClient) ReadVReplicationWorkflows(ctx context.Context, tablet * tmc.mu.Lock() defer tmc.mu.Unlock() + workflowKey := tmc.GetWorkflowKey(tablet.Keyspace, tablet.Shard) + if resp := tmc.getVReplicationWorkflowsResponse(workflowKey); resp != nil { + return resp, nil + } workflowType := binlogdatapb.VReplicationWorkflowType_MoveTables if len(req.IncludeWorkflows) > 0 { for _, wf := range req.IncludeWorkflows { @@ -494,7 +506,7 @@ func (tmc *testTMClient) ReadVReplicationWorkflows(ctx context.Context, tablet * }, }, }, - Pos: "MySQL56/" + position, + Pos: position, TimeUpdated: protoutil.TimeToProto(time.Now()), TimeHeartbeat: protoutil.TimeToProto(time.Now()), }, @@ -541,6 +553,25 @@ func (tmc *testTMClient) VReplicationWaitForPos(ctx context.Context, tablet *top return nil } +func (tmc *testTMClient) AddVReplicationWorkflowsResponse(key string, resp *tabletmanagerdatapb.ReadVReplicationWorkflowsResponse) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + tmc.readVReplicationWorkflowsResponses[key] = append(tmc.readVReplicationWorkflowsResponses[key], resp) +} + +func (tmc *testTMClient) getVReplicationWorkflowsResponse(key string) *tabletmanagerdatapb.ReadVReplicationWorkflowsResponse { + if len(tmc.readVReplicationWorkflowsResponses) == 0 { + return nil + } + responses, ok := tmc.readVReplicationWorkflowsResponses[key] + if !ok || len(responses) == 0 { + return nil + } + resp := tmc.readVReplicationWorkflowsResponses[key][0] + tmc.readVReplicationWorkflowsResponses[key] = tmc.readVReplicationWorkflowsResponses[key][1:] + return resp +} + // // Utility / helper functions. // diff --git a/go/vt/vtctl/workflow/materializer_env_test.go b/go/vt/vtctl/workflow/materializer_env_test.go index 569651f85ca..aada59c244d 100644 --- a/go/vt/vtctl/workflow/materializer_env_test.go +++ b/go/vt/vtctl/workflow/materializer_env_test.go @@ -61,7 +61,7 @@ type testMaterializerEnv struct { venv *vtenv.Environment } -//---------------------------------------------- +// ---------------------------------------------- // testMaterializerEnv func newTestMaterializerEnv(t *testing.T, ctx context.Context, ms *vtctldatapb.MaterializeSettings, sourceShards, targetShards []string) *testMaterializerEnv { @@ -426,7 +426,7 @@ func (tmc *testMaterializerTMClient) ReadVReplicationWorkflows(ctx context.Conte }, }, }, - Pos: "MySQL56/" + position, + Pos: position, TimeUpdated: protoutil.TimeToProto(time.Now()), TimeHeartbeat: protoutil.TimeToProto(time.Now()), } diff --git a/go/vt/vtctl/workflow/materializer_test.go b/go/vt/vtctl/workflow/materializer_test.go index 51a7d22d5eb..763dd7c04d3 100644 --- a/go/vt/vtctl/workflow/materializer_test.go +++ b/go/vt/vtctl/workflow/materializer_test.go @@ -44,7 +44,7 @@ import ( ) const ( - position = "9d10e6ec-07a0-11ee-ae73-8e53f4cf3083:1-97" + position = "MySQL56/9d10e6ec-07a0-11ee-ae73-8e53f4cf3083:1-97" mzSelectFrozenQuery = "select 1 from _vt.vreplication where db_name='vt_targetks' and message='FROZEN' and workflow_sub_type != 1" mzCheckJournal = "/select val from _vt.resharding_journal where id=" mzGetCopyState = "select distinct table_name from _vt.copy_state cs, _vt.vreplication vr where vr.id = cs.vrepl_id and vr.id = 1" @@ -56,6 +56,14 @@ var ( defaultOnDDL = binlogdatapb.OnDDLAction_IGNORE.String() ) +func gtid(position string) string { + arr := strings.Split(position, "/") + if len(arr) != 2 { + return "" + } + return arr[1] +} + func TestStripForeignKeys(t *testing.T) { tcs := []struct { desc string @@ -577,7 +585,7 @@ func TestMoveTablesDDLFlag(t *testing.T) { sourceShard, err := env.topoServ.GetShardNames(ctx, ms.SourceKeyspace) require.NoError(t, err) want := fmt.Sprintf("shard_streams:{key:\"%s/%s\" value:{streams:{id:1 tablet:{cell:\"%s\" uid:200} source_shard:\"%s/%s\" position:\"%s\" status:\"Running\" info:\"VStream Lag: 0s\"}}} traffic_state:\"Reads Not Switched. Writes Not Switched\"", - ms.TargetKeyspace, targetShard[0], env.cell, ms.SourceKeyspace, sourceShard[0], position) + ms.TargetKeyspace, targetShard[0], env.cell, ms.SourceKeyspace, sourceShard[0], gtid(position)) res, err := env.ws.MoveTablesCreate(ctx, &vtctldatapb.MoveTablesCreateRequest{ Workflow: ms.Workflow, @@ -636,7 +644,7 @@ func TestMoveTablesNoRoutingRules(t *testing.T) { Uid: 200, }, SourceShard: fmt.Sprintf("%s/%s", ms.SourceKeyspace, sourceShard[0]), - Position: position, + Position: gtid(position), Status: binlogdatapb.VReplicationWorkflowState_Running.String(), Info: "VStream Lag: 0s", }, diff --git a/go/vt/vtctl/workflow/mount_test.go b/go/vt/vtctl/workflow/mount_test.go new file mode 100644 index 00000000000..2fec275e4cb --- /dev/null +++ b/go/vt/vtctl/workflow/mount_test.go @@ -0,0 +1,77 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package workflow + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + vtctldatapb "vitess.io/vitess/go/vt/proto/vtctldata" + "vitess.io/vitess/go/vt/topo/memorytopo" + "vitess.io/vitess/go/vt/vtenv" +) + +// TestMount tests various Mount-related methods. +func TestMount(t *testing.T) { + const ( + extCluster = "extcluster" + topoType = "etcd2" + topoServer = "localhost:2379" + topoRoot = "/vitess/global" + ) + ctx := context.Background() + ts := memorytopo.NewServer(ctx, "cell") + tmc := &fakeTMC{} + s := NewServer(vtenv.NewTestEnv(), ts, tmc) + + resp, err := s.MountRegister(ctx, &vtctldatapb.MountRegisterRequest{ + Name: extCluster, + TopoType: topoType, + TopoServer: topoServer, + TopoRoot: topoRoot, + }) + require.NoError(t, err) + require.NotNil(t, resp) + + respList, err := s.MountList(ctx, &vtctldatapb.MountListRequest{}) + require.NoError(t, err) + require.NotNil(t, respList) + require.Equal(t, []string{extCluster}, respList.Names) + + respShow, err := s.MountShow(ctx, &vtctldatapb.MountShowRequest{ + Name: extCluster, + }) + require.NoError(t, err) + require.NotNil(t, respShow) + require.Equal(t, extCluster, respShow.Name) + require.Equal(t, topoType, respShow.TopoType) + require.Equal(t, topoServer, respShow.TopoServer) + require.Equal(t, topoRoot, respShow.TopoRoot) + + respUnregister, err := s.MountUnregister(ctx, &vtctldatapb.MountUnregisterRequest{ + Name: extCluster, + }) + require.NoError(t, err) + require.NotNil(t, respUnregister) + + respList, err = s.MountList(ctx, &vtctldatapb.MountListRequest{}) + require.NoError(t, err) + require.NotNil(t, respList) + require.Nil(t, respList.Names) +} diff --git a/go/vt/vtctl/workflow/resharder_test.go b/go/vt/vtctl/workflow/resharder_test.go index 1bb2f065e0f..6353f36db9f 100644 --- a/go/vt/vtctl/workflow/resharder_test.go +++ b/go/vt/vtctl/workflow/resharder_test.go @@ -84,7 +84,7 @@ func TestReshardCreate(t *testing.T) { { Id: 1, Tablet: &topodatapb.TabletAlias{Cell: defaultCellName, Uid: startingTargetTabletUID}, - SourceShard: "targetks/0", Position: position, Status: "Running", Info: "VStream Lag: 0s", + SourceShard: "targetks/0", Position: gtid(position), Status: "Running", Info: "VStream Lag: 0s", }, }, }, @@ -93,7 +93,7 @@ func TestReshardCreate(t *testing.T) { { Id: 1, Tablet: &topodatapb.TabletAlias{Cell: defaultCellName, Uid: startingTargetTabletUID + tabletUIDStep}, - SourceShard: "targetks/0", Position: position, Status: "Running", Info: "VStream Lag: 0s", + SourceShard: "targetks/0", Position: gtid(position), Status: "Running", Info: "VStream Lag: 0s", }, }, }, diff --git a/go/vt/vtctl/workflow/stream_migrator_test.go b/go/vt/vtctl/workflow/stream_migrator_test.go index 38ae10280f7..5e9c2a79038 100644 --- a/go/vt/vtctl/workflow/stream_migrator_test.go +++ b/go/vt/vtctl/workflow/stream_migrator_test.go @@ -19,17 +19,22 @@ package workflow import ( "context" "encoding/json" + "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/key" + "vitess.io/vitess/go/vt/proto/tabletmanagerdata" "vitess.io/vitess/go/vt/sqlparser" - + "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/vtgate/vindexes" "vitess.io/vitess/go/vt/vttablet/tabletmanager/vreplication" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" vschemapb "vitess.io/vitess/go/vt/proto/vschema" ) @@ -347,3 +352,271 @@ func stringifyVRS(streams []*VReplicationStream) string { b, _ := json.Marshal(converted) return string(b) } + +var testVSchema = &vschemapb.Keyspace{ + Sharded: true, + Vindexes: map[string]*vschemapb.Vindex{ + "xxhash": { + Type: "xxhash", + }, + }, + Tables: map[string]*vschemapb.Table{ + "t1": { + ColumnVindexes: []*vschemapb.ColumnVindex{{ + Columns: []string{"c1"}, + Name: "xxhash", + }}, + }, + "t2": { + ColumnVindexes: []*vschemapb.ColumnVindex{{ + Columns: []string{"c1"}, + Name: "xxhash", + }}, + }, + "ref": { + Type: vindexes.TypeReference, + }, + }, +} + +var ( + commerceKeyspace = &testKeyspace{ + KeyspaceName: "commerce", + ShardNames: []string{"0"}, + } + customerUnshardedKeyspace = &testKeyspace{ + KeyspaceName: "customer", + ShardNames: []string{"0"}, + } + customerShardedKeyspace = &testKeyspace{ + KeyspaceName: "customer", + ShardNames: []string{"-80", "80-"}, + } +) + +type streamMigratorEnv struct { + tenv *testEnv + ts *testTrafficSwitcher + sourceTabletIds []int + targetTabletIds []int +} + +func (env *streamMigratorEnv) close() { + env.tenv.close() +} + +func (env *streamMigratorEnv) addSourceQueries(queries []string) { + for _, id := range env.sourceTabletIds { + for _, q := range queries { + env.tenv.tmc.expectVRQuery(id, q, &sqltypes.Result{}) + } + } +} + +func (env *streamMigratorEnv) addTargetQueries(queries []string) { + for _, id := range env.targetTabletIds { + for _, q := range queries { + env.tenv.tmc.expectVRQuery(id, q, &sqltypes.Result{}) + } + } +} + +func newStreamMigratorEnv(ctx context.Context, t *testing.T, sourceKeyspace, targetKeyspace *testKeyspace) *streamMigratorEnv { + tenv := newTestEnv(t, ctx, "cell1", sourceKeyspace, targetKeyspace) + env := &streamMigratorEnv{tenv: tenv} + + ksschema, err := vindexes.BuildKeyspaceSchema(testVSchema, "ks", sqlparser.NewTestParser()) + require.NoError(t, err, "could not create test keyspace %+v", testVSchema) + sources := make(map[string]*MigrationSource, len(sourceKeyspace.ShardNames)) + targets := make(map[string]*MigrationTarget, len(targetKeyspace.ShardNames)) + for i, shard := range sourceKeyspace.ShardNames { + tablet := tenv.tablets[sourceKeyspace.KeyspaceName][startingSourceTabletUID+(i*tabletUIDStep)] + kr, _ := key.ParseShardingSpec(shard) + sources[shard] = &MigrationSource{ + si: topo.NewShardInfo(sourceKeyspace.KeyspaceName, shard, &topodatapb.Shard{KeyRange: kr[0]}, nil), + primary: &topo.TabletInfo{ + Tablet: tablet, + }, + } + env.sourceTabletIds = append(env.sourceTabletIds, int(tablet.Alias.Uid)) + } + for i, shard := range targetKeyspace.ShardNames { + tablet := tenv.tablets[targetKeyspace.KeyspaceName][startingTargetTabletUID+(i*tabletUIDStep)] + kr, _ := key.ParseShardingSpec(shard) + targets[shard] = &MigrationTarget{ + si: topo.NewShardInfo(targetKeyspace.KeyspaceName, shard, &topodatapb.Shard{KeyRange: kr[0]}, nil), + primary: &topo.TabletInfo{ + Tablet: tablet, + }, + } + env.targetTabletIds = append(env.targetTabletIds, int(tablet.Alias.Uid)) + } + ts := &testTrafficSwitcher{ + trafficSwitcher: trafficSwitcher{ + migrationType: binlogdatapb.MigrationType_SHARDS, + workflow: "wf1", + id: 1, + sources: sources, + targets: targets, + sourceKeyspace: sourceKeyspace.KeyspaceName, + targetKeyspace: targetKeyspace.KeyspaceName, + sourceKSSchema: ksschema, + workflowType: binlogdatapb.VReplicationWorkflowType_Reshard, + ws: tenv.ws, + }, + sourceKeyspaceSchema: ksschema, + } + env.ts = ts + + return env +} + +func addMaterializeWorkflow(t *testing.T, env *streamMigratorEnv, id int32, sourceShard string) { + var wfs tabletmanagerdata.ReadVReplicationWorkflowsResponse + wfName := "wfMat1" + wfs.Workflows = append(wfs.Workflows, &tabletmanagerdata.ReadVReplicationWorkflowResponse{ + Workflow: wfName, + WorkflowType: binlogdatapb.VReplicationWorkflowType_Materialize, + }) + wfs.Workflows[0].Streams = append(wfs.Workflows[0].Streams, &tabletmanagerdata.ReadVReplicationWorkflowResponse_Stream{ + Id: id, + Bls: &binlogdatapb.BinlogSource{ + Keyspace: env.tenv.sourceKeyspace.KeyspaceName, + Shard: sourceShard, + Filter: &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{ + {Match: "t1", Filter: "select * from t1"}, + }, + }, + }, + Pos: position, + State: binlogdatapb.VReplicationWorkflowState_Running, + }) + workflowKey := env.tenv.tmc.GetWorkflowKey(env.tenv.sourceKeyspace.KeyspaceName, sourceShard) + workflowResponses := []*tabletmanagerdata.ReadVReplicationWorkflowsResponse{ + nil, // this is the response for getting stopped workflows + &wfs, &wfs, &wfs, // return the full list for subsequent GetWorkflows calls + } + for _, resp := range workflowResponses { + env.tenv.tmc.AddVReplicationWorkflowsResponse(workflowKey, resp) + } + queries := []string{ + fmt.Sprintf("select distinct vrepl_id from _vt.copy_state where vrepl_id in (%d)", id), + fmt.Sprintf("update _vt.vreplication set state='Stopped', message='for cutover' where id in (%d)", id), + fmt.Sprintf("delete from _vt.vreplication where db_name='vt_%s' and workflow in ('%s')", + env.tenv.sourceKeyspace.KeyspaceName, wfName), + } + env.addSourceQueries(queries) + queries = []string{ + fmt.Sprintf("delete from _vt.vreplication where db_name='vt_%s' and workflow in ('%s')", + env.tenv.sourceKeyspace.KeyspaceName, wfName), + } + env.addTargetQueries(queries) + +} + +func addReferenceWorkflow(t *testing.T, env *streamMigratorEnv, id int32, sourceShard string) { + var wfs tabletmanagerdata.ReadVReplicationWorkflowsResponse + wfName := "wfRef1" + wfs.Workflows = append(wfs.Workflows, &tabletmanagerdata.ReadVReplicationWorkflowResponse{ + Workflow: wfName, + WorkflowType: binlogdatapb.VReplicationWorkflowType_Materialize, + }) + wfs.Workflows[0].Streams = append(wfs.Workflows[0].Streams, &tabletmanagerdata.ReadVReplicationWorkflowResponse_Stream{ + Id: id, + Bls: &binlogdatapb.BinlogSource{ + Keyspace: env.tenv.sourceKeyspace.KeyspaceName, + Shard: sourceShard, + Filter: &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{ + {Match: "ref", Filter: "select * from ref"}, + }, + }, + }, + Pos: position, + State: binlogdatapb.VReplicationWorkflowState_Running, + }) + workflowKey := env.tenv.tmc.GetWorkflowKey(env.tenv.sourceKeyspace.KeyspaceName, sourceShard) + workflowResponses := []*tabletmanagerdata.ReadVReplicationWorkflowsResponse{ + nil, // this is the response for getting stopped workflows + &wfs, &wfs, &wfs, // return the full list for subsequent GetWorkflows calls + } + for _, resp := range workflowResponses { + env.tenv.tmc.AddVReplicationWorkflowsResponse(workflowKey, resp) + } +} + +func TestBuildStreamMigratorOneMaterialize(t *testing.T) { + ctx := context.Background() + env := newStreamMigratorEnv(ctx, t, customerUnshardedKeyspace, customerShardedKeyspace) + defer env.close() + tmc := env.tenv.tmc + + addMaterializeWorkflow(t, env, 100, "0") + + // FIXME: Note: currently it is not optimal: we create two streams for each shard from all the + // shards even if the key ranges don't intersect. TBD + getInsert := func(shard string) string { + s := "/insert into _vt.vreplication.*" + s += fmt.Sprintf("shard:\"-80\".*in_keyrange.*c1.*%s.*", shard) + s += fmt.Sprintf("shard:\"80-\".*in_keyrange.*c1.*%s.*", shard) + return s + } + tmc.expectVRQuery(200, getInsert("-80"), &sqltypes.Result{}) + tmc.expectVRQuery(210, getInsert("80-"), &sqltypes.Result{}) + + sm, err := BuildStreamMigrator(ctx, env.ts, false, sqlparser.NewTestParser()) + require.NoError(t, err) + require.NotNil(t, sm) + require.NotNil(t, sm.streams) + require.Equal(t, 1, len(sm.streams)) + + workflows, err := sm.StopStreams(ctx) + require.NoError(t, err) + require.Equal(t, 1, len(workflows)) + require.NoError(t, sm.MigrateStreams(ctx)) + require.Len(t, sm.templates, 1) + env.addTargetQueries([]string{ + fmt.Sprintf("update _vt.vreplication set state='Running' where db_name='vt_%s' and workflow in ('%s')", + env.tenv.sourceKeyspace.KeyspaceName, "wfMat1"), + }) + require.NoError(t, StreamMigratorFinalize(ctx, env.ts, []string{"wfMat1"})) +} + +func TestBuildStreamMigratorNoStreams(t *testing.T) { + ctx := context.Background() + env := newStreamMigratorEnv(ctx, t, customerUnshardedKeyspace, customerShardedKeyspace) + defer env.close() + + sm, err := BuildStreamMigrator(ctx, env.ts, false, sqlparser.NewTestParser()) + require.NoError(t, err) + require.NotNil(t, sm) + require.NotNil(t, sm.streams) + require.Equal(t, 0, len(sm.streams)) + + workflows, err := sm.StopStreams(ctx) + require.NoError(t, err) + require.Equal(t, 0, len(workflows)) + require.NoError(t, sm.MigrateStreams(ctx)) + require.Len(t, sm.templates, 0) +} + +func TestBuildStreamMigratorRefStream(t *testing.T) { + ctx := context.Background() + env := newStreamMigratorEnv(ctx, t, customerUnshardedKeyspace, customerShardedKeyspace) + defer env.close() + + addReferenceWorkflow(t, env, 100, "0") + + sm, err := BuildStreamMigrator(ctx, env.ts, false, sqlparser.NewTestParser()) + require.NoError(t, err) + require.NotNil(t, sm) + require.NotNil(t, sm.streams) + require.Equal(t, 0, len(sm.streams)) + + workflows, err := sm.StopStreams(ctx) + require.NoError(t, err) + require.Equal(t, 0, len(workflows)) + require.NoError(t, sm.MigrateStreams(ctx)) + require.Len(t, sm.templates, 0) +} diff --git a/go/vt/vtctl/workflow/utils_test.go b/go/vt/vtctl/workflow/utils_test.go index d79c4710b77..b315e1aa991 100644 --- a/go/vt/vtctl/workflow/utils_test.go +++ b/go/vt/vtctl/workflow/utils_test.go @@ -16,12 +16,85 @@ import ( "vitess.io/vitess/go/testfiles" "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/proto/vtctldata" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/etcd2topo" "vitess.io/vitess/go/vt/topo/memorytopo" "vitess.io/vitess/go/vt/topotools" ) +// TestCreateDefaultShardRoutingRules confirms that the default shard routing rules are created correctly for sharded +// and unsharded keyspaces. +func TestCreateDefaultShardRoutingRules(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ks1 := &testKeyspace{ + KeyspaceName: "sourceks", + } + ks2 := &testKeyspace{ + KeyspaceName: "targetks", + } + + type testCase struct { + name string + sourceKeyspace *testKeyspace + targetKeyspace *testKeyspace + shards []string + want map[string]string + } + getExpectedRules := func(sourceKeyspace, targetKeyspace *testKeyspace) map[string]string { + rules := make(map[string]string) + for _, targetShard := range targetKeyspace.ShardNames { + rules[fmt.Sprintf("%s.%s", targetKeyspace.KeyspaceName, targetShard)] = sourceKeyspace.KeyspaceName + } + return rules + + } + testCases := []testCase{ + { + name: "unsharded", + sourceKeyspace: ks1, + targetKeyspace: ks2, + shards: []string{"0"}, + }, + { + name: "sharded", + sourceKeyspace: ks2, + targetKeyspace: ks1, + shards: []string{"-80", "80-"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.sourceKeyspace.ShardNames = tc.shards + tc.targetKeyspace.ShardNames = tc.shards + env := newTestEnv(t, ctx, defaultCellName, tc.sourceKeyspace, tc.targetKeyspace) + defer env.close() + ms := &vtctldata.MaterializeSettings{ + Workflow: "wf1", + SourceKeyspace: tc.sourceKeyspace.KeyspaceName, + TargetKeyspace: tc.targetKeyspace.KeyspaceName, + TableSettings: []*vtctldata.TableMaterializeSettings{ + { + TargetTable: "t1", + SourceExpression: "select * from t1", + }, + }, + Cell: "zone1", + SourceShards: tc.sourceKeyspace.ShardNames, + } + err := createDefaultShardRoutingRules(ctx, ms, env.ts) + require.NoError(t, err) + rules, err := topotools.GetShardRoutingRules(ctx, env.ts) + require.NoError(t, err) + require.Len(t, rules, len(tc.shards)) + want := getExpectedRules(tc.sourceKeyspace, tc.targetKeyspace) + require.EqualValues(t, want, rules) + }) + } +} + // TestUpdateKeyspaceRoutingRule confirms that the keyspace routing rules are updated correctly. func TestUpdateKeyspaceRoutingRule(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/go/vt/vtctl/workflow/vreplication_stream_test.go b/go/vt/vtctl/workflow/vreplication_stream_test.go new file mode 100644 index 00000000000..6269cfa978e --- /dev/null +++ b/go/vt/vtctl/workflow/vreplication_stream_test.go @@ -0,0 +1,52 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package workflow + +import ( + "fmt" + "reflect" + "testing" +) + +// TestVReplicationStreams tests various methods of VReplicationStreams. +func TestVReplicationStreams(t *testing.T) { + var streams VReplicationStreams + for i := 1; i <= 3; i++ { + streams = append(streams, &VReplicationStream{ID: int32(i), Workflow: fmt.Sprintf("workflow%d", i)}) + } + + tests := []struct { + name string + funcUnderTest func(VReplicationStreams) interface{} + expectedResult interface{} + }{ + {"Test IDs", func(s VReplicationStreams) interface{} { return s.IDs() }, []int32{1, 2, 3}}, + {"Test Values", func(s VReplicationStreams) interface{} { return s.Values() }, "(1, 2, 3)"}, + {"Test Workflows", func(s VReplicationStreams) interface{} { return s.Workflows() }, []string{"workflow1", "workflow2", "workflow3"}}, + {"Test Copy", func(s VReplicationStreams) interface{} { return s.Copy() }, streams.Copy()}, + {"Test ToSlice", func(s VReplicationStreams) interface{} { return s.ToSlice() }, []*VReplicationStream{streams[0], streams[1], streams[2]}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.funcUnderTest(streams) + if !reflect.DeepEqual(result, tt.expectedResult) { + t.Errorf("Failed %s: expected %v, got %v", tt.name, tt.expectedResult, result) + } + }) + } +}