diff --git a/go/vt/vtctl/workflow/framework_test.go b/go/vt/vtctl/workflow/framework_test.go index 249ff07cf41..d84d1236075 100644 --- a/go/vt/vtctl/workflow/framework_test.go +++ b/go/vt/vtctl/workflow/framework_test.go @@ -271,6 +271,7 @@ type testTMClient struct { vrQueries map[int][]*queryResult createVReplicationWorkflowRequests map[uint32]*createVReplicationWorkflowRequestResponse readVReplicationWorkflowRequests map[uint32]*readVReplicationWorkflowRequestResponse + updateVReplicationWorklowsRequests map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest applySchemaRequests map[uint32]*applySchemaRequestResponse primaryPositions map[uint32]string vdiffRequests map[uint32]*vdiffRequestResponse @@ -294,6 +295,7 @@ func newTestTMClient(env *testEnv) *testTMClient { vrQueries: make(map[int][]*queryResult), createVReplicationWorkflowRequests: make(map[uint32]*createVReplicationWorkflowRequestResponse), readVReplicationWorkflowRequests: make(map[uint32]*readVReplicationWorkflowRequestResponse), + updateVReplicationWorklowsRequests: make(map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest), applySchemaRequests: make(map[uint32]*applySchemaRequestResponse), readVReplicationWorkflowsResponses: make(map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse), primaryPositions: make(map[uint32]string), @@ -677,6 +679,19 @@ func (tmc *testTMClient) UpdateVReplicationWorkflow(ctx context.Context, tablet }, nil } +func (tmc *testTMClient) UpdateVReplicationWorkflows(ctx context.Context, tablet *topodatapb.Tablet, req *tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest) (*tabletmanagerdatapb.UpdateVReplicationWorkflowsResponse, error) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + if expect := tmc.updateVReplicationWorklowsRequests[tablet.Alias.Uid]; expect != nil { + if !proto.Equal(expect, req) { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected ReadVReplicationWorkflow request on tablet %s: got %+v, want %+v", + topoproto.TabletAliasString(tablet.Alias), req, expect) + } + } + delete(tmc.updateVReplicationWorklowsRequests, tablet.Alias.Uid) + return nil, nil +} + func (tmc *testTMClient) ValidateVReplicationPermissions(ctx context.Context, tablet *topodatapb.Tablet, req *tabletmanagerdatapb.ValidateVReplicationPermissionsRequest) (*tabletmanagerdatapb.ValidateVReplicationPermissionsResponse, error) { return &tabletmanagerdatapb.ValidateVReplicationPermissionsResponse{ User: "vt_filtered", @@ -736,6 +751,12 @@ func (tmc *testTMClient) AddVReplicationWorkflowsResponse(key string, resp *tabl tmc.readVReplicationWorkflowsResponses[key] = append(tmc.readVReplicationWorkflowsResponses[key], resp) } +func (tmc *testTMClient) AddUpdateVReplicationRequests(tabletUID uint32, req *tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + tmc.updateVReplicationWorklowsRequests[tabletUID] = req +} + func (tmc *testTMClient) getVReplicationWorkflowsResponse(key string) *tabletmanagerdatapb.ReadVReplicationWorkflowsResponse { if len(tmc.readVReplicationWorkflowsResponses) == 0 { return nil diff --git a/go/vt/vtctl/workflow/resharder_test.go b/go/vt/vtctl/workflow/resharder_test.go index 6353f36db9f..2123f1fdcc8 100644 --- a/go/vt/vtctl/workflow/resharder_test.go +++ b/go/vt/vtctl/workflow/resharder_test.go @@ -22,14 +22,20 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + "vitess.io/vitess/go/ptr" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/topoproto" + "vitess.io/vitess/go/vt/vtgate/vindexes" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vschemapb "vitess.io/vitess/go/vt/proto/vschema" vtctldatapb "vitess.io/vitess/go/vt/proto/vtctldata" ) @@ -65,6 +71,8 @@ func TestReshardCreate(t *testing.T) { sourceKeyspace, targetKeyspace *testKeyspace preFunc func(env *testEnv) want *vtctldatapb.WorkflowStatusResponse + updateVReplicationRequest *tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest + autoStart bool wantErr string }{ { @@ -77,6 +85,11 @@ func TestReshardCreate(t *testing.T) { KeyspaceName: targetKeyspaceName, ShardNames: []string{"-80", "80-"}, }, + autoStart: true, + updateVReplicationRequest: &tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest{ + AllWorkflows: true, + State: ptr.Of(binlogdatapb.VReplicationWorkflowState_Running), + }, want: &vtctldatapb.WorkflowStatusResponse{ ShardStreams: map[string]*vtctldatapb.WorkflowStatusResponse_ShardStreams{ "targetks/-80": { @@ -137,6 +150,7 @@ func TestReshardCreate(t *testing.T) { SourceShards: tc.sourceKeyspace.ShardNames, TargetShards: tc.targetKeyspace.ShardNames, Cells: []string{env.cell}, + AutoStart: tc.autoStart, } for i := range tc.sourceKeyspace.ShardNames { @@ -172,6 +186,9 @@ func TestReshardCreate(t *testing.T) { "select vrepl_id, table_name, lastpk from _vt.copy_state where vrepl_id in (1) and id in (select max(id) from _vt.copy_state where vrepl_id in (1) group by vrepl_id, table_name)", &sqltypes.Result{}, ) + if tc.updateVReplicationRequest != nil { + env.tmc.AddUpdateVReplicationRequests(uint32(tabletUID), tc.updateVReplicationRequest) + } } if tc.preFunc != nil { @@ -187,6 +204,299 @@ func TestReshardCreate(t *testing.T) { if tc.want != nil { require.Equal(t, tc.want, res) } + + // Expect updateVReplicationWorklowsRequests to be empty, + // if AutoStart is enabled. This is because we delete the specific + // key from the map in the testTMC, once updateVReplicationWorklows() + // with the expected request is called. + if tc.autoStart { + assert.Len(t, env.tmc.updateVReplicationWorklowsRequests, 0) + } + }) + } +} + +func TestReadRefStreams(t *testing.T) { + ctx := context.Background() + + sourceKeyspace := &testKeyspace{ + KeyspaceName: "sourceKeyspace", + ShardNames: []string{"-"}, + } + targetKeyspace := &testKeyspace{ + KeyspaceName: "targetKeyspace", + ShardNames: []string{"-"}, + } + + env := newTestEnv(t, ctx, defaultCellName, sourceKeyspace, targetKeyspace) + + s1, err := env.ts.UpdateShardFields(ctx, targetKeyspace.KeyspaceName, "-", func(si *topo.ShardInfo) error { + return nil + }) + require.NoError(t, err) + + sourceTablet, ok := env.tablets[sourceKeyspace.KeyspaceName][100] + require.True(t, ok) + + env.tmc.schema = map[string]*tabletmanagerdatapb.SchemaDefinition{ + "t1": {}, + } + + rules := make([]*binlogdatapb.Rule, len(env.tmc.schema)) + for i, table := range maps.Keys(env.tmc.schema) { + rules[i] = &binlogdatapb.Rule{ + Match: table, + Filter: fmt.Sprintf("select * from %s", table), + } + } + + refKey := fmt.Sprintf("wf:%s:-", sourceKeyspace.KeyspaceName) + + testCases := []struct { + name string + addVReplicationWorkflowsResponse *tabletmanagerdatapb.ReadVReplicationWorkflowsResponse + preRefStreams map[string]*refStream + wantRefStreamKeys []string + wantErr bool + errContains string + }{ + { + name: "error for unnamed workflow", + addVReplicationWorkflowsResponse: &tabletmanagerdatapb.ReadVReplicationWorkflowsResponse{ + Workflows: []*tabletmanagerdatapb.ReadVReplicationWorkflowResponse{ + { + Workflow: "", + WorkflowType: binlogdatapb.VReplicationWorkflowType_Reshard, + }, + }, + }, + wantErr: true, + }, + { + name: "populate ref streams", + addVReplicationWorkflowsResponse: &tabletmanagerdatapb.ReadVReplicationWorkflowsResponse{ + Workflows: []*tabletmanagerdatapb.ReadVReplicationWorkflowResponse{ + { + Workflow: "wf", + WorkflowType: binlogdatapb.VReplicationWorkflowType_Reshard, + Streams: []*tabletmanagerdatapb.ReadVReplicationWorkflowResponse_Stream{ + { + + Bls: &binlogdatapb.BinlogSource{ + Keyspace: sourceKeyspace.KeyspaceName, + Shard: "-", + Tables: maps.Keys(env.tmc.schema), + Filter: &binlogdatapb.Filter{ + Rules: rules, + }, + }, + }, + }, + }, + }, + }, + wantRefStreamKeys: []string{refKey}, + }, + { + name: "mismatched streams with empty map", + preRefStreams: map[string]*refStream{}, + addVReplicationWorkflowsResponse: &tabletmanagerdatapb.ReadVReplicationWorkflowsResponse{ + Workflows: []*tabletmanagerdatapb.ReadVReplicationWorkflowResponse{ + { + Workflow: "wf", + WorkflowType: binlogdatapb.VReplicationWorkflowType_Reshard, + Streams: []*tabletmanagerdatapb.ReadVReplicationWorkflowResponse_Stream{ + { + + Bls: &binlogdatapb.BinlogSource{ + Keyspace: sourceKeyspace.KeyspaceName, + Shard: "-", + Tables: maps.Keys(env.tmc.schema), + Filter: &binlogdatapb.Filter{ + Rules: rules, + }, + }, + }, + }, + }, + }, + }, + wantErr: true, + errContains: "mismatch", + }, + { + name: "mismatched streams", + preRefStreams: map[string]*refStream{ + refKey: nil, + "nonexisting": nil, + }, + addVReplicationWorkflowsResponse: &tabletmanagerdatapb.ReadVReplicationWorkflowsResponse{ + Workflows: []*tabletmanagerdatapb.ReadVReplicationWorkflowResponse{ + { + Workflow: "wf", + WorkflowType: binlogdatapb.VReplicationWorkflowType_Reshard, + Streams: []*tabletmanagerdatapb.ReadVReplicationWorkflowResponse_Stream{ + { + + Bls: &binlogdatapb.BinlogSource{ + Keyspace: sourceKeyspace.KeyspaceName, + Shard: "-", + Tables: maps.Keys(env.tmc.schema), + Filter: &binlogdatapb.Filter{ + Rules: rules, + }, + }, + }, + }, + }, + }, + }, + wantErr: true, + errContains: "mismatch", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rs := &resharder{ + s: env.ws, + keyspace: targetKeyspace.KeyspaceName, + sourceShards: []*topo.ShardInfo{s1}, + sourcePrimaries: map[string]*topo.TabletInfo{ + "-": { + Tablet: sourceTablet, + }, + }, + workflow: "wf", + vschema: &vschemapb.Keyspace{ + Tables: map[string]*vschemapb.Table{ + "t1": { + Type: vindexes.TypeReference, + }, + }, + }, + refStreams: tc.preRefStreams, + } + + workflowKey := env.tmc.GetWorkflowKey(sourceKeyspace.KeyspaceName, "-") + + env.tmc.AddVReplicationWorkflowsResponse(workflowKey, tc.addVReplicationWorkflowsResponse) + + err := rs.readRefStreams(ctx) + if !tc.wantErr { + assert.NoError(t, err) + for _, rk := range tc.wantRefStreamKeys { + assert.Contains(t, rs.refStreams, rk) + } + return + } + + assert.Error(t, err) + assert.ErrorContains(t, err, tc.errContains) + }) + } +} + +func TestBlsIsReference(t *testing.T) { + testCases := []struct { + name string + bls *binlogdatapb.BinlogSource + tables map[string]*vschemapb.Table + expected bool + wantErr bool + errContains string + }{ + { + name: "all references", + bls: &binlogdatapb.BinlogSource{ + Filter: &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{ + {Match: "ref_table1"}, + {Match: "ref_table2"}, + }, + }, + }, + tables: map[string]*vschemapb.Table{ + "ref_table1": {Type: vindexes.TypeReference}, + "ref_table2": {Type: vindexes.TypeReference}, + }, + expected: true, + }, + { + name: "all sharded", + bls: &binlogdatapb.BinlogSource{ + Filter: &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{ + {Match: "sharded_table1"}, + {Match: "sharded_table2"}, + }, + }, + }, + tables: map[string]*vschemapb.Table{ + "sharded_table1": {Type: vindexes.TypeTable}, + "sharded_table2": {Type: vindexes.TypeTable}, + }, + expected: false, + }, + { + name: "mixed reference and sharded tables", + bls: &binlogdatapb.BinlogSource{ + Filter: &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{ + {Match: "ref_table"}, + {Match: "sharded_table"}, + }, + }, + }, + tables: map[string]*vschemapb.Table{ + "ref_table": {Type: vindexes.TypeReference}, + "sharded_table": {Type: vindexes.TypeTable}, + }, + wantErr: true, + }, + { + name: "rule table not found in vschema", + bls: &binlogdatapb.BinlogSource{ + Filter: &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{ + {Match: "unknown_table"}, + }, + }, + }, + tables: map[string]*vschemapb.Table{}, + wantErr: true, + errContains: "unknown_table", + }, + { + name: "internal operation table ignored", + bls: &binlogdatapb.BinlogSource{ + Filter: &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{ + {Match: "_vt_hld_6ace8bcef73211ea87e9f875a4d24e90_20200915120410_"}, + }, + }, + }, + tables: map[string]*vschemapb.Table{}, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rs := &resharder{ + vschema: &vschemapb.Keyspace{ + Tables: tc.tables, + }, + } + + result, err := rs.blsIsReference(tc.bls) + + if tc.wantErr { + assert.ErrorContains(t, err, tc.errContains) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } }) } }