Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PRS from being blocked because of misbehaving clients #15339

Merged
merged 13 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions go/test/endtoend/tabletgateway/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package healthcheck

import (
"flag"
"fmt"
"os"
"testing"

Expand All @@ -26,11 +27,12 @@ import (
)

var (
clusterInstance *cluster.LocalProcessCluster
vtParams mysql.ConnParams
keyspaceName = "commerce"
cell = "zone1"
sqlSchema = `create table product(
clusterInstance *cluster.LocalProcessCluster
vtParams mysql.ConnParams
keyspaceName = "commerce"
vtgateGrpcAddress string
cell = "zone1"
sqlSchema = `create table product(
sku varbinary(128),
description varbinary(128),
price bigint,
Expand Down Expand Up @@ -64,7 +66,7 @@ func TestMain(m *testing.M) {

exitCode := func() int {
clusterInstance = cluster.NewCluster(cell, "localhost")
clusterInstance.VtTabletExtraArgs = []string{"--health_check_interval", "1s"}
clusterInstance.VtTabletExtraArgs = []string{"--health_check_interval", "1s", "--shutdown_grace_period", "3s"}
defer clusterInstance.Teardown()

// Start topo server
Expand Down Expand Up @@ -96,6 +98,7 @@ func TestMain(m *testing.M) {
Host: clusterInstance.Hostname,
Port: clusterInstance.VtgateMySQLPort,
}
vtgateGrpcAddress = fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateGrpcPort)
return m.Run()
}()
os.Exit(exitCode)
Expand Down
44 changes: 40 additions & 4 deletions go/test/endtoend/tabletgateway/vtgate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@ import (
"testing"
"time"

"vitess.io/vitess/go/test/endtoend/utils"
"vitess.io/vitess/go/vt/proto/topodata"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/topodata"
)

func TestVtgateHealthCheck(t *testing.T) {
Expand Down Expand Up @@ -247,6 +246,43 @@ func TestReplicaTransactions(t *testing.T) {
assert.Equal(t, `[[INT64(1) VARCHAR("email1")] [INT64(2) VARCHAR("email2")]]`, fmt.Sprintf("%v", qr4.Rows), "we are not able to reconnect after restart")
}

// TestStreamingRPCStuck tests that StreamExecute calls don't get stuck on the vttablets if a client stop reading from a stream.
func TestStreamingRPCStuck(t *testing.T) {
defer cluster.PanicHandler(t)
ctx := context.Background()
vtConn, err := mysql.Connect(ctx, &vtParams)
require.NoError(t, err)
defer vtConn.Close()

// We want the table to have enough rows such that a streaming call returns multiple packets.
// Therefore, we insert one row and keep doubling it.
utils.Exec(t, vtConn, "insert into customer(email) values('testemail')")
for i := 0; i < 15; i++ {
// Double the number of rows in customer table.
utils.Exec(t, vtConn, "insert into customer (email) select email from customer")
}

// Connect to vtgate and run a streaming query.
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "test_user", "")
require.NoError(t, err)
stream, err := vtgateConn.Session("", &querypb.ExecuteOptions{}).StreamExecute(ctx, "select * from customer", map[string]*querypb.BindVariable{})
require.NoError(t, err)

// We read packets until we see the first set of results. This ensures that the stream is working.
for {
res, err := stream.Recv()
require.NoError(t, err)
if res != nil && len(res.Rows) > 0 {
break
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
}
}

// We simulate a misbehaving client that doesn't read from the stream anymore.
// This however shouldn't block PlannedReparentShard calls.
err = clusterInstance.VtctldClientProcess.PlannedReparentShard(keyspaceName, "0", clusterInstance.Keyspaces[0].Shards[0].Vttablets[1].Alias)
require.NoError(t, err)
}

func getMapFromJSON(JSON map[string]any, key string) map[string]any {
result := make(map[string]any)
object := reflect.ValueOf(JSON[key])
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction

// Add query detail object into QueryExecutor TableServer list w.r.t if it is a transactional or not. Previously we were adding it
// to olapql list regardless but that resulted in problems, where long-running stream queries which can be stateful (or transactional)
// weren't getting cleaned up during unserveCommon>handleShutdownGracePeriod in state_manager.go.
// weren't getting cleaned up during unserveCommon>terminateAllQueries in state_manager.go.
// This change will ensure that long-running streaming stateful queries get gracefully shutdown during ServingTypeChange
// once their grace period is over.
qd := NewQueryDetail(qre.logStats.Ctx, conn.Conn)
Expand Down
83 changes: 83 additions & 0 deletions go/vt/vttablet/tabletserver/requests_waiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
Copyright 2024 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tabletserver

import "sync"

// requestsWaiter is used to wait for requests. It stores the count of the requests pending,
// and also the number of waiters currently waiting. It has a mutex as well to protects its fields.
type requestsWaiter struct {
mu sync.Mutex
wg sync.WaitGroup
// waitCounter is the number of goroutines that are waiting for wg to be empty.
// If this value is greater than zero, then we have to ensure that we don't Add to the requests
// to avoid any panics in the wait.
waitCounter int
// counter is the count of the number of outstanding requests.
counter int
}

// newRequestsWaiter creates a new requestsWaiter.
func newRequestsWaiter() *requestsWaiter {
return &requestsWaiter{
mu: sync.Mutex{},
wg: sync.WaitGroup{},
waitCounter: 0,
counter: 0,
}
}
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved

// Add adds to the requestsWaiter.
func (r *requestsWaiter) Add(val int) {
r.mu.Lock()
defer r.mu.Unlock()
r.counter += val
r.wg.Add(val)
}

// Done subtracts 1 from the requestsWaiter.
func (r *requestsWaiter) Done() {
r.Add(-1)
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
}

// addToWaitCounter adds to the waitCounter while being protected by a mutex.
func (r *requestsWaiter) addToWaitCounter(val int) {
r.mu.Lock()
defer r.mu.Unlock()
r.waitCounter += val
}

// WaitToBeEmpty waits for requests to be empty. It also increments and decrements the waitCounter as required.
func (r *requestsWaiter) WaitToBeEmpty() {
r.addToWaitCounter(1)
r.wg.Wait()
r.addToWaitCounter(-1)
}

// GetWaiterCount gets the number of go routines currently waiting on the wait group.
func (r *requestsWaiter) GetWaiterCount() int {
r.mu.Lock()
defer r.mu.Unlock()
return r.waitCounter
}

// GetOutstandingRequestsCount gets the number of requests outstanding.
func (r *requestsWaiter) GetOutstandingRequestsCount() int {
r.mu.Lock()
defer r.mu.Unlock()
return r.counter
}
72 changes: 40 additions & 32 deletions go/vt/vttablet/tabletserver/state_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,8 @@
alsoAllow []topodatapb.TabletType
reason string
transitionErr error
// requestsWaitCounter is the number of goroutines that are waiting for requests to be empty.
// If this value is greater than zero, then we have to ensure that we don't Add to the requests
// to avoid any panics in the wait.
requestsWaitCounter int

requests sync.WaitGroup
rw *requestsWaiter

// QueryList does not have an Open or Close.
statelessql *QueryList
Expand Down Expand Up @@ -358,20 +354,6 @@
}()
}

// addRequestsWaitCounter adds to the requestsWaitCounter while being protected by a mutex.
func (sm *stateManager) addRequestsWaitCounter(val int) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.requestsWaitCounter += val
}

// waitForRequestsToBeEmpty waits for requests to be empty. It also increments and decrements the requestsWaitCounter as required.
func (sm *stateManager) waitForRequestsToBeEmpty() {
sm.addRequestsWaitCounter(1)
sm.requests.Wait()
sm.addRequestsWaitCounter(-1)
}

func (sm *stateManager) setWantState(stateWanted servingState) {
sm.mu.Lock()
defer sm.mu.Unlock()
Expand Down Expand Up @@ -410,9 +392,9 @@
}

shuttingDown := sm.wantState != StateServing
// If requestsWaitCounter is not zero, then there are go-routines blocked on waiting for requests to be empty.
// If wait counter for the requests is not zero, then there are go-routines blocked on waiting for requests to be empty.
// We cannot allow adding to the requests to prevent any panics from happening.
if (shuttingDown && !allowOnShutdown) || sm.requestsWaitCounter > 0 {
if (shuttingDown && !allowOnShutdown) || sm.rw.GetWaiterCount() > 0 {
// This specific error string needs to be returned for vtgate buffering to work.
return vterrors.New(vtrpcpb.Code_CLUSTER_EVENT, vterrors.ShuttingDown)
}
Expand All @@ -421,13 +403,13 @@
if err != nil {
return err
}
sm.requests.Add(1)
sm.rw.Add(1)
return nil
}

// EndRequest unregisters the current request (a waitgroup) as done.
func (sm *stateManager) EndRequest() {
sm.requests.Done()
sm.rw.Done()
}

// VerifyTarget allows requests to be executed even in non-serving state.
Expand Down Expand Up @@ -507,7 +489,7 @@
func (sm *stateManager) serveNonPrimary(wantTabletType topodatapb.TabletType) error {
// We are likely transitioning from primary. We have to honor
// the shutdown grace period.
cancel := sm.handleShutdownGracePeriod()
cancel := sm.terminateAllQueries(nil)
defer cancel()

sm.ddle.Close()
Expand Down Expand Up @@ -560,9 +542,12 @@
}

func (sm *stateManager) unserveCommon() {
// We create a wait group that tracks whether all the queries have been terminated or not.
wg := sync.WaitGroup{}
wg.Add(1)
log.Infof("Started execution of unserveCommon")
cancel := sm.handleShutdownGracePeriod()
log.Infof("Finished execution of handleShutdownGracePeriod")
cancel := sm.terminateAllQueries(&wg)
log.Infof("Finished execution of terminateAllQueries")
defer cancel()

log.Infof("Started online ddl executor close")
Expand All @@ -580,22 +565,45 @@
log.Info("Finished Killing all OLAP queries. Started tracker close")
sm.tracker.Close()
log.Infof("Finished tracker close. Started wait for requests")
sm.waitForRequestsToBeEmpty()
log.Infof("Finished wait for requests. Finished execution of unserveCommon")
sm.handleShutdownGracePeriod(&wg)
log.Infof("Finished handling grace period. Finished execution of unserveCommon")
}

// handleShutdownGracePeriod checks if we have shutdwonGracePeriod specified.
// If its not, then we have to wait for all the requests to be empty.
// Otherwise, we only wait for all the queries against MySQL to be terminated.
func (sm *stateManager) handleShutdownGracePeriod(wg *sync.WaitGroup) {
// If there is no shutdown grace period specified, then we should wait for all the requests to be empty.
if sm.shutdownGracePeriod == 0 {
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
sm.rw.WaitToBeEmpty()
} else {
// We quickly check if the requests are empty or not.
// If they are, then we don't need to wait for the shutdown to complete.
count := sm.rw.GetOutstandingRequestsCount()
if count == 0 {
return
}
// Otherwise, we should wait for all olap queries to be killed.
// We don't need to wait for requests to be empty since we have ensured all the queries against MySQL have been killed.
wg.Wait()

Check warning on line 588 in go/vt/vttablet/tabletserver/state_manager.go

View check run for this annotation

Codecov / codecov/patch

go/vt/vttablet/tabletserver/state_manager.go#L588

Added line #L588 was not covered by tests
}
}

func (sm *stateManager) handleShutdownGracePeriod() (cancel func()) {
func (sm *stateManager) terminateAllQueries(wg *sync.WaitGroup) (cancel func()) {
if sm.shutdownGracePeriod == 0 {
return func() {}
}
ctx, cancel := context.WithCancel(context.TODO())
go func() {
if wg != nil {
defer wg.Done()
}
if err := timer.SleepContext(ctx, sm.shutdownGracePeriod); err != nil {
return
}
log.Infof("Grace Period %v exceeded. Killing all OLTP queries.", sm.shutdownGracePeriod)
sm.statelessql.TerminateAll()
log.Infof("Killed all stateful OLTP queries.")
log.Infof("Killed all stateless OLTP queries.")
sm.statefulql.TerminateAll()
log.Infof("Killed all OLTP queries.")
}()
Expand Down Expand Up @@ -645,7 +653,7 @@
log.Infof("TabletServer transition: %v -> %v for tablet %s:%s/%s",
sm.stateStringLocked(sm.target.TabletType, sm.state), sm.stateStringLocked(tabletType, state),
sm.target.Cell, sm.target.Keyspace, sm.target.Shard)
sm.handleGracePeriod(tabletType)
sm.handleTransitionGracePeriod(tabletType)
sm.target.TabletType = tabletType
if sm.state == StateNotConnected {
// If we're transitioning out of StateNotConnected, we have
Expand All @@ -664,7 +672,7 @@
return fmt.Sprintf("%v: %v, %v", tabletType, state, sm.ptsTimestamp.Local().Format("Jan 2, 2006 at 15:04:05 (MST)"))
}

func (sm *stateManager) handleGracePeriod(tabletType topodatapb.TabletType) {
func (sm *stateManager) handleTransitionGracePeriod(tabletType topodatapb.TabletType) {
if tabletType != topodatapb.TabletType_PRIMARY {
// We allow serving of previous type only for a primary transition.
sm.alsoAllow = nil
Expand Down
3 changes: 2 additions & 1 deletion go/vt/vttablet/tabletserver/state_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ func TestPanicInWait(t *testing.T) {

// Simulate going to a not serving state and calling unserveCommon that waits on requests.
sm.wantState = StateNotServing
sm.waitForRequestsToBeEmpty()
sm.rw.WaitToBeEmpty()
}

func verifySubcomponent(t *testing.T, order int64, component any, state testState) {
Expand Down Expand Up @@ -752,6 +752,7 @@ func newTestStateManager(t *testing.T) *stateManager {
ddle: &testOnlineDDLExecutor{},
throttler: &testLagThrottler{},
tableGC: &testTableGC{},
rw: newRequestsWaiter(),
}
sm.Init(env, &querypb.Target{})
sm.hs.InitDBConfig(&querypb.Target{}, dbconfigs.New(fakesqldb.New(t).ConnParams()))
Expand Down
1 change: 1 addition & 0 deletions go/vt/vttablet/tabletserver/tabletserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ func NewTabletServer(ctx context.Context, env *vtenv.Environment, name string, c
ddle: tsv.onlineDDLExecutor,
throttler: tsv.lagThrottler,
tableGC: tsv.tableGC,
rw: newRequestsWaiter(),
}

tsv.exporter.NewGaugeFunc("TabletState", "Tablet server state", func() int64 { return int64(tsv.sm.State()) })
Expand Down
Loading