diff --git a/go/test/endtoend/vtgate/transaction/restart/main_test.go b/go/test/endtoend/vtgate/transaction/restart/main_test.go
index de52a3e8870..01185b5fa59 100644
--- a/go/test/endtoend/vtgate/transaction/restart/main_test.go
+++ b/go/test/endtoend/vtgate/transaction/restart/main_test.go
@@ -23,7 +23,6 @@ import (
 	"os"
 	"testing"
 
-	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
 	"vitess.io/vitess/go/mysql"
@@ -113,5 +112,4 @@ func TestStreamTxRestart(t *testing.T) {
 	// query should return connection error
 	_, err = utils.ExecAllowError(t, conn, "select connection_id()")
 	require.Error(t, err)
-	assert.Contains(t, err.Error(), "broken pipe (errno 2006) (sqlstate HY000)")
 }
diff --git a/go/vt/vttablet/tabletserver/livequeryz_test.go b/go/vt/vttablet/tabletserver/livequeryz_test.go
index e507f365afb..8dad3cd1631 100644
--- a/go/vt/vttablet/tabletserver/livequeryz_test.go
+++ b/go/vt/vttablet/tabletserver/livequeryz_test.go
@@ -22,6 +22,8 @@ import (
 	"net/http/httptest"
 	"testing"
 
+	"github.com/stretchr/testify/require"
+
 	"vitess.io/vitess/go/vt/sqlparser"
 )
 
@@ -30,8 +32,10 @@ func TestLiveQueryzHandlerJSON(t *testing.T) {
 	req, _ := http.NewRequest("GET", "/livequeryz/?format=json", nil)
 
 	queryList := NewQueryList("test", sqlparser.NewTestParser())
-	queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
-	queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
+	err := queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
+	require.NoError(t, err)
+	err = queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
+	require.NoError(t, err)
 
 	livequeryzHandler([]*QueryList{queryList}, resp, req)
 }
@@ -41,8 +45,10 @@ func TestLiveQueryzHandlerHTTP(t *testing.T) {
 	req, _ := http.NewRequest("GET", "/livequeryz/", nil)
 
 	queryList := NewQueryList("test", sqlparser.NewTestParser())
-	queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
-	queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
+	err := queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
+	require.NoError(t, err)
+	err = queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
+	require.NoError(t, err)
 
 	livequeryzHandler([]*QueryList{queryList}, resp, req)
 }
@@ -64,7 +70,8 @@ func TestLiveQueryzHandlerTerminateConn(t *testing.T) {
 
 	queryList := NewQueryList("test", sqlparser.NewTestParser())
 	testConn := &testConn{id: 1}
-	queryList.Add(NewQueryDetail(context.Background(), testConn))
+	err := queryList.Add(NewQueryDetail(context.Background(), testConn))
+	require.NoError(t, err)
 	if testConn.IsKilled() {
 		t.Fatalf("conn should still be alive")
 	}
diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go
index f371d62006c..d1fbc96123f 100644
--- a/go/vt/vttablet/tabletserver/query_executor.go
+++ b/go/vt/vttablet/tabletserver/query_executor.go
@@ -1085,7 +1085,10 @@ func (qre *QueryExecutor) execDBConn(conn *connpool.Conn, sql string, wantfields
 	defer qre.logStats.AddRewrittenSQL(sql, time.Now())
 
 	qd := NewQueryDetail(qre.logStats.Ctx, conn)
-	qre.tsv.statelessql.Add(qd)
+	err := qre.tsv.statelessql.Add(qd)
+	if err != nil {
+		return nil, err
+	}
 	defer qre.tsv.statelessql.Remove(qd)
 
 	return conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields)
@@ -1098,7 +1101,10 @@ func (qre *QueryExecutor) execStatefulConn(conn *StatefulConnection, sql string,
 	defer qre.logStats.AddRewrittenSQL(sql, time.Now())
 
 	qd := NewQueryDetail(qre.logStats.Ctx, conn)
-	qre.tsv.statefulql.Add(qd)
+	err := qre.tsv.statefulql.Add(qd)
+	if err != nil {
+		return nil, err
+	}
 	defer qre.tsv.statefulql.Remove(qd)
 
 	return conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields)
@@ -1122,11 +1128,17 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction
 	// once their grace period is over.
 	qd := NewQueryDetail(qre.logStats.Ctx, conn.Conn)
 	if isTransaction {
-		qre.tsv.statefulql.Add(qd)
+		err := qre.tsv.statefulql.Add(qd)
+		if err != nil {
+			return err
+		}
 		defer qre.tsv.statefulql.Remove(qd)
 		return conn.Conn.StreamOnce(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
 	}
-	qre.tsv.olapql.Add(qd)
+	err := qre.tsv.olapql.Add(qd)
+	if err != nil {
+		return err
+	}
 	defer qre.tsv.olapql.Remove(qd)
 	return conn.Conn.Stream(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
 }
diff --git a/go/vt/vttablet/tabletserver/query_list.go b/go/vt/vttablet/tabletserver/query_list.go
index a41f23b6aa0..a21acd6f92a 100644
--- a/go/vt/vttablet/tabletserver/query_list.go
+++ b/go/vt/vttablet/tabletserver/query_list.go
@@ -26,7 +26,9 @@ import (
 
 	"vitess.io/vitess/go/streamlog"
 	"vitess.io/vitess/go/vt/callinfo"
+	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
 	"vitess.io/vitess/go/vt/sqlparser"
+	"vitess.io/vitess/go/vt/vterrors"
 )
 
 // QueryDetail is a simple wrapper for Query, Context and a killable conn.
@@ -59,27 +61,52 @@ type QueryList struct {
 	queryDetails map[int64][]*QueryDetail
 
 	parser *sqlparser.Parser
+
+	ca ClusterActionState
 }
 
+type ClusterActionState int
+
+const (
+	ClusterActionNotInProgress ClusterActionState = iota
+	ClusterActionInProgress    ClusterActionState = iota
+	ClusterActionNoQueries     ClusterActionState = iota
+)
+
 // NewQueryList creates a new QueryList
 func NewQueryList(name string, parser *sqlparser.Parser) *QueryList {
 	return &QueryList{
 		name:         name,
 		queryDetails: make(map[int64][]*QueryDetail),
 		parser:       parser,
+		ca:           ClusterActionNotInProgress,
+	}
+}
+
+// SetClusterAction sets the clusterActionInProgress field.
+func (ql *QueryList) SetClusterAction(ca ClusterActionState) {
+	ql.mu.Lock()
+	defer ql.mu.Unlock()
+	// If the current state is ClusterActionNotInProgress, then we want to ignore setting ClusterActionNoQueries.
+	if ca == ClusterActionNoQueries && ql.ca == ClusterActionNotInProgress {
+		return
 	}
+	ql.ca = ca
 }
 
 // Add adds a QueryDetail to QueryList
-func (ql *QueryList) Add(qd *QueryDetail) {
+func (ql *QueryList) Add(qd *QueryDetail) error {
 	ql.mu.Lock()
 	defer ql.mu.Unlock()
+	if ql.ca == ClusterActionNoQueries {
+		return vterrors.New(vtrpcpb.Code_CLUSTER_EVENT, vterrors.ShuttingDown)
+	}
 	qds, exists := ql.queryDetails[qd.connID]
 	if exists {
 		ql.queryDetails[qd.connID] = append(qds, qd)
 	} else {
 		ql.queryDetails[qd.connID] = []*QueryDetail{qd}
 	}
+	return nil
 }
 
 // Remove removes a QueryDetail from QueryList
diff --git a/go/vt/vttablet/tabletserver/query_list_test.go b/go/vt/vttablet/tabletserver/query_list_test.go
index 57b672a16e0..1e9dc2bf42c 100644
--- a/go/vt/vttablet/tabletserver/query_list_test.go
+++ b/go/vt/vttablet/tabletserver/query_list_test.go
@@ -49,7 +49,8 @@ func TestQueryList(t *testing.T) {
 	ql := NewQueryList("test", sqlparser.NewTestParser())
 	connID := int64(1)
 	qd := NewQueryDetail(context.Background(), &testConn{id: connID})
-	ql.Add(qd)
+	err := ql.Add(qd)
+	require.NoError(t, err)
 
 	if qd1, ok := ql.queryDetails[connID]; !ok || qd1[0].connID != connID {
 		t.Errorf("failed to add to QueryList")
@@ -57,7 +58,8 @@ func TestQueryList(t *testing.T) {
 
 	conn2ID := int64(2)
 	qd2 := NewQueryDetail(context.Background(), &testConn{id: conn2ID})
-	ql.Add(qd2)
+	err = ql.Add(qd2)
+	require.NoError(t, err)
 
 	rows := ql.AppendQueryzRows(nil)
 	if len(rows) != 2 || rows[0].ConnID != 1 || rows[1].ConnID != 2 {
@@ -74,11 +76,13 @@ func TestQueryListChangeConnIDInMiddle(t *testing.T) {
 	ql := NewQueryList("test", sqlparser.NewTestParser())
 	connID := int64(1)
 	qd1 := NewQueryDetail(context.Background(), &testConn{id: connID})
-	ql.Add(qd1)
+	err := ql.Add(qd1)
+	require.NoError(t, err)
 
 	conn := &testConn{id: connID}
 	qd2 := NewQueryDetail(context.Background(), conn)
-	ql.Add(qd2)
+	err = ql.Add(qd2)
+	require.NoError(t, err)
 
 	require.Len(t, ql.queryDetails[1], 2)
 
@@ -92,3 +96,22 @@ func TestQueryListChangeConnIDInMiddle(t *testing.T) {
 	require.Equal(t, qd1, ql.queryDetails[1][0])
 	require.NotEqual(t, qd2, ql.queryDetails[1][0])
 }
+
+func TestClusterAction(t *testing.T) {
+	ql := NewQueryList("test", sqlparser.NewTestParser())
+	connID := int64(1)
+	qd1 := NewQueryDetail(context.Background(), &testConn{id: connID})
+
+	ql.SetClusterAction(ClusterActionInProgress)
+	ql.SetClusterAction(ClusterActionNoQueries)
+	err := ql.Add(qd1)
+	require.ErrorContains(t, err, "operation not allowed in state SHUTTING_DOWN")
+
+	ql.SetClusterAction(ClusterActionNotInProgress)
+	err = ql.Add(qd1)
+	require.NoError(t, err)
+	// If the current state is not in progress, then setting no queries, shouldn't change anything.
+	ql.SetClusterAction(ClusterActionNoQueries)
+	err = ql.Add(qd1)
+	require.NoError(t, err)
+}
diff --git a/go/vt/vttablet/tabletserver/state_manager.go b/go/vt/vttablet/tabletserver/state_manager.go
index 60b1f1281d0..308f9165ba6 100644
--- a/go/vt/vttablet/tabletserver/state_manager.go
+++ b/go/vt/vttablet/tabletserver/state_manager.go
@@ -542,6 +542,8 @@ func (sm *stateManager) connect(tabletType topodatapb.TabletType) error {
 }
 
 func (sm *stateManager) unserveCommon() {
+	sm.markClusterAction(ClusterActionInProgress)
+	defer sm.markClusterAction(ClusterActionNotInProgress)
 	// We create a wait group that tracks whether all the queries have been terminated or not.
 	wg := sync.WaitGroup{}
 	wg.Add(1)
@@ -601,6 +603,8 @@ func (sm *stateManager) terminateAllQueries(wg *sync.WaitGroup) (cancel func())
 	if err := timer.SleepContext(ctx, sm.shutdownGracePeriod); err != nil {
 		return
 	}
+	// Prevent any new queries from being added before we kill all the queries in the list.
+	sm.markClusterAction(ClusterActionNoQueries)
 	log.Infof("Grace Period %v exceeded. Killing all OLTP queries.", sm.shutdownGracePeriod)
 	sm.statelessql.TerminateAll()
 	log.Infof("Killed all stateless OLTP queries.")
@@ -850,3 +854,10 @@ func (sm *stateManager) IsServingString() string {
 func (sm *stateManager) SetUnhealthyThreshold(v time.Duration) {
 	sm.unhealthyThreshold.Store(v.Nanoseconds())
 }
+
+// markClusterAction marks whether a cluster action is in progress or not for all the query details.
+func (sm *stateManager) markClusterAction(ca ClusterActionState) {
+	sm.statefulql.SetClusterAction(ca)
+	sm.statelessql.SetClusterAction(ca)
+	sm.olapql.SetClusterAction(ca)
+}
diff --git a/go/vt/vttablet/tabletserver/state_manager_test.go b/go/vt/vttablet/tabletserver/state_manager_test.go
index a0ef3557074..f6345b9b29c 100644
--- a/go/vt/vttablet/tabletserver/state_manager_test.go
+++ b/go/vt/vttablet/tabletserver/state_manager_test.go
@@ -409,18 +409,20 @@ func TestStateManagerShutdownGracePeriod(t *testing.T) {
 
 	sm.te = &delayedTxEngine{}
 	kconn1 := &killableConn{id: 1}
-	sm.statelessql.Add(&QueryDetail{
+	err := sm.statelessql.Add(&QueryDetail{
 		conn:   kconn1,
 		connID: kconn1.id,
 	})
+	require.NoError(t, err)
 	kconn2 := &killableConn{id: 2}
-	sm.statefulql.Add(&QueryDetail{
+	err = sm.statefulql.Add(&QueryDetail{
 		conn:   kconn2,
 		connID: kconn2.id,
 	})
+	require.NoError(t, err)
 
 	// Transition to replica with no shutdown grace period should kill kconn2 but not kconn1.
-	err := sm.SetServingType(topodatapb.TabletType_PRIMARY, testNow, StateServing, "")
+	err = sm.SetServingType(topodatapb.TabletType_PRIMARY, testNow, StateServing, "")
 	require.NoError(t, err)
 	assert.False(t, kconn1.killed.Load())
 	assert.True(t, kconn2.killed.Load())