From d18e3031682d1a196377a03af107352b78249bcd Mon Sep 17 00:00:00 2001
From: Matt Lord <mattalord@gmail.com>
Date: Wed, 17 Jan 2024 12:37:21 -0500
Subject: [PATCH 01/12] Flakes: De-flake
 TestGatewayBufferingWhenPrimarySwitchesServingState (#14968)

Signed-off-by: Matt Lord <mattalord@gmail.com>
---
 go/vt/vtgate/tabletgateway_flaky_test.go | 31 ++++++++++++++++++------
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go
index f625b5599cd..74e6751162a 100644
--- a/go/vt/vtgate/tabletgateway_flaky_test.go
+++ b/go/vt/vtgate/tabletgateway_flaky_test.go
@@ -22,15 +22,14 @@ import (
 
 	"github.com/stretchr/testify/require"
 
-	"vitess.io/vitess/go/test/utils"
-
 	"vitess.io/vitess/go/mysql/collations"
-
 	"vitess.io/vitess/go/sqltypes"
+	"vitess.io/vitess/go/test/utils"
 	"vitess.io/vitess/go/vt/discovery"
+	"vitess.io/vitess/go/vt/vtgate/buffer"
+
 	querypb "vitess.io/vitess/go/vt/proto/query"
 	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
-	"vitess.io/vitess/go/vt/vtgate/buffer"
 )
 
 // TestGatewayBufferingWhenPrimarySwitchesServingState is used to test that the buffering mechanism buffers the queries when a primary goes to a non serving state and
@@ -64,6 +63,20 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) {
 	// add a primary tablet which is serving
 	sbc := hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil)
 
+	bufferingWaitTimeout := 60 * time.Second
+	waitForBuffering := func(enabled bool) {
+		timer := time.NewTimer(bufferingWaitTimeout)
+		defer timer.Stop()
+		for _, buffering := tg.kev.PrimaryIsNotServing(ctx, target); buffering != enabled; _, buffering = tg.kev.PrimaryIsNotServing(ctx, target) {
+			select {
+			case <-timer.C:
+				require.Fail(t, "timed out waiting for the expected buffering state", "buffering enabled: %t", enabled)
+			default:
+			}
+			time.Sleep(10 * time.Millisecond)
+		}
+	}
+
 	// add a result to the sandbox connection
 	sqlResult1 := &sqltypes.Result{
 		Fields: []*querypb.Field{{
@@ -94,6 +107,8 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) {
 	// add another result to the sandbox connection
 	sbc.SetResults([]*sqltypes.Result{sqlResult1})
 
+	waitForBuffering(true)
+
 	// execute the query in a goroutine since it should be buffered, and check that it eventually succeeds
 	queryChan := make(chan struct{})
 	go func() {
@@ -102,17 +117,17 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) {
 	}()
 
 	// set the serving type for the primary tablet true and broadcast it so that the buffering code registers this change
-	// this should stop the buffering and the query executed in the go routine should work. This should be done with some delay so
-	// that we know that the query was buffered
-	time.Sleep(1 * time.Second)
+	// This should stop the buffering, and the query executed in the goroutine should then succeed.
 	hc.SetServing(primaryTablet, true)
 	hc.Broadcast(primaryTablet)
 
+	waitForBuffering(false)
+
 	// wait for the query to execute before checking for results
 	select {
 	case <-queryChan:
 		require.NoError(t, err)
-		require.Equal(t, res, sqlResult1)
+		require.Equal(t, sqlResult1, res)
 	case <-time.After(15 * time.Second):
 		t.Fatalf("timed out waiting for query to execute")
 	}
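
For reference, the de-flaking approach in this patch replaces the fixed one-second sleep with polling until the keyspace event watcher reports the expected buffering state. A minimal, self-contained sketch of that polling pattern, with a hypothetical helper name and condition (illustrative only, not part of the patch itself):

package gatewaytest

import (
	"testing"
	"time"
)

// waitForCondition polls cond every 10ms until it returns true or the timeout
// expires, failing the test on timeout.
func waitForCondition(t *testing.T, timeout time.Duration, cond func() bool) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for !cond() {
		if time.Now().After(deadline) {
			t.Fatalf("timed out after %v waiting for condition", timeout)
		}
		time.Sleep(10 * time.Millisecond)
	}
}

A caller would pass something like func() bool { _, buffering := kev.PrimaryIsNotServing(ctx, target); return buffering == want } as the condition.
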

From 5d2abca4402390c5f55bfac1baf0c359d7a68820 Mon Sep 17 00:00:00 2001
From: Matt Lord <mattalord@gmail.com>
Date: Thu, 25 Apr 2024 12:58:32 -0400
Subject: [PATCH 02/12] VReplication: Improve query buffering behavior during
 MoveTables traffic switching (#15701)

Signed-off-by: Matt Lord <mattalord@gmail.com>
---
 go/test/endtoend/vreplication/cluster_test.go |   5 +-
 go/test/endtoend/vreplication/helper_test.go  | 127 ++++++-----
 .../vreplication/movetables_buffering_test.go |   9 +-
 .../vreplication/partial_movetables_test.go   |   6 +
 go/vt/discovery/keyspace_events.go            | 208 +++++++++++-------
 go/vt/discovery/keyspace_events_test.go       |  86 +++++++-
 go/vt/srvtopo/watch_srvvschema.go             |   3 +-
 go/vt/topo/etcd2topo/watch.go                 |  20 +-
 go/vt/vtctl/workflow/server.go                |   8 +-
 go/vt/vtctl/workflow/switcher.go              |   4 +-
 go/vt/vtctl/workflow/switcher_dry_run.go      |   5 +-
 go/vt/vtctl/workflow/switcher_interface.go    |   2 +-
 go/vt/vtctl/workflow/traffic_switcher.go      |   7 +-
 go/vt/vtgate/buffer/buffer.go                 |   4 +
 go/vt/vtgate/buffer/flags.go                  |   3 +
 go/vt/vtgate/buffer/shard_buffer.go           |   7 +-
 go/vt/vtgate/executor_test.go                 |   2 +-
 go/vt/vtgate/executor_vschema_ddl_test.go     |  27 +--
 go/vt/vtgate/plan_execute.go                  |  98 ++++++---
 19 files changed, 423 insertions(+), 208 deletions(-)

diff --git a/go/test/endtoend/vreplication/cluster_test.go b/go/test/endtoend/vreplication/cluster_test.go
index af93ac40726..74923cf6a6c 100644
--- a/go/test/endtoend/vreplication/cluster_test.go
+++ b/go/test/endtoend/vreplication/cluster_test.go
@@ -54,8 +54,9 @@ var (
 	sidecarDBIdentifier   = sqlparser.String(sqlparser.NewIdentifierCS(sidecarDBName))
 	mainClusterConfig     *ClusterConfig
 	externalClusterConfig *ClusterConfig
-	extraVTGateArgs       = []string{"--tablet_refresh_interval", "10ms", "--enable_buffer", "--buffer_window", loadTestBufferingWindowDurationStr,
-		"--buffer_size", "100000", "--buffer_min_time_between_failovers", "0s", "--buffer_max_failover_duration", loadTestBufferingWindowDurationStr}
+	extraVTGateArgs       = []string{"--tablet_refresh_interval", "10ms", "--enable_buffer", "--buffer_window", loadTestBufferingWindowDuration.String(),
+		"--buffer_size", "250000", "--buffer_min_time_between_failovers", "1s", "--buffer_max_failover_duration", loadTestBufferingWindowDuration.String(),
+		"--buffer_drain_concurrency", "10"}
 	extraVtctldArgs = []string{"--remote_operation_timeout", "600s", "--topo_etcd_lease_ttl", "120"}
 	// This variable can be used within specific tests to alter vttablet behavior
 	extraVTTabletArgs = []string{}
diff --git a/go/test/endtoend/vreplication/helper_test.go b/go/test/endtoend/vreplication/helper_test.go
index 445a15f7767..8c96d06226a 100644
--- a/go/test/endtoend/vreplication/helper_test.go
+++ b/go/test/endtoend/vreplication/helper_test.go
@@ -18,16 +18,17 @@ package vreplication
 
 import (
 	"context"
-	"crypto/rand"
 	"encoding/hex"
 	"encoding/json"
 	"fmt"
 	"io"
+	"math/rand"
 	"net/http"
 	"os/exec"
 	"regexp"
 	"sort"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -75,9 +76,10 @@ func execQuery(t *testing.T, conn *mysql.Conn, query string) *sqltypes.Result {
 
 func getConnection(t *testing.T, hostname string, port int) *mysql.Conn {
 	vtParams := mysql.ConnParams{
-		Host:  hostname,
-		Port:  port,
-		Uname: "vt_dba",
+		Host:             hostname,
+		Port:             port,
+		Uname:            "vt_dba",
+		ConnectTimeoutMs: 1000,
 	}
 	ctx := context.Background()
 	conn, err := mysql.Connect(ctx, &vtParams)
@@ -714,31 +716,35 @@ func isBinlogRowImageNoBlob(t *testing.T, tablet *cluster.VttabletProcess) bool
 }
 
 const (
-	loadTestBufferingWindowDurationStr = "30s"
-	loadTestPostBufferingInsertWindow  = 60 * time.Second // should be greater than loadTestBufferingWindowDurationStr
-	loadTestWaitForCancel              = 30 * time.Second
-	loadTestWaitBetweenQueries         = 2 * time.Millisecond
+	loadTestBufferingWindowDuration = 10 * time.Second
+	loadTestAvgWaitBetweenQueries   = 500 * time.Microsecond
+	loadTestDefaultConnections      = 100
 )
 
 type loadGenerator struct {
-	t      *testing.T
-	vc     *VitessCluster
-	ctx    context.Context
-	cancel context.CancelFunc
+	t           *testing.T
+	vc          *VitessCluster
+	ctx         context.Context
+	cancel      context.CancelFunc
+	connections int
+	wg          sync.WaitGroup
 }
 
 func newLoadGenerator(t *testing.T, vc *VitessCluster) *loadGenerator {
 	return &loadGenerator{
-		t:  t,
-		vc: vc,
+		t:           t,
+		vc:          vc,
+		connections: loadTestDefaultConnections,
 	}
 }
 
 func (lg *loadGenerator) stop() {
-	time.Sleep(loadTestPostBufferingInsertWindow) // wait for buffering to stop and additional records to be inserted by startLoad after traffic is switched
+	// Wait for buffering to stop and for additional records to be inserted by start()
+	// after traffic is switched.
+	time.Sleep(loadTestBufferingWindowDuration * 2)
 	log.Infof("Canceling load")
 	lg.cancel()
-	time.Sleep(loadTestWaitForCancel) // wait for cancel to take effect
+	lg.wg.Wait()
 	log.Flush()
 
 }
@@ -746,62 +752,77 @@ func (lg *loadGenerator) stop() {
 func (lg *loadGenerator) start() {
 	t := lg.t
 	lg.ctx, lg.cancel = context.WithCancel(context.Background())
+	var connectionCount atomic.Int64
 
 	var id int64
-	log.Infof("startLoad: starting")
+	log.Infof("loadGenerator: starting")
 	queryTemplate := "insert into loadtest(id, name) values (%d, 'name-%d')"
 	var totalQueries, successfulQueries int64
 	var deniedErrors, ambiguousErrors, reshardedErrors, tableNotFoundErrors, otherErrors int64
+	lg.wg.Add(1)
 	defer func() {
-
-		log.Infof("startLoad: totalQueries: %d, successfulQueries: %d, deniedErrors: %d, ambiguousErrors: %d, reshardedErrors: %d, tableNotFoundErrors: %d, otherErrors: %d",
+		defer lg.wg.Done()
+		log.Infof("loadGenerator: totalQueries: %d, successfulQueries: %d, deniedErrors: %d, ambiguousErrors: %d, reshardedErrors: %d, tableNotFoundErrors: %d, otherErrors: %d",
 			totalQueries, successfulQueries, deniedErrors, ambiguousErrors, reshardedErrors, tableNotFoundErrors, otherErrors)
 	}()
-	logOnce := true
 	for {
 		select {
 		case <-lg.ctx.Done():
-			log.Infof("startLoad: context cancelled")
-			log.Infof("startLoad: deniedErrors: %d, ambiguousErrors: %d, reshardedErrors: %d, tableNotFoundErrors: %d, otherErrors: %d",
+			log.Infof("loadGenerator: context cancelled")
+			log.Infof("loadGenerator: deniedErrors: %d, ambiguousErrors: %d, reshardedErrors: %d, tableNotFoundErrors: %d, otherErrors: %d",
 				deniedErrors, ambiguousErrors, reshardedErrors, tableNotFoundErrors, otherErrors)
 			require.Equal(t, int64(0), deniedErrors)
 			require.Equal(t, int64(0), otherErrors)
+			require.Equal(t, int64(0), reshardedErrors)
 			require.Equal(t, totalQueries, successfulQueries)
 			return
 		default:
-			go func() {
-				conn := vc.GetVTGateConn(t)
-				defer conn.Close()
-				atomic.AddInt64(&id, 1)
-				query := fmt.Sprintf(queryTemplate, id, id)
-				_, err := conn.ExecuteFetch(query, 1, false)
-				atomic.AddInt64(&totalQueries, 1)
-				if err != nil {
-					sqlErr := err.(*sqlerror.SQLError)
-					if strings.Contains(strings.ToLower(err.Error()), "denied tables") {
-						log.Infof("startLoad: denied tables error executing query: %d:%v", sqlErr.Number(), err)
-						atomic.AddInt64(&deniedErrors, 1)
-					} else if strings.Contains(strings.ToLower(err.Error()), "ambiguous") {
-						// this can happen when a second keyspace is setup with the same tables, but there are no routing rules
-						// set yet by MoveTables. So we ignore these errors.
-						atomic.AddInt64(&ambiguousErrors, 1)
-					} else if strings.Contains(strings.ToLower(err.Error()), "current keyspace is being resharded") {
-						atomic.AddInt64(&reshardedErrors, 1)
-					} else if strings.Contains(strings.ToLower(err.Error()), "not found") {
-						atomic.AddInt64(&tableNotFoundErrors, 1)
-					} else {
-						if logOnce {
-							log.Infof("startLoad: error executing query: %d:%v", sqlErr.Number(), err)
-							logOnce = false
+			if int(connectionCount.Load()) < lg.connections {
+				connectionCount.Add(1)
+				lg.wg.Add(1)
+				go func() {
+					defer lg.wg.Done()
+					defer connectionCount.Add(-1)
+					conn := vc.GetVTGateConn(t)
+					defer conn.Close()
+					for {
+						select {
+						case <-lg.ctx.Done():
+							return
+						default:
+						}
+						newID := atomic.AddInt64(&id, 1)
+						query := fmt.Sprintf(queryTemplate, newID, newID)
+						_, err := conn.ExecuteFetch(query, 1, false)
+						atomic.AddInt64(&totalQueries, 1)
+						if err != nil {
+							sqlErr := err.(*sqlerror.SQLError)
+							if strings.Contains(strings.ToLower(err.Error()), "denied tables") {
+								if debugMode {
+									t.Logf("loadGenerator: denied tables error executing query: %d:%v", sqlErr.Number(), err)
+								}
+								atomic.AddInt64(&deniedErrors, 1)
+							} else if strings.Contains(strings.ToLower(err.Error()), "ambiguous") {
+								// This can happen when a second keyspace is set up with the same tables, but
+								// there are no routing rules set yet by MoveTables. So we ignore these errors.
+								atomic.AddInt64(&ambiguousErrors, 1)
+							} else if strings.Contains(strings.ToLower(err.Error()), "current keyspace is being resharded") {
+								atomic.AddInt64(&reshardedErrors, 1)
+							} else if strings.Contains(strings.ToLower(err.Error()), "not found") {
+								atomic.AddInt64(&tableNotFoundErrors, 1)
+							} else {
+								if debugMode {
+									t.Logf("loadGenerator: error executing query: %d:%v", sqlErr.Number(), err)
+								}
+								atomic.AddInt64(&otherErrors, 1)
+							}
+						} else {
+							atomic.AddInt64(&successfulQueries, 1)
 						}
-						atomic.AddInt64(&otherErrors, 1)
+						time.Sleep(time.Duration(int64(float64(loadTestAvgWaitBetweenQueries.Microseconds()) * rand.Float64())))
 					}
-					time.Sleep(loadTestWaitBetweenQueries)
-				} else {
-					atomic.AddInt64(&successfulQueries, 1)
-				}
-			}()
-			time.Sleep(loadTestWaitBetweenQueries)
+				}()
+			}
 		}
 	}
 }
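
The reworked load generator above keeps a bounded pool of long-lived connections rather than spawning a goroutine per query: an atomic counter caps the number of concurrent workers and a WaitGroup lets stop() wait for them all to exit. A stripped-down sketch of that pattern, assuming a generic work function (names are illustrative, not the test's actual API):

package loadtest

import (
	"context"
	"sync"
	"sync/atomic"
	"time"
)

// runBoundedWorkers keeps at most maxWorkers goroutines running work in a loop
// and returns only after the context is canceled and every worker has exited.
func runBoundedWorkers(ctx context.Context, maxWorkers int, work func()) {
	var (
		active atomic.Int64
		wg     sync.WaitGroup
	)
	for {
		select {
		case <-ctx.Done():
			wg.Wait()
			return
		default:
		}
		if int(active.Load()) < maxWorkers {
			active.Add(1)
			wg.Add(1)
			go func() {
				defer wg.Done()
				defer active.Add(-1)
				for {
					select {
					case <-ctx.Done():
						return
					default:
					}
					work()
				}
			}()
		} else {
			// At capacity: back off briefly instead of spinning.
			time.Sleep(time.Millisecond)
		}
	}
}
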
diff --git a/go/test/endtoend/vreplication/movetables_buffering_test.go b/go/test/endtoend/vreplication/movetables_buffering_test.go
index 3171c35e35b..9b9b1c69163 100644
--- a/go/test/endtoend/vreplication/movetables_buffering_test.go
+++ b/go/test/endtoend/vreplication/movetables_buffering_test.go
@@ -2,6 +2,7 @@ package vreplication
 
 import (
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 
@@ -34,8 +35,12 @@ func TestMoveTablesBuffering(t *testing.T) {
 	catchup(t, targetTab2, workflowName, "MoveTables")
 	vdiffSideBySide(t, ksWorkflow, "")
 	waitForLowLag(t, "customer", workflowName)
-	tstWorkflowSwitchReads(t, "", "")
-	tstWorkflowSwitchWrites(t)
+	for i := 0; i < 10; i++ {
+		tstWorkflowSwitchReadsAndWrites(t)
+		time.Sleep(loadTestBufferingWindowDuration + 1*time.Second)
+		tstWorkflowReverseReadsAndWrites(t)
+		time.Sleep(loadTestBufferingWindowDuration + 1*time.Second)
+	}
 	log.Infof("SwitchWrites done")
 	lg.stop()
 
diff --git a/go/test/endtoend/vreplication/partial_movetables_test.go b/go/test/endtoend/vreplication/partial_movetables_test.go
index 40bb585495e..62d9a0dcafa 100644
--- a/go/test/endtoend/vreplication/partial_movetables_test.go
+++ b/go/test/endtoend/vreplication/partial_movetables_test.go
@@ -20,6 +20,7 @@ import (
 	"fmt"
 	"strings"
 	"testing"
+	"time"
 
 	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
 
@@ -64,10 +65,12 @@ func testCancel(t *testing.T) {
 	mt.SwitchReadsAndWrites()
 	checkDenyList(targetKeyspace, false)
 	checkDenyList(sourceKeyspace, true)
+	time.Sleep(loadTestBufferingWindowDuration + 1*time.Second)
 
 	mt.ReverseReadsAndWrites()
 	checkDenyList(targetKeyspace, true)
 	checkDenyList(sourceKeyspace, false)
+	time.Sleep(loadTestBufferingWindowDuration + 1*time.Second)
 
 	mt.Cancel()
 	checkDenyList(targetKeyspace, false)
@@ -108,6 +111,7 @@ func TestPartialMoveTablesBasic(t *testing.T) {
 	// sharded customer keyspace.
 	createMoveTablesWorkflow(t, "customer,loadtest,customer2")
 	tstWorkflowSwitchReadsAndWrites(t)
+	time.Sleep(loadTestBufferingWindowDuration + 1*time.Second)
 	tstWorkflowComplete(t)
 
 	emptyGlobalRoutingRules := "{}\n"
@@ -218,6 +222,7 @@ func TestPartialMoveTablesBasic(t *testing.T) {
 	expectedSwitchOutput := fmt.Sprintf("SwitchTraffic was successful for workflow %s.%s\n\nStart State: Reads Not Switched. Writes Not Switched\nCurrent State: Reads partially switched, for shards: %s. Writes partially switched, for shards: %s\n\n",
 		targetKs, wfName, shard, shard)
 	require.Equal(t, expectedSwitchOutput, lastOutput)
+	time.Sleep(loadTestBufferingWindowDuration + 1*time.Second)
 
 	// Confirm global routing rules -- everything should still be routed
 	// to the source side, customer, globally.
@@ -296,6 +301,7 @@ func TestPartialMoveTablesBasic(t *testing.T) {
 	expectedSwitchOutput = fmt.Sprintf("SwitchTraffic was successful for workflow %s.%s\n\nStart State: Reads partially switched, for shards: 80-. Writes partially switched, for shards: 80-\nCurrent State: All Reads Switched. All Writes Switched\n\n",
 		targetKs, wfName)
 	require.Equal(t, expectedSwitchOutput, lastOutput)
+	time.Sleep(loadTestBufferingWindowDuration + 1*time.Second)
 
 	// Confirm global routing rules: everything should still be routed
 	// to the source side, customer, globally.
diff --git a/go/vt/discovery/keyspace_events.go b/go/vt/discovery/keyspace_events.go
index 014284ed5ee..9fa457c1589 100644
--- a/go/vt/discovery/keyspace_events.go
+++ b/go/vt/discovery/keyspace_events.go
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"sync"
 
+	"golang.org/x/sync/errgroup"
 	"google.golang.org/protobuf/proto"
 
 	"vitess.io/vitess/go/vt/key"
@@ -93,18 +94,8 @@ func NewKeyspaceEventWatcher(ctx context.Context, topoServer srvtopo.Server, hc
 	return kew
 }
 
-type MoveTablesStatus int
-
-const (
-	MoveTablesUnknown MoveTablesStatus = iota
-	// MoveTablesSwitching is set when the write traffic is the middle of being switched from the source to the target
-	MoveTablesSwitching
-	// MoveTablesSwitched is set when write traffic has been completely switched to the target
-	MoveTablesSwitched
-)
-
 // keyspaceState is the internal state for all the keyspaces that the KEW is
-// currently watching
+// currently watching.
 type keyspaceState struct {
 	kew      *KeyspaceEventWatcher
 	keyspace string
@@ -120,7 +111,7 @@ type keyspaceState struct {
 	moveTablesState *MoveTablesState
 }
 
-// Format prints the internal state for this keyspace for debug purposes
+// Format prints the internal state for this keyspace for debug purposes.
 func (kss *keyspaceState) Format(f fmt.State, verb rune) {
 	kss.mu.Lock()
 	defer kss.mu.Unlock()
@@ -137,9 +128,9 @@ func (kss *keyspaceState) Format(f fmt.State, verb rune) {
 	fmt.Fprintf(f, "]\n")
 }
 
-// beingResharded returns whether this keyspace is thought to be in the middle of a resharding
-// operation. currentShard is the name of the shard that belongs to this keyspace and which
-// we are trying to access. currentShard can _only_ be a primary shard.
+// beingResharded returns whether this keyspace is thought to be in the middle of a
+// resharding operation. currentShard is the name of the shard that belongs to this
+// keyspace and which we are trying to access. currentShard can _only_ be a primary shard.
 func (kss *keyspaceState) beingResharded(currentShard string) bool {
 	kss.mu.Lock()
 	defer kss.mu.Unlock()
@@ -179,11 +170,19 @@ type shardState struct {
 	currentPrimary       *topodatapb.TabletAlias
 }
 
-// Subscribe returns a channel that will receive any KeyspaceEvents for all keyspaces in the current cell
+// Subscribe returns a channel that will receive any KeyspaceEvents for all keyspaces in the
+// current cell.
 func (kew *KeyspaceEventWatcher) Subscribe() chan *KeyspaceEvent {
 	kew.subsMu.Lock()
 	defer kew.subsMu.Unlock()
-	c := make(chan *KeyspaceEvent, 2)
+	// Use a decently sized buffer so that we:
+	// 1. Avoid blocking the KEW
+	// 2. Do not lose/miss any events
+	// 3. Process them in the order received
+	// TODO: do we care about intermediate events?
+	// If not, then we could instead e.g. pull the first/oldest event
+	// from the channel, discard it, and add the current/latest.
+	c := make(chan *KeyspaceEvent, 10)
 	kew.subs[c] = struct{}{}
 	return c
 }
@@ -195,14 +194,11 @@ func (kew *KeyspaceEventWatcher) Unsubscribe(c chan *KeyspaceEvent) {
 	delete(kew.subs, c)
 }
 
-func (kew *KeyspaceEventWatcher) broadcast(th *KeyspaceEvent) {
+func (kew *KeyspaceEventWatcher) broadcast(ev *KeyspaceEvent) {
 	kew.subsMu.Lock()
 	defer kew.subsMu.Unlock()
 	for c := range kew.subs {
-		select {
-		case c <- th:
-		default:
-		}
+		c <- ev
 	}
 }
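
With the change above, Subscribe hands out a larger buffered channel and broadcast uses a blocking send, so subscribers see every keyspace event in order instead of having events silently dropped when a buffer is full. A minimal sketch of that publish/subscribe shape, with simplified types rather than the real KeyspaceEventWatcher API:

package events

import "sync"

// Event is a simplified stand-in for a keyspace event.
type Event struct {
	Keyspace string
}

// Broadcaster fans events out to all subscribers without dropping any.
type Broadcaster struct {
	mu   sync.Mutex
	subs map[chan *Event]struct{}
}

func NewBroadcaster() *Broadcaster {
	return &Broadcaster{subs: make(map[chan *Event]struct{})}
}

// Subscribe returns a buffered channel so the broadcaster rarely blocks while
// subscribers still receive every event in the order it was sent.
func (b *Broadcaster) Subscribe() chan *Event {
	b.mu.Lock()
	defer b.mu.Unlock()
	c := make(chan *Event, 10)
	b.subs[c] = struct{}{}
	return c
}

// Unsubscribe removes a subscriber channel.
func (b *Broadcaster) Unsubscribe(c chan *Event) {
	b.mu.Lock()
	defer b.mu.Unlock()
	delete(b.subs, c)
}

// Broadcast blocks until every subscriber has accepted the event.
func (b *Broadcaster) Broadcast(ev *Event) {
	b.mu.Lock()
	defer b.mu.Unlock()
	for c := range b.subs {
		c <- ev
	}
}

The trade-off is that a subscriber that stops draining its channel can eventually block Broadcast once the buffer fills; the buffer of 10 only makes that unlikely.
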
 
@@ -240,7 +236,8 @@ func (kew *KeyspaceEventWatcher) run(ctx context.Context) {
 }
 
 // ensureConsistentLocked checks if the current keyspace has recovered from an availability
-// event, and if so, returns information about the availability event to all subscribers
+// event, and if so, returns information about the availability event to all subscribers.
+// Note: you MUST be holding the kss.mu lock when calling this function.
 func (kss *keyspaceState) ensureConsistentLocked() {
 	// if this keyspace is consistent, there's no ongoing availability event
 	if kss.consistent {
@@ -285,7 +282,8 @@ func (kss *keyspaceState) ensureConsistentLocked() {
 		}
 	}
 
-	// clone the current moveTablesState, if any, to handle race conditions where it can get updated while we're broadcasting
+	// Clone the current moveTablesState, if any, to handle race conditions where it can get
+	// updated while we're broadcasting.
 	var moveTablesState MoveTablesState
 	if kss.moveTablesState != nil {
 		moveTablesState = *kss.moveTablesState
@@ -312,8 +310,8 @@ func (kss *keyspaceState) ensureConsistentLocked() {
 			Serving: sstate.serving,
 		})
 
-		log.Infof("keyspace event resolved: %s/%s is now consistent (serving: %v)",
-			sstate.target.Keyspace, sstate.target.Keyspace,
+		log.Infof("keyspace event resolved: %s is now consistent (serving: %t)",
+			topoproto.KeyspaceShardString(sstate.target.Keyspace, sstate.target.Shard),
 			sstate.serving,
 		)
 
@@ -325,9 +323,10 @@ func (kss *keyspaceState) ensureConsistentLocked() {
 	kss.kew.broadcast(ksevent)
 }
 
-// onHealthCheck is the callback that updates this keyspace with event data from the HealthCheck stream.
-// the HealthCheck stream applies to all the keyspaces in the cluster and emits TabletHealth events to our
-// parent KeyspaceWatcher, which will mux them into their corresponding keyspaceState
+// onHealthCheck is the callback that updates this keyspace with event data from the HealthCheck
+// stream. The HealthCheck stream applies to all the keyspaces in the cluster and emits
+// TabletHealth events to our parent KeyspaceWatcher, which will mux them into their
+// corresponding keyspaceState.
 func (kss *keyspaceState) onHealthCheck(th *TabletHealth) {
 	// we only care about health events on the primary
 	if th.Target.TabletType != topodatapb.TabletType_PRIMARY {
@@ -371,6 +370,17 @@ func (kss *keyspaceState) onHealthCheck(th *TabletHealth) {
 	kss.ensureConsistentLocked()
 }
 
+type MoveTablesStatus int
+
+const (
+	MoveTablesUnknown MoveTablesStatus = iota
+	// MoveTablesSwitching is set when the write traffic is in the middle of being switched from
+	// the source to the target.
+	MoveTablesSwitching
+	// MoveTablesSwitched is set when write traffic has been completely switched to the target.
+	MoveTablesSwitched
+)
+
 type MoveTablesType int
 
 const (
@@ -384,33 +394,66 @@ type MoveTablesState struct {
 	State MoveTablesStatus
 }
 
+func (mts MoveTablesState) String() string {
+	var typ, state string
+	switch mts.Typ {
+	case MoveTablesRegular:
+		typ = "Regular"
+	case MoveTablesShardByShard:
+		typ = "ShardByShard"
+	default:
+		typ = "None"
+	}
+	switch mts.State {
+	case MoveTablesSwitching:
+		state = "Switching"
+	case MoveTablesSwitched:
+		state = "Switched"
+	default:
+		state = "Unknown"
+	}
+	return fmt.Sprintf("{Type: %s, State: %s}", typ, state)
+}
+
 func (kss *keyspaceState) getMoveTablesStatus(vs *vschemapb.SrvVSchema) (*MoveTablesState, error) {
 	mtState := &MoveTablesState{
 		Typ:   MoveTablesNone,
 		State: MoveTablesUnknown,
 	}
 
-	// if there are no routing rules defined, then movetables is not in progress, exit early
+	// If there are no routing rules defined, then movetables is not in progress, exit early.
 	if len(vs.GetRoutingRules().GetRules()) == 0 && len(vs.GetShardRoutingRules().GetRules()) == 0 {
 		return mtState, nil
 	}
 
 	shortCtx, cancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout)
 	defer cancel()
-	ts, _ := kss.kew.ts.GetTopoServer()
-
-	// collect all current shard information from the topo
+	ts, err := kss.kew.ts.GetTopoServer()
+	if err != nil {
+		return mtState, err
+	}
+	// Collect all current shard information from the topo.
 	var shardInfos []*topo.ShardInfo
+	mu := sync.Mutex{}
+	eg, ectx := errgroup.WithContext(shortCtx)
 	for _, sstate := range kss.shards {
-		si, err := ts.GetShard(shortCtx, kss.keyspace, sstate.target.Shard)
-		if err != nil {
-			return nil, err
-		}
-		shardInfos = append(shardInfos, si)
+		eg.Go(func() error {
+			si, err := ts.GetShard(ectx, kss.keyspace, sstate.target.Shard)
+			if err != nil {
+				return err
+			}
+			mu.Lock()
+			defer mu.Unlock()
+			shardInfos = append(shardInfos, si)
+			return nil
+		})
+	}
+	if err := eg.Wait(); err != nil {
+		return mtState, err
 	}
 
-	// check if any shard has denied tables and if so, record one of these to check where it currently points to
-	// using the (shard) routing rules
+	// Check if any shard has denied tables and if so, record one of these to check where it
+	// currently points to using the (shard) routing rules.
 	var shardsWithDeniedTables []string
 	var oneDeniedTable string
 	for _, si := range shardInfos {
@@ -425,11 +468,11 @@ func (kss *keyspaceState) getMoveTablesStatus(vs *vschemapb.SrvVSchema) (*MoveTa
 		return mtState, nil
 	}
 
-	// check if a shard by shard migration is in progress and if so detect if it has been switched
-	isPartialTables := vs.ShardRoutingRules != nil && len(vs.ShardRoutingRules.Rules) > 0
+	// Check if a shard by shard migration is in progress and if so detect if it has been switched.
+	isPartialTables := vs.GetShardRoutingRules() != nil && len(vs.GetShardRoutingRules().GetRules()) > 0
 
 	if isPartialTables {
-		srr := topotools.GetShardRoutingRulesMap(vs.ShardRoutingRules)
+		srr := topotools.GetShardRoutingRulesMap(vs.GetShardRoutingRules())
 		mtState.Typ = MoveTablesShardByShard
 		mtState.State = MoveTablesSwitched
 		for _, shard := range shardsWithDeniedTables {
@@ -440,31 +483,32 @@ func (kss *keyspaceState) getMoveTablesStatus(vs *vschemapb.SrvVSchema) (*MoveTa
 				break
 			}
 		}
-		log.Infof("getMoveTablesStatus: keyspace %s declaring partial move tables %v", kss.keyspace, mtState)
+		log.Infof("getMoveTablesStatus: keyspace %s declaring partial move tables %s", kss.keyspace, mtState.String())
 		return mtState, nil
 	}
 
-	// it wasn't a shard by shard migration, but since we have denied tables it must be a regular MoveTables
+	// It wasn't a shard by shard migration, but since we have denied tables it must be a
+	// regular MoveTables.
 	mtState.Typ = MoveTablesRegular
 	mtState.State = MoveTablesSwitching
-	rr := topotools.GetRoutingRulesMap(vs.RoutingRules)
+	rr := topotools.GetRoutingRulesMap(vs.GetRoutingRules())
 	if rr != nil {
 		r, ok := rr[oneDeniedTable]
-		// if a rule exists for the table and points to the target keyspace, writes have been switched
+		// If a rule exists for the table and points to the target keyspace, writes have been switched.
 		if ok && len(r) > 0 && r[0] != fmt.Sprintf("%s.%s", kss.keyspace, oneDeniedTable) {
 			mtState.State = MoveTablesSwitched
 			log.Infof("onSrvKeyspace::  keyspace %s writes have been switched for table %s, rule %v", kss.keyspace, oneDeniedTable, r[0])
 		}
 	}
-	log.Infof("getMoveTablesStatus: keyspace %s declaring regular move tables %v", kss.keyspace, mtState)
+	log.Infof("getMoveTablesStatus: keyspace %s declaring regular move tables %s", kss.keyspace, mtState.String())
 
 	return mtState, nil
 }
 
-// onSrvKeyspace is the callback that updates this keyspace with fresh topology data from our topology server.
-// this callback is called from a Watcher in the topo server whenever a change to the topology for this keyspace
-// occurs. this watcher is dedicated to this keyspace, and will only yield topology metadata changes for as
-// long as we're interested on this keyspace.
+// onSrvKeyspace is the callback that updates this keyspace with fresh topology data from our
+// topology server. This callback is called from a Watcher in the topo server whenever a change to
+// the topology for this keyspace occurs. This watcher is dedicated to this keyspace, and will
+// only yield topology metadata changes for as long as we're interested in this keyspace.
 func (kss *keyspaceState) onSrvKeyspace(newKeyspace *topodatapb.SrvKeyspace, newError error) bool {
 	kss.mu.Lock()
 	defer kss.mu.Unlock()
@@ -478,23 +522,25 @@ func (kss *keyspaceState) onSrvKeyspace(newKeyspace *topodatapb.SrvKeyspace, new
 		return false
 	}
 
-	// if there's another kind of error while watching this keyspace, we assume it's temporary and related
-	// to the topology server, not to the keyspace itself. we'll keep waiting for more topology events.
+	// If there's another kind of error while watching this keyspace, we assume it's temporary and
+	// related to the topology server, not to the keyspace itself. We'll keep waiting for more
+	// topology events.
 	if newError != nil {
 		kss.lastError = newError
 		log.Errorf("error while watching keyspace %q: %v", kss.keyspace, newError)
 		return true
 	}
 
-	// if the topology metadata for our keyspace is identical to the last one we saw there's nothing to do
-	// here. this is a side-effect of the way ETCD watchers work.
+	// If the topology metadata for our keyspace is identical to the last one we saw, there's nothing to
+	// do here. This is a side-effect of the way ETCD watchers work.
 	if proto.Equal(kss.lastKeyspace, newKeyspace) {
 		// no changes
 		return true
 	}
 
-	// we only mark this keyspace as inconsistent if there has been a topology change in the PRIMARY for
-	// this keyspace, but we store the topology metadata for both primary and replicas for future-proofing.
+	// We only mark this keyspace as inconsistent if there has been a topology change in the PRIMARY
+	// for this keyspace, but we store the topology metadata for both primary and replicas for
+	// future-proofing.
 	var oldPrimary, newPrimary *topodatapb.SrvKeyspace_KeyspacePartition
 	if kss.lastKeyspace != nil {
 		oldPrimary = topoproto.SrvKeyspaceGetPartition(kss.lastKeyspace, topodatapb.TabletType_PRIMARY)
@@ -525,20 +571,24 @@ func (kss *keyspaceState) isServing() bool {
 
 // onSrvVSchema is called from a Watcher in the topo server whenever the SrvVSchema is updated by Vitess.
 // For the purposes here, we are interested in updates to the RoutingRules or ShardRoutingRules.
-// In addition, the traffic switcher updates SrvVSchema when the DeniedTables attributes in a Shard record is
-// modified.
+// In addition, the traffic switcher updates SrvVSchema when the DeniedTables attributes in a Shard
+// record is modified.
 func (kss *keyspaceState) onSrvVSchema(vs *vschemapb.SrvVSchema, err error) bool {
-	// the vschema can be nil if the server is currently shutting down
+	// The vschema can be nil if the server is currently shutting down.
 	if vs == nil {
 		return true
 	}
 
 	kss.mu.Lock()
 	defer kss.mu.Unlock()
-	kss.moveTablesState, _ = kss.getMoveTablesStatus(vs)
+	var kerr error
+	if kss.moveTablesState, kerr = kss.getMoveTablesStatus(vs); kerr != nil {
+		log.Errorf("onSrvVSchema: keyspace %s failed to get move tables status: %v", kss.keyspace, kerr)
+	}
 	if kss.moveTablesState != nil && kss.moveTablesState.Typ != MoveTablesNone {
-		// mark the keyspace as inconsistent. ensureConsistentLocked() checks if the workflow is switched,
-		// and if so, it will send an event to the buffering subscribers to indicate that buffering can be stopped.
+		// Mark the keyspace as inconsistent. ensureConsistentLocked() checks if the workflow is
+		// switched, and if so, it will send an event to the buffering subscribers to indicate that
+		// buffering can be stopped.
 		kss.consistent = false
 		kss.ensureConsistentLocked()
 	}
@@ -560,8 +610,9 @@ func newKeyspaceState(ctx context.Context, kew *KeyspaceEventWatcher, cell, keys
 	return kss
 }
 
-// processHealthCheck is the callback that is called by the global HealthCheck stream that was initiated
-// by this KeyspaceEventWatcher. it redirects the TabletHealth event to the corresponding keyspaceState
+// processHealthCheck is the callback that is called by the global HealthCheck stream that was
+// initiated by this KeyspaceEventWatcher. It redirects the TabletHealth event to the
+// corresponding keyspaceState.
 func (kew *KeyspaceEventWatcher) processHealthCheck(ctx context.Context, th *TabletHealth) {
 	kss := kew.getKeyspaceStatus(ctx, th.Target.Keyspace)
 	if kss == nil {
@@ -571,8 +622,8 @@ func (kew *KeyspaceEventWatcher) processHealthCheck(ctx context.Context, th *Tab
 	kss.onHealthCheck(th)
 }
 
-// getKeyspaceStatus returns the keyspaceState object for the corresponding keyspace, allocating it
-// if we've never seen the keyspace before.
+// getKeyspaceStatus returns the keyspaceState object for the corresponding keyspace, allocating
+// it if we've never seen the keyspace before.
 func (kew *KeyspaceEventWatcher) getKeyspaceStatus(ctx context.Context, keyspace string) *keyspaceState {
 	kew.mu.Lock()
 	defer kew.mu.Unlock()
@@ -612,15 +663,15 @@ func (kew *KeyspaceEventWatcher) TargetIsBeingResharded(ctx context.Context, tar
 }
 
 // PrimaryIsNotServing checks if the reason why the given target is not accessible right now is
-// that the primary tablet for that shard is not serving. This is possible during a Planned Reparent Shard
-// operation. Just as the operation completes, a new primary will be elected, and it will send its own healthcheck
-// stating that it is serving. We should buffer requests until that point.
-// There are use cases where people do not run with a Primary server at all, so we must verify that
-// we only start buffering when a primary was present, and it went not serving.
-// The shard state keeps track of the current primary and the last externally reparented time, which we can use
-// to determine that there was a serving primary which now became non serving. This is only possible in a DemotePrimary
-// RPC which are only called from ERS and PRS. So buffering will stop when these operations succeed.
-// We return the tablet alias of the primary if it is serving.
+// that the primary tablet for that shard is not serving. This is possible during a Planned
+// Reparent Shard operation. Just as the operation completes, a new primary will be elected, and
+// it will send its own healthcheck stating that it is serving. We should buffer requests until
+// that point. There are use cases where people do not run with a Primary server at all, so we must
+// verify that we only start buffering when a primary was present, and it went not serving.
+// The shard state keeps track of the current primary and the last externally reparented time, which
+// we can use to determine that there was a serving primary which now became non serving. This is
+// only possible in a DemotePrimary RPC, which is only called from ERS and PRS. So buffering will
+// stop when these operations succeed. We return the tablet alias of the primary if it is serving.
 func (kew *KeyspaceEventWatcher) PrimaryIsNotServing(ctx context.Context, target *querypb.Target) (*topodatapb.TabletAlias, bool) {
 	if target.TabletType != topodatapb.TabletType_PRIMARY {
 		return nil, false
@@ -632,7 +683,8 @@ func (kew *KeyspaceEventWatcher) PrimaryIsNotServing(ctx context.Context, target
 	ks.mu.Lock()
 	defer ks.mu.Unlock()
 	if state, ok := ks.shards[target.Shard]; ok {
-		// If the primary tablet was present then externallyReparented will be non-zero and currentPrimary will be not nil
+		// If the primary tablet was present then externallyReparented will be non-zero and
+		// currentPrimary will be not nil.
 		return state.currentPrimary, !state.serving && !ks.consistent && state.externallyReparented != 0 && state.currentPrimary != nil
 	}
 	return nil, false
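
getMoveTablesStatus now fetches all shard records concurrently through an errgroup and collects the results under a mutex, returning the first error and canceling the remaining lookups. A generic sketch of that fan-out/collect pattern, with a placeholder fetch function standing in for the topo lookup (names are illustrative):

package topoutil

import (
	"context"
	"sync"

	"golang.org/x/sync/errgroup"
)

// fetchAll fetches one value per key concurrently. The first error cancels the
// remaining fetches through the errgroup's derived context.
func fetchAll(ctx context.Context, keys []string, fetch func(context.Context, string) (string, error)) ([]string, error) {
	var (
		mu      sync.Mutex
		results []string
	)
	eg, ectx := errgroup.WithContext(ctx)
	for _, key := range keys {
		key := key // redundant on Go 1.22+, kept for older toolchains
		eg.Go(func() error {
			res, err := fetch(ectx, key)
			if err != nil {
				return err
			}
			mu.Lock()
			defer mu.Unlock()
			results = append(results, res)
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		return nil, err
	}
	return results, nil
}
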
diff --git a/go/vt/discovery/keyspace_events_test.go b/go/vt/discovery/keyspace_events_test.go
index 43af4bf49de..e9406ff1de2 100644
--- a/go/vt/discovery/keyspace_events_test.go
+++ b/go/vt/discovery/keyspace_events_test.go
@@ -19,6 +19,8 @@ package discovery
 import (
 	"context"
 	"encoding/hex"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -60,6 +62,67 @@ func TestSrvKeyspaceWithNilNewKeyspace(t *testing.T) {
 	require.True(t, kss.onSrvKeyspace(nil, nil))
 }
 
+// TestKeyspaceEventConcurrency confirms that the keyspace event watcher
+// does not fail to broadcast received keyspace events to subscribers.
+// This verifies that no events are lost when there's a high number of
+// concurrent keyspace events.
+func TestKeyspaceEventConcurrency(t *testing.T) {
+	cell := "cell1"
+	factory := faketopo.NewFakeTopoFactory()
+	factory.AddCell(cell)
+	sts := &fakeTopoServer{}
+	hc := NewFakeHealthCheck(make(chan *TabletHealth))
+	defer hc.Close()
+	kew := &KeyspaceEventWatcher{
+		hc:        hc,
+		ts:        sts,
+		localCell: cell,
+		keyspaces: make(map[string]*keyspaceState),
+		subs:      make(map[chan *KeyspaceEvent]struct{}),
+	}
+
+	// Subscribe to the watcher's broadcasted keyspace events.
+	receiver := kew.Subscribe()
+
+	updates := atomic.Uint32{}
+	updates.Store(0)
+	wg := sync.WaitGroup{}
+	concurrency := 100
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-receiver:
+				updates.Add(1)
+			}
+		}
+	}()
+	// Start up concurrent goroutines that will broadcast keyspace events.
+	for i := 1; i <= concurrency; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			kew.broadcast(&KeyspaceEvent{})
+		}()
+	}
+	wg.Wait()
+	for {
+		select {
+		case <-ctx.Done():
+			require.Equal(t, concurrency, int(updates.Load()), "expected %d updates, got %d", concurrency, updates.Load())
+			return
+		default:
+			if int(updates.Load()) == concurrency { // Pass
+				cancel()
+				return
+			}
+		}
+	}
+}
+
 // TestKeyspaceEventTypes confirms that the keyspace event watcher determines
 // that the unavailability event is caused by the correct scenario. We should
 // consider it to be caused by a resharding operation when the following
@@ -309,6 +372,26 @@ func (f *fakeTopoServer) GetSrvKeyspace(ctx context.Context, cell, keyspace stri
 	return ks, nil
 }
 
+// GetSrvVSchema returns the SrvVSchema for a cell.
+func (f *fakeTopoServer) GetSrvVSchema(ctx context.Context, cell string) (*vschemapb.SrvVSchema, error) {
+	vs := &vschemapb.SrvVSchema{
+		Keyspaces: map[string]*vschemapb.Keyspace{
+			"ks1": {
+				Sharded: true,
+			},
+		},
+		RoutingRules: &vschemapb.RoutingRules{
+			Rules: []*vschemapb.RoutingRule{
+				{
+					FromTable: "db1.t1",
+					ToTables:  []string{"db1.t1"},
+				},
+			},
+		},
+	}
+	return vs, nil
+}
+
 func (f *fakeTopoServer) WatchSrvKeyspace(ctx context.Context, cell, keyspace string, callback func(*topodatapb.SrvKeyspace, error) bool) {
 	ks, err := f.GetSrvKeyspace(ctx, cell, keyspace)
 	callback(ks, err)
@@ -318,5 +401,6 @@ func (f *fakeTopoServer) WatchSrvKeyspace(ctx context.Context, cell, keyspace st
 // the provided cell.  It will call the callback when
 // a new value or an error occurs.
 func (f *fakeTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) {
-
+	sv, err := f.GetSrvVSchema(ctx, cell)
+	callback(sv, err)
 }
diff --git a/go/vt/srvtopo/watch_srvvschema.go b/go/vt/srvtopo/watch_srvvschema.go
index 1b5536e623d..c758211375d 100644
--- a/go/vt/srvtopo/watch_srvvschema.go
+++ b/go/vt/srvtopo/watch_srvvschema.go
@@ -21,8 +21,9 @@ import (
 	"time"
 
 	"vitess.io/vitess/go/stats"
-	vschemapb "vitess.io/vitess/go/vt/proto/vschema"
 	"vitess.io/vitess/go/vt/topo"
+
+	vschemapb "vitess.io/vitess/go/vt/proto/vschema"
 )
 
 type SrvVSchemaWatcher struct {
diff --git a/go/vt/topo/etcd2topo/watch.go b/go/vt/topo/etcd2topo/watch.go
index cdc9be44b21..2fc58d437ff 100644
--- a/go/vt/topo/etcd2topo/watch.go
+++ b/go/vt/topo/etcd2topo/watch.go
@@ -51,7 +51,7 @@ func (s *Server) Watch(ctx context.Context, filePath string) (*topo.WatchData, <
 	}
 	wd := &topo.WatchData{
 		Contents: initial.Kvs[0].Value,
-		Version:  EtcdVersion(initial.Kvs[0].ModRevision),
+		Version:  EtcdVersion(initial.Kvs[0].Version),
 	}
 
 	// Create an outer context that will be canceled on return and will cancel all inner watches.
@@ -76,7 +76,7 @@ func (s *Server) Watch(ctx context.Context, filePath string) (*topo.WatchData, <
 		defer close(notifications)
 		defer outerCancel()
 
-		var currVersion = initial.Header.Revision
+		var rev = initial.Header.Revision
 		var watchRetries int
 		for {
 			select {
@@ -107,9 +107,9 @@ func (s *Server) Watch(ctx context.Context, filePath string) (*topo.WatchData, <
 					// Cancel inner context on retry and create new one.
 					watchCancel()
 					watchCtx, watchCancel = context.WithCancel(ctx)
-					newWatcher := s.cli.Watch(watchCtx, nodePath, clientv3.WithRev(currVersion))
+					newWatcher := s.cli.Watch(watchCtx, nodePath, clientv3.WithRev(rev))
 					if newWatcher == nil {
-						log.Warningf("watch %v failed and get a nil channel returned, currVersion: %v", nodePath, currVersion)
+						log.Warningf("watch %v failed and get a nil channel returned, rev: %v", nodePath, rev)
 					} else {
 						watcher = newWatcher
 					}
@@ -126,7 +126,7 @@ func (s *Server) Watch(ctx context.Context, filePath string) (*topo.WatchData, <
 					return
 				}
 
-				currVersion = wresp.Header.GetRevision()
+				rev = wresp.Header.GetRevision()
 
 				for _, ev := range wresp.Events {
 					switch ev.Type {
@@ -174,7 +174,7 @@ func (s *Server) WatchRecursive(ctx context.Context, dirpath string) ([]*topo.Wa
 		var wd topo.WatchDataRecursive
 		wd.Path = string(kv.Key)
 		wd.Contents = kv.Value
-		wd.Version = EtcdVersion(initial.Kvs[0].ModRevision)
+		wd.Version = EtcdVersion(initial.Kvs[0].Version)
 		initialwd = append(initialwd, &wd)
 	}
 
@@ -200,7 +200,7 @@ func (s *Server) WatchRecursive(ctx context.Context, dirpath string) ([]*topo.Wa
 		defer close(notifications)
 		defer outerCancel()
 
-		var currVersion = initial.Header.Revision
+		var rev = initial.Header.Revision
 		var watchRetries int
 		for {
 			select {
@@ -228,9 +228,9 @@ func (s *Server) WatchRecursive(ctx context.Context, dirpath string) ([]*topo.Wa
 					watchCancel()
 					watchCtx, watchCancel = context.WithCancel(ctx)
 
-					newWatcher := s.cli.Watch(watchCtx, nodePath, clientv3.WithRev(currVersion), clientv3.WithPrefix())
+					newWatcher := s.cli.Watch(watchCtx, nodePath, clientv3.WithRev(rev), clientv3.WithPrefix())
 					if newWatcher == nil {
-						log.Warningf("watch %v failed and get a nil channel returned, currVersion: %v", nodePath, currVersion)
+						log.Warningf("watch %v failed and get a nil channel returned, rev: %v", nodePath, rev)
 					} else {
 						watcher = newWatcher
 					}
@@ -247,7 +247,7 @@ func (s *Server) WatchRecursive(ctx context.Context, dirpath string) ([]*topo.Wa
 					return
 				}
 
-				currVersion = wresp.Header.GetRevision()
+				rev = wresp.Header.GetRevision()
 
 				for _, ev := range wresp.Events {
 					switch ev.Type {
diff --git a/go/vt/vtctl/workflow/server.go b/go/vt/vtctl/workflow/server.go
index 43d1f1a2b05..3f6297d345f 100644
--- a/go/vt/vtctl/workflow/server.go
+++ b/go/vt/vtctl/workflow/server.go
@@ -2891,7 +2891,9 @@ func (s *Server) WorkflowSwitchTraffic(ctx context.Context, req *vtctldatapb.Wor
 		return nil, err
 	}
 	if hasReplica || hasRdonly {
-		if rdDryRunResults, err = s.switchReads(ctx, req, ts, startState, timeout, false, direction); err != nil {
+		// If we're going to switch writes immediately after, then we don't need to
+		// rebuild the SrvVSchema here as we will do it after switching writes.
+		if rdDryRunResults, err = s.switchReads(ctx, req, ts, startState, !hasPrimary /* rebuildSrvVSchema */, direction); err != nil {
 			return nil, err
 		}
 		log.Infof("Switch Reads done for workflow %s.%s", req.Keyspace, req.Workflow)
@@ -2945,7 +2947,7 @@ func (s *Server) WorkflowSwitchTraffic(ctx context.Context, req *vtctldatapb.Wor
 }
 
 // switchReads is a generic way of switching read traffic for a workflow.
-func (s *Server) switchReads(ctx context.Context, req *vtctldatapb.WorkflowSwitchTrafficRequest, ts *trafficSwitcher, state *State, timeout time.Duration, cancel bool, direction TrafficSwitchDirection) (*[]string, error) {
+func (s *Server) switchReads(ctx context.Context, req *vtctldatapb.WorkflowSwitchTrafficRequest, ts *trafficSwitcher, state *State, rebuildSrvVSchema bool, direction TrafficSwitchDirection) (*[]string, error) {
 	var roTabletTypes []topodatapb.TabletType
 	// When we are switching all traffic we also get the primary tablet type, which we need to
 	// filter out for switching reads.
@@ -3032,7 +3034,7 @@ func (s *Server) switchReads(ctx context.Context, req *vtctldatapb.WorkflowSwitc
 	if ts.MigrationType() == binlogdatapb.MigrationType_TABLES {
 		if ts.isPartialMigration {
 			ts.Logger().Infof("Partial migration, skipping switchTableReads as traffic is all or nothing per shard and overridden for reads AND writes in the ShardRoutingRule created when switching writes.")
-		} else if err := sw.switchTableReads(ctx, req.Cells, roTabletTypes, direction); err != nil {
+		} else if err := sw.switchTableReads(ctx, req.Cells, roTabletTypes, rebuildSrvVSchema, direction); err != nil {
 			return handleError("failed to switch read traffic for the tables", err)
 		}
 		return sw.logs(), nil
diff --git a/go/vt/vtctl/workflow/switcher.go b/go/vt/vtctl/workflow/switcher.go
index 0cbdce164dc..d7690458439 100644
--- a/go/vt/vtctl/workflow/switcher.go
+++ b/go/vt/vtctl/workflow/switcher.go
@@ -66,8 +66,8 @@ func (r *switcher) switchShardReads(ctx context.Context, cells []string, servedT
 	return r.ts.switchShardReads(ctx, cells, servedTypes, direction)
 }
 
-func (r *switcher) switchTableReads(ctx context.Context, cells []string, servedTypes []topodatapb.TabletType, direction TrafficSwitchDirection) error {
-	return r.ts.switchTableReads(ctx, cells, servedTypes, direction)
+func (r *switcher) switchTableReads(ctx context.Context, cells []string, servedTypes []topodatapb.TabletType, rebuildSrvVSchema bool, direction TrafficSwitchDirection) error {
+	return r.ts.switchTableReads(ctx, cells, servedTypes, rebuildSrvVSchema, direction)
 }
 
 func (r *switcher) startReverseVReplication(ctx context.Context) error {
diff --git a/go/vt/vtctl/workflow/switcher_dry_run.go b/go/vt/vtctl/workflow/switcher_dry_run.go
index 1c8a05e00c2..60213194071 100644
--- a/go/vt/vtctl/workflow/switcher_dry_run.go
+++ b/go/vt/vtctl/workflow/switcher_dry_run.go
@@ -76,7 +76,7 @@ func (dr *switcherDryRun) switchShardReads(ctx context.Context, cells []string,
 	return nil
 }
 
-func (dr *switcherDryRun) switchTableReads(ctx context.Context, cells []string, servedTypes []topodatapb.TabletType, direction TrafficSwitchDirection) error {
+func (dr *switcherDryRun) switchTableReads(ctx context.Context, cells []string, servedTypes []topodatapb.TabletType, rebuildSrvVSchema bool, direction TrafficSwitchDirection) error {
 	ks := dr.ts.TargetKeyspaceName()
 	if direction == DirectionBackward {
 		ks = dr.ts.SourceKeyspaceName()
@@ -88,6 +88,9 @@ func (dr *switcherDryRun) switchTableReads(ctx context.Context, cells []string,
 	tables := strings.Join(dr.ts.Tables(), ",")
 	dr.drLog.Logf("Switch reads for tables [%s] to keyspace %s for tablet types [%s]", tables, ks, strings.Join(tabletTypes, ","))
 	dr.drLog.Logf("Routing rules for tables [%s] will be updated", tables)
+	if rebuildSrvVSchema {
+		dr.drLog.Logf("Serving VSchema will be rebuilt for the %s keyspace", ks)
+	}
 	return nil
 }
 
diff --git a/go/vt/vtctl/workflow/switcher_interface.go b/go/vt/vtctl/workflow/switcher_interface.go
index 8d0f9e847be..9af9ff49f2f 100644
--- a/go/vt/vtctl/workflow/switcher_interface.go
+++ b/go/vt/vtctl/workflow/switcher_interface.go
@@ -36,7 +36,7 @@ type iswitcher interface {
 	changeRouting(ctx context.Context) error
 	streamMigraterfinalize(ctx context.Context, ts *trafficSwitcher, workflows []string) error
 	startReverseVReplication(ctx context.Context) error
-	switchTableReads(ctx context.Context, cells []string, servedType []topodatapb.TabletType, direction TrafficSwitchDirection) error
+	switchTableReads(ctx context.Context, cells []string, servedType []topodatapb.TabletType, rebuildSrvVSchema bool, direction TrafficSwitchDirection) error
 	switchShardReads(ctx context.Context, cells []string, servedType []topodatapb.TabletType, direction TrafficSwitchDirection) error
 	validateWorkflowHasCompleted(ctx context.Context) error
 	removeSourceTables(ctx context.Context, removalType TableRemovalType) error
diff --git a/go/vt/vtctl/workflow/traffic_switcher.go b/go/vt/vtctl/workflow/traffic_switcher.go
index a102c3162f3..66bae0942d0 100644
--- a/go/vt/vtctl/workflow/traffic_switcher.go
+++ b/go/vt/vtctl/workflow/traffic_switcher.go
@@ -573,7 +573,7 @@ func (ts *trafficSwitcher) switchShardReads(ctx context.Context, cells []string,
 	return nil
 }
 
-func (ts *trafficSwitcher) switchTableReads(ctx context.Context, cells []string, servedTypes []topodatapb.TabletType, direction TrafficSwitchDirection) error {
+func (ts *trafficSwitcher) switchTableReads(ctx context.Context, cells []string, servedTypes []topodatapb.TabletType, rebuildSrvVSchema bool, direction TrafficSwitchDirection) error {
 	log.Infof("switchTableReads: cells: %s, tablet types: %+v, direction %d", strings.Join(cells, ","), servedTypes, direction)
 	rules, err := topotools.GetRoutingRules(ctx, ts.TopoServer())
 	if err != nil {
@@ -605,7 +605,10 @@ func (ts *trafficSwitcher) switchTableReads(ctx context.Context, cells []string,
 	if err := topotools.SaveRoutingRules(ctx, ts.TopoServer(), rules); err != nil {
 		return err
 	}
-	return ts.TopoServer().RebuildSrvVSchema(ctx, cells)
+	if rebuildSrvVSchema {
+		return ts.TopoServer().RebuildSrvVSchema(ctx, cells)
+	}
+	return nil
 }
 
 func (ts *trafficSwitcher) startReverseVReplication(ctx context.Context) error {
diff --git a/go/vt/vtgate/buffer/buffer.go b/go/vt/vtgate/buffer/buffer.go
index 622bb03b082..260fb272544 100644
--- a/go/vt/vtgate/buffer/buffer.go
+++ b/go/vt/vtgate/buffer/buffer.go
@@ -164,6 +164,10 @@ func New(cfg *Config) *Buffer {
 	}
 }
 
+func (b *Buffer) GetConfig() *Config {
+	return b.config
+}
+
 // WaitForFailoverEnd blocks until a pending buffering due to a failover for
 // keyspace/shard is over.
 // If there is no ongoing failover, "err" is checked. If it's caused by a
diff --git a/go/vt/vtgate/buffer/flags.go b/go/vt/vtgate/buffer/flags.go
index a17cc09ccc3..19f30c64ed8 100644
--- a/go/vt/vtgate/buffer/flags.go
+++ b/go/vt/vtgate/buffer/flags.go
@@ -70,6 +70,9 @@ func verifyFlags() error {
 	if bufferSize < 1 {
 		return fmt.Errorf("--buffer_size must be >= 1 (specified value: %d)", bufferSize)
 	}
+	if bufferMinTimeBetweenFailovers < 1*time.Second {
+		return fmt.Errorf("--buffer_min_time_between_failovers must be >= 1s (specified value: %v)", bufferMinTimeBetweenFailovers)
+	}
 
 	if bufferDrainConcurrency < 1 {
 		return fmt.Errorf("--buffer_drain_concurrency must be >= 1 (specified value: %d)", bufferDrainConcurrency)
diff --git a/go/vt/vtgate/buffer/shard_buffer.go b/go/vt/vtgate/buffer/shard_buffer.go
index ae33aabb399..a58b86f670a 100644
--- a/go/vt/vtgate/buffer/shard_buffer.go
+++ b/go/vt/vtgate/buffer/shard_buffer.go
@@ -24,14 +24,13 @@ import (
 	"time"
 
 	"vitess.io/vitess/go/vt/discovery"
-
-	"vitess.io/vitess/go/vt/vtgate/errorsanitizer"
-
 	"vitess.io/vitess/go/vt/log"
 	"vitess.io/vitess/go/vt/logutil"
-	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
 	"vitess.io/vitess/go/vt/topo/topoproto"
 	"vitess.io/vitess/go/vt/vterrors"
+	"vitess.io/vitess/go/vt/vtgate/errorsanitizer"
+
+	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
 )
 
 // bufferState represents the different states a shardBuffer object can be in.
diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go
index 2d3a82f0b9b..319c2507d13 100644
--- a/go/vt/vtgate/executor_test.go
+++ b/go/vt/vtgate/executor_test.go
@@ -1408,7 +1408,7 @@ func TestExecutorAlterVSchemaKeyspace(t *testing.T) {
 	session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true})
 
 	vschemaUpdates := make(chan *vschemapb.SrvVSchema, 2)
-	executor.serv.WatchSrvVSchema(ctx, "aa", func(vschema *vschemapb.SrvVSchema, err error) bool {
+	executor.serv.WatchSrvVSchema(ctx, executor.cell, func(vschema *vschemapb.SrvVSchema, err error) bool {
 		vschemaUpdates <- vschema
 		return true
 	})
diff --git a/go/vt/vtgate/executor_vschema_ddl_test.go b/go/vt/vtgate/executor_vschema_ddl_test.go
index 1c2813a33c4..1c912ed0d62 100644
--- a/go/vt/vtgate/executor_vschema_ddl_test.go
+++ b/go/vt/vtgate/executor_vschema_ddl_test.go
@@ -17,26 +17,23 @@ limitations under the License.
 package vtgate
 
 import (
-	"context"
 	"reflect"
 	"slices"
 	"testing"
 	"time"
 
-	"vitess.io/vitess/go/test/utils"
-
-	"vitess.io/vitess/go/vt/callerid"
-	querypb "vitess.io/vitess/go/vt/proto/query"
-	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"vitess.io/vitess/go/sqltypes"
+	"vitess.io/vitess/go/test/utils"
+	"vitess.io/vitess/go/vt/callerid"
 	"vitess.io/vitess/go/vt/vtgate/vschemaacl"
 
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
-
+	querypb "vitess.io/vitess/go/vt/proto/query"
 	vschemapb "vitess.io/vitess/go/vt/proto/vschema"
 	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
+	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
 )
 
 func waitForVindex(t *testing.T, ks, name string, watch chan *vschemapb.SrvVSchema, executor *Executor) (*vschemapb.SrvVSchema, *vschemapb.Vindex) {
@@ -426,9 +423,7 @@ func TestExecutorDropSequenceDDL(t *testing.T) {
 	_, err = executor.Execute(ctx, nil, "TestExecute", session, stmt, nil)
 	require.NoError(t, err)
 
-	ctxWithTimeout, cancel := context.WithTimeout(ctx, 5*time.Second)
-	defer cancel()
-	if !waitForNewerVSchema(ctxWithTimeout, executor, ts) {
+	if !waitForNewerVSchema(ctx, executor, ts, 5*time.Second) {
 		t.Fatalf("vschema did not drop the sequence 'test_seq'")
 	}
 
@@ -464,9 +459,7 @@ func TestExecutorDropAutoIncDDL(t *testing.T) {
 	stmt = "alter vschema on test_table add auto_increment id using `db-name`.`test_seq`"
 	_, err = executor.Execute(ctx, nil, "TestExecute", session, stmt, nil)
 	require.NoError(t, err)
-	ctxWithTimeout, cancel := context.WithTimeout(ctx, 5*time.Second)
-	defer cancel()
-	if !waitForNewerVSchema(ctxWithTimeout, executor, ts) {
+	if !waitForNewerVSchema(ctx, executor, ts, 5*time.Second) {
 		t.Fatalf("vschema did not update with auto_increment for 'test_table'")
 	}
 	ts = executor.VSchema().GetCreated()
@@ -480,9 +473,7 @@ func TestExecutorDropAutoIncDDL(t *testing.T) {
 	_, err = executor.Execute(ctx, nil, "TestExecute", session, stmt, nil)
 	require.NoError(t, err)
 
-	ctxWithTimeout, cancel2 := context.WithTimeout(ctx, 5*time.Second)
-	defer cancel2()
-	if !waitForNewerVSchema(ctxWithTimeout, executor, ts) {
+	if !waitForNewerVSchema(ctx, executor, ts, 5*time.Second) {
 		t.Fatalf("vschema did not drop the auto_increment for 'test_table'")
 	}
 	if executor.vm.GetCurrentSrvVschema().Keyspaces[ks].Tables["test_table"].AutoIncrement != nil {
diff --git a/go/vt/vtgate/plan_execute.go b/go/vt/vtgate/plan_execute.go
index 5d2414ac275..73ad109ef37 100644
--- a/go/vt/vtgate/plan_execute.go
+++ b/go/vt/vtgate/plan_execute.go
@@ -24,20 +24,20 @@ import (
 
 	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/vt/log"
-	querypb "vitess.io/vitess/go/vt/proto/query"
-	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
 	"vitess.io/vitess/go/vt/sqlparser"
 	"vitess.io/vitess/go/vt/vterrors"
 	"vitess.io/vitess/go/vt/vtgate/engine"
 	"vitess.io/vitess/go/vt/vtgate/logstats"
 	"vitess.io/vitess/go/vt/vtgate/vtgateservice"
+
+	querypb "vitess.io/vitess/go/vt/proto/query"
+	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
 )
 
 type planExec func(ctx context.Context, plan *engine.Plan, vc *vcursorImpl, bindVars map[string]*querypb.BindVariable, startTime time.Time) error
 type txResult func(sqlparser.StatementType, *sqltypes.Result) error
 
-func waitForNewerVSchema(ctx context.Context, e *Executor, lastVSchemaCreated time.Time) bool {
-	timeout := 30 * time.Second
+func waitForNewerVSchema(ctx context.Context, e *Executor, lastVSchemaCreated time.Time, timeout time.Duration) bool {
 	pollingInterval := 10 * time.Millisecond
 	waitCtx, cancel := context.WithTimeout(ctx, timeout)
 	ticker := time.NewTicker(pollingInterval)
@@ -48,7 +48,7 @@ func waitForNewerVSchema(ctx context.Context, e *Executor, lastVSchemaCreated ti
 		case <-waitCtx.Done():
 			return false
 		case <-ticker.C:
-			if e.VSchema().GetCreated().After(lastVSchemaCreated) {
+			if e.VSchema() != nil && e.VSchema().GetCreated().After(lastVSchemaCreated) {
 				return true
 			}
 		}
@@ -64,11 +64,11 @@ func (e *Executor) newExecute(
 	logStats *logstats.LogStats,
 	execPlan planExec, // used when there is a plan to execute
 	recResult txResult, // used when it's something simple like begin/commit/rollback/savepoint
-) error {
-	// 1: Prepare before planning and execution
+) (err error) {
+	// 1: Prepare before planning and execution.
 
 	// Start an implicit transaction if necessary.
-	err := e.startTxIfNecessary(ctx, safeSession)
+	err = e.startTxIfNecessary(ctx, safeSession)
 	if err != nil {
 		return err
 	}
@@ -79,21 +79,35 @@ func (e *Executor) newExecute(
 
 	query, comments := sqlparser.SplitMarginComments(sql)
 
-	// 2: Parse and Validate query
+	// 2: Parse and Validate query.
 	stmt, reservedVars, err := parseAndValidateQuery(query)
 	if err != nil {
 		return err
 	}
 
-	var lastVSchemaCreated time.Time
-	vs := e.VSchema()
-	lastVSchemaCreated = vs.GetCreated()
+	var (
+		vs                 = e.VSchema()
+		lastVSchemaCreated = vs.GetCreated()
+		result             *sqltypes.Result
+		plan               *engine.Plan
+	)
+
 	for try := 0; try < MaxBufferingRetries; try++ {
-		if try > 0 && !vs.GetCreated().After(lastVSchemaCreated) {
-			// There is a race due to which the executor's vschema may not have been updated yet.
-			// Without a wait we fail non-deterministically since the previous vschema will not have the updated routing rules
-			if waitForNewerVSchema(ctx, e, lastVSchemaCreated) {
+		if try > 0 && !vs.GetCreated().After(lastVSchemaCreated) { // We need to wait for a vschema update
+			// Without a wait we fail non-deterministically since the previous vschema will not have
+			// the updated routing rules.
+			// We retry MaxBufferingRetries-1 (2) times before giving up. How long we wait before each retry
+			// -- IF we don't see a newer vschema come in -- affects how long we retry in total and how quickly
+			// we retry the query and (should) succeed when the traffic switch fails or we otherwise hit the
+			// max buffer failover time without resolving the keyspace event and marking it as consistent.
+			// This calculation attempts to ensure that we retry at a sensible interval and number of times
+			// based on the buffering configuration. This way we should be able to perform the max retries
+			// within the given window of time for most queries and we should not end up waiting too long
+			// after the traffic switch fails or the buffer window has ended, retrying old queries.
+			timeout := e.resolver.scatterConn.gateway.buffer.GetConfig().MaxFailoverDuration / (MaxBufferingRetries - 1)
+			if waitForNewerVSchema(ctx, e, lastVSchemaCreated, timeout) {
 				vs = e.VSchema()
+				lastVSchemaCreated = vs.GetCreated()
 			}
 		}
 
@@ -102,16 +116,13 @@ func (e *Executor) newExecute(
 			return err
 		}
 
-		// 3: Create a plan for the query
+		// 3: Create a plan for the query.
 		// If we are retrying, it is likely that the routing rules have changed and hence we need to
 		// replan the query since the target keyspace of the resolved shards may have changed as a
-		// result of MoveTables. So we cannot reuse the plan from the first try.
-		// When buffering ends, many queries might be getting planned at the same time. Ideally we
-		// should be able to reuse plans once the first drained query has been planned. For now, we
-		// punt on this and choose not to prematurely optimize since it is not clear how much caching
-		// will help and if it will result in hard-to-track edge cases.
-
-		var plan *engine.Plan
+		// result of MoveTables SwitchTraffic which does a RebuildSrvVSchema which in turn causes
+		// the vtgate to clear the cached plans when processing the new serving vschema.
+		// When buffering ends, many queries might be getting planned at the same time and we then
+		// take full advantage of the cached plan.
 		plan, err = e.getPlan(ctx, vcursor, query, stmt, comments, bindVars, reservedVars, e.normalize, logStats)
 		execStart := e.logPlanningFinished(logStats, plan)
 
@@ -124,12 +135,12 @@ func (e *Executor) newExecute(
 			safeSession.ClearWarnings()
 		}
 
-		// add any warnings that the planner wants to add
+		// Add any warnings that the planner wants to add.
 		for _, warning := range plan.Warnings {
 			safeSession.RecordWarning(warning)
 		}
 
-		result, err := e.handleTransactions(ctx, mysqlCtx, safeSession, plan, logStats, vcursor, stmt)
+		result, err = e.handleTransactions(ctx, mysqlCtx, safeSession, plan, logStats, vcursor, stmt)
 		if err != nil {
 			return err
 		}
@@ -137,14 +148,14 @@ func (e *Executor) newExecute(
 			return recResult(plan.Type, result)
 		}
 
-		// 4: Prepare for execution
+		// 4: Prepare for execution.
 		err = e.addNeededBindVars(vcursor, plan.BindVarNeeds, bindVars, safeSession)
 		if err != nil {
 			logStats.Error = err
 			return err
 		}
 
-		// 5: Execute the plan and retry if needed
+		// 5: Execute the plan.
 		if plan.Instructions.NeedsTransaction() {
 			err = e.insideTransaction(ctx, safeSession, logStats,
 				func() error {
@@ -158,10 +169,39 @@ func (e *Executor) newExecute(
 			return err
 		}
 
+		// 6: Retry if needed.
 		rootCause := vterrors.RootCause(err)
 		if rootCause != nil && strings.Contains(rootCause.Error(), "enforce denied tables") {
 			log.V(2).Infof("Retry: %d, will retry query %s due to %v", try, query, err)
-			lastVSchemaCreated = vs.GetCreated()
+			if try == 0 { // We are going to retry at least once
+				defer func() {
+					// Prevent any plan cache pollution from queries planned against the wrong keyspace during a MoveTables
+					// traffic switching operation.
+					if err != nil { // The error we're checking here is the return value from the newExecute function
+						cause := vterrors.RootCause(err)
+						if cause != nil && strings.Contains(cause.Error(), "enforce denied tables") {
+							// The executor's VSchemaManager clears the plan cache when it receives a new vschema via its
+							// SrvVSchema watcher (it calls executor.SaveVSchema() in its watch's subscriber callback). This
+							// happens concurrently with the KeyspaceEventWatcher also receiving the new vschema in its
+							// SrvVSchema watcher and in its subscriber callback processing it (which includes getting info
+							// on all shards from the topo), and eventually determining that the keyspace is consistent and
+							// ending the buffering window. So there's a race with query retries such that a query could be
+							// planned against the wrong side just as the keyspace event is getting resolved and the buffers
+							// drained. Then that bad plan is the cached plan for the query until you do another
+							// topo.RebuildSrvVSchema/vtctldclient RebuildVSchemaGraph which then causes the VSchemaManager
+							// to clear the plan cache. It's essentially a race between the two SrvVSchema watchers and the
+							// work they do when a new one is received. If we DID a retry AND the last time we retried
+							// still encountered the error, we know that the plan used was 1) not valid/correct and going to
+							// the wrong side of the traffic switch as it failed with the denied tables error and 2) it will
+							// remain the plan in the cache if we do not clear the plans after it was added to the cache.
+							// So here we clear the plan cache in order to prevent this scenario where the bad plan is
+							// cached indefinitely and re-used after the buffering window ends and the keyspace event is
+							// resolved.
+							e.ClearPlans()
+						}
+					}
+				}()
+			}
 			continue
 		}
 

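For illustration of the retry wait above: the per-retry vschema timeout is now derived from the
buffer configuration rather than being hardcoded. A minimal sketch of that arithmetic, assuming a
hypothetical 20 second max failover duration (not a Vitess default) and the MaxBufferingRetries
value of 3 implied by the comment:

package main

import (
	"fmt"
	"time"
)

// maxBufferingRetries mirrors vtgate's MaxBufferingRetries constant; the
// failover duration below is a made-up example value, not a Vitess default.
const maxBufferingRetries = 3

func main() {
	maxFailoverDuration := 20 * time.Second
	// Each retry waits for a newer vschema for at most an equal share of the
	// failover window, so all retries fit inside the buffering window.
	perRetryTimeout := maxFailoverDuration / (maxBufferingRetries - 1)
	fmt.Println(perRetryTimeout) // 10s with these example values
}
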
From 286e1eaad114171f938ce6569c72f9dc0768dc62 Mon Sep 17 00:00:00 2001
From: Tim Vaillancourt <tim@timvaillancourt.com>
Date: Tue, 27 Feb 2024 19:18:23 +0100
Subject: [PATCH 03/12] Filter by keyspace earlier in `tabletgateway`s
 `WaitForTablets(...)` (#15347)

Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>
---
 go/vt/discovery/healthcheck.go      |  22 ------
 go/vt/discovery/healthcheck_test.go |  21 ------
 go/vt/srvtopo/discover.go           |  19 +++---
 go/vt/srvtopo/discover_test.go      | 100 ++++++++++++++++------------
 go/vt/vtgate/tabletgateway.go       |   2 +-
 5 files changed, 70 insertions(+), 94 deletions(-)

diff --git a/go/vt/discovery/healthcheck.go b/go/vt/discovery/healthcheck.go
index cb1bbc1b4dd..c3a4d473947 100644
--- a/go/vt/discovery/healthcheck.go
+++ b/go/vt/discovery/healthcheck.go
@@ -746,30 +746,8 @@ func (hc *HealthCheckImpl) WaitForAllServingTablets(ctx context.Context, targets
 	return hc.waitForTablets(ctx, targets, true)
 }
 
-// FilterTargetsByKeyspaces only returns the targets that are part of the provided keyspaces
-func FilterTargetsByKeyspaces(keyspaces []string, targets []*query.Target) []*query.Target {
-	filteredTargets := make([]*query.Target, 0)
-
-	// Keep them all if there are no keyspaces to watch
-	if len(KeyspacesToWatch) == 0 {
-		return append(filteredTargets, targets...)
-	}
-
-	// Let's remove from the target shards that are not in the keyspaceToWatch list.
-	for _, target := range targets {
-		for _, keyspaceToWatch := range keyspaces {
-			if target.Keyspace == keyspaceToWatch {
-				filteredTargets = append(filteredTargets, target)
-			}
-		}
-	}
-	return filteredTargets
-}
-
 // waitForTablets is the internal method that polls for tablets.
 func (hc *HealthCheckImpl) waitForTablets(ctx context.Context, targets []*query.Target, requireServing bool) error {
-	targets = FilterTargetsByKeyspaces(KeyspacesToWatch, targets)
-
 	for {
 		// We nil targets as we find them.
 		allPresent := true
diff --git a/go/vt/discovery/healthcheck_test.go b/go/vt/discovery/healthcheck_test.go
index 6193159b66c..eaa78305a62 100644
--- a/go/vt/discovery/healthcheck_test.go
+++ b/go/vt/discovery/healthcheck_test.go
@@ -672,27 +672,6 @@ func TestWaitForAllServingTablets(t *testing.T) {
 
 	err = hc.WaitForAllServingTablets(ctx, targets)
 	assert.NotNil(t, err, "error should not be nil (there are no tablets on this keyspace")
-
-	targets = []*querypb.Target{
-
-		{
-			Keyspace:   tablet.Keyspace,
-			Shard:      tablet.Shard,
-			TabletType: tablet.Type,
-		},
-		{
-			Keyspace:   "newkeyspace",
-			Shard:      tablet.Shard,
-			TabletType: tablet.Type,
-		},
-	}
-
-	KeyspacesToWatch = []string{tablet.Keyspace}
-
-	err = hc.WaitForAllServingTablets(ctx, targets)
-	assert.Nil(t, err, "error should be nil. Keyspace with no tablets is filtered")
-
-	KeyspacesToWatch = []string{}
 }
 
 // TestRemoveTablet tests the behavior when a tablet goes away.
diff --git a/go/vt/srvtopo/discover.go b/go/vt/srvtopo/discover.go
index 91aaea9daf6..2997dc42e21 100644
--- a/go/vt/srvtopo/discover.go
+++ b/go/vt/srvtopo/discover.go
@@ -29,20 +29,23 @@ import (
 	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
 )
 
-// FindAllTargets goes through all serving shards in the topology
-// for the provided tablet types. It returns one Target object per
-// keyspace / shard / matching TabletType.
-func FindAllTargets(ctx context.Context, ts Server, cell string, tabletTypes []topodatapb.TabletType) ([]*querypb.Target, error) {
-	ksNames, err := ts.GetSrvKeyspaceNames(ctx, cell, true)
-	if err != nil {
-		return nil, err
+// FindAllTargets goes through all serving shards in the topology for the provided keyspaces
+// and tablet types. If no keyspaces are provided all available keyspaces in the topo are
+// fetched. It returns one Target object per keyspace/shard/matching TabletType.
+func FindAllTargets(ctx context.Context, ts Server, cell string, keyspaces []string, tabletTypes []topodatapb.TabletType) ([]*querypb.Target, error) {
+	var err error
+	if len(keyspaces) == 0 {
+		keyspaces, err = ts.GetSrvKeyspaceNames(ctx, cell, true)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	var targets []*querypb.Target
 	var wg sync.WaitGroup
 	var mu sync.Mutex
 	var errRecorder concurrency.AllErrorRecorder
-	for _, ksName := range ksNames {
+	for _, ksName := range keyspaces {
 		wg.Add(1)
 		go func(keyspace string) {
 			defer wg.Done()
diff --git a/go/vt/srvtopo/discover_test.go b/go/vt/srvtopo/discover_test.go
index ca4774a1b84..3f730bba3d3 100644
--- a/go/vt/srvtopo/discover_test.go
+++ b/go/vt/srvtopo/discover_test.go
@@ -18,11 +18,12 @@ package srvtopo
 
 import (
 	"context"
-	"reflect"
 	"sort"
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
+
 	"vitess.io/vitess/go/vt/topo/memorytopo"
 
 	querypb "vitess.io/vitess/go/vt/proto/query"
@@ -62,16 +63,12 @@ func TestFindAllTargets(t *testing.T) {
 	rs := NewResilientServer(ctx, ts, "TestFindAllKeyspaceShards")
 
 	// No keyspace / shards.
-	ks, err := FindAllTargets(ctx, rs, "cell1", []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
-	if err != nil {
-		t.Errorf("unexpected error: %v", err)
-	}
-	if len(ks) > 0 {
-		t.Errorf("why did I get anything? %v", ks)
-	}
+	ks, err := FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
+	assert.NoError(t, err)
+	assert.Len(t, ks, 0)
 
 	// Add one.
-	if err := ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace", &topodatapb.SrvKeyspace{
+	assert.NoError(t, ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace", &topodatapb.SrvKeyspace{
 		Partitions: []*topodatapb.SrvKeyspace_KeyspacePartition{
 			{
 				ServedType: topodatapb.TabletType_PRIMARY,
@@ -82,28 +79,34 @@ func TestFindAllTargets(t *testing.T) {
 				},
 			},
 		},
-	}); err != nil {
-		t.Fatalf("can't add srvKeyspace: %v", err)
-	}
+	}))
 
 	// Get it.
-	ks, err = FindAllTargets(ctx, rs, "cell1", []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
-	if err != nil {
-		t.Errorf("unexpected error: %v", err)
-	}
-	if !reflect.DeepEqual(ks, []*querypb.Target{
+	ks, err = FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
+	assert.NoError(t, err)
+	assert.EqualValues(t, []*querypb.Target{
 		{
 			Cell:       "cell1",
 			Keyspace:   "test_keyspace",
 			Shard:      "test_shard0",
 			TabletType: topodatapb.TabletType_PRIMARY,
 		},
-	}) {
-		t.Errorf("got wrong value: %v", ks)
-	}
+	}, ks)
+
+	// Get any keyspace.
+	ks, err = FindAllTargets(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
+	assert.NoError(t, err)
+	assert.EqualValues(t, []*querypb.Target{
+		{
+			Cell:       "cell1",
+			Keyspace:   "test_keyspace",
+			Shard:      "test_shard0",
+			TabletType: topodatapb.TabletType_PRIMARY,
+		},
+	}, ks)
 
 	// Add another one.
-	if err := ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace2", &topodatapb.SrvKeyspace{
+	assert.NoError(t, ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace2", &topodatapb.SrvKeyspace{
 		Partitions: []*topodatapb.SrvKeyspace_KeyspacePartition{
 			{
 				ServedType: topodatapb.TabletType_PRIMARY,
@@ -122,17 +125,13 @@ func TestFindAllTargets(t *testing.T) {
 				},
 			},
 		},
-	}); err != nil {
-		t.Fatalf("can't add srvKeyspace: %v", err)
-	}
+	}))
 
-	// Get it for all types.
-	ks, err = FindAllTargets(ctx, rs, "cell1", []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
-	if err != nil {
-		t.Errorf("unexpected error: %v", err)
-	}
+	// Get it for any keyspace, all types.
+	ks, err = FindAllTargets(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
+	assert.NoError(t, err)
 	sort.Sort(TargetArray(ks))
-	if !reflect.DeepEqual(ks, []*querypb.Target{
+	assert.EqualValues(t, []*querypb.Target{
 		{
 			Cell:       "cell1",
 			Keyspace:   "test_keyspace",
@@ -151,23 +150,40 @@ func TestFindAllTargets(t *testing.T) {
 			Shard:      "test_shard2",
 			TabletType: topodatapb.TabletType_REPLICA,
 		},
-	}) {
-		t.Errorf("got wrong value: %v", ks)
-	}
+	}, ks)
 
-	// Only get the REPLICA targets.
-	ks, err = FindAllTargets(ctx, rs, "cell1", []topodatapb.TabletType{topodatapb.TabletType_REPLICA})
-	if err != nil {
-		t.Errorf("unexpected error: %v", err)
-	}
-	if !reflect.DeepEqual(ks, []*querypb.Target{
+	// Only get 1 keyspace for all types.
+	ks, err = FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace2"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
+	assert.NoError(t, err)
+	assert.EqualValues(t, []*querypb.Target{
+		{
+			Cell:       "cell1",
+			Keyspace:   "test_keyspace2",
+			Shard:      "test_shard1",
+			TabletType: topodatapb.TabletType_PRIMARY,
+		},
 		{
 			Cell:       "cell1",
 			Keyspace:   "test_keyspace2",
 			Shard:      "test_shard2",
 			TabletType: topodatapb.TabletType_REPLICA,
 		},
-	}) {
-		t.Errorf("got wrong value: %v", ks)
-	}
+	}, ks)
+
+	// Only get the REPLICA targets for any keyspace.
+	ks, err = FindAllTargets(ctx, rs, "cell1", []string{}, []topodatapb.TabletType{topodatapb.TabletType_REPLICA})
+	assert.NoError(t, err)
+	assert.Equal(t, []*querypb.Target{
+		{
+			Cell:       "cell1",
+			Keyspace:   "test_keyspace2",
+			Shard:      "test_shard2",
+			TabletType: topodatapb.TabletType_REPLICA,
+		},
+	}, ks)
+
+	// Get non-existent keyspace.
+	ks, err = FindAllTargets(ctx, rs, "cell1", []string{"doesnt-exist"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
+	assert.NoError(t, err)
+	assert.Len(t, ks, 0)
 }
diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go
index de63da87907..084a5059fd8 100644
--- a/go/vt/vtgate/tabletgateway.go
+++ b/go/vt/vtgate/tabletgateway.go
@@ -191,7 +191,7 @@ func (gw *TabletGateway) WaitForTablets(ctx context.Context, tabletTypesToWait [
 	}
 
 	// Finds the targets to look for.
-	targets, err := srvtopo.FindAllTargets(ctx, gw.srvTopoServer, gw.localCell, tabletTypesToWait)
+	targets, err := srvtopo.FindAllTargets(ctx, gw.srvTopoServer, gw.localCell, discovery.KeyspacesToWatch, tabletTypesToWait)
 	if err != nil {
 		return err
 	}
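
With this change the keyspace filter is applied before the topo is walked instead of filtering the
targets afterwards. A rough standalone sketch of the selection logic, using a hypothetical stub in
place of the srvtopo server's GetSrvKeyspaceNames:

package main

import "fmt"

// listSrvKeyspaceNames is a hypothetical stand-in for ts.GetSrvKeyspaceNames,
// not the srvtopo API.
func listSrvKeyspaceNames() []string {
	return []string{"commerce", "customer"}
}

// keyspacesToScan mirrors the new FindAllTargets behavior: an explicit list
// (e.g. KeyspacesToWatch) wins, otherwise every keyspace in the cell is scanned.
func keyspacesToScan(requested []string) []string {
	if len(requested) == 0 {
		return listSrvKeyspaceNames()
	}
	return requested
}

func main() {
	fmt.Println(keyspacesToScan(nil))                  // all keyspaces in the cell
	fmt.Println(keyspacesToScan([]string{"commerce"})) // only the watched keyspace
}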

From 7d79407b198031feccc88feb0e5c94c36c8cd4e3 Mon Sep 17 00:00:00 2001
From: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com>
Date: Thu, 29 Aug 2024 08:53:29 +0530
Subject: [PATCH 04/12] Fix race condition that prevents queries from being
 buffered after vtgate startup (#16655)

Signed-off-by: Manan Gupta <manan@planetscale.com>
---
 go/vt/discovery/keyspace_events.go       |  90 +++++++++++++++--
 go/vt/discovery/keyspace_events_test.go  | 119 +++++++++++++++++++----
 go/vt/srvtopo/discover.go                |  14 +--
 go/vt/srvtopo/discover_test.go           |  41 +++++---
 go/vt/vtgate/tabletgateway.go            |  28 ++++--
 go/vt/vtgate/tabletgateway_flaky_test.go |  10 +-
 go/vt/vtgate/tabletgateway_test.go       |  56 +++++++++++
 7 files changed, 297 insertions(+), 61 deletions(-)

diff --git a/go/vt/discovery/keyspace_events.go b/go/vt/discovery/keyspace_events.go
index 9fa457c1589..036d4f3ad14 100644
--- a/go/vt/discovery/keyspace_events.go
+++ b/go/vt/discovery/keyspace_events.go
@@ -19,7 +19,9 @@ package discovery
 import (
 	"context"
 	"fmt"
+	"slices"
 	"sync"
+	"time"
 
 	"golang.org/x/sync/errgroup"
 	"google.golang.org/protobuf/proto"
@@ -37,6 +39,11 @@ import (
 	vschemapb "vitess.io/vitess/go/vt/proto/vschema"
 )
 
+var (
+	// waitConsistentKeyspacesCheck is the amount of time to wait between checks to verify the keyspace is consistent.
+	waitConsistentKeyspacesCheck = 100 * time.Millisecond
+)
+
 // KeyspaceEventWatcher is an auxiliary watcher that watches all availability incidents
 // for all keyspaces in a Vitess cell and notifies listeners when the events have been resolved.
 // Right now this is capable of detecting the end of failovers, both planned and unplanned,
@@ -662,29 +669,53 @@ func (kew *KeyspaceEventWatcher) TargetIsBeingResharded(ctx context.Context, tar
 	return ks.beingResharded(target.Shard)
 }
 
-// PrimaryIsNotServing checks if the reason why the given target is not accessible right now is
-// that the primary tablet for that shard is not serving. This is possible during a Planned
-// Reparent Shard operation. Just as the operation completes, a new primary will be elected, and
+// ShouldStartBufferingForTarget checks if we should be starting buffering for the given target.
+// We check the following things before we start buffering -
+//  1. The shard must have a primary.
+//  2. The primary must be non-serving.
+//  3. The keyspace must be marked inconsistent.
+//
+// This buffering is meant to kick in during a Planned Reparent Shard operation.
+// As part of that operation the old primary will become non-serving. At that point
+// this code should return true to start buffering requests.
+// Just as the PRS operation completes, a new primary will be elected, and
 // it will send its own healthcheck stating that it is serving. We should buffer requests until
-// that point. There are use cases where people do not run with a Primary server at all, so we must
+// that point.
+//
+// There are use cases where people do not run with a Primary server at all, so we must
 // verify that we only start buffering when a primary was present, and it went not serving.
 // The shard state keeps track of the current primary and the last externally reparented time, which
 // we can use to determine that there was a serving primary which now became non serving. This is
 // only possible in a DemotePrimary RPC which are only called from ERS and PRS. So buffering will
-// stop when these operations succeed. We return the tablet alias of the primary if it is serving.
-func (kew *KeyspaceEventWatcher) PrimaryIsNotServing(ctx context.Context, target *querypb.Target) (*topodatapb.TabletAlias, bool) {
+// stop when these operations succeed. We also return the tablet alias of the primary if it is serving.
+func (kew *KeyspaceEventWatcher) ShouldStartBufferingForTarget(ctx context.Context, target *querypb.Target) (*topodatapb.TabletAlias, bool) {
 	if target.TabletType != topodatapb.TabletType_PRIMARY {
+		// We don't support buffering for any target tablet type other than the primary.
 		return nil, false
 	}
 	ks := kew.getKeyspaceStatus(ctx, target.Keyspace)
 	if ks == nil {
+		// If the keyspace status is nil, then the keyspace must be deleted.
+		// The user query is trying to access a keyspace that has been deleted.
+		// There is no reason to buffer this query.
 		return nil, false
 	}
 	ks.mu.Lock()
 	defer ks.mu.Unlock()
 	if state, ok := ks.shards[target.Shard]; ok {
-		// If the primary tablet was present then externallyReparented will be non-zero and
-		// currentPrimary will be not nil.
+		// As described in the function comment, we only want to start buffering when all the following conditions are met -
+		// 1. The shard must have a primary. We check this by checking the currentPrimary and externallyReparented fields being non-empty.
+		//    They are set the first time the shard registers an update from a serving primary and are never cleared out after that.
+		//    If the user has configured vtgates to wait for the primary tablet healthchecks before starting query service, this condition
+		//    will always be true.
+		// 2. The primary must be non-serving. We check this by checking the serving field in the shard state.
+		// 	  When a primary becomes non-serving, it also marks the keyspace inconsistent. So the next check is only added
+		//    When a primary becomes non-serving, it also marks the keyspace inconsistent. So the next check is only added
+		//    to be defensive against any bugs.
+		//
+		// The reason we need all three checks is that we want to be very defensive about when we start buffering.
+		// We don't want to start buffering when we don't know for sure whether the primary
+		// is not serving and we will soon receive an update that stops the buffering.
 		return state.currentPrimary, !state.serving && !ks.consistent && state.externallyReparented != 0 && state.currentPrimary != nil
 	}
 	return nil, false
@@ -703,3 +734,46 @@ func (kew *KeyspaceEventWatcher) GetServingKeyspaces() []string {
 	}
 	return servingKeyspaces
 }
+
+// WaitForConsistentKeyspaces waits for the given set of keyspaces to be marked consistent.
+func (kew *KeyspaceEventWatcher) WaitForConsistentKeyspaces(ctx context.Context, ksList []string) error {
+	// We don't want to change the original keyspace list that we receive so we clone it
+	// before we empty its elements down below.
+	keyspaces := slices.Clone(ksList)
+	for {
+		// We empty keyspaces as we find them to be consistent.
+		allConsistent := true
+		for i, ks := range keyspaces {
+			if ks == "" {
+				continue
+			}
+
+			// Get the keyspace status and see whether it is consistent yet or not.
+			kss := kew.getKeyspaceStatus(ctx, ks)
+			// If kss is nil, then the keyspace must have been deleted. In that case it is
+			// fine for us to consider it consistent, since the keyspace no longer exists.
+			if kss == nil || kss.consistent {
+				keyspaces[i] = ""
+			} else {
+				allConsistent = false
+			}
+		}
+
+		if allConsistent {
+			// all the keyspaces are consistent.
+			return nil
+		}
+
+		// Unblock after the sleep or when the context has expired.
+		select {
+		case <-ctx.Done():
+			for _, ks := range keyspaces {
+				if ks != "" {
+					log.Infof("keyspace %v didn't become consistent", ks)
+				}
+			}
+			return ctx.Err()
+		case <-time.After(waitConsistentKeyspacesCheck):
+		}
+	}
+}
diff --git a/go/vt/discovery/keyspace_events_test.go b/go/vt/discovery/keyspace_events_test.go
index e9406ff1de2..1a4c473e7cb 100644
--- a/go/vt/discovery/keyspace_events_test.go
+++ b/go/vt/discovery/keyspace_events_test.go
@@ -155,11 +155,11 @@ func TestKeyspaceEventTypes(t *testing.T) {
 	kew := NewKeyspaceEventWatcher(ctx, ts2, hc, cell)
 
 	type testCase struct {
-		name                    string
-		kss                     *keyspaceState
-		shardToCheck            string
-		expectResharding        bool
-		expectPrimaryNotServing bool
+		name               string
+		kss                *keyspaceState
+		shardToCheck       string
+		expectResharding   bool
+		expectShouldBuffer bool
 	}
 
 	testCases := []testCase{
@@ -196,9 +196,9 @@ func TestKeyspaceEventTypes(t *testing.T) {
 				},
 				consistent: false,
 			},
-			shardToCheck:            "-",
-			expectResharding:        true,
-			expectPrimaryNotServing: false,
+			shardToCheck:       "-",
+			expectResharding:   true,
+			expectShouldBuffer: false,
 		},
 		{
 			name: "two to four resharding in progress",
@@ -257,9 +257,9 @@ func TestKeyspaceEventTypes(t *testing.T) {
 				},
 				consistent: false,
 			},
-			shardToCheck:            "-80",
-			expectResharding:        true,
-			expectPrimaryNotServing: false,
+			shardToCheck:       "-80",
+			expectResharding:   true,
+			expectShouldBuffer: false,
 		},
 		{
 			name: "unsharded primary not serving",
@@ -283,9 +283,9 @@ func TestKeyspaceEventTypes(t *testing.T) {
 				},
 				consistent: false,
 			},
-			shardToCheck:            "-",
-			expectResharding:        false,
-			expectPrimaryNotServing: true,
+			shardToCheck:       "-",
+			expectResharding:   false,
+			expectShouldBuffer: true,
 		},
 		{
 			name: "sharded primary not serving",
@@ -317,9 +317,9 @@ func TestKeyspaceEventTypes(t *testing.T) {
 				},
 				consistent: false,
 			},
-			shardToCheck:            "-80",
-			expectResharding:        false,
-			expectPrimaryNotServing: true,
+			shardToCheck:       "-80",
+			expectResharding:   false,
+			expectShouldBuffer: true,
 		},
 	}
 
@@ -334,8 +334,89 @@ func TestKeyspaceEventTypes(t *testing.T) {
 			resharding := kew.TargetIsBeingResharded(ctx, tc.kss.shards[tc.shardToCheck].target)
 			require.Equal(t, resharding, tc.expectResharding, "TargetIsBeingResharded should return %t", tc.expectResharding)
 
-			_, primaryDown := kew.PrimaryIsNotServing(ctx, tc.kss.shards[tc.shardToCheck].target)
-			require.Equal(t, primaryDown, tc.expectPrimaryNotServing, "PrimaryIsNotServing should return %t", tc.expectPrimaryNotServing)
+			_, shouldBuffer := kew.ShouldStartBufferingForTarget(ctx, tc.kss.shards[tc.shardToCheck].target)
+			require.Equal(t, shouldBuffer, tc.expectShouldBuffer, "ShouldStartBufferingForTarget should return %t", tc.expectShouldBuffer)
+		})
+	}
+}
+
+// TestWaitForConsistentKeyspaces tests the behaviour of WaitForConsistentKeyspaces for different scenarios.
+func TestWaitForConsistentKeyspaces(t *testing.T) {
+	testcases := []struct {
+		name        string
+		ksMap       map[string]*keyspaceState
+		ksList      []string
+		errExpected string
+	}{
+		{
+			name:   "Empty keyspace list",
+			ksList: nil,
+			ksMap: map[string]*keyspaceState{
+				"ks1": {},
+			},
+			errExpected: "",
+		},
+		{
+			name:   "All keyspaces consistent",
+			ksList: []string{"ks1", "ks2"},
+			ksMap: map[string]*keyspaceState{
+				"ks1": {
+					consistent: true,
+				},
+				"ks2": {
+					consistent: true,
+				},
+			},
+			errExpected: "",
+		},
+		{
+			name:   "One keyspace inconsistent",
+			ksList: []string{"ks1", "ks2"},
+			ksMap: map[string]*keyspaceState{
+				"ks1": {
+					consistent: true,
+				},
+				"ks2": {
+					consistent: false,
+				},
+			},
+			errExpected: "context canceled",
+		},
+		{
+			name:   "One deleted keyspace - consistent",
+			ksList: []string{"ks1", "ks2"},
+			ksMap: map[string]*keyspaceState{
+				"ks1": {
+					consistent: true,
+				},
+				"ks2": {
+					deleted: true,
+				},
+			},
+			errExpected: "",
+		},
+	}
+
+	for _, tt := range testcases {
+		t.Run(tt.name, func(t *testing.T) {
+			// We create a cancelable context and immediately cancel it.
+			// We don't want the unit tests to wait, so we only test the first
+			// iteration of whether the keyspace event watcher returns
+			// that the keyspaces are consistent or not.
+			ctx, cancel := context.WithCancel(context.Background())
+			cancel()
+			kew := KeyspaceEventWatcher{
+				keyspaces: tt.ksMap,
+				mu:        sync.Mutex{},
+				ts:        &fakeTopoServer{},
+			}
+			err := kew.WaitForConsistentKeyspaces(ctx, tt.ksList)
+			if tt.errExpected != "" {
+				require.ErrorContains(t, err, tt.errExpected)
+			} else {
+				require.NoError(t, err)
+			}
+
 		})
 	}
 }
diff --git a/go/vt/srvtopo/discover.go b/go/vt/srvtopo/discover.go
index 2997dc42e21..2b020e89887 100644
--- a/go/vt/srvtopo/discover.go
+++ b/go/vt/srvtopo/discover.go
@@ -17,9 +17,8 @@ limitations under the License.
 package srvtopo
 
 import (
-	"sync"
-
 	"context"
+	"sync"
 
 	"vitess.io/vitess/go/vt/concurrency"
 	"vitess.io/vitess/go/vt/log"
@@ -29,15 +28,16 @@ import (
 	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
 )
 
-// FindAllTargets goes through all serving shards in the topology for the provided keyspaces
+// FindAllTargetsAndKeyspaces goes through all serving shards in the topology for the provided keyspaces
 // and tablet types. If no keyspaces are provided all available keyspaces in the topo are
 // fetched. It returns one Target object per keyspace/shard/matching TabletType.
-func FindAllTargets(ctx context.Context, ts Server, cell string, keyspaces []string, tabletTypes []topodatapb.TabletType) ([]*querypb.Target, error) {
+// It also returns all the keyspaces that it found.
+func FindAllTargetsAndKeyspaces(ctx context.Context, ts Server, cell string, keyspaces []string, tabletTypes []topodatapb.TabletType) ([]*querypb.Target, []string, error) {
 	var err error
 	if len(keyspaces) == 0 {
 		keyspaces, err = ts.GetSrvKeyspaceNames(ctx, cell, true)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 	}
 
@@ -95,8 +95,8 @@ func FindAllTargets(ctx context.Context, ts Server, cell string, keyspaces []str
 	}
 	wg.Wait()
 	if errRecorder.HasErrors() {
-		return nil, errRecorder.Error()
+		return nil, nil, errRecorder.Error()
 	}
 
-	return targets, nil
+	return targets, keyspaces, nil
 }
diff --git a/go/vt/srvtopo/discover_test.go b/go/vt/srvtopo/discover_test.go
index 3f730bba3d3..0232bce7a65 100644
--- a/go/vt/srvtopo/discover_test.go
+++ b/go/vt/srvtopo/discover_test.go
@@ -48,7 +48,7 @@ func (a TargetArray) Less(i, j int) bool {
 	return a[i].TabletType < a[j].TabletType
 }
 
-func TestFindAllTargets(t *testing.T) {
+func TestFindAllTargetsAndKeyspaces(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	ts := memorytopo.NewServer(ctx, "cell1", "cell2")
@@ -63,9 +63,10 @@ func TestFindAllTargets(t *testing.T) {
 	rs := NewResilientServer(ctx, ts, "TestFindAllKeyspaceShards")
 
 	// No keyspace / shards.
-	ks, err := FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
+	targets, ksList, err := FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
 	assert.NoError(t, err)
-	assert.Len(t, ks, 0)
+	assert.Len(t, targets, 0)
+	assert.EqualValues(t, []string{"test_keyspace"}, ksList)
 
 	// Add one.
 	assert.NoError(t, ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace", &topodatapb.SrvKeyspace{
@@ -82,7 +83,7 @@ func TestFindAllTargets(t *testing.T) {
 	}))
 
 	// Get it.
-	ks, err = FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
+	targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
 	assert.NoError(t, err)
 	assert.EqualValues(t, []*querypb.Target{
 		{
@@ -91,10 +92,11 @@ func TestFindAllTargets(t *testing.T) {
 			Shard:      "test_shard0",
 			TabletType: topodatapb.TabletType_PRIMARY,
 		},
-	}, ks)
+	}, targets)
+	assert.EqualValues(t, []string{"test_keyspace"}, ksList)
 
 	// Get any keyspace.
-	ks, err = FindAllTargets(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
+	targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY})
 	assert.NoError(t, err)
 	assert.EqualValues(t, []*querypb.Target{
 		{
@@ -103,7 +105,8 @@ func TestFindAllTargets(t *testing.T) {
 			Shard:      "test_shard0",
 			TabletType: topodatapb.TabletType_PRIMARY,
 		},
-	}, ks)
+	}, targets)
+	assert.EqualValues(t, []string{"test_keyspace"}, ksList)
 
 	// Add another one.
 	assert.NoError(t, ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace2", &topodatapb.SrvKeyspace{
@@ -128,9 +131,9 @@ func TestFindAllTargets(t *testing.T) {
 	}))
 
 	// Get it for any keyspace, all types.
-	ks, err = FindAllTargets(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
+	targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
 	assert.NoError(t, err)
-	sort.Sort(TargetArray(ks))
+	sort.Sort(TargetArray(targets))
 	assert.EqualValues(t, []*querypb.Target{
 		{
 			Cell:       "cell1",
@@ -150,10 +153,12 @@ func TestFindAllTargets(t *testing.T) {
 			Shard:      "test_shard2",
 			TabletType: topodatapb.TabletType_REPLICA,
 		},
-	}, ks)
+	}, targets)
+	sort.Strings(ksList)
+	assert.EqualValues(t, []string{"test_keyspace", "test_keyspace2"}, ksList)
 
 	// Only get 1 keyspace for all types.
-	ks, err = FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace2"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
+	targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{"test_keyspace2"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
 	assert.NoError(t, err)
 	assert.EqualValues(t, []*querypb.Target{
 		{
@@ -168,10 +173,11 @@ func TestFindAllTargets(t *testing.T) {
 			Shard:      "test_shard2",
 			TabletType: topodatapb.TabletType_REPLICA,
 		},
-	}, ks)
+	}, targets)
+	assert.EqualValues(t, []string{"test_keyspace2"}, ksList)
 
 	// Only get the REPLICA targets for any keyspace.
-	ks, err = FindAllTargets(ctx, rs, "cell1", []string{}, []topodatapb.TabletType{topodatapb.TabletType_REPLICA})
+	targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{}, []topodatapb.TabletType{topodatapb.TabletType_REPLICA})
 	assert.NoError(t, err)
 	assert.Equal(t, []*querypb.Target{
 		{
@@ -180,10 +186,13 @@ func TestFindAllTargets(t *testing.T) {
 			Shard:      "test_shard2",
 			TabletType: topodatapb.TabletType_REPLICA,
 		},
-	}, ks)
+	}, targets)
+	sort.Strings(ksList)
+	assert.EqualValues(t, []string{"test_keyspace", "test_keyspace2"}, ksList)
 
 	// Get non-existent keyspace.
-	ks, err = FindAllTargets(ctx, rs, "cell1", []string{"doesnt-exist"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
+	targets, ksList, err = FindAllTargetsAndKeyspaces(ctx, rs, "cell1", []string{"doesnt-exist"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA})
 	assert.NoError(t, err)
-	assert.Len(t, ks, 0)
+	assert.Len(t, targets, 0)
+	assert.EqualValues(t, []string{"doesnt-exist"}, ksList)
 }
diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go
index 084a5059fd8..21087fe5370 100644
--- a/go/vt/vtgate/tabletgateway.go
+++ b/go/vt/vtgate/tabletgateway.go
@@ -191,11 +191,24 @@ func (gw *TabletGateway) WaitForTablets(ctx context.Context, tabletTypesToWait [
 	}
 
 	// Finds the targets to look for.
-	targets, err := srvtopo.FindAllTargets(ctx, gw.srvTopoServer, gw.localCell, discovery.KeyspacesToWatch, tabletTypesToWait)
+	targets, keyspaces, err := srvtopo.FindAllTargetsAndKeyspaces(ctx, gw.srvTopoServer, gw.localCell, discovery.KeyspacesToWatch, tabletTypesToWait)
 	if err != nil {
 		return err
 	}
-	return gw.hc.WaitForAllServingTablets(ctx, targets)
+	err = gw.hc.WaitForAllServingTablets(ctx, targets)
+	if err != nil {
+		return err
+	}
+	// After having waited for all serving tablets, we should also wait for the keyspace event watcher to have seen
+	// the updates and marked all the keyspaces as consistent (if we want to wait for primary tablets).
+	// Otherwise, we could be in a situation where even though the healthchecks have arrived, the keyspace event watcher hasn't finished processing them.
+	// So, if a primary tablet goes non-serving (because of a PRS or some other reason), we won't be able to start buffering.
+	// Waiting for the keyspaces to become consistent ensures that all the primary tablets for all the shards are serving as seen by the keyspace event watcher,
+	// so any disruption from now on will make sure we start buffering properly.
+	if topoproto.IsTypeInList(topodatapb.TabletType_PRIMARY, tabletTypesToWait) && gw.kev != nil {
+		return gw.kev.WaitForConsistentKeyspaces(ctx, keyspaces)
+	}
+	return nil
 }
 
 // Close shuts down underlying connections.
@@ -282,18 +295,21 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target,
 		if len(tablets) == 0 {
 			// if we have a keyspace event watcher, check if the reason why our primary is not available is that it's currently being resharded
 			// or if a reparent operation is in progress.
-			if kev := gw.kev; kev != nil {
+			// We only check whether a reshard is ongoing or the primary is not serving if the target is a primary. We don't want to buffer
+			// replica queries, so it doesn't make any sense to check for resharding or reparenting in that case.
+			if kev := gw.kev; kev != nil && target.TabletType == topodatapb.TabletType_PRIMARY {
 				if kev.TargetIsBeingResharded(ctx, target) {
 					log.V(2).Infof("current keyspace is being resharded, retrying: %s: %s", target.Keyspace, debug.Stack())
 					err = vterrors.Errorf(vtrpcpb.Code_CLUSTER_EVENT, buffer.ClusterEventReshardingInProgress)
 					continue
 				}
-				primary, notServing := kev.PrimaryIsNotServing(ctx, target)
-				if notServing {
+				primary, shouldBuffer := kev.ShouldStartBufferingForTarget(ctx, target)
+				if shouldBuffer {
 					err = vterrors.Errorf(vtrpcpb.Code_CLUSTER_EVENT, buffer.ClusterEventReparentInProgress)
 					continue
 				}
-				// if primary is serving, but we initially found no tablet, we're in an inconsistent state
+				// if the keyspace event manager doesn't think we should buffer queries, and also sees a primary tablet,
+				// but we initially found no tablet, we're in an inconsistent state
 				// we then retry the entire loop
 				if primary != nil {
 					err = vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "inconsistent state detected, primary is serving but initially found no available tablet")
diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go
index 74e6751162a..fbca19ecbad 100644
--- a/go/vt/vtgate/tabletgateway_flaky_test.go
+++ b/go/vt/vtgate/tabletgateway_flaky_test.go
@@ -67,7 +67,7 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) {
 	waitForBuffering := func(enabled bool) {
 		timer := time.NewTimer(bufferingWaitTimeout)
 		defer timer.Stop()
-		for _, buffering := tg.kev.PrimaryIsNotServing(ctx, target); buffering != enabled; _, buffering = tg.kev.PrimaryIsNotServing(ctx, target) {
+		for _, buffering := tg.kev.ShouldStartBufferingForTarget(ctx, target); buffering != enabled; _, buffering = tg.kev.ShouldStartBufferingForTarget(ctx, target) {
 			select {
 			case <-timer.C:
 				require.Fail(t, "timed out waiting for buffering of enabled: %t", enabled)
@@ -213,8 +213,8 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) {
 	hc.Broadcast(primaryTablet)
 
 	require.Len(t, tg.hc.GetHealthyTabletStats(target), 0, "GetHealthyTabletStats has tablets even though it shouldn't")
-	_, isNotServing := tg.kev.PrimaryIsNotServing(ctx, target)
-	require.True(t, isNotServing)
+	_, shouldStartBuffering := tg.kev.ShouldStartBufferingForTarget(ctx, target)
+	require.True(t, shouldStartBuffering)
 
 	// add a result to the sandbox connection of the new primary
 	sbcReplica.SetResults([]*sqltypes.Result{sqlResult1})
@@ -244,8 +244,8 @@ outer:
 		case <-timeout:
 			require.Fail(t, "timed out - could not verify the new primary")
 		case <-time.After(10 * time.Millisecond):
-			newPrimary, notServing := tg.kev.PrimaryIsNotServing(ctx, target)
-			if newPrimary != nil && newPrimary.Uid == 1 && !notServing {
+			newPrimary, shouldBuffer := tg.kev.ShouldStartBufferingForTarget(ctx, target)
+			if newPrimary != nil && newPrimary.Uid == 1 && !shouldBuffer {
 				break outer
 			}
 		}
diff --git a/go/vt/vtgate/tabletgateway_test.go b/go/vt/vtgate/tabletgateway_test.go
index 32d18dcc9ab..fc86ab358c8 100644
--- a/go/vt/vtgate/tabletgateway_test.go
+++ b/go/vt/vtgate/tabletgateway_test.go
@@ -26,6 +26,7 @@ import (
 	"github.com/stretchr/testify/require"
 
 	"vitess.io/vitess/go/test/utils"
+	"vitess.io/vitess/go/vt/vttablet/queryservice"
 
 	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/vt/discovery"
@@ -298,3 +299,58 @@ func verifyShardErrors(t *testing.T, err error, wantErrors []string, wantCode vt
 	}
 	require.Equal(t, vterrors.Code(err), wantCode, "wanted error code: %s, got: %v", wantCode, vterrors.Code(err))
 }
+
+// TestWithRetry tests the functionality of withRetry function in different circumstances.
+func TestWithRetry(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	tg := NewTabletGateway(ctx, discovery.NewFakeHealthCheck(nil), &fakeTopoServer{}, "cell")
+	tg.kev = discovery.NewKeyspaceEventWatcher(ctx, tg.srvTopoServer, tg.hc, tg.localCell)
+	defer func() {
+		cancel()
+		tg.Close(ctx)
+	}()
+
+	testcases := []struct {
+		name          string
+		target        *querypb.Target
+		inTransaction bool
+		inner         func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error)
+		expectedErr   string
+	}{
+		{
+			name: "Transaction on a replica",
+			target: &querypb.Target{
+				Keyspace:   "ks",
+				Shard:      "0",
+				TabletType: topodatapb.TabletType_REPLICA,
+			},
+			inTransaction: true,
+			inner: func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error) {
+				return false, nil
+			},
+			expectedErr: "tabletGateway's query service can only be used for non-transactional queries on replicas",
+		}, {
+			name: "No replica tablets available",
+			target: &querypb.Target{
+				Keyspace:   "ks",
+				Shard:      "0",
+				TabletType: topodatapb.TabletType_REPLICA,
+			},
+			inTransaction: false,
+			inner: func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error) {
+				return false, nil
+			},
+			expectedErr: `target: ks.0.replica: no healthy tablet available for 'keyspace:"ks" shard:"0" tablet_type:REPLICA'`,
+		},
+	}
+	for _, tt := range testcases {
+		t.Run(tt.name, func(t *testing.T) {
+			err := tg.withRetry(ctx, tt.target, nil, "", tt.inTransaction, tt.inner)
+			if tt.expectedErr == "" {
+				require.NoError(t, err)
+			} else {
+				require.ErrorContains(t, err, tt.expectedErr)
+			}
+		})
+	}
+}
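
The buffering predicate above reduces to three fields tracked by the keyspace event watcher. A
condensed sketch of that decision, using simplified stand-in types rather than the real
keyspaceState and shardState structs:

package main

import "fmt"

// shardView and keyspaceView are simplified stand-ins for the watcher's
// shardState and keyspaceState; they are not the discovery package types.
type shardView struct {
	hasKnownPrimary bool // currentPrimary != nil && externallyReparented != 0
	serving         bool
}

type keyspaceView struct {
	consistent bool
}

// shouldStartBuffering mirrors the three checks described in
// ShouldStartBufferingForTarget: a primary was seen at some point, it is
// currently non-serving, and the keyspace is still marked inconsistent.
func shouldStartBuffering(ks keyspaceView, sh shardView) bool {
	return sh.hasKnownPrimary && !sh.serving && !ks.consistent
}

func main() {
	fmt.Println(shouldStartBuffering(keyspaceView{consistent: false}, shardView{hasKnownPrimary: true})) // true: start buffering
	fmt.Println(shouldStartBuffering(keyspaceView{consistent: true}, shardView{hasKnownPrimary: true}))  // false: keyspace already consistent
}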

From e24e5d97cbe6574f57f8bd7f511856811a52c172 Mon Sep 17 00:00:00 2001
From: Manan Gupta <manan@planetscale.com>
Date: Sat, 19 Oct 2024 12:03:12 +0530
Subject: [PATCH 05/12] test: add a test for premature buffering

Signed-off-by: Manan Gupta <manan@planetscale.com>
Signed-off-by: Arthur Schreiber <arthurschreiber@github.com>
---
 .../reparent/newfeaturetest/reparent_test.go  | 67 +++++++++++++++++++
 go/test/endtoend/reparent/utils/utils.go      | 52 +++++++++++++-
 2 files changed, 118 insertions(+), 1 deletion(-)

diff --git a/go/test/endtoend/reparent/newfeaturetest/reparent_test.go b/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
index d5f37dc8604..7782afecb21 100644
--- a/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
+++ b/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
@@ -19,10 +19,14 @@ package newfeaturetest
 import (
 	"context"
 	"fmt"
+	"math/rand/v2"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 
+	"vitess.io/vitess/go/mysql"
 	"vitess.io/vitess/go/test/endtoend/cluster"
 	"vitess.io/vitess/go/test/endtoend/reparent/utils"
 )
@@ -146,3 +150,66 @@ func TestChangeTypeWithoutSemiSync(t *testing.T) {
 	err = clusterInstance.VtctlclientProcess.ExecuteCommand("ChangeTabletType", replica.Alias, "replica")
 	require.NoError(t, err)
 }
+
+func TestSimultaneousPRS(t *testing.T) {
+	defer cluster.PanicHandler(t)
+	clusterInstance := utils.SetupShardedReparentCluster(t, "semi_sync")
+	defer utils.TeardownCluster(clusterInstance)
+
+	// Start by reparenting all the shards to the first tablet.
+	keyspace := clusterInstance.Keyspaces[0]
+	shards := keyspace.Shards
+	for _, shard := range shards {
+		err := clusterInstance.VtctldClientProcess.PlannedReparentShard(keyspace.Name, shard.Name, shard.Vttablets[0].Alias)
+		require.NoError(t, err)
+	}
+
+	rowCount := 1000
+	vtParams := clusterInstance.GetVTParams(keyspace.Name)
+	conn, err := mysql.Connect(context.Background(), &vtParams)
+	require.NoError(t, err)
+	// Now, we need to insert some data into the cluster.
+	for i := 1; i <= rowCount; i++ {
+		_, err = conn.ExecuteFetch(utils.GetInsertQuery(i), 0, false)
+		require.NoError(t, err)
+	}
+
+	// Now we start a goroutine that continues to read the data until we've finished the test.
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	go func() {
+		tick := time.NewTicker(100 * time.Millisecond)
+		defer tick.Stop()
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-tick.C:
+				go func() {
+					conn, err := mysql.Connect(context.Background(), &vtParams)
+					if err != nil {
+						return
+					}
+					// We're running queries every 100 milliseconds and verifying the results are all correct.
+					res, err := conn.ExecuteFetch(utils.GetSelectionQuery(), rowCount+10, false)
+					require.NoError(t, err)
+					require.Len(t, res.Rows, rowCount)
+				}()
+			}
+		}
+	}()
+
+	// Now, we run go routines to run PRS calls on all the shards simultaneously.
+	wg := sync.WaitGroup{}
+	for _, shard := range shards {
+		wg.Add(1)
+		go func() {
+			time.Sleep(time.Second * time.Duration(rand.IntN(6)))
+			defer wg.Done()
+			err := clusterInstance.VtctldClientProcess.PlannedReparentShard(keyspace.Name, shard.Name, shard.Vttablets[1].Alias)
+			require.NoError(t, err)
+		}()
+	}
+	wg.Wait()
+	cancel()
+}
diff --git a/go/test/endtoend/reparent/utils/utils.go b/go/test/endtoend/reparent/utils/utils.go
index 0f48f2b3fa8..4532724f8b8 100644
--- a/go/test/endtoend/reparent/utils/utils.go
+++ b/go/test/endtoend/reparent/utils/utils.go
@@ -54,7 +54,7 @@ var (
 	id bigint,
 	msg varchar(64),
 	primary key (id)
-	) Engine=InnoDB	
+	) Engine=InnoDB
 `
 	cell1                  = "zone1"
 	cell2                  = "zone2"
@@ -75,6 +75,56 @@ func SetupRangeBasedCluster(ctx context.Context, t *testing.T) *cluster.LocalPro
 	return setupCluster(ctx, t, ShardName, []string{cell1}, []int{2}, "semi_sync")
 }
 
+// SetupShardedReparentCluster is used to set up a sharded cluster for testing
+func SetupShardedReparentCluster(t *testing.T, durability string) *cluster.LocalProcessCluster {
+	clusterInstance := cluster.NewCluster(cell1, Hostname)
+	// Start topo server
+	err := clusterInstance.StartTopo()
+	require.NoError(t, err)
+
+	clusterInstance.VtTabletExtraArgs = append(clusterInstance.VtTabletExtraArgs,
+		"--lock_tables_timeout", "5s",
+		"--track_schema_versions=true",
+		"--queryserver_enable_online_ddl=false")
+	clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs,
+		"--enable_buffer",
+		// Long timeout in case failover is slow.
+		"--buffer_window", "10m",
+		"--buffer_max_failover_duration", "10m",
+		"--buffer_min_time_between_failovers", "20m",
+	)
+
+	// Start keyspace
+	keyspace := &cluster.Keyspace{
+		Name:      KeyspaceName,
+		SchemaSQL: sqlSchema,
+		VSchema:   `{"sharded": true, "vindexes": {"hash_index": {"type": "hash"}}, "tables": {"vt_insert_test": {"column_vindexes": [{"column": "id", "name": "hash_index"}]}}}`,
+	}
+	err = clusterInstance.StartKeyspace(*keyspace, []string{"-40", "40-80", "80-"}, 2, false)
+	require.NoError(t, err)
+
+	if clusterInstance.VtctlMajorVersion >= 14 {
+		clusterInstance.VtctldClientProcess = *cluster.VtctldClientProcessInstance("localhost", clusterInstance.VtctldProcess.GrpcPort, clusterInstance.TmpDirectory)
+		out, err := clusterInstance.VtctldClientProcess.ExecuteCommandWithOutput("SetKeyspaceDurabilityPolicy", KeyspaceName, fmt.Sprintf("--durability-policy=%s", durability))
+		require.NoError(t, err, out)
+	}
+
+	// Start Vtgate
+	err = clusterInstance.StartVtgate()
+	require.NoError(t, err)
+	return clusterInstance
+}
+
+// GetInsertQuery returns a built insert query to insert a row.
+func GetInsertQuery(idx int) string {
+	return fmt.Sprintf(insertSQL, idx, idx)
+}
+
+// GetSelectionQuery returns a built selection query to read the data.
+func GetSelectionQuery() string {
+	return `select * from vt_insert_test`
+}
+
 // TeardownCluster is used to teardown the reparent cluster. When
 // run in a CI environment -- which is considered true when the
 // "CI" env variable is set to "true" -- the teardown also removes

From 01124f2ceccf15d65e829b3b8894dc868fff7070 Mon Sep 17 00:00:00 2001
From: Manan Gupta <manan@planetscale.com>
Date: Mon, 21 Oct 2024 16:24:31 +0530
Subject: [PATCH 06/12] test: update test to reproduce the buffering problem

Signed-off-by: Manan Gupta <manan@planetscale.com>
Signed-off-by: Arthur Schreiber <arthurschreiber@github.com>
---
 .../reparent/newfeaturetest/reparent_test.go  | 80 +++++++++----------
 1 file changed, 36 insertions(+), 44 deletions(-)

diff --git a/go/test/endtoend/reparent/newfeaturetest/reparent_test.go b/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
index 7782afecb21..dc3f64f0964 100644
--- a/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
+++ b/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
@@ -19,10 +19,8 @@ package newfeaturetest
 import (
 	"context"
 	"fmt"
-	"math/rand/v2"
 	"sync"
 	"testing"
-	"time"
 
 	"github.com/stretchr/testify/require"
 
@@ -151,11 +149,17 @@ func TestChangeTypeWithoutSemiSync(t *testing.T) {
 	require.NoError(t, err)
 }
 
-func TestSimultaneousPRS(t *testing.T) {
+func TestBufferingWithMultipleDisruptions(t *testing.T) {
 	defer cluster.PanicHandler(t)
 	clusterInstance := utils.SetupShardedReparentCluster(t, "semi_sync")
 	defer utils.TeardownCluster(clusterInstance)
 
+	// Stop all VTOrc instances, so that they don't interfere with the test.
+	for _, vtorc := range clusterInstance.VTOrcProcesses {
+		err := vtorc.TearDown()
+		require.NoError(t, err)
+	}
+
 	// Start by reparenting all the shards to the first tablet.
 	keyspace := clusterInstance.Keyspaces[0]
 	shards := keyspace.Shards
@@ -164,52 +168,40 @@ func TestSimultaneousPRS(t *testing.T) {
 		require.NoError(t, err)
 	}
 
-	rowCount := 1000
-	vtParams := clusterInstance.GetVTParams(keyspace.Name)
-	conn, err := mysql.Connect(context.Background(), &vtParams)
-	require.NoError(t, err)
-	// Now, we need to insert some data into the cluster.
-	for i := 1; i <= rowCount; i++ {
-		_, err = conn.ExecuteFetch(utils.GetInsertQuery(i), 0, false)
-		require.NoError(t, err)
-	}
+	// We simulate the start of an external reparent or a PRS where the healthcheck update from the tablet gets lost in transit
+	// to vtgate by just setting the primary read only. This is also why we needed to shut down all VTOrcs, so that they don't
+	// fix this.
+	//utils.RunSQL(context.Background(), t, "set global read_only=1", shards[0].Vttablets[0])
+	//utils.RunSQL(context.Background(), t, "set global read_only=1", shards[1].Vttablets[0])
 
-	// Now we start a goroutine that continues to read the data until we've finished the test.
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	go func() {
-		tick := time.NewTicker(100 * time.Millisecond)
-		defer tick.Stop()
-		for {
-			select {
-			case <-ctx.Done():
-				return
-			case <-tick.C:
-				go func() {
-					conn, err := mysql.Connect(context.Background(), &vtParams)
-					if err != nil {
-						return
-					}
-					// We're running queries every 100 milliseconds and verifying the results are all correct.
-					res, err := conn.ExecuteFetch(utils.GetSelectionQuery(), rowCount+10, false)
-					require.NoError(t, err)
-					require.Len(t, res.Rows, rowCount)
-				}()
-			}
-		}
-	}()
-
-	// Now, we run go routines to run PRS calls on all the shards simultaneously.
 	wg := sync.WaitGroup{}
-	for _, shard := range shards {
+	rowCount := 10
+	vtParams := clusterInstance.GetVTParams(keyspace.Name)
+	// We now spawn writes in a bunch of goroutines.
+	// The ones going to shard 1 and shard 2 should block, since
+	// they're in the midst of a reparenting operation (as seen by the buffering code).
+	for i := 1; i <= rowCount; i++ {
 		wg.Add(1)
-		go func() {
-			time.Sleep(time.Second * time.Duration(rand.IntN(6)))
+		go func(i int) {
 			defer wg.Done()
-			err := clusterInstance.VtctldClientProcess.PlannedReparentShard(keyspace.Name, shard.Name, shard.Vttablets[1].Alias)
+			conn, err := mysql.Connect(context.Background(), &vtParams)
+			if err != nil {
+				return
+			}
+			_, err = conn.ExecuteFetch(utils.GetInsertQuery(i), 0, false)
 			require.NoError(t, err)
-		}()
+		}(i)
 	}
+
+	// Now, run a PRS call on the last shard. This shouldn't unbuffer the queries that are buffered for shards 1 and 2
+	// since the disruption on the two shards hasn't stopped.
+	err := clusterInstance.VtctldClientProcess.PlannedReparentShard(keyspace.Name, shards[2].Name, shards[2].Vttablets[1].Alias)
+	require.NoError(t, err)
+	// We wait a second just to make sure the PRS changes are processed by the buffering logic in vtgate.
+	//time.Sleep(1 * time.Second)
+	// Finally, we'll now simulate the 2 shards being healthy again by setting them back to read-write.
+	//utils.RunSQL(context.Background(), t, "set global read_only=0", shards[0].Vttablets[0])
+	//utils.RunSQL(context.Background(), t, "set global read_only=0", shards[1].Vttablets[0])
+	// Wait for all the writes to have succeeded.
 	wg.Wait()
-	cancel()
 }
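
The closing wg.Wait() blocks indefinitely if a buffered write never completes. A small, generic
pattern for bounding such a wait (not something the test above does), shown only as a sketch of how
a bounded variant could look:

package main

import (
	"fmt"
	"sync"
	"time"
)

// waitTimeout waits for the WaitGroup and reports whether it finished before
// the timeout elapsed, so a hung buffered write fails fast instead of hanging.
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
		return true
	case <-time.After(timeout):
		return false
	}
}

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() { defer wg.Done(); time.Sleep(10 * time.Millisecond) }()
	fmt.Println(waitTimeout(&wg, time.Second)) // true: the goroutine finished in time
}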

From 21a8ad2fe6c761a27b3307f8efefda1f6db2a3f3 Mon Sep 17 00:00:00 2001
From: Manan Gupta <manan@planetscale.com>
Date: Mon, 21 Oct 2024 21:46:40 +0530
Subject: [PATCH 07/12] test: uncomment lines that were accidentally commented

Signed-off-by: Manan Gupta <manan@planetscale.com>
---
 .../reparent/newfeaturetest/reparent_test.go         | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/go/test/endtoend/reparent/newfeaturetest/reparent_test.go b/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
index dc3f64f0964..e52229dc84c 100644
--- a/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
+++ b/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 
@@ -171,8 +172,8 @@ func TestBufferingWithMultipleDisruptions(t *testing.T) {
 	// We simulate the start of an external reparent or a PRS where the healthcheck update from the tablet gets lost in transit
 	// to vtgate by just setting the primary read only. This is also why we needed to shut down all VTOrcs, so that they don't
 	// fix this.
-	//utils.RunSQL(context.Background(), t, "set global read_only=1", shards[0].Vttablets[0])
-	//utils.RunSQL(context.Background(), t, "set global read_only=1", shards[1].Vttablets[0])
+	utils.RunSQL(context.Background(), t, "set global read_only=1", shards[0].Vttablets[0])
+	utils.RunSQL(context.Background(), t, "set global read_only=1", shards[1].Vttablets[0])
 
 	wg := sync.WaitGroup{}
 	rowCount := 10
@@ -188,6 +189,7 @@ func TestBufferingWithMultipleDisruptions(t *testing.T) {
 			if err != nil {
 				return
 			}
+			defer conn.Close()
 			_, err = conn.ExecuteFetch(utils.GetInsertQuery(i), 0, false)
 			require.NoError(t, err)
 		}(i)
@@ -198,10 +200,10 @@ func TestBufferingWithMultipleDisruptions(t *testing.T) {
 	err := clusterInstance.VtctldClientProcess.PlannedReparentShard(keyspace.Name, shards[2].Name, shards[2].Vttablets[1].Alias)
 	require.NoError(t, err)
 	// We wait a second just to make sure the PRS changes are processed by the buffering logic in vtgate.
-	//time.Sleep(1 * time.Second)
+	time.Sleep(1 * time.Second)
 	// Finally, we'll now simulate the 2 shards being healthy again by setting them back to read-write.
-	//utils.RunSQL(context.Background(), t, "set global read_only=0", shards[0].Vttablets[0])
-	//utils.RunSQL(context.Background(), t, "set global read_only=0", shards[1].Vttablets[0])
+	utils.RunSQL(context.Background(), t, "set global read_only=0", shards[0].Vttablets[0])
+	utils.RunSQL(context.Background(), t, "set global read_only=0", shards[1].Vttablets[0])
 	// Wait for all the writes to have succeeded.
 	wg.Wait()
 }

From 0c87ea72caecefecc0cf6e4d825877b80eba1336 Mon Sep 17 00:00:00 2001
From: Manan Gupta <manan@planetscale.com>
Date: Fri, 25 Oct 2024 15:38:54 +0530
Subject: [PATCH 08/12] feat: fix the problem by adding a new field in the
 shard state

Signed-off-by: Manan Gupta <manan@planetscale.com>
---
 .../reparent/newfeaturetest/reparent_test.go  |  8 ++-
 go/test/endtoend/reparent/utils/utils.go      |  2 +
 go/vt/discovery/keyspace_events.go            | 67 ++++++++++++++++++-
 go/vt/vtgate/buffer/buffer.go                 | 16 ++++-
 go/vt/vtgate/buffer/buffer_helper_test.go     |  2 +-
 go/vt/vtgate/buffer/buffer_test.go            | 22 +++---
 go/vt/vtgate/buffer/shard_buffer.go           | 20 +++++-
 go/vt/vtgate/buffer/variables_test.go         |  2 +-
 go/vt/vtgate/tabletgateway.go                 |  2 +-
 9 files changed, 116 insertions(+), 25 deletions(-)

diff --git a/go/test/endtoend/reparent/newfeaturetest/reparent_test.go b/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
index e52229dc84c..44732ba12d0 100644
--- a/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
+++ b/go/test/endtoend/reparent/newfeaturetest/reparent_test.go
@@ -201,9 +201,11 @@ func TestBufferingWithMultipleDisruptions(t *testing.T) {
 	require.NoError(t, err)
 	// We wait a second just to make sure the PRS changes are processed by the buffering logic in vtgate.
 	time.Sleep(1 * time.Second)
-	// Finally, we'll now simulate the 2 shards being healthy again by setting them back to read-write.
-	utils.RunSQL(context.Background(), t, "set global read_only=0", shards[0].Vttablets[0])
-	utils.RunSQL(context.Background(), t, "set global read_only=0", shards[1].Vttablets[0])
+	// Finally, we'll now make the 2 shards healthy again by running PRS.
+	err = clusterInstance.VtctldClientProcess.PlannedReparentShard(keyspace.Name, shards[0].Name, shards[0].Vttablets[1].Alias)
+	require.NoError(t, err)
+	err = clusterInstance.VtctldClientProcess.PlannedReparentShard(keyspace.Name, shards[1].Name, shards[1].Vttablets[1].Alias)
+	require.NoError(t, err)
 	// Wait for all the writes to have succeeded.
 	wg.Wait()
 }
diff --git a/go/test/endtoend/reparent/utils/utils.go b/go/test/endtoend/reparent/utils/utils.go
index 4532724f8b8..c5ddf75d667 100644
--- a/go/test/endtoend/reparent/utils/utils.go
+++ b/go/test/endtoend/reparent/utils/utils.go
@@ -84,6 +84,8 @@ func SetupShardedReparentCluster(t *testing.T, durability string) *cluster.Local
 
 	clusterInstance.VtTabletExtraArgs = append(clusterInstance.VtTabletExtraArgs,
 		"--lock_tables_timeout", "5s",
+		// Fast health checks help find corner cases.
+		"--health_check_interval", "1s",
 		"--track_schema_versions=true",
 		"--queryserver_enable_online_ddl=false")
 	clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs,
diff --git a/go/vt/discovery/keyspace_events.go b/go/vt/discovery/keyspace_events.go
index 036d4f3ad14..8332e99679b 100644
--- a/go/vt/discovery/keyspace_events.go
+++ b/go/vt/discovery/keyspace_events.go
@@ -171,8 +171,12 @@ func (kss *keyspaceState) beingResharded(currentShard string) bool {
 }
 
 type shardState struct {
-	target               *querypb.Target
-	serving              bool
+	target  *querypb.Target
+	serving bool
+	// waitForReparent is used to tell the keyspace event watcher
+	// that this shard should be marked serving only after a reparent
+	// operation has succeeded.
+	waitForReparent      bool
 	externallyReparented int64
 	currentPrimary       *topodatapb.TabletAlias
 }
@@ -361,8 +365,32 @@ func (kss *keyspaceState) onHealthCheck(th *TabletHealth) {
 	// if the shard went from serving to not serving, or the other way around, the keyspace
 	// is undergoing an availability event
 	if sstate.serving != th.Serving {
-		sstate.serving = th.Serving
 		kss.consistent = false
+		switch {
+		case th.Serving && sstate.waitForReparent:
+			// While waiting for a reparent, if we receive a serving primary,
+			// we should check if the primary term start time is greater than the externally reparented time.
+			// We mark the shard serving only if it is. This is required so that we don't prematurely stop
+			// buffering for PRS, or TabletExternallyReparented, after seeing a serving healthcheck from the
+			// same old primary tablet that has already been turned read-only.
+			if th.PrimaryTermStartTime > sstate.externallyReparented {
+				sstate.waitForReparent = false
+				sstate.serving = true
+			}
+		case th.Serving && !sstate.waitForReparent:
+			sstate.serving = true
+		case !th.Serving:
+			sstate.serving = false
+			// Once we have seen a non-serving primary healthcheck, there is no need for us to explicitly wait
+			// for a reparent to happen. We use waitForReparent to ensure that we don't prematurely stop
+			// buffering when we receive a serving healthcheck from the primary that is being demoted.
+			// However, if we receive a non-serving check, then we know that we won't receive any more serving
+			// healthchecks until the reparent finishes. Specifically, this helps us when PRS fails, but
+			// stops gracefully because the new candidate couldn't get caught up in time. In this case, we promote
+			// the previous primary back. Without turning off waitForReparent here, we wouldn't be able to stop
+			// buffering for that case.
+			sstate.waitForReparent = false
+		}
 	}
 
 	// if the primary for this shard has been externally reparented, we're undergoing a failover,
@@ -777,3 +805,36 @@ func (kew *KeyspaceEventWatcher) WaitForConsistentKeyspaces(ctx context.Context,
 		}
 	}
 }
+
+// MarkShardNotServing marks the given shard not serving.
+// We use this when we start buffering for a given shard. This helps
+// coordinate between the sharding logic and the keyspace event watcher.
+// We also take a boolean to tell us whether this error is because
+// a reparent is ongoing. If it is, we also mark the shard to wait for the reparent to finish.
+// The return value indicates whether the shard was found and successfully marked not serving.
+func (kew *KeyspaceEventWatcher) MarkShardNotServing(ctx context.Context, keyspace string, shard string, isReparentErr bool) bool {
+	kss := kew.getKeyspaceStatus(ctx, keyspace)
+	if kss == nil {
+		// Only happens if the keyspace was deleted.
+		return false
+	}
+	kss.mu.Lock()
+	defer kss.mu.Unlock()
+	sstate := kss.shards[shard]
+	if sstate == nil {
+		// This only happens if the shard is deleted, or if
+		// the keyspace event watcher hasn't seen the shard at all.
+		return false
+	}
+	// Mark the keyspace inconsistent and the shard not serving.
+	kss.consistent = false
+	sstate.serving = false
+	if isReparentErr {
+		// The error was triggered because a reparent operation has started, so
+		// we mark the shard to wait for the reparent to finish before marking it serving.
+		// This is required to prevent premature stopping of buffering if we receive
+		// a serving healthcheck from a primary that is being demoted.
+		sstate.waitForReparent = true
+	}
+	return true
+}
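
To make the demotion race concrete: a serving healthcheck from the primary that is being demoted can still arrive after buffering has started, and without the guard above it would end buffering immediately. The sketch below uses cut-down stand-in types and illustrative timestamps, not the real keyspace event watcher; it only shows how comparing the incoming primary term start time against the recorded reparent time keeps that stale update from stopping the wait.

package main

import "fmt"

// shardView is a cut-down stand-in for the watcher's shardState.
type shardView struct {
	serving              bool
	waitForReparent      bool
	externallyReparented int64 // primary term start time we have already recorded
}

// applyServingPrimary mirrors the guard added in onHealthCheck for a serving primary.
func (s *shardView) applyServingPrimary(primaryTermStartTime int64) {
	if s.waitForReparent && primaryTermStartTime <= s.externallyReparented {
		// Stale update from the primary being demoted: keep buffering.
		return
	}
	s.waitForReparent = false
	s.serving = true
}

func main() {
	s := &shardView{waitForReparent: true, externallyReparented: 10}

	s.applyServingPrimary(10) // old primary, same term: ignored
	fmt.Println(s.serving)    // false, still buffering

	s.applyServingPrimary(20) // new primary, newer term: ends the wait
	fmt.Println(s.serving)    // true
}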
diff --git a/go/vt/vtgate/buffer/buffer.go b/go/vt/vtgate/buffer/buffer.go
index 260fb272544..126a3a8826d 100644
--- a/go/vt/vtgate/buffer/buffer.go
+++ b/go/vt/vtgate/buffer/buffer.go
@@ -94,6 +94,18 @@ func CausedByFailover(err error) bool {
 	return isFailover
 }
 
+// IsErrorDueToReparenting is a stronger check than CausedByFailover, meant to return
+// whether the failure was caused by a reparent operation.
+func IsErrorDueToReparenting(err error) bool {
+	if vterrors.Code(err) != vtrpcpb.Code_CLUSTER_EVENT {
+		return false
+	}
+	if strings.Contains(err.Error(), ClusterEventReshardingInProgress) {
+		return false
+	}
+	return true
+}
+
 // for debugging purposes
 func getReason(err error) string {
 	for _, ce := range ClusterEvents {
@@ -175,7 +187,7 @@ func (b *Buffer) GetConfig() *Config {
 // It returns an error if buffering failed (e.g. buffer full).
 // If it does not return an error, it may return a RetryDoneFunc which must be
 // called after the request was retried.
-func (b *Buffer) WaitForFailoverEnd(ctx context.Context, keyspace, shard string, err error) (RetryDoneFunc, error) {
+func (b *Buffer) WaitForFailoverEnd(ctx context.Context, keyspace, shard string, kev *discovery.KeyspaceEventWatcher, err error) (RetryDoneFunc, error) {
 	// If an err is given, it must be related to a failover.
 	// We never buffer requests with other errors.
 	if err != nil && !CausedByFailover(err) {
@@ -192,7 +204,7 @@ func (b *Buffer) WaitForFailoverEnd(ctx context.Context, keyspace, shard string,
 		requestsSkipped.Add([]string{keyspace, shard, skippedDisabled}, 1)
 		return nil, nil
 	}
-	return sb.waitForFailoverEnd(ctx, keyspace, shard, err)
+	return sb.waitForFailoverEnd(ctx, keyspace, shard, kev, err)
 }
 
 func (b *Buffer) HandleKeyspaceEvent(ksevent *discovery.KeyspaceEvent) {
diff --git a/go/vt/vtgate/buffer/buffer_helper_test.go b/go/vt/vtgate/buffer/buffer_helper_test.go
index 2deb460fc39..1276f0cd751 100644
--- a/go/vt/vtgate/buffer/buffer_helper_test.go
+++ b/go/vt/vtgate/buffer/buffer_helper_test.go
@@ -50,7 +50,7 @@ func issueRequestAndBlockRetry(ctx context.Context, t *testing.T, b *Buffer, err
 	bufferingStopped := make(chan error)
 
 	go func() {
-		retryDone, err := b.WaitForFailoverEnd(ctx, keyspace, shard, failoverErr)
+		retryDone, err := b.WaitForFailoverEnd(ctx, keyspace, shard, nil, failoverErr)
 		if err != nil {
 			bufferingStopped <- err
 		}
diff --git a/go/vt/vtgate/buffer/buffer_test.go b/go/vt/vtgate/buffer/buffer_test.go
index 7f32364d57f..2973ef2dfb9 100644
--- a/go/vt/vtgate/buffer/buffer_test.go
+++ b/go/vt/vtgate/buffer/buffer_test.go
@@ -107,7 +107,7 @@ func testBuffering1(t *testing.T, fail failover) {
 	}
 
 	// Subsequent requests with errors not related to the failover are not buffered.
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nonFailoverErr); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil, nonFailoverErr); err != nil || retryDone != nil {
 		t.Fatalf("requests with non-failover errors must never be buffered. err: %v retryDone: %v", err, retryDone)
 	}
 
@@ -155,7 +155,7 @@ func testBuffering1(t *testing.T, fail failover) {
 	}
 
 	// Second failover: Buffering is skipped because last failover is too recent.
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, failoverErr); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil, failoverErr); err != nil || retryDone != nil {
 		t.Fatalf("subsequent failovers must be skipped due to -buffer_min_time_between_failovers setting. err: %v retryDone: %v", err, retryDone)
 	}
 	if got, want := requestsSkipped.Counts()[statsKeyJoinedLastFailoverTooRecent], int64(1); got != want {
@@ -213,7 +213,7 @@ func testDryRun1(t *testing.T, fail failover) {
 	b := New(cfg)
 
 	// Request does not get buffered.
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, failoverErr); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil, failoverErr); err != nil || retryDone != nil {
 		t.Fatalf("requests must not be buffered during dry-run. err: %v retryDone: %v", err, retryDone)
 	}
 	// But the internal state changes though.
@@ -259,10 +259,10 @@ func testPassthrough1(t *testing.T, fail failover) {
 
 	b := New(cfg)
 
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil, nil); err != nil || retryDone != nil {
 		t.Fatalf("requests with no error must never be buffered. err: %v retryDone: %v", err, retryDone)
 	}
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nonFailoverErr); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil, nonFailoverErr); err != nil || retryDone != nil {
 		t.Fatalf("requests with non-failover errors must never be buffered. err: %v retryDone: %v", err, retryDone)
 	}
 
@@ -298,7 +298,7 @@ func testLastReparentTooRecentBufferingSkipped1(t *testing.T, fail failover) {
 	now = now.Add(1 * time.Second)
 	fail(b, newPrimary, keyspace, shard, now)
 
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, failoverErr); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil, failoverErr); err != nil || retryDone != nil {
 		t.Fatalf("requests where the failover end was recently detected before the start must not be buffered. err: %v retryDone: %v", err, retryDone)
 	}
 	if err := waitForPoolSlots(b, cfg.Size); err != nil {
@@ -395,10 +395,10 @@ func testPassthroughDuringDrain1(t *testing.T, fail failover) {
 	}
 
 	// Requests during the drain will be passed through and not buffered.
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil, nil); err != nil || retryDone != nil {
 		t.Fatalf("requests with no error must not be buffered during a drain. err: %v retryDone: %v", err, retryDone)
 	}
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, failoverErr); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, shard, nil, failoverErr); err != nil || retryDone != nil {
 		t.Fatalf("requests with failover errors must not be buffered during a drain. err: %v retryDone: %v", err, retryDone)
 	}
 
@@ -430,7 +430,7 @@ func testPassthroughIgnoredKeyspaceOrShard1(t *testing.T, fail failover) {
 	b := New(cfg)
 
 	ignoredKeyspace := "ignored_ks"
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), ignoredKeyspace, shard, failoverErr); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), ignoredKeyspace, shard, nil, failoverErr); err != nil || retryDone != nil {
 		t.Fatalf("requests for ignored keyspaces must not be buffered. err: %v retryDone: %v", err, retryDone)
 	}
 	statsKeyJoined := strings.Join([]string{ignoredKeyspace, shard, skippedDisabled}, ".")
@@ -439,7 +439,7 @@ func testPassthroughIgnoredKeyspaceOrShard1(t *testing.T, fail failover) {
 	}
 
 	ignoredShard := "ff-"
-	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, ignoredShard, failoverErr); err != nil || retryDone != nil {
+	if retryDone, err := b.WaitForFailoverEnd(context.Background(), keyspace, ignoredShard, nil, failoverErr); err != nil || retryDone != nil {
 		t.Fatalf("requests for ignored shards must not be buffered. err: %v retryDone: %v", err, retryDone)
 	}
 	if err := waitForPoolSlots(b, cfg.Size); err != nil {
@@ -621,7 +621,7 @@ func testEvictionNotPossible1(t *testing.T, fail failover) {
 
 	// Newer requests of the second failover cannot evict anything because
 	// they have no entries buffered.
-	retryDone, bufferErr := b.WaitForFailoverEnd(context.Background(), keyspace, shard2, failoverErr)
+	retryDone, bufferErr := b.WaitForFailoverEnd(context.Background(), keyspace, shard2, nil, failoverErr)
 	if bufferErr == nil || retryDone != nil {
 		t.Fatalf("buffer should have returned an error because it's full: err: %v retryDone: %v", bufferErr, retryDone)
 	}
diff --git a/go/vt/vtgate/buffer/shard_buffer.go b/go/vt/vtgate/buffer/shard_buffer.go
index a58b86f670a..34fa919cfac 100644
--- a/go/vt/vtgate/buffer/shard_buffer.go
+++ b/go/vt/vtgate/buffer/shard_buffer.go
@@ -136,7 +136,7 @@ func (sb *shardBuffer) disabled() bool {
 	return sb.mode == bufferModeDisabled
 }
 
-func (sb *shardBuffer) waitForFailoverEnd(ctx context.Context, keyspace, shard string, err error) (RetryDoneFunc, error) {
+func (sb *shardBuffer) waitForFailoverEnd(ctx context.Context, keyspace, shard string, kev *discovery.KeyspaceEventWatcher, err error) (RetryDoneFunc, error) {
 	// We assume if err != nil then it's always caused by a failover.
 	// Other errors must be filtered at higher layers.
 	failoverDetected := err != nil
@@ -210,7 +210,11 @@ func (sb *shardBuffer) waitForFailoverEnd(ctx context.Context, keyspace, shard s
 			return nil, nil
 		}
 
-		sb.startBufferingLocked(err)
+		// Try to start buffering. If we're unsuccessful, then we exit early.
+		if !sb.startBufferingLocked(ctx, kev, err) {
+			sb.mu.Unlock()
+			return nil, nil
+		}
 	}
 
 	if sb.mode == bufferModeDryRun {
@@ -254,7 +258,16 @@ func (sb *shardBuffer) shouldBufferLocked(failoverDetected bool) bool {
 	panic("BUG: All possible states must be covered by the switch expression above.")
 }
 
-func (sb *shardBuffer) startBufferingLocked(err error) {
+func (sb *shardBuffer) startBufferingLocked(ctx context.Context, kev *discovery.KeyspaceEventWatcher, err error) bool {
+	if kev != nil {
+		if !kev.MarkShardNotServing(ctx, sb.keyspace, sb.shard, IsErrorDueToReparenting(err)) {
+			// We failed to mark the shard as not serving. Do not buffer the request.
+			// This can happen if the keyspace has been deleted or if the keyspace event watcher
+			// hasn't yet seen the shard. In that case the keyspace event watcher might never stop
+			// buffering for this request until it times out, so it's better not to buffer it.
+			return false
+		}
+	}
 	// Reset monitoring data from previous failover.
 	lastRequestsInFlightMax.Set(sb.statsKey, 0)
 	lastRequestsDryRunMax.Set(sb.statsKey, 0)
@@ -280,6 +293,7 @@ func (sb *shardBuffer) startBufferingLocked(err error) {
 		sb.buf.config.MaxFailoverDuration,
 		errorsanitizer.NormalizeError(err.Error()),
 	)
+	return true
 }
 
 // logErrorIfStateNotLocked logs an error if the current state is not "state".
diff --git a/go/vt/vtgate/buffer/variables_test.go b/go/vt/vtgate/buffer/variables_test.go
index a0640bde9e4..30d2426c639 100644
--- a/go/vt/vtgate/buffer/variables_test.go
+++ b/go/vt/vtgate/buffer/variables_test.go
@@ -51,7 +51,7 @@ func TestVariablesAreInitialized(t *testing.T) {
 	// Create a new buffer and make a call which will create the shardBuffer object.
 	// After that, the variables should be initialized for that shard.
 	b := New(NewDefaultConfig())
-	_, err := b.WaitForFailoverEnd(context.Background(), "init_test", "0", nil /* err */)
+	_, err := b.WaitForFailoverEnd(context.Background(), "init_test", "0", nil, nil)
 	if err != nil {
 		t.Fatalf("buffer should just passthrough and not return an error: %v", err)
 	}
diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go
index 21087fe5370..2a7be6595bb 100644
--- a/go/vt/vtgate/tabletgateway.go
+++ b/go/vt/vtgate/tabletgateway.go
@@ -273,7 +273,7 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target,
 		// b) no transaction was created yet.
 		if gw.buffer != nil && !bufferedOnce && !inTransaction && target.TabletType == topodatapb.TabletType_PRIMARY {
 			// The next call blocks if we should buffer during a failover.
-			retryDone, bufferErr := gw.buffer.WaitForFailoverEnd(ctx, target.Keyspace, target.Shard, err)
+			retryDone, bufferErr := gw.buffer.WaitForFailoverEnd(ctx, target.Keyspace, target.Shard, gw.kev, err)
 
 			// Request may have been buffered.
 			if retryDone != nil {

From 517f5192b9f28502700deb1e754c81505b08f5e9 Mon Sep 17 00:00:00 2001
From: Manan Gupta <manan@planetscale.com>
Date: Fri, 25 Oct 2024 16:23:09 +0530
Subject: [PATCH 09/12] test: fix test to also send a higher primary timestamp

Signed-off-by: Manan Gupta <manan@planetscale.com>
---
 go/vt/discovery/fake_healthcheck.go      | 15 +++++++++++++++
 go/vt/vtgate/tabletgateway_flaky_test.go |  3 ++-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/go/vt/discovery/fake_healthcheck.go b/go/vt/discovery/fake_healthcheck.go
index cb959902c19..823ad4e5503 100644
--- a/go/vt/discovery/fake_healthcheck.go
+++ b/go/vt/discovery/fake_healthcheck.go
@@ -172,6 +172,21 @@ func (fhc *FakeHealthCheck) SetTabletType(tablet *topodatapb.Tablet, tabletType
 	item.ts.Target.TabletType = tabletType
 }
 
+// SetPrimaryTimestamp sets the primary timestamp for the given tablet
+func (fhc *FakeHealthCheck) SetPrimaryTimestamp(tablet *topodatapb.Tablet, timestamp int64) {
+	if fhc.ch == nil {
+		return
+	}
+	fhc.mu.Lock()
+	defer fhc.mu.Unlock()
+	key := TabletToMapKey(tablet)
+	item, isPresent := fhc.items[key]
+	if !isPresent {
+		return
+	}
+	item.ts.PrimaryTermStartTime = timestamp
+}
+
 // Unsubscribe is not implemented.
 func (fhc *FakeHealthCheck) Unsubscribe(c chan *TabletHealth) {
 }
diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go
index fbca19ecbad..55b559ef3ca 100644
--- a/go/vt/vtgate/tabletgateway_flaky_test.go
+++ b/go/vt/vtgate/tabletgateway_flaky_test.go
@@ -234,6 +234,7 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) {
 	hc.SetTabletType(primaryTablet, topodatapb.TabletType_REPLICA)
 	hc.Broadcast(primaryTablet)
 	hc.SetTabletType(replicaTablet, topodatapb.TabletType_PRIMARY)
+	hc.SetPrimaryTimestamp(replicaTablet, 100) // We set a higher timestamp than before to simulate a PRS.
 	hc.SetServing(replicaTablet, true)
 	hc.Broadcast(replicaTablet)
 
@@ -245,7 +246,7 @@ outer:
 			require.Fail(t, "timed out - could not verify the new primary")
 		case <-time.After(10 * time.Millisecond):
 			newPrimary, shouldBuffer := tg.kev.ShouldStartBufferingForTarget(ctx, target)
-			if newPrimary != nil && newPrimary.Uid == 1 && !shouldBuffer {
+			if newPrimary != nil && newPrimary.Uid == replicaTablet.Alias.Uid && !shouldBuffer {
 				break outer
 			}
 		}

From d77fae7d2e42416279ec7db72b02bfeb2998703c Mon Sep 17 00:00:00 2001
From: Manan Gupta <manan@planetscale.com>
Date: Fri, 25 Oct 2024 16:54:12 +0530
Subject: [PATCH 10/12] feat: add tests for processing of health checks and
 also fix a bug

Signed-off-by: Manan Gupta <manan@planetscale.com>
---
 go/vt/discovery/keyspace_events.go      |  20 ++-
 go/vt/discovery/keyspace_events_test.go | 228 ++++++++++++++++++++++++
 2 files changed, 239 insertions(+), 9 deletions(-)

diff --git a/go/vt/discovery/keyspace_events.go b/go/vt/discovery/keyspace_events.go
index 8332e99679b..4a82b173e9b 100644
--- a/go/vt/discovery/keyspace_events.go
+++ b/go/vt/discovery/keyspace_events.go
@@ -381,17 +381,19 @@ func (kss *keyspaceState) onHealthCheck(th *TabletHealth) {
 			sstate.serving = true
 		case !th.Serving:
 			sstate.serving = false
-			// Once we have seen a non-serving primary healthcheck, there is no need for us to explicitly wait
-			// for a reparent to happen. We use waitForReparent to ensure that we don't prematurely stop
-			// buffering when we receive a serving healthcheck from the primary that is being demoted.
-			// However, if we receive a non-serving check, then we know that we won't receive any more serving
-			// healthchecks until the reparent finishes. Specifically, this helps us when PRS fails, but
-			// stops gracefully because the new candidate couldn't get caught up in time. In this case, we promote
-			// the previous primary back. Without turning off waitForReparent here, we wouldn't be able to stop
-			// buffering for that case.
-			sstate.waitForReparent = false
 		}
 	}
+	if !th.Serving {
+		// Once we have seen a non-serving primary healthcheck, there is no need for us to explicitly wait
+		// for a reparent to happen. We use waitForReparent to ensure that we don't prematurely stop
+		// buffering when we receive a serving healthcheck from the primary that is being demoted.
+		// However, if we receive a non-serving check, then we know that we won't receive any more serving
+		// healthchecks until the reparent finishes. Specifically, this helps us when PRS fails, but
+		// stops gracefully because the new candidate couldn't get caught up in time. In this case, we promote
+		// the previous primary back. Without turning off waitForReparent here, we wouldn't be able to stop
+		// buffering for that case.
+		sstate.waitForReparent = false
+	}
 
 	// if the primary for this shard has been externally reparented, we're undergoing a failover,
 	// which is considered an availability event. update this shard to point it to the new tablet
diff --git a/go/vt/discovery/keyspace_events_test.go b/go/vt/discovery/keyspace_events_test.go
index 1a4c473e7cb..87ecac3355d 100644
--- a/go/vt/discovery/keyspace_events_test.go
+++ b/go/vt/discovery/keyspace_events_test.go
@@ -421,6 +421,234 @@ func TestWaitForConsistentKeyspaces(t *testing.T) {
 	}
 }
 
+func TestOnHealthCheck(t *testing.T) {
+	testcases := []struct {
+		name                     string
+		ss                       *shardState
+		th                       *TabletHealth
+		wantServing              bool
+		wantWaitForReparent      bool
+		wantExternallyReparented int64
+		wantUID                  uint32
+	}{
+		{
+			name: "Non primary tablet health ignored",
+			ss: &shardState{
+				serving:              false,
+				waitForReparent:      false,
+				externallyReparented: 10,
+				currentPrimary: &topodatapb.TabletAlias{
+					Cell: testCell,
+					Uid:  1,
+				},
+			},
+			th: &TabletHealth{
+				Target: &querypb.Target{
+					TabletType: topodatapb.TabletType_REPLICA,
+				},
+				Serving: true,
+			},
+			wantServing:              false,
+			wantWaitForReparent:      false,
+			wantExternallyReparented: 10,
+			wantUID:                  1,
+		}, {
+			name: "Serving primary seen in non-serving shard",
+			ss: &shardState{
+				serving:              false,
+				waitForReparent:      false,
+				externallyReparented: 10,
+				currentPrimary: &topodatapb.TabletAlias{
+					Cell: testCell,
+					Uid:  1,
+				},
+			},
+			th: &TabletHealth{
+				Target: &querypb.Target{
+					TabletType: topodatapb.TabletType_PRIMARY,
+				},
+				Serving:              true,
+				PrimaryTermStartTime: 20,
+				Tablet: &topodatapb.Tablet{
+					Alias: &topodatapb.TabletAlias{
+						Cell: testCell,
+						Uid:  2,
+					},
+				},
+			},
+			wantServing:              true,
+			wantWaitForReparent:      false,
+			wantExternallyReparented: 20,
+			wantUID:                  2,
+		}, {
+			name: "New serving primary seen while waiting for reparent",
+			ss: &shardState{
+				serving:              false,
+				waitForReparent:      true,
+				externallyReparented: 10,
+				currentPrimary: &topodatapb.TabletAlias{
+					Cell: testCell,
+					Uid:  1,
+				},
+			},
+			th: &TabletHealth{
+				Target: &querypb.Target{
+					TabletType: topodatapb.TabletType_PRIMARY,
+				},
+				Serving:              true,
+				PrimaryTermStartTime: 20,
+				Tablet: &topodatapb.Tablet{
+					Alias: &topodatapb.TabletAlias{
+						Cell: testCell,
+						Uid:  2,
+					},
+				},
+			},
+			wantServing:              true,
+			wantWaitForReparent:      false,
+			wantExternallyReparented: 20,
+			wantUID:                  2,
+		}, {
+			name: "Old serving primary seen while waiting for reparent",
+			ss: &shardState{
+				serving:              false,
+				waitForReparent:      true,
+				externallyReparented: 10,
+				currentPrimary: &topodatapb.TabletAlias{
+					Cell: testCell,
+					Uid:  1,
+				},
+			},
+			th: &TabletHealth{
+				Target: &querypb.Target{
+					TabletType: topodatapb.TabletType_PRIMARY,
+				},
+				Serving:              true,
+				PrimaryTermStartTime: 10,
+				Tablet: &topodatapb.Tablet{
+					Alias: &topodatapb.TabletAlias{
+						Cell: testCell,
+						Uid:  1,
+					},
+				},
+			},
+			wantServing:              false,
+			wantWaitForReparent:      true,
+			wantExternallyReparented: 10,
+			wantUID:                  1,
+		}, {
+			name: "Old non-serving primary seen while waiting for reparent",
+			ss: &shardState{
+				serving:              false,
+				waitForReparent:      true,
+				externallyReparented: 10,
+				currentPrimary: &topodatapb.TabletAlias{
+					Cell: testCell,
+					Uid:  1,
+				},
+			},
+			th: &TabletHealth{
+				Target: &querypb.Target{
+					TabletType: topodatapb.TabletType_PRIMARY,
+				},
+				Serving:              false,
+				PrimaryTermStartTime: 10,
+				Tablet: &topodatapb.Tablet{
+					Alias: &topodatapb.TabletAlias{
+						Cell: testCell,
+						Uid:  1,
+					},
+				},
+			},
+			wantServing:              false,
+			wantWaitForReparent:      false,
+			wantExternallyReparented: 10,
+			wantUID:                  1,
+		}, {
+			name: "New serving primary while already serving",
+			ss: &shardState{
+				serving:              true,
+				waitForReparent:      false,
+				externallyReparented: 10,
+				currentPrimary: &topodatapb.TabletAlias{
+					Cell: testCell,
+					Uid:  1,
+				},
+			},
+			th: &TabletHealth{
+				Target: &querypb.Target{
+					TabletType: topodatapb.TabletType_PRIMARY,
+				},
+				Serving:              true,
+				PrimaryTermStartTime: 20,
+				Tablet: &topodatapb.Tablet{
+					Alias: &topodatapb.TabletAlias{
+						Cell: testCell,
+						Uid:  2,
+					},
+				},
+			},
+			wantServing:              true,
+			wantWaitForReparent:      false,
+			wantExternallyReparented: 20,
+			wantUID:                  2,
+		}, {
+			name: "Primary goes non serving",
+			ss: &shardState{
+				serving:              true,
+				waitForReparent:      false,
+				externallyReparented: 10,
+				currentPrimary: &topodatapb.TabletAlias{
+					Cell: testCell,
+					Uid:  1,
+				},
+			},
+			th: &TabletHealth{
+				Target: &querypb.Target{
+					TabletType: topodatapb.TabletType_PRIMARY,
+				},
+				Serving:              false,
+				PrimaryTermStartTime: 10,
+				Tablet: &topodatapb.Tablet{
+					Alias: &topodatapb.TabletAlias{
+						Cell: testCell,
+						Uid:  1,
+					},
+				},
+			},
+			wantServing:              false,
+			wantWaitForReparent:      false,
+			wantExternallyReparented: 10,
+			wantUID:                  1,
+		},
+	}
+
+	ksName := "ks"
+	shard := "-80"
+	kss := &keyspaceState{
+		mu:       sync.Mutex{},
+		keyspace: ksName,
+		shards:   make(map[string]*shardState),
+	}
+	// Adding this so that we don't run any topo calls from ensureConsistentLocked.
+	kss.moveTablesState = &MoveTablesState{
+		Typ:   MoveTablesRegular,
+		State: MoveTablesSwitching,
+	}
+	for _, tt := range testcases {
+		t.Run(tt.name, func(t *testing.T) {
+			kss.shards[shard] = tt.ss
+			tt.th.Target.Keyspace = ksName
+			tt.th.Target.Shard = shard
+			kss.onHealthCheck(tt.th)
+			require.Equal(t, tt.wantServing, tt.ss.serving)
+			require.Equal(t, tt.wantWaitForReparent, tt.ss.waitForReparent)
+			require.Equal(t, tt.wantExternallyReparented, tt.ss.externallyReparented)
+			require.Equal(t, tt.wantUID, tt.ss.currentPrimary.Uid)
+		})
+	}
+}
+
 type fakeTopoServer struct {
 }
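
The reshuffled hunk above is the bug fix mentioned in the commit message: the waitForReparent reset used to live inside the serving-flip branch, so a non-serving healthcheck that arrived while the shard was already marked non-serving (for example right after MarkShardNotServing) never cleared the flag. Moving the reset out means any non-serving primary healthcheck clears it. Here is a self-contained sketch of the reordered logic, using stand-in types rather than the real watcher.

package main

import "fmt"

type shard struct {
	serving              bool
	waitForReparent      bool
	externallyReparented int64
}

// onPrimaryHealth mirrors the reordered logic: the waitForReparent reset is
// applied for every non-serving primary healthcheck, not only when the serving
// flag flips.
func (s *shard) onPrimaryHealth(serving bool, primaryTermStartTime int64) {
	if s.serving != serving {
		switch {
		case serving && s.waitForReparent:
			if primaryTermStartTime > s.externallyReparented {
				s.waitForReparent = false
				s.serving = true
			}
		case serving:
			s.serving = true
		default:
			s.serving = false
		}
	}
	if !serving {
		// With the pre-patch placement (inside the block above) this reset was
		// skipped when the shard was already non-serving, leaving waitForReparent
		// stuck at true if the reparent then failed and the old primary came back.
		s.waitForReparent = false
	}
}

func main() {
	s := &shard{serving: false, waitForReparent: true, externallyReparented: 10}
	s.onPrimaryHealth(false, 10)   // non-serving check from the old primary
	fmt.Println(s.waitForReparent) // false: buffering can end once a primary serves again
}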
 

From b0e2f6435aa818244047e08dfd95d5de4a29530a Mon Sep 17 00:00:00 2001
From: Manan Gupta <manan@planetscale.com>
Date: Fri, 25 Oct 2024 17:00:53 +0530
Subject: [PATCH 11/12] test: add more unit tests

Signed-off-by: Manan Gupta <manan@planetscale.com>
Signed-off-by: Arthur Schreiber <arthurschreiber@github.com>
---
 go/vt/vtgate/buffer/buffer_test.go | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/go/vt/vtgate/buffer/buffer_test.go b/go/vt/vtgate/buffer/buffer_test.go
index 2973ef2dfb9..0ce05ab1f38 100644
--- a/go/vt/vtgate/buffer/buffer_test.go
+++ b/go/vt/vtgate/buffer/buffer_test.go
@@ -23,6 +23,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
+
 	"vitess.io/vitess/go/vt/topo/topoproto"
 	"vitess.io/vitess/go/vt/vterrors"
 
@@ -68,6 +70,32 @@ var (
 	}
 )
 
+func TestIsErrorDueToReparenting(t *testing.T) {
+	testcases := []struct {
+		err  error
+		want bool
+	}{
+		{
+			err:  vterrors.Errorf(vtrpcpb.Code_CLUSTER_EVENT, ClusterEventReshardingInProgress),
+			want: false,
+		},
+		{
+			err:  vterrors.Errorf(vtrpcpb.Code_CLUSTER_EVENT, ClusterEventReparentInProgress),
+			want: true,
+		},
+		{
+			err:  vterrors.Errorf(vtrpcpb.Code_CLUSTER_EVENT, "The MySQL server is running with the --super-read-only option"),
+			want: true,
+		},
+	}
+	for _, tt := range testcases {
+		t.Run(tt.err.Error(), func(t *testing.T) {
+			got := IsErrorDueToReparenting(tt.err)
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
+
 func TestBuffering(t *testing.T) {
 	testAllImplementations(t, testBuffering1)
 }

From fd8c3cbcc2191340aea5320085d522047978b8da Mon Sep 17 00:00:00 2001
From: Arthur Schreiber <arthurschreiber@github.com>
Date: Fri, 8 Nov 2024 01:55:57 +0000
Subject: [PATCH 12/12] Appease the linter.

Signed-off-by: Arthur Schreiber <arthurschreiber@github.com>
---
 go/vt/discovery/keyspace_events.go | 1 +
 1 file changed, 1 insertion(+)

diff --git a/go/vt/discovery/keyspace_events.go b/go/vt/discovery/keyspace_events.go
index 4a82b173e9b..24bb47be585 100644
--- a/go/vt/discovery/keyspace_events.go
+++ b/go/vt/discovery/keyspace_events.go
@@ -474,6 +474,7 @@ func (kss *keyspaceState) getMoveTablesStatus(vs *vschemapb.SrvVSchema) (*MoveTa
 	mu := sync.Mutex{}
 	eg, ectx := errgroup.WithContext(shortCtx)
 	for _, sstate := range kss.shards {
+		sstate := sstate
 		eg.Go(func() error {
 			si, err := ts.GetShard(ectx, kss.keyspace, sstate.target.Shard)
 			if err != nil {
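
The added sstate := sstate is the classic pre-Go-1.22 loop-variable rebinding: each closure handed to the errgroup gets its own copy instead of all of them sharing (and racing on) the iteration variable. From Go 1.22 onward the language gives each iteration a fresh variable, so the line becomes redundant but harmless. A tiny illustration, unrelated to the Vitess types:

package main

import (
	"fmt"
	"sync"
)

func main() {
	shards := []string{"-40", "40-80", "80-"}

	var wg sync.WaitGroup
	for _, shard := range shards {
		shard := shard // without this, pre-Go-1.22 goroutines may all observe the last element
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println("checking shard", shard)
		}()
	}
	wg.Wait()
}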