Skip to content

Commit

Permalink
Prevent adding to query details after unserve common has started (#15684
Browse files Browse the repository at this point in the history
)

Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 authored Apr 11, 2024
1 parent f118ba2 commit 9e40015
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 19 deletions.
2 changes: 0 additions & 2 deletions go/test/endtoend/vtgate/transaction/restart/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql"
Expand Down Expand Up @@ -113,5 +112,4 @@ func TestStreamTxRestart(t *testing.T) {
// query should return connection error
_, err = utils.ExecAllowError(t, conn, "select connection_id()")
require.Error(t, err)
assert.Contains(t, err.Error(), "broken pipe (errno 2006) (sqlstate HY000)")
}
17 changes: 12 additions & 5 deletions go/vt/vttablet/tabletserver/livequeryz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"net/http/httptest"
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/vt/sqlparser"
)

Expand All @@ -30,8 +32,10 @@ func TestLiveQueryzHandlerJSON(t *testing.T) {
req, _ := http.NewRequest("GET", "/livequeryz/?format=json", nil)

queryList := NewQueryList("test", sqlparser.NewTestParser())
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
err := queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
require.NoError(t, err)
err = queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
require.NoError(t, err)

livequeryzHandler([]*QueryList{queryList}, resp, req)
}
Expand All @@ -41,8 +45,10 @@ func TestLiveQueryzHandlerHTTP(t *testing.T) {
req, _ := http.NewRequest("GET", "/livequeryz/", nil)

queryList := NewQueryList("test", sqlparser.NewTestParser())
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
err := queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
require.NoError(t, err)
err = queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
require.NoError(t, err)

livequeryzHandler([]*QueryList{queryList}, resp, req)
}
Expand All @@ -64,7 +70,8 @@ func TestLiveQueryzHandlerTerminateConn(t *testing.T) {

queryList := NewQueryList("test", sqlparser.NewTestParser())
testConn := &testConn{id: 1}
queryList.Add(NewQueryDetail(context.Background(), testConn))
err := queryList.Add(NewQueryDetail(context.Background(), testConn))
require.NoError(t, err)
if testConn.IsKilled() {
t.Fatalf("conn should still be alive")
}
Expand Down
20 changes: 16 additions & 4 deletions go/vt/vttablet/tabletserver/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,10 @@ func (qre *QueryExecutor) execDBConn(conn *connpool.Conn, sql string, wantfields
defer qre.logStats.AddRewrittenSQL(sql, time.Now())

qd := NewQueryDetail(qre.logStats.Ctx, conn)
qre.tsv.statelessql.Add(qd)
err := qre.tsv.statelessql.Add(qd)
if err != nil {
return nil, err
}
defer qre.tsv.statelessql.Remove(qd)

return conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields)
Expand All @@ -1098,7 +1101,10 @@ func (qre *QueryExecutor) execStatefulConn(conn *StatefulConnection, sql string,
defer qre.logStats.AddRewrittenSQL(sql, time.Now())

qd := NewQueryDetail(qre.logStats.Ctx, conn)
qre.tsv.statefulql.Add(qd)
err := qre.tsv.statefulql.Add(qd)
if err != nil {
return nil, err
}
defer qre.tsv.statefulql.Remove(qd)

return conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields)
Expand All @@ -1122,11 +1128,17 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction
// once their grace period is over.
qd := NewQueryDetail(qre.logStats.Ctx, conn.Conn)
if isTransaction {
qre.tsv.statefulql.Add(qd)
err := qre.tsv.statefulql.Add(qd)
if err != nil {
return err
}
defer qre.tsv.statefulql.Remove(qd)
return conn.Conn.StreamOnce(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
}
qre.tsv.olapql.Add(qd)
err := qre.tsv.olapql.Add(qd)
if err != nil {
return err
}
defer qre.tsv.olapql.Remove(qd)
return conn.Conn.Stream(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
}
Expand Down
29 changes: 28 additions & 1 deletion go/vt/vttablet/tabletserver/query_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (

"vitess.io/vitess/go/streamlog"
"vitess.io/vitess/go/vt/callinfo"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
)

// QueryDetail is a simple wrapper for Query, Context and a killable conn.
Expand Down Expand Up @@ -59,27 +61,52 @@ type QueryList struct {
queryDetails map[int64][]*QueryDetail

parser *sqlparser.Parser
ca ClusterActionState
}

type ClusterActionState int

const (
ClusterActionNotInProgress ClusterActionState = iota
ClusterActionInProgress ClusterActionState = iota
ClusterActionNoQueries ClusterActionState = iota
)

// NewQueryList creates a new QueryList
func NewQueryList(name string, parser *sqlparser.Parser) *QueryList {
return &QueryList{
name: name,
queryDetails: make(map[int64][]*QueryDetail),
parser: parser,
ca: ClusterActionNotInProgress,
}
}

// SetClusterAction sets the clusterActionInProgress field.
func (ql *QueryList) SetClusterAction(ca ClusterActionState) {
ql.mu.Lock()
defer ql.mu.Unlock()
// If the current state is ClusterActionNotInProgress, then we want to ignore setting ClusterActionNoQueries.
if ca == ClusterActionNoQueries && ql.ca == ClusterActionNotInProgress {
return
}
ql.ca = ca
}

// Add adds a QueryDetail to QueryList
func (ql *QueryList) Add(qd *QueryDetail) {
func (ql *QueryList) Add(qd *QueryDetail) error {
ql.mu.Lock()
defer ql.mu.Unlock()
if ql.ca == ClusterActionNoQueries {
return vterrors.New(vtrpcpb.Code_CLUSTER_EVENT, vterrors.ShuttingDown)
}
qds, exists := ql.queryDetails[qd.connID]
if exists {
ql.queryDetails[qd.connID] = append(qds, qd)
} else {
ql.queryDetails[qd.connID] = []*QueryDetail{qd}
}
return nil
}

// Remove removes a QueryDetail from QueryList
Expand Down
31 changes: 27 additions & 4 deletions go/vt/vttablet/tabletserver/query_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,17 @@ func TestQueryList(t *testing.T) {
ql := NewQueryList("test", sqlparser.NewTestParser())
connID := int64(1)
qd := NewQueryDetail(context.Background(), &testConn{id: connID})
ql.Add(qd)
err := ql.Add(qd)
require.NoError(t, err)

if qd1, ok := ql.queryDetails[connID]; !ok || qd1[0].connID != connID {
t.Errorf("failed to add to QueryList")
}

conn2ID := int64(2)
qd2 := NewQueryDetail(context.Background(), &testConn{id: conn2ID})
ql.Add(qd2)
err = ql.Add(qd2)
require.NoError(t, err)

rows := ql.AppendQueryzRows(nil)
if len(rows) != 2 || rows[0].ConnID != 1 || rows[1].ConnID != 2 {
Expand All @@ -74,11 +76,13 @@ func TestQueryListChangeConnIDInMiddle(t *testing.T) {
ql := NewQueryList("test", sqlparser.NewTestParser())
connID := int64(1)
qd1 := NewQueryDetail(context.Background(), &testConn{id: connID})
ql.Add(qd1)
err := ql.Add(qd1)
require.NoError(t, err)

conn := &testConn{id: connID}
qd2 := NewQueryDetail(context.Background(), conn)
ql.Add(qd2)
err = ql.Add(qd2)
require.NoError(t, err)

require.Len(t, ql.queryDetails[1], 2)

Expand All @@ -92,3 +96,22 @@ func TestQueryListChangeConnIDInMiddle(t *testing.T) {
require.Equal(t, qd1, ql.queryDetails[1][0])
require.NotEqual(t, qd2, ql.queryDetails[1][0])
}

func TestClusterAction(t *testing.T) {
ql := NewQueryList("test", sqlparser.NewTestParser())
connID := int64(1)
qd1 := NewQueryDetail(context.Background(), &testConn{id: connID})

ql.SetClusterAction(ClusterActionInProgress)
ql.SetClusterAction(ClusterActionNoQueries)
err := ql.Add(qd1)
require.ErrorContains(t, err, "operation not allowed in state SHUTTING_DOWN")

ql.SetClusterAction(ClusterActionNotInProgress)
err = ql.Add(qd1)
require.NoError(t, err)
// If the current state is not in progress, then setting no queries, shouldn't change anything.
ql.SetClusterAction(ClusterActionNoQueries)
err = ql.Add(qd1)
require.NoError(t, err)
}
11 changes: 11 additions & 0 deletions go/vt/vttablet/tabletserver/state_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ func (sm *stateManager) connect(tabletType topodatapb.TabletType) error {
}

func (sm *stateManager) unserveCommon() {
sm.markClusterAction(ClusterActionInProgress)
defer sm.markClusterAction(ClusterActionNotInProgress)
// We create a wait group that tracks whether all the queries have been terminated or not.
wg := sync.WaitGroup{}
wg.Add(1)
Expand Down Expand Up @@ -601,6 +603,8 @@ func (sm *stateManager) terminateAllQueries(wg *sync.WaitGroup) (cancel func())
if err := timer.SleepContext(ctx, sm.shutdownGracePeriod); err != nil {
return
}
// Prevent any new queries from being added before we kill all the queries in the list.
sm.markClusterAction(ClusterActionNoQueries)
log.Infof("Grace Period %v exceeded. Killing all OLTP queries.", sm.shutdownGracePeriod)
sm.statelessql.TerminateAll()
log.Infof("Killed all stateless OLTP queries.")
Expand Down Expand Up @@ -850,3 +854,10 @@ func (sm *stateManager) IsServingString() string {
func (sm *stateManager) SetUnhealthyThreshold(v time.Duration) {
sm.unhealthyThreshold.Store(v.Nanoseconds())
}

// markClusterAction marks whether a cluster action is in progress or not for all the query details.
func (sm *stateManager) markClusterAction(ca ClusterActionState) {
sm.statefulql.SetClusterAction(ca)
sm.statelessql.SetClusterAction(ca)
sm.olapql.SetClusterAction(ca)
}
8 changes: 5 additions & 3 deletions go/vt/vttablet/tabletserver/state_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,18 +409,20 @@ func TestStateManagerShutdownGracePeriod(t *testing.T) {

sm.te = &delayedTxEngine{}
kconn1 := &killableConn{id: 1}
sm.statelessql.Add(&QueryDetail{
err := sm.statelessql.Add(&QueryDetail{
conn: kconn1,
connID: kconn1.id,
})
require.NoError(t, err)
kconn2 := &killableConn{id: 2}
sm.statefulql.Add(&QueryDetail{
err = sm.statefulql.Add(&QueryDetail{
conn: kconn2,
connID: kconn2.id,
})
require.NoError(t, err)

// Transition to replica with no shutdown grace period should kill kconn2 but not kconn1.
err := sm.SetServingType(topodatapb.TabletType_PRIMARY, testNow, StateServing, "")
err = sm.SetServingType(topodatapb.TabletType_PRIMARY, testNow, StateServing, "")
require.NoError(t, err)
assert.False(t, kconn1.killed.Load())
assert.True(t, kconn2.killed.Load())
Expand Down

0 comments on commit 9e40015

Please sign in to comment.