From da9c1304572cdd656bfce15eae0a3e339aacc7da Mon Sep 17 00:00:00 2001 From: Matt Lord Date: Sat, 7 Sep 2024 00:12:10 -0400 Subject: [PATCH] Add unit test Signed-off-by: Matt Lord --- go/vt/vtctl/workflow/framework_test.go | 21 +++++++++ go/vt/vtctl/workflow/server_test.go | 59 +++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/go/vt/vtctl/workflow/framework_test.go b/go/vt/vtctl/workflow/framework_test.go index ce58b34a9be..57ccaa0a7d0 100644 --- a/go/vt/vtctl/workflow/framework_test.go +++ b/go/vt/vtctl/workflow/framework_test.go @@ -177,6 +177,7 @@ func (env *testEnv) addTablet(t *testing.T, ctx context.Context, id int, keyspac Shard: shard, KeyRange: &topodatapb.KeyRange{}, Type: tabletType, + Hostname: "localhost", // Without a hostname the RefreshState call is skipped. PortMap: map[string]int32{ "test": int32(id), }, @@ -255,6 +256,7 @@ type testTMClient struct { createVReplicationWorkflowRequests map[uint32]*tabletmanagerdatapb.CreateVReplicationWorkflowRequest readVReplicationWorkflowRequests map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest primaryPositions map[uint32]string + refreshStateErrors map[uint32]error // Stack of ReadVReplicationWorkflowsResponse to return, in order, for each shard readVReplicationWorkflowsResponses map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse @@ -272,6 +274,7 @@ func newTestTMClient(env *testEnv) *testTMClient { readVReplicationWorkflowRequests: make(map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest), readVReplicationWorkflowsResponses: make(map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse), primaryPositions: make(map[uint32]string), + refreshStateErrors: make(map[uint32]error), env: env, } } @@ -553,6 +556,24 @@ func (tmc *testTMClient) VReplicationWaitForPos(ctx context.Context, tablet *top return nil } +func (tmc *testTMClient) SetRefreshStateError(tablet *topodatapb.Tablet, err error) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + if tmc.refreshStateErrors == nil { + tmc.refreshStateErrors = make(map[uint32]error) + } + tmc.refreshStateErrors[tablet.Alias.Uid] = err +} + +func (tmc *testTMClient) RefreshState(ctx context.Context, tablet *topodatapb.Tablet) error { + tmc.mu.Lock() + defer tmc.mu.Unlock() + if tmc.refreshStateErrors == nil { + tmc.refreshStateErrors = make(map[uint32]error) + } + return tmc.refreshStateErrors[tablet.Alias.Uid] +} + func (tmc *testTMClient) AddVReplicationWorkflowsResponse(key string, resp *tabletmanagerdatapb.ReadVReplicationWorkflowsResponse) { tmc.mu.Lock() defer tmc.mu.Unlock() diff --git a/go/vt/vtctl/workflow/server_test.go b/go/vt/vtctl/workflow/server_test.go index 9d23479ad7e..a422db28859 100644 --- a/go/vt/vtctl/workflow/server_test.go +++ b/go/vt/vtctl/workflow/server_test.go @@ -19,6 +19,7 @@ package workflow import ( "context" "encoding/json" + "errors" "fmt" "slices" "sort" @@ -833,6 +834,7 @@ func TestMoveTablesTrafficSwitching(t *testing.T) { name string sourceKeyspace, targetKeyspace *testKeyspace req *vtctldatapb.WorkflowSwitchTrafficRequest + preFunc func(env *testEnv) want *vtctldatapb.WorkflowSwitchTrafficResponse wantErr bool }{ @@ -880,6 +882,55 @@ func TestMoveTablesTrafficSwitching(t *testing.T) { CurrentState: "Reads Not Switched. Writes Not Switched", }, }, + { + name: "forward with tablet refresh error", + sourceKeyspace: &testKeyspace{ + KeyspaceName: sourceKeyspaceName, + ShardNames: []string{"0"}, + }, + targetKeyspace: &testKeyspace{ + KeyspaceName: targetKeyspaceName, + ShardNames: []string{"-80", "80-"}, + }, + req: &vtctldatapb.WorkflowSwitchTrafficRequest{ + Keyspace: targetKeyspaceName, + Workflow: workflowName, + Direction: int32(DirectionForward), + TabletTypes: tabletTypes, + }, + preFunc: func(env *testEnv) { + env.tmc.SetRefreshStateError(env.tablets[sourceKeyspaceName][startingSourceTabletUID], errors.New("tablet refresh error")) + env.tmc.SetRefreshStateError(env.tablets[targetKeyspaceName][startingTargetTabletUID], errors.New("tablet refresh error")) + }, + wantErr: true, + }, + { + name: "forward with tablet refresh error and force", + sourceKeyspace: &testKeyspace{ + KeyspaceName: sourceKeyspaceName, + ShardNames: []string{"0"}, + }, + targetKeyspace: &testKeyspace{ + KeyspaceName: targetKeyspaceName, + ShardNames: []string{"-80", "80-"}, + }, + req: &vtctldatapb.WorkflowSwitchTrafficRequest{ + Keyspace: targetKeyspaceName, + Workflow: workflowName, + Direction: int32(DirectionForward), + TabletTypes: tabletTypes, + Force: true, + }, + preFunc: func(env *testEnv) { + env.tmc.SetRefreshStateError(env.tablets[sourceKeyspaceName][startingSourceTabletUID], errors.New("tablet refresh error")) + env.tmc.SetRefreshStateError(env.tablets[targetKeyspaceName][startingTargetTabletUID], errors.New("tablet refresh error")) + }, + want: &vtctldatapb.WorkflowSwitchTrafficResponse{ + Summary: fmt.Sprintf("SwitchTraffic was successful for workflow %s.%s", targetKeyspaceName, workflowName), + StartState: "Reads Not Switched. Writes Not Switched", + CurrentState: "All Reads Switched. Writes Switched", + }, + }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { @@ -923,11 +974,15 @@ func TestMoveTablesTrafficSwitching(t *testing.T) { env.tmc.expectVRQueryResultOnKeyspaceTablets(tc.targetKeyspace.KeyspaceName, createJournalQR) env.tmc.expectVRQueryResultOnKeyspaceTablets(tc.sourceKeyspace.KeyspaceName, freezeReverseWFQR) } + if tc.preFunc != nil { + tc.preFunc(env) + } got, err := env.ws.WorkflowSwitchTraffic(ctx, tc.req) - if (err != nil) != tc.wantErr { - require.Fail(t, "unexpected error value", "Server.WorkflowSwitchTraffic() error = %v, wantErr %v", err, tc.wantErr) + if tc.wantErr { + require.Error(t, err) return } + require.NoError(t, err) require.Equal(t, tc.want.String(), got.String(), "Server.WorkflowSwitchTraffic() = %v, want %v", got, tc.want) // Confirm that we have the expected routing rules.