diff --git a/go/vt/vttablet/tabletserver/livequeryz_test.go b/go/vt/vttablet/tabletserver/livequeryz_test.go index e507f365afb..18ce01c2273 100644 --- a/go/vt/vttablet/tabletserver/livequeryz_test.go +++ b/go/vt/vttablet/tabletserver/livequeryz_test.go @@ -30,8 +30,8 @@ func TestLiveQueryzHandlerJSON(t *testing.T) { req, _ := http.NewRequest("GET", "/livequeryz/?format=json", nil) queryList := NewQueryList("test", sqlparser.NewTestParser()) - queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1})) - queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2})) + _ = queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1})) + _ = queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2})) livequeryzHandler([]*QueryList{queryList}, resp, req) } @@ -41,8 +41,8 @@ func TestLiveQueryzHandlerHTTP(t *testing.T) { req, _ := http.NewRequest("GET", "/livequeryz/", nil) queryList := NewQueryList("test", sqlparser.NewTestParser()) - queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1})) - queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2})) + _ = queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1})) + _ = queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2})) livequeryzHandler([]*QueryList{queryList}, resp, req) } @@ -64,7 +64,7 @@ func TestLiveQueryzHandlerTerminateConn(t *testing.T) { queryList := NewQueryList("test", sqlparser.NewTestParser()) testConn := &testConn{id: 1} - queryList.Add(NewQueryDetail(context.Background(), testConn)) + _ = queryList.Add(NewQueryDetail(context.Background(), testConn)) if testConn.IsKilled() { t.Fatalf("conn should still be alive") } diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index f371d62006c..d1fbc96123f 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -1085,7 +1085,10 @@ func (qre *QueryExecutor) 
execDBConn(conn *connpool.Conn, sql string, wantfields defer qre.logStats.AddRewrittenSQL(sql, time.Now()) qd := NewQueryDetail(qre.logStats.Ctx, conn) - qre.tsv.statelessql.Add(qd) + err := qre.tsv.statelessql.Add(qd) + if err != nil { + return nil, err + } defer qre.tsv.statelessql.Remove(qd) return conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields) @@ -1098,7 +1101,10 @@ func (qre *QueryExecutor) execStatefulConn(conn *StatefulConnection, sql string, defer qre.logStats.AddRewrittenSQL(sql, time.Now()) qd := NewQueryDetail(qre.logStats.Ctx, conn) - qre.tsv.statefulql.Add(qd) + err := qre.tsv.statefulql.Add(qd) + if err != nil { + return nil, err + } defer qre.tsv.statefulql.Remove(qd) return conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields) @@ -1122,11 +1128,17 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction // once their grace period is over. qd := NewQueryDetail(qre.logStats.Ctx, conn.Conn) if isTransaction { - qre.tsv.statefulql.Add(qd) + err := qre.tsv.statefulql.Add(qd) + if err != nil { + return err + } defer qre.tsv.statefulql.Remove(qd) return conn.Conn.StreamOnce(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options)) } - qre.tsv.olapql.Add(qd) + err := qre.tsv.olapql.Add(qd) + if err != nil { + return err + } defer qre.tsv.olapql.Remove(qd) return conn.Conn.Stream(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options)) } diff --git a/go/vt/vttablet/tabletserver/query_list.go b/go/vt/vttablet/tabletserver/query_list.go index a41f23b6aa0..3ccf13418a7 100644 --- a/go/vt/vttablet/tabletserver/query_list.go +++ b/go/vt/vttablet/tabletserver/query_list.go @@ -26,7 +26,9 @@ import ( "vitess.io/vitess/go/streamlog" "vitess.io/vitess/go/vt/callinfo" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" 
+ "vitess.io/vitess/go/vt/vterrors" ) // QueryDetail is a simple wrapper for Query, Context and a killable conn. @@ -58,28 +60,41 @@ type QueryList struct { // and remove appropriately. queryDetails map[int64][]*QueryDetail - parser *sqlparser.Parser + parser *sqlparser.Parser + clusterActionInProgress bool } // NewQueryList creates a new QueryList func NewQueryList(name string, parser *sqlparser.Parser) *QueryList { return &QueryList{ - name: name, - queryDetails: make(map[int64][]*QueryDetail), - parser: parser, + name: name, + queryDetails: make(map[int64][]*QueryDetail), + parser: parser, + clusterActionInProgress: false, } } +// SetClusterAction sets the clusterActionInProgress field. +func (ql *QueryList) SetClusterAction(inProgress bool) { + ql.mu.Lock() + defer ql.mu.Unlock() + ql.clusterActionInProgress = inProgress +} + // Add adds a QueryDetail to QueryList -func (ql *QueryList) Add(qd *QueryDetail) { +func (ql *QueryList) Add(qd *QueryDetail) error { ql.mu.Lock() defer ql.mu.Unlock() + if ql.clusterActionInProgress { + return vterrors.New(vtrpcpb.Code_CLUSTER_EVENT, vterrors.ShuttingDown) + } qds, exists := ql.queryDetails[qd.connID] if exists { ql.queryDetails[qd.connID] = append(qds, qd) } else { ql.queryDetails[qd.connID] = []*QueryDetail{qd} } + return nil } // Remove removes a QueryDetail from QueryList diff --git a/go/vt/vttablet/tabletserver/query_list_test.go b/go/vt/vttablet/tabletserver/query_list_test.go index 57b672a16e0..bb1b47db4af 100644 --- a/go/vt/vttablet/tabletserver/query_list_test.go +++ b/go/vt/vttablet/tabletserver/query_list_test.go @@ -49,7 +49,8 @@ func TestQueryList(t *testing.T) { ql := NewQueryList("test", sqlparser.NewTestParser()) connID := int64(1) qd := NewQueryDetail(context.Background(), &testConn{id: connID}) - ql.Add(qd) + err := ql.Add(qd) + require.NoError(t, err) if qd1, ok := ql.queryDetails[connID]; !ok || qd1[0].connID != connID { t.Errorf("failed to add to QueryList") @@ -57,7 +58,8 @@ func TestQueryList(t 
*testing.T) { conn2ID := int64(2) qd2 := NewQueryDetail(context.Background(), &testConn{id: conn2ID}) - ql.Add(qd2) + err = ql.Add(qd2) + require.NoError(t, err) rows := ql.AppendQueryzRows(nil) if len(rows) != 2 || rows[0].ConnID != 1 || rows[1].ConnID != 2 { @@ -74,11 +76,13 @@ func TestQueryListChangeConnIDInMiddle(t *testing.T) { ql := NewQueryList("test", sqlparser.NewTestParser()) connID := int64(1) qd1 := NewQueryDetail(context.Background(), &testConn{id: connID}) - ql.Add(qd1) + err := ql.Add(qd1) + require.NoError(t, err) conn := &testConn{id: connID} qd2 := NewQueryDetail(context.Background(), conn) - ql.Add(qd2) + err = ql.Add(qd2) + require.NoError(t, err) require.Len(t, ql.queryDetails[1], 2) @@ -92,3 +96,17 @@ func TestQueryListChangeConnIDInMiddle(t *testing.T) { require.Equal(t, qd1, ql.queryDetails[1][0]) require.NotEqual(t, qd2, ql.queryDetails[1][0]) } + +func TestClusterAction(t *testing.T) { + ql := NewQueryList("test", sqlparser.NewTestParser()) + connID := int64(1) + qd1 := NewQueryDetail(context.Background(), &testConn{id: connID}) + + ql.SetClusterAction(true) + err := ql.Add(qd1) + require.ErrorContains(t, err, "operation not allowed in state SHUTTING_DOWN") + + ql.SetClusterAction(false) + err = ql.Add(qd1) + require.NoError(t, err) +} diff --git a/go/vt/vttablet/tabletserver/state_manager.go b/go/vt/vttablet/tabletserver/state_manager.go index 60b1f1281d0..af2da48f75d 100644 --- a/go/vt/vttablet/tabletserver/state_manager.go +++ b/go/vt/vttablet/tabletserver/state_manager.go @@ -542,6 +542,8 @@ func (sm *stateManager) connect(tabletType topodatapb.TabletType) error { } func (sm *stateManager) unserveCommon() { + sm.markClusterAction(true) + defer sm.markClusterAction(false) // We create a wait group that tracks whether all the queries have been terminated or not. 
wg := sync.WaitGroup{} wg.Add(1) @@ -850,3 +852,10 @@ func (sm *stateManager) IsServingString() string { func (sm *stateManager) SetUnhealthyThreshold(v time.Duration) { sm.unhealthyThreshold.Store(v.Nanoseconds()) } + +// markClusterAction marks whether a cluster action is in progress or not for all the query lists. +func (sm *stateManager) markClusterAction(inProgress bool) { + sm.statefulql.SetClusterAction(inProgress) + sm.statelessql.SetClusterAction(inProgress) + sm.olapql.SetClusterAction(inProgress) +} diff --git a/go/vt/vttablet/tabletserver/state_manager_test.go b/go/vt/vttablet/tabletserver/state_manager_test.go index a0ef3557074..f6345b9b29c 100644 --- a/go/vt/vttablet/tabletserver/state_manager_test.go +++ b/go/vt/vttablet/tabletserver/state_manager_test.go @@ -409,18 +409,20 @@ func TestStateManagerShutdownGracePeriod(t *testing.T) { sm.te = &delayedTxEngine{} kconn1 := &killableConn{id: 1} - sm.statelessql.Add(&QueryDetail{ + err := sm.statelessql.Add(&QueryDetail{ conn: kconn1, connID: kconn1.id, }) + require.NoError(t, err) kconn2 := &killableConn{id: 2} - sm.statefulql.Add(&QueryDetail{ + err = sm.statefulql.Add(&QueryDetail{ conn: kconn2, connID: kconn2.id, }) + require.NoError(t, err) // Transition to replica with no shutdown grace period should kill kconn2 but not kconn1. - err := sm.SetServingType(topodatapb.TabletType_PRIMARY, testNow, StateServing, "") + err = sm.SetServingType(topodatapb.TabletType_PRIMARY, testNow, StateServing, "") require.NoError(t, err) assert.False(t, kconn1.killed.Load()) assert.True(t, kconn2.killed.Load())