From 1543fabf163029dbe4a1d00e2bf34d0f30961f12 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Sat, 10 Aug 2024 18:09:52 +0530 Subject: [PATCH] feat: mark the preparedPool shutdown after the connections have been fetched for rolling back Signed-off-by: Manan Gupta --- go/vt/vttablet/tabletserver/tx_engine.go | 4 +++- go/vt/vttablet/tabletserver/tx_prep_pool.go | 16 ++++++++++++++-- go/vt/vttablet/tabletserver/tx_prep_pool_test.go | 12 +++++------- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/go/vt/vttablet/tabletserver/tx_engine.go b/go/vt/vttablet/tabletserver/tx_engine.go index 57c6ff1fd64..33e22e321bc 100644 --- a/go/vt/vttablet/tabletserver/tx_engine.go +++ b/go/vt/vttablet/tabletserver/tx_engine.go @@ -160,6 +160,8 @@ func (te *TxEngine) transition(state txEngineState) { te.txPool.Open(te.env.Config().DB.AppWithDB(), te.env.Config().DB.DbaWithDB(), te.env.Config().DB.AppDebugWithDB()) if te.twopcEnabled && te.state == AcceptingReadAndWrite { + // Set the preparedPool to start accepting connections. + te.preparedPool.shutdown = false // If there are errors, we choose to raise an alert and // continue anyway. Serving traffic is considered more important // than blocking everything for the sake of a few transactions. @@ -442,7 +444,7 @@ func (te *TxEngine) shutdownTransactions() { func (te *TxEngine) rollbackPrepared() { ctx := tabletenv.LocalContext() - for _, conn := range te.preparedPool.FetchAll() { + for _, conn := range te.preparedPool.FetchAllForRollback() { te.txPool.Rollback(ctx, conn) conn.Release(tx.TxRollback) } diff --git a/go/vt/vttablet/tabletserver/tx_prep_pool.go b/go/vt/vttablet/tabletserver/tx_prep_pool.go index 22e0ce295c0..907c0f0f3f9 100644 --- a/go/vt/vttablet/tabletserver/tx_prep_pool.go +++ b/go/vt/vttablet/tabletserver/tx_prep_pool.go @@ -34,6 +34,8 @@ type TxPreparedPool struct { mu sync.Mutex conns map[string]*StatefulConnection reserved map[string]error + // shutdown tells if the prepared pool has been drained and shutdown. + shutdown bool capacity int } @@ -55,6 +57,10 @@ func NewTxPreparedPool(capacity int) *TxPreparedPool { func (pp *TxPreparedPool) Put(c *StatefulConnection, dtid string) error { pp.mu.Lock() defer pp.mu.Unlock() + // If the pool is shutdown, we don't accept new prepared transactions. + if pp.shutdown { + return errors.New("pool is shutdown") + } if _, ok := pp.reserved[dtid]; ok { return errors.New("duplicate DTID in Prepare: " + dtid) } @@ -95,6 +101,11 @@ func (pp *TxPreparedPool) FetchForRollback(dtid string) *StatefulConnection { func (pp *TxPreparedPool) FetchForCommit(dtid string) (*StatefulConnection, error) { pp.mu.Lock() defer pp.mu.Unlock() + // If the pool is shutdown, we don't have any connections to return. + // That however doesn't mean this transaction was committed, it could very well have been rollbacked. + if pp.shutdown { + return nil, errors.New("pool is shutdown") + } if err, ok := pp.reserved[dtid]; ok { return nil, err } @@ -121,11 +132,12 @@ func (pp *TxPreparedPool) Forget(dtid string) { delete(pp.reserved, dtid) } -// FetchAll removes all connections and returns them as a list. +// FetchAllForRollback removes all connections and returns them as a list. // It also forgets all reserved dtids. -func (pp *TxPreparedPool) FetchAll() []*StatefulConnection { +func (pp *TxPreparedPool) FetchAllForRollback() []*StatefulConnection { pp.mu.Lock() defer pp.mu.Unlock() + pp.shutdown = true conns := make([]*StatefulConnection, 0, len(pp.conns)) for _, c := range pp.conns { conns = append(conns, c) diff --git a/go/vt/vttablet/tabletserver/tx_prep_pool_test.go b/go/vt/vttablet/tabletserver/tx_prep_pool_test.go index cd2b5a180c1..d16ee7b9b6e 100644 --- a/go/vt/vttablet/tabletserver/tx_prep_pool_test.go +++ b/go/vt/vttablet/tabletserver/tx_prep_pool_test.go @@ -113,11 +113,9 @@ func TestPrepFetchAll(t *testing.T) { conn2 := &StatefulConnection{} pp.Put(conn1, "aa") pp.Put(conn2, "bb") - got := pp.FetchAll() - if len(got) != 2 { - t.Errorf("FetchAll len: %d, want 2", len(got)) - } - if len(pp.conns) != 0 { - t.Errorf("len(pp.conns): %d, want 0", len(pp.conns)) - } + got := pp.FetchAllForRollback() + require.Len(t, got, 2) + require.Len(t, pp.conns, 0) + _, err := pp.FetchForCommit("aa") + require.ErrorContains(t, err, "pool is shutdown") }