diff --git a/go/vt/discovery/keyspace_events.go b/go/vt/discovery/keyspace_events.go
index 9fa457c1589..036d4f3ad14 100644
--- a/go/vt/discovery/keyspace_events.go
+++ b/go/vt/discovery/keyspace_events.go
@@ -19,7 +19,9 @@ package discovery
 import (
 	"context"
 	"fmt"
+	"slices"
 	"sync"
+	"time"
 
 	"golang.org/x/sync/errgroup"
 	"google.golang.org/protobuf/proto"
@@ -37,6 +39,11 @@ import (
 	vschemapb "vitess.io/vitess/go/vt/proto/vschema"
 )
 
+var (
+	// waitConsistentKeyspacesCheck is the amount of time to wait between checks to verify the keyspace is consistent.
+	waitConsistentKeyspacesCheck = 100 * time.Millisecond
+)
+
 // KeyspaceEventWatcher is an auxiliary watcher that watches all availability incidents
 // for all keyspaces in a Vitess cell and notifies listeners when the events have been resolved.
 // Right now this is capable of detecting the end of failovers, both planned and unplanned,
@@ -662,29 +669,53 @@ func (kew *KeyspaceEventWatcher) TargetIsBeingResharded(ctx context.Context, tar
 	return ks.beingResharded(target.Shard)
 }
 
-// PrimaryIsNotServing checks if the reason why the given target is not accessible right now is
-// that the primary tablet for that shard is not serving. This is possible during a Planned
-// Reparent Shard operation. Just as the operation completes, a new primary will be elected, and
+// ShouldStartBufferingForTarget checks whether we should start buffering for the given target.
+// We check the following things before we start buffering -
+// 1. The shard must have a primary.
+// 2. The primary must be non-serving.
+// 3. The keyspace must be marked inconsistent.
+//
+// This buffering is meant to kick in during a Planned Reparent Shard operation.
+// As part of that operation the old primary will become non-serving. At that point
+// this code should return true to start buffering requests.
+// Just as the PRS operation completes, a new primary will be elected, and
 // it will send its own healthcheck stating that it is serving. We should buffer requests until
-// that point. There are use cases where people do not run with a Primary server at all, so we must
+// that point.
+//
+// There are use cases where people do not run with a Primary server at all, so we must
 // verify that we only start buffering when a primary was present, and it went not serving.
 // The shard state keeps track of the current primary and the last externally reparented time, which
 // we can use to determine that there was a serving primary which now became non serving. This is
 // only possible in a DemotePrimary RPC which are only called from ERS and PRS. So buffering will
-// stop when these operations succeed. We return the tablet alias of the primary if it is serving.
-func (kew *KeyspaceEventWatcher) PrimaryIsNotServing(ctx context.Context, target *querypb.Target) (*topodatapb.TabletAlias, bool) {
+// stop when these operations succeed. We also return the tablet alias of the primary if it is serving.
+func (kew *KeyspaceEventWatcher) ShouldStartBufferingForTarget(ctx context.Context, target *querypb.Target) (*topodatapb.TabletAlias, bool) {
 	if target.TabletType != topodatapb.TabletType_PRIMARY {
+		// We don't support buffering for any target tablet type other than the primary.
 		return nil, false
 	}
 	ks := kew.getKeyspaceStatus(ctx, target.Keyspace)
 	if ks == nil {
+		// If the keyspace status is nil, then the keyspace must have been deleted.
+		// The user query is trying to access a keyspace that has been deleted.
+		// There is no reason to buffer this query.
 		return nil, false
 	}
 	ks.mu.Lock()
 	defer ks.mu.Unlock()
 	if state, ok := ks.shards[target.Shard]; ok {
-		// If the primary tablet was present then externallyReparented will be non-zero and
-		// currentPrimary will be not nil.
+		// As described in the function comment, we only want to start buffering when all the following conditions are met -
+		// 1. The shard must have a primary. We check this by verifying that the currentPrimary and externallyReparented fields are set.
+		//    They are set the first time the shard registers an update from a serving primary and are never cleared out after that.
+		//    If the user has configured vtgates to wait for the primary tablet healthchecks before starting query service, this condition
+		//    will always be true.
+		// 2. The primary must be non-serving. We check the serving field in the shard state.
+		//    When a primary becomes non-serving, it also marks the keyspace inconsistent, so the next check is added
+		//    only as a defensive measure against bugs.
+		// 3. The keyspace must be marked inconsistent. We check the consistent field in the keyspace state.
+		//
+		// The reason we need all three checks is that we want to be very conservative about when we start buffering.
+		// We only want to start buffering when we know for sure that the primary is not serving
+		// and that we will soon receive an update that stops buffering.
 		return state.currentPrimary, !state.serving && !ks.consistent && state.externallyReparented != 0 && state.currentPrimary != nil
 	}
 	return nil, false
@@ -703,3 +734,46 @@ func (kew *KeyspaceEventWatcher) GetServingKeyspaces() []string {
 	}
 	return servingKeyspaces
 }
+
+// WaitForConsistentKeyspaces waits for the given set of keyspaces to be marked consistent.
+func (kew *KeyspaceEventWatcher) WaitForConsistentKeyspaces(ctx context.Context, ksList []string) error {
+	// We don't want to change the original keyspace list that we receive, so we clone it
+	// before we blank out its entries below.
+	keyspaces := slices.Clone(ksList)
+	for {
+		// We blank out keyspace entries as we find them to be consistent.
+		allConsistent := true
+		for i, ks := range keyspaces {
+			if ks == "" {
+				continue
+			}
+
+			// Get the keyspace status and check whether it is consistent yet.
+			kss := kew.getKeyspaceStatus(ctx, ks)
+			// If kss is nil, the keyspace must have been deleted. In that case too it is fine for us to
+			// consider it consistent, since the keyspace no longer exists.
+			if kss == nil || kss.consistent {
+				keyspaces[i] = ""
+			} else {
+				allConsistent = false
+			}
+		}
+
+		if allConsistent {
+			// All the keyspaces are consistent.
+			return nil
+		}
+
+		// Unblock after the sleep or when the context has expired.
+		select {
+		case <-ctx.Done():
+			for _, ks := range keyspaces {
+				if ks != "" {
+					log.Infof("keyspace %v didn't become consistent", ks)
+				}
+			}
+			return ctx.Err()
+		case <-time.After(waitConsistentKeyspacesCheck):
+		}
+	}
+}
diff --git a/go/vt/discovery/keyspace_events_test.go b/go/vt/discovery/keyspace_events_test.go
index e9406ff1de2..1a4c473e7cb 100644
--- a/go/vt/discovery/keyspace_events_test.go
+++ b/go/vt/discovery/keyspace_events_test.go
@@ -155,11 +155,11 @@ func TestKeyspaceEventTypes(t *testing.T) {
 	kew := NewKeyspaceEventWatcher(ctx, ts2, hc, cell)
 
 	type testCase struct {
-		name                    string
-		kss                     *keyspaceState
-		shardToCheck            string
-		expectResharding        bool
-		expectPrimaryNotServing bool
+		name               string
+		kss                *keyspaceState
+		shardToCheck       string
+		expectResharding   bool
+		expectShouldBuffer bool
 	}
 
 	testCases := []testCase{
@@ -196,9 +196,9 @@
 				},
 				consistent: false,
 			},
-			shardToCheck:            "-",
-			expectResharding:        true,
-			expectPrimaryNotServing: false,
+			shardToCheck:       "-",
+			expectResharding:   true,
+			expectShouldBuffer: false,
 		},
 		{
 			name: "two to four resharding in progress",
@@ -257,9 +257,9 @@
 				},
 				consistent: false,
 			},
-			shardToCheck:            "-80",
-			expectResharding:        true,
-			expectPrimaryNotServing: false,
+			shardToCheck:       "-80",
+			expectResharding:   true,
+			expectShouldBuffer: false,
 		},
 		{
 			name: "unsharded primary not serving",
@@ -283,9 +283,9 @@
 				},
 				consistent: false,
 			},
-			shardToCheck:            "-",
-			expectResharding:        false,
-			expectPrimaryNotServing: true,
+			shardToCheck:       "-",
+			expectResharding:   false,
+			expectShouldBuffer: true,
 		},
 		{
 			name: "sharded primary not serving",
@@ -317,9 +317,9 @@
 				},
 				consistent: false,
 			},
-			shardToCheck:            "-80",
-			expectResharding:        false,
-			expectPrimaryNotServing: true,
+			shardToCheck:       "-80",
+			expectResharding:   false,
+			expectShouldBuffer: true,
 		},
 	}
 
@@ -334,8 +334,89 @@
 			resharding := kew.TargetIsBeingResharded(ctx, tc.kss.shards[tc.shardToCheck].target)
 			require.Equal(t, resharding, tc.expectResharding, "TargetIsBeingResharded should return %t", tc.expectResharding)
 
-			_, primaryDown := kew.PrimaryIsNotServing(ctx, tc.kss.shards[tc.shardToCheck].target)
-			require.Equal(t, primaryDown, tc.expectPrimaryNotServing, "PrimaryIsNotServing should return %t", tc.expectPrimaryNotServing)
+			_, shouldBuffer := kew.ShouldStartBufferingForTarget(ctx, tc.kss.shards[tc.shardToCheck].target)
+			require.Equal(t, shouldBuffer, tc.expectShouldBuffer, "ShouldStartBufferingForTarget should return %t", tc.expectShouldBuffer)
 		})
 	}
 }
+
+// TestWaitForConsistentKeyspaces tests the behaviour of WaitForConsistentKeyspaces in different scenarios.
+func TestWaitForConsistentKeyspaces(t *testing.T) {
+	testcases := []struct {
+		name        string
+		ksMap       map[string]*keyspaceState
+		ksList      []string
+		errExpected string
+	}{
+		{
+			name:   "Empty keyspace list",
+			ksList: nil,
+			ksMap: map[string]*keyspaceState{
+				"ks1": {},
+			},
+			errExpected: "",
+		},
+		{
+			name:   "All keyspaces consistent",
+			ksList: []string{"ks1", "ks2"},
+			ksMap: map[string]*keyspaceState{
+				"ks1": {
+					consistent: true,
+				},
+				"ks2": {
+					consistent: true,
+				},
+			},
+			errExpected: "",
+		},
+		{
+			name:   "One keyspace inconsistent",
+			ksList: []string{"ks1", "ks2"},
+			ksMap: map[string]*keyspaceState{
+				"ks1": {
+					consistent: true,
+				},
+				"ks2": {
+					consistent: false,
+				},
+			},
+			errExpected: "context canceled",
+		},
+		{
+			name:   "One deleted keyspace - consistent",
+			ksList: []string{"ks1", "ks2"},
+			ksMap: map[string]*keyspaceState{
+				"ks1": {
+					consistent: true,
+				},
+				"ks2": {
+					deleted: true,
+				},
+			},
+			errExpected: "",
+		},
+	}
+
+	for _, tt := range testcases {
+		t.Run(tt.name, func(t *testing.T) {
+			// We create a cancelable context and immediately cancel it.
+			// We don't want the unit tests to wait, so we only exercise the first
+			// iteration of the loop and check whether the keyspace event watcher
+			// reports the keyspaces as consistent or not.
+			ctx, cancel := context.WithCancel(context.Background())
+			cancel()
+			kew := KeyspaceEventWatcher{
+				keyspaces: tt.ksMap,
+				mu:        sync.Mutex{},
+				ts:        &fakeTopoServer{},
+			}
+			err := kew.WaitForConsistentKeyspaces(ctx, tt.ksList)
+			if tt.errExpected != "" {
+				require.ErrorContains(t, err, tt.errExpected)
+			} else {
+				require.NoError(t, err)
+			}
+		})
+	}
+}
diff --git a/go/vt/srvtopo/discover.go b/go/vt/srvtopo/discover.go
index 2997dc42e21..2b020e89887 100644
--- a/go/vt/srvtopo/discover.go
+++ b/go/vt/srvtopo/discover.go
@@ -17,9 +17,8 @@ limitations under the License.
 package srvtopo
 
 import (
-	"sync"
-
 	"context"
+	"sync"
 
 	"vitess.io/vitess/go/vt/concurrency"
 	"vitess.io/vitess/go/vt/log"
@@ -29,15 +28,16 @@ import (
 	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
 )
 
-// FindAllTargets goes through all serving shards in the topology for the provided keyspaces
+// FindAllTargetsAndKeyspaces goes through all serving shards in the topology for the provided keyspaces
 // and tablet types. If no keyspaces are provided all available keyspaces in the topo are
 // fetched. It returns one Target object per keyspace/shard/matching TabletType.
-func FindAllTargets(ctx context.Context, ts Server, cell string, keyspaces []string, tabletTypes []topodatapb.TabletType) ([]*querypb.Target, error) {
+// It also returns all the keyspaces that it found.
+func FindAllTargetsAndKeyspaces(ctx context.Context, ts Server, cell string, keyspaces []string, tabletTypes []topodatapb.TabletType) ([]*querypb.Target, []string, error) {
 	var err error
 	if len(keyspaces) == 0 {
 		keyspaces, err = ts.GetSrvKeyspaceNames(ctx, cell, true)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 	}
 
@@ -95,8 +95,8 @@
 	}
 	wg.Wait()
 	if errRecorder.HasErrors() {
-		return nil, errRecorder.Error()
+		return nil, nil, errRecorder.Error()
 	}
 
-	return targets, nil
+	return targets, keyspaces, nil
 }
diff --git a/go/vt/srvtopo/discover_test.go b/go/vt/srvtopo/discover_test.go
index 3f730bba3d3..0232bce7a65 100644
--- a/go/vt/srvtopo/discover_test.go
+++ b/go/vt/srvtopo/discover_test.go
@@ -48,7 +48,7 @@ func (a TargetArray) Less(i, j int) bool {
 	return a[i].TabletType < a[j].TabletType
 }
 
-func TestFindAllTargets(t *testing.T) {
+func TestFindAllTargetsAndKeyspaces(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	ts := memorytopo.NewServer(ctx, "cell1", "cell2")
@@ -63,9 +63,10 @@
 	rs := NewResilientServer(ctx, ts, "TestFindAllKeyspaceShards")
 
 	// No keyspace / shards.
-	ks, err := FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
+	targets, ksList, err := FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
 	assert.NoError(t, err)
-	assert.Len(t, ks, 0)
+	assert.Len(t, targets, 0)
+	assert.EqualValues(t, []string{"test_keyspace"}, ksList)
 
 	// Add one.
 	assert.NoError(t, ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace", &topodatapb.SrvKeyspace{
@@ -82,7 +83,7 @@
 	}))
 
 	// Get it.
-	ks, err = FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
+	targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
 	assert.NoError(t, err)
 	assert.EqualValues(t, []*querypb.Target{
 		{
@@ -91,10 +92,11 @@
 			Cell:       "cell1",
 			Keyspace:   "test_keyspace",
 			Shard:      "test_shard0",
 			TabletType: topodatapb.TabletType_PRIMARY,
 		},
-	}, ks)
+	}, targets)
+	assert.EqualValues(t, []string{"test_keyspace"}, ksList)
 
 	// Get any keyspace.
-	ks, err = FindAllTargets(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
+	targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
 	assert.NoError(t, err)
 	assert.EqualValues(t, []*querypb.Target{
 		{
@@ -103,7 +105,8 @@
 			Cell:       "cell1",
 			Keyspace:   "test_keyspace",
 			Shard:      "test_shard0",
 			TabletType: topodatapb.TabletType_PRIMARY,
 		},
-	}, ks)
+	}, targets)
+	assert.EqualValues(t, []string{"test_keyspace"}, ksList)
 
 	// Add another one.
 	assert.NoError(t, ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace2", &topodatapb.SrvKeyspace{
@@ -128,9 +131,9 @@
 	}))
 
 	// Get it for any keyspace, all types.
- ks, err = FindAllTargets(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) + targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) assert.NoError(t, err) - sort.Sort(TargetArray(ks)) + sort.Sort(TargetArray(targets)) assert.EqualValues(t, []*querypb.Target{ { Cell: "cell1", @@ -150,10 +153,12 @@ func TestFindAllTargets(t *testing.T) { Shard: "test_shard2", TabletType: topodatapb.TabletType_REPLICA, }, - }, ks) + }, targets) + sort.Strings(ksList) + assert.EqualValues(t, []string{"test_keyspace", "test_keyspace2"}, ksList) // Only get 1 keyspace for all types. - ks, err = FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace2"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) + targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{"test_keyspace2"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) assert.NoError(t, err) assert.EqualValues(t, []*querypb.Target{ { @@ -168,10 +173,11 @@ func TestFindAllTargets(t *testing.T) { Shard: "test_shard2", TabletType: topodatapb.TabletType_REPLICA, }, - }, ks) + }, targets) + assert.EqualValues(t, []string{"test_keyspace2"}, ksList) // Only get the REPLICA targets for any keyspace. - ks, err = FindAllTargets(ctx, rs, "cell1", []string{}, []topodatapb.TabletType{topodatapb.TabletType_REPLICA}) + targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{}, []topodatapb.TabletType{topodatapb.TabletType_REPLICA}) assert.NoError(t, err) assert.Equal(t, []*querypb.Target{ { @@ -180,10 +186,13 @@ func TestFindAllTargets(t *testing.T) { Shard: "test_shard2", TabletType: topodatapb.TabletType_REPLICA, }, - }, ks) + }, targets) + sort.Strings(ksList) + assert.EqualValues(t, []string{"test_keyspace", "test_keyspace2"}, ksList) // Get non-existent keyspace. - ks, err = FindAllTargets(ctx, rs, "cell1", []string{"doesnt-exist"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) + targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{"doesnt-exist"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) assert.NoError(t, err) - assert.Len(t, ks, 0) + assert.Len(t, targets, 0) + assert.EqualValues(t, []string{"doesnt-exist"}, ksList) } diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 084a5059fd8..21087fe5370 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -191,11 +191,24 @@ func (gw *TabletGateway) WaitForTablets(ctx context.Context, tabletTypesToWait [ } // Finds the targets to look for. - targets, err := srvtopo.FindAllTargets(ctx, gw.srvTopoServer, gw.localCell, discovery.KeyspacesToWatch, tabletTypesToWait) + targets, keyspaces, err := srvtopo.FindAllTargetsAndKeyspaces(ctx, gw.srvTopoServer, gw.localCell, discovery.KeyspacesToWatch, tabletTypesToWait) if err != nil { return err } - return gw.hc.WaitForAllServingTablets(ctx, targets) + err = gw.hc.WaitForAllServingTablets(ctx, targets) + if err != nil { + return err + } + // After having waited for all serving tablets. We should also wait for the keyspace event watcher to have seen + // the updates and marked all the keyspaces as consistent (if we want to wait for primary tablets). 
+	// Otherwise, we could be in a situation where even though the healthchecks have arrived, the keyspace event watcher hasn't finished processing them.
+	// So, if a primary tablet goes non-serving (because of a PRS or some other reason), we won't be able to start buffering.
+	// Waiting for the keyspaces to become consistent ensures that the keyspace event watcher sees all the primary tablets
+	// for all the shards as serving, so any disruption from now on will start buffering properly.
+	if topoproto.IsTypeInList(topodatapb.TabletType_PRIMARY, tabletTypesToWait) && gw.kev != nil {
+		return gw.kev.WaitForConsistentKeyspaces(ctx, keyspaces)
+	}
+	return nil
 }
 
 // Close shuts down underlying connections.
@@ -282,18 +295,21 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target,
 		if len(tablets) == 0 {
 			// if we have a keyspace event watcher, check if the reason why our primary is not available is that it's currently being resharded
 			// or if a reparent operation is in progress.
-			if kev := gw.kev; kev != nil {
+			// We only check whether a reshard is ongoing or the primary is non-serving if the target is a primary. We don't want to buffer
+			// replica queries, so it doesn't make any sense to check for resharding or reparenting in that case.
+			if kev := gw.kev; kev != nil && target.TabletType == topodatapb.TabletType_PRIMARY {
 				if kev.TargetIsBeingResharded(ctx, target) {
 					log.V(2).Infof("current keyspace is being resharded, retrying: %s: %s", target.Keyspace, debug.Stack())
 					err = vterrors.Errorf(vtrpcpb.Code_CLUSTER_EVENT, buffer.ClusterEventReshardingInProgress)
 					continue
 				}
-				primary, notServing := kev.PrimaryIsNotServing(ctx, target)
-				if notServing {
+				primary, shouldBuffer := kev.ShouldStartBufferingForTarget(ctx, target)
+				if shouldBuffer {
 					err = vterrors.Errorf(vtrpcpb.Code_CLUSTER_EVENT, buffer.ClusterEventReparentInProgress)
 					continue
 				}
-				// if primary is serving, but we initially found no tablet, we're in an inconsistent state
+				// if the keyspace event manager doesn't think we should buffer queries, and also sees a primary tablet,
+				// but we initially found no tablet, we're in an inconsistent state
 				// we then retry the entire loop
 				if primary != nil {
 					err = vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "inconsistent state detected, primary is serving but initially found no available tablet")
diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go
index 74e6751162a..fbca19ecbad 100644
--- a/go/vt/vtgate/tabletgateway_flaky_test.go
+++ b/go/vt/vtgate/tabletgateway_flaky_test.go
@@ -67,7 +67,7 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) {
 	waitForBuffering := func(enabled bool) {
 		timer := time.NewTimer(bufferingWaitTimeout)
 		defer timer.Stop()
-		for _, buffering := tg.kev.PrimaryIsNotServing(ctx, target); buffering != enabled; _, buffering = tg.kev.PrimaryIsNotServing(ctx, target) {
+		for _, buffering := tg.kev.ShouldStartBufferingForTarget(ctx, target); buffering != enabled; _, buffering = tg.kev.ShouldStartBufferingForTarget(ctx, target) {
 			select {
 			case <-timer.C:
 				require.Fail(t, "timed out waiting for buffering of enabled: %t", enabled)
@@ -213,8 +213,8 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) {
 	hc.Broadcast(primaryTablet)
 	require.Len(t, tg.hc.GetHealthyTabletStats(target), 0, "GetHealthyTabletStats has tablets even though it shouldn't")
 
-	_, isNotServing := tg.kev.PrimaryIsNotServing(ctx, target)
-	require.True(t, isNotServing)
+	_, shouldStartBuffering := tg.kev.ShouldStartBufferingForTarget(ctx, target)
+	require.True(t, shouldStartBuffering)
 
 	// add a result to the sandbox connection of the new primary
 	sbcReplica.SetResults([]*sqltypes.Result{sqlResult1})
@@ -244,8 +244,8 @@ outer:
 		case <-timeout:
 			require.Fail(t, "timed out - could not verify the new primary")
 		case <-time.After(10 * time.Millisecond):
-			newPrimary, notServing := tg.kev.PrimaryIsNotServing(ctx, target)
-			if newPrimary != nil && newPrimary.Uid == 1 && !notServing {
+			newPrimary, shouldBuffer := tg.kev.ShouldStartBufferingForTarget(ctx, target)
+			if newPrimary != nil && newPrimary.Uid == 1 && !shouldBuffer {
 				break outer
 			}
 		}
diff --git a/go/vt/vtgate/tabletgateway_test.go b/go/vt/vtgate/tabletgateway_test.go
index 32d18dcc9ab..fc86ab358c8 100644
--- a/go/vt/vtgate/tabletgateway_test.go
+++ b/go/vt/vtgate/tabletgateway_test.go
@@ -26,6 +26,7 @@ import (
 	"github.com/stretchr/testify/require"
 
 	"vitess.io/vitess/go/test/utils"
+	"vitess.io/vitess/go/vt/vttablet/queryservice"
 
 	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/vt/discovery"
@@ -298,3 +299,58 @@ func verifyShardErrors(t *testing.T, err error, wantErrors []string, wantCode vt
 	}
 	require.Equal(t, vterrors.Code(err), wantCode, "wanted error code: %s, got: %v", wantCode, vterrors.Code(err))
 }
+
+// TestWithRetry tests the functionality of the withRetry function in different circumstances.
+func TestWithRetry(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	tg := NewTabletGateway(ctx, discovery.NewFakeHealthCheck(nil), &fakeTopoServer{}, "cell")
+	tg.kev = discovery.NewKeyspaceEventWatcher(ctx, tg.srvTopoServer, tg.hc, tg.localCell)
+	defer func() {
+		cancel()
+		tg.Close(ctx)
+	}()
+
+	testcases := []struct {
+		name          string
+		target        *querypb.Target
+		inTransaction bool
+		inner         func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error)
+		expectedErr   string
+	}{
+		{
+			name: "Transaction on a replica",
+			target: &querypb.Target{
+				Keyspace:   "ks",
+				Shard:      "0",
+				TabletType: topodatapb.TabletType_REPLICA,
+			},
+			inTransaction: true,
+			inner: func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error) {
+				return false, nil
+			},
+			expectedErr: "tabletGateway's query service can only be used for non-transactional queries on replicas",
+		}, {
+			name: "No replica tablets available",
+			target: &querypb.Target{
+				Keyspace:   "ks",
+				Shard:      "0",
+				TabletType: topodatapb.TabletType_REPLICA,
+			},
+			inTransaction: false,
+			inner: func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error) {
+				return false, nil
+			},
+			expectedErr: `target: ks.0.replica: no healthy tablet available for 'keyspace:"ks" shard:"0" tablet_type:REPLICA'`,
+		},
+	}
+	for _, tt := range testcases {
+		t.Run(tt.name, func(t *testing.T) {
+			err := tg.withRetry(ctx, tt.target, nil, "", tt.inTransaction, tt.inner)
+			if tt.expectedErr == "" {
+				require.NoError(t, err)
+			} else {
+				require.ErrorContains(t, err, tt.expectedErr)
+			}
+		})
+	}
+}
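
Reviewer note on the polling loop added in WaitForConsistentKeyspaces above: it clones the caller's slice and blanks out entries in place as each keyspace settles, so later passes skip already-consistent keyspaces without reallocating. Below is a minimal, self-contained sketch of that same clone-and-blank polling pattern; the names (waitAll, check) are illustrative only and not part of this patch.

package main

import (
	"context"
	"fmt"
	"slices"
	"time"
)

// waitAll polls check for every item until all pass, blanking out entries in a
// cloned slice so items that have already settled are skipped on later passes.
// This mirrors the loop structure of WaitForConsistentKeyspaces.
func waitAll(ctx context.Context, items []string, interval time.Duration, check func(string) bool) error {
	pending := slices.Clone(items) // clone so we don't mutate the caller's slice
	for {
		allDone := true
		for i, it := range pending {
			if it == "" {
				continue // settled on an earlier pass
			}
			if check(it) {
				pending[i] = ""
			} else {
				allDone = false
			}
		}
		if allDone {
			return nil
		}
		// Unblock after the poll interval or when the context expires.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(interval):
		}
	}
}

func main() {
	start := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	// "ks1" is consistent immediately; "ks2" becomes consistent after 50ms.
	err := waitAll(ctx, []string{"ks1", "ks2"}, 10*time.Millisecond, func(ks string) bool {
		return ks == "ks1" || time.Since(start) > 50*time.Millisecond
	})
	fmt.Println(err) // <nil>
}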
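Similarly, here is a sketch of the startup ordering this patch establishes in TabletGateway.WaitForTablets: wait for the primary healthchecks to arrive first, then wait for the keyspace event watcher to mark the keyspaces consistent, so that a later PRS reliably starts buffering. The wrapper function and its wiring are hypothetical; the three API calls are the ones introduced or renamed in this diff.

package example

import (
	"context"

	"vitess.io/vitess/go/vt/discovery"
	"vitess.io/vitess/go/vt/srvtopo"

	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
)

// waitForPrimariesAndConsistency is a hypothetical wrapper showing the intended
// ordering: healthchecks first, keyspace-event-watcher consistency second.
func waitForPrimariesAndConsistency(
	ctx context.Context,
	ts srvtopo.Server,
	hc discovery.HealthCheck,
	kew *discovery.KeyspaceEventWatcher,
	cell string,
) error {
	// One target per serving keyspace/shard/PRIMARY, plus the keyspace names,
	// which are needed again for the consistency wait below.
	targets, keyspaces, err := srvtopo.FindAllTargetsAndKeyspaces(
		ctx, ts, cell, nil /* all keyspaces */, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
	if err != nil {
		return err
	}
	if err := hc.WaitForAllServingTablets(ctx, targets); err != nil {
		return err
	}
	// The healthchecks have arrived, but the keyspace event watcher may not
	// have processed them yet; wait until it has, so that a primary going
	// non-serving later reliably starts buffering.
	return kew.WaitForConsistentKeyspaces(ctx, keyspaces)
}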