From bb7ec106d378d5e974ade0d51bff612c58edb9c3 Mon Sep 17 00:00:00 2001 From: Callum Waters Date: Wed, 8 Jan 2025 15:38:56 +0100 Subject: [PATCH] fix: mempool locking mechanism in v1 and cat (#1582) This is residual of https://github.com/celestiaorg/celestia-core/pull/1553 The problem is now even more subtle. Because the mempool mutexes weren't over both CheckTx and the actual adding of the transaction to the mempool we occasionally hit a situation as follows: - CheckTx a tx with nonce 2 - Before saving it to the store, collect all transactions and recheck tx. This excludes the last tx with nonce 2, thus the nonce in the state machine is still 1 - Save the tx to the pool - New tx comes in with nonce 3. The application is at 1 so rejects it expecting the next to be 2. This PR fixes this pattern, however this won't be watertight for the CAT pool until we can order both on gas fee (priority) and nonce. --------- Co-authored-by: Rootul Patel --- mempool/cat/pool.go | 81 +++++++++++++------------------- mempool/cat/store.go | 50 +++++++++++++++++++- mempool/cat/store_test.go | 79 +++++++++++++++++++++++++++++++ mempool/v1/mempool.go | 99 +++++++++++++++++++-------------------- 4 files changed, 210 insertions(+), 99 deletions(-) diff --git a/mempool/cat/pool.go b/mempool/cat/pool.go index 7fa57b7993..cc9d479337 100644 --- a/mempool/cat/pool.go +++ b/mempool/cat/pool.go @@ -51,7 +51,7 @@ type TxPool struct { metrics *mempool.Metrics // these values are modified once per height - updateMtx sync.Mutex + mtx sync.Mutex notifiedTxsAvailable bool txsAvailable chan struct{} // one value sent per height when mempool is not empty preCheckFn mempool.PreCheckFunc @@ -127,11 +127,15 @@ func WithMetrics(metrics *mempool.Metrics) TxPoolOption { return func(txmp *TxPool) { txmp.metrics = metrics } } -// Lock is a noop as ABCI calls are serialized -func (txmp *TxPool) Lock() {} +// Lock locks the mempool, no new transactions can be processed +func (txmp *TxPool) Lock() { + txmp.mtx.Lock() +} -// Unlock is a noop as ABCI calls are serialized -func (txmp *TxPool) Unlock() {} +// Unlock unlocks the mempool +func (txmp *TxPool) Unlock() { + txmp.mtx.Unlock() +} // Size returns the number of valid transactions in the mempool. It is // thread-safe. @@ -161,8 +165,8 @@ func (txmp *TxPool) TxsAvailable() <-chan struct{} { return txmp.txsAvailable } // Height returns the latest height that the mempool is at func (txmp *TxPool) Height() int64 { - txmp.updateMtx.Lock() - defer txmp.updateMtx.Unlock() + txmp.mtx.Lock() + defer txmp.mtx.Unlock() return txmp.height } @@ -203,8 +207,8 @@ func (txmp *TxPool) IsRejectedTx(txKey types.TxKey) bool { // the txpool looped through all transactions and if so, performs a purge of any transaction // that has expired according to the TTLDuration. This is thread safe. func (txmp *TxPool) CheckToPurgeExpiredTxs() { - txmp.updateMtx.Lock() - defer txmp.updateMtx.Unlock() + txmp.mtx.Lock() + defer txmp.mtx.Unlock() if txmp.config.TTLDuration > 0 && time.Since(txmp.lastPurgeTime) > txmp.config.TTLDuration { expirationAge := time.Now().Add(-txmp.config.TTLDuration) // A height of 0 means no transactions will be removed because of height @@ -330,6 +334,9 @@ func (txmp *TxPool) TryAddNewTx(tx types.Tx, key types.TxKey, txInfo mempool.TxI return nil, err } + txmp.mtx.Lock() + defer txmp.mtx.Unlock() + // Invoke an ABCI CheckTx for this transaction. rsp, err := txmp.proxyAppConn.CheckTxSync(abci.RequestCheckTx{Tx: tx}) if err != nil { @@ -345,7 +352,7 @@ func (txmp *TxPool) TryAddNewTx(tx types.Tx, key types.TxKey, txInfo mempool.TxI // Create wrapped tx wtx := newWrappedTx( - tx, key, txmp.Height(), rsp.GasWanted, rsp.Priority, rsp.Sender, + tx, key, txmp.height, rsp.GasWanted, rsp.Priority, rsp.Sender, ) // Perform the post check @@ -402,20 +409,6 @@ func (txmp *TxPool) PeerHasTx(peer uint16, txKey types.TxKey) { txmp.seenByPeersSet.Add(txKey, peer) } -// allEntriesSorted returns a slice of all the transactions currently in the -// mempool, sorted in nonincreasing order by priority with ties broken by -// increasing order of arrival time. -func (txmp *TxPool) allEntriesSorted() []*wrappedTx { - txs := txmp.store.getAllTxs() - sort.Slice(txs, func(i, j int) bool { - if txs[i].priority == txs[j].priority { - return txs[i].timestamp.Before(txs[j].timestamp) - } - return txs[i].priority > txs[j].priority // N.B. higher priorities first - }) - return txs -} - // ReapMaxBytesMaxGas returns a slice of valid transactions that fit within the // size and gas constraints. The results are ordered by nonincreasing priority, // with ties broken by increasing order of arrival. Reaping transactions does @@ -429,19 +422,20 @@ func (txmp *TxPool) allEntriesSorted() []*wrappedTx { func (txmp *TxPool) ReapMaxBytesMaxGas(maxBytes, maxGas int64) types.Txs { var totalGas, totalBytes int64 - var keep []types.Tx //nolint:prealloc - for _, w := range txmp.allEntriesSorted() { + var keep []types.Tx + txmp.store.iterateOrderedTxs(func(w *wrappedTx) bool { // N.B. When computing byte size, we need to include the overhead for // encoding as protobuf to send to the application. This actually overestimates it // as we add the proto overhead to each transaction txBytes := types.ComputeProtoSizeForTxs([]types.Tx{w.tx}) if (maxGas >= 0 && totalGas+w.gasWanted > maxGas) || (maxBytes >= 0 && totalBytes+txBytes > maxBytes) { - continue + return true } totalBytes += txBytes totalGas += w.gasWanted keep = append(keep, w.tx) - } + return true + }) return keep } @@ -454,14 +448,15 @@ func (txmp *TxPool) ReapMaxBytesMaxGas(maxBytes, maxGas int64) types.Txs { // The result may have fewer than max elements (possibly zero) if the mempool // does not have that many transactions available. func (txmp *TxPool) ReapMaxTxs(max int) types.Txs { - var keep []types.Tx //nolint:prealloc + var keep []types.Tx - for _, w := range txmp.allEntriesSorted() { + txmp.store.iterateOrderedTxs(func(w *wrappedTx) bool { if max >= 0 && len(keep) >= max { - break + return false } keep = append(keep, w.tx) - } + return true + }) return keep } @@ -490,7 +485,6 @@ func (txmp *TxPool) Update( } txmp.logger.Debug("updating mempool", "height", blockHeight, "txs", len(blockTxs)) - txmp.updateMtx.Lock() txmp.height = blockHeight txmp.notifiedTxsAvailable = false @@ -501,7 +495,6 @@ func (txmp *TxPool) Update( txmp.postCheckFn = newPostFn } txmp.lastPurgeTime = time.Now() - txmp.updateMtx.Unlock() txmp.metrics.SuccessfulTxs.Add(float64(len(blockTxs))) for _, tx := range blockTxs { @@ -665,19 +658,12 @@ func (txmp *TxPool) recheckTransactions() { txmp.logger.Debug( "executing re-CheckTx for all remaining transactions", "num_txs", txmp.Size(), - "height", txmp.Height(), + "height", txmp.height, ) - // Collect transactions currently in the mempool requiring recheck. - // TODO: we are iterating over a map, which may scramble the order of transactions - // such that they are not in order, dictated by nonce and then priority. This may - // cause transactions to needlessly be kicked out in RecheckTx - wtxs := txmp.store.getAllTxs() - // Issue CheckTx calls for each remaining transaction, and when all the // rechecks are complete signal watchers that transactions may be available. - for _, wtx := range wtxs { - wtx := wtx + txmp.store.iterateOrderedTxs(func(wtx *wrappedTx) bool { // The response for this CheckTx is handled by the default recheckTxCallback. rsp, err := txmp.proxyAppConn.CheckTxSync(abci.RequestCheckTx{ Tx: wtx.tx, @@ -689,7 +675,8 @@ func (txmp *TxPool) recheckTransactions() { } else { txmp.handleRecheckResult(wtx, rsp) } - } + return true + }) _ = txmp.proxyAppConn.FlushAsync() // When recheck is complete, trigger a notification for more transactions. @@ -766,8 +753,8 @@ func (txmp *TxPool) notifyTxsAvailable() { } func (txmp *TxPool) preCheck(tx types.Tx) error { - txmp.updateMtx.Lock() - defer txmp.updateMtx.Unlock() + txmp.mtx.Lock() + defer txmp.mtx.Unlock() if txmp.preCheckFn != nil { return txmp.preCheckFn(tx) } @@ -775,8 +762,6 @@ func (txmp *TxPool) preCheck(tx types.Tx) error { } func (txmp *TxPool) postCheck(tx types.Tx, res *abci.ResponseCheckTx) error { - txmp.updateMtx.Lock() - defer txmp.updateMtx.Unlock() if txmp.postCheckFn != nil { return txmp.postCheckFn(tx, res) } diff --git a/mempool/cat/store.go b/mempool/cat/store.go index 29e4914110..9840d52418 100644 --- a/mempool/cat/store.go +++ b/mempool/cat/store.go @@ -1,6 +1,8 @@ package cat import ( + "fmt" + "sort" "sync" "time" @@ -11,6 +13,7 @@ import ( type store struct { mtx sync.RWMutex bytes int64 + orderedTxs []*wrappedTx txs map[types.TxKey]*wrappedTx reservedTxs map[types.TxKey]struct{} } @@ -18,6 +21,7 @@ type store struct { func newStore() *store { return &store{ bytes: 0, + orderedTxs: make([]*wrappedTx, 0), txs: make(map[types.TxKey]*wrappedTx), reservedTxs: make(map[types.TxKey]struct{}), } @@ -31,6 +35,7 @@ func (s *store) set(wtx *wrappedTx) bool { defer s.mtx.Unlock() if _, exists := s.txs[wtx.key]; !exists { s.txs[wtx.key] = wtx + s.orderTx(wtx) s.bytes += wtx.size() return true } @@ -58,6 +63,9 @@ func (s *store) remove(txKey types.TxKey) bool { return false } s.bytes -= tx.size() + if err := s.deleteOrderedTx(tx); err != nil { + panic(err) + } delete(s.txs, txKey) return true } @@ -131,10 +139,13 @@ func (s *store) getTxsBelowPriority(priority int64) ([]*wrappedTx, int64) { defer s.mtx.RUnlock() txs := make([]*wrappedTx, 0, len(s.txs)) bytes := int64(0) - for _, tx := range s.txs { + for i := len(s.orderedTxs) - 1; i >= 0; i-- { + tx := s.orderedTxs[i] if tx.priority < priority { txs = append(txs, tx) bytes += tx.size() + } else { + break } } return txs, bytes @@ -165,4 +176,41 @@ func (s *store) reset() { defer s.mtx.Unlock() s.bytes = 0 s.txs = make(map[types.TxKey]*wrappedTx) + s.orderedTxs = make([]*wrappedTx, 0) +} + +func (s *store) orderTx(tx *wrappedTx) { + idx := s.getTxOrder(tx) + s.orderedTxs = append(s.orderedTxs[:idx], append([]*wrappedTx{tx}, s.orderedTxs[idx:]...)...) +} + +func (s *store) deleteOrderedTx(tx *wrappedTx) error { + if len(s.orderedTxs) == 0 { + return fmt.Errorf("ordered transactions list is empty") + } + idx := s.getTxOrder(tx) - 1 + if idx >= len(s.orderedTxs) || s.orderedTxs[idx] != tx { + return fmt.Errorf("transaction %X not found in ordered list", tx.key) + } + s.orderedTxs = append(s.orderedTxs[:idx], s.orderedTxs[idx+1:]...) + return nil +} + +func (s *store) getTxOrder(tx *wrappedTx) int { + return sort.Search(len(s.orderedTxs), func(i int) bool { + if s.orderedTxs[i].priority == tx.priority { + return tx.timestamp.Before(s.orderedTxs[i].timestamp) + } + return s.orderedTxs[i].priority < tx.priority + }) +} + +func (s *store) iterateOrderedTxs(fn func(tx *wrappedTx) bool) { + s.mtx.RLock() + defer s.mtx.RUnlock() + for _, tx := range s.orderedTxs { + if !fn(tx) { + break + } + } } diff --git a/mempool/cat/store_test.go b/mempool/cat/store_test.go index 4397f239ea..56bae0d750 100644 --- a/mempool/cat/store_test.go +++ b/mempool/cat/store_test.go @@ -40,6 +40,85 @@ func TestStoreSimple(t *testing.T) { require.Nil(t, store.get(key)) require.Zero(t, store.size()) require.Zero(t, store.totalBytes()) + require.Empty(t, store.orderedTxs) + require.Empty(t, store.txs) +} + +func TestStoreOrdering(t *testing.T) { + store := newStore() + + tx1 := types.Tx("tx1") + tx2 := types.Tx("tx2") + tx3 := types.Tx("tx3") + + // Create wrapped txs with different priorities + wtx1 := newWrappedTx(tx1, tx1.Key(), 1, 1, 1, "") + wtx2 := newWrappedTx(tx2, tx2.Key(), 2, 2, 2, "") + wtx3 := newWrappedTx(tx3, tx3.Key(), 3, 3, 3, "") + + // Add txs in reverse priority order + store.set(wtx1) + store.set(wtx2) + store.set(wtx3) + + // Check that iteration returns txs in correct priority order + var orderedTxs []*wrappedTx + store.iterateOrderedTxs(func(tx *wrappedTx) bool { + orderedTxs = append(orderedTxs, tx) + return true + }) + + require.Equal(t, 3, len(orderedTxs)) + require.Equal(t, wtx3, orderedTxs[0]) + require.Equal(t, wtx2, orderedTxs[1]) + require.Equal(t, wtx1, orderedTxs[2]) +} + +func TestStore(t *testing.T) { + t.Run("deleteOrderedTx", func(*testing.T) { + store := newStore() + + tx1 := types.Tx("tx1") + tx2 := types.Tx("tx2") + tx3 := types.Tx("tx3") + + // Create wrapped txs with different priorities + wtx1 := newWrappedTx(tx1, tx1.Key(), 1, 1, 1, "") + wtx2 := newWrappedTx(tx2, tx2.Key(), 2, 2, 2, "") + wtx3 := newWrappedTx(tx3, tx3.Key(), 3, 3, 3, "") + + // Add txs in reverse priority order + store.set(wtx1) + store.set(wtx2) + store.set(wtx3) + + orderedTxs := getOrderedTxs(store) + require.Equal(t, []*wrappedTx{wtx3, wtx2, wtx1}, orderedTxs) + + err := store.deleteOrderedTx(wtx2) + require.NoError(t, err) + require.Equal(t, []*wrappedTx{wtx3, wtx1}, getOrderedTxs(store)) + + err = store.deleteOrderedTx(wtx3) + require.NoError(t, err) + require.Equal(t, []*wrappedTx{wtx1}, getOrderedTxs(store)) + + err = store.deleteOrderedTx(wtx1) + require.NoError(t, err) + require.Equal(t, []*wrappedTx{}, getOrderedTxs(store)) + + err = store.deleteOrderedTx(wtx1) + require.ErrorContains(t, err, "ordered transactions list is empty") + }) +} + +func getOrderedTxs(store *store) []*wrappedTx { + orderedTxs := []*wrappedTx{} + store.iterateOrderedTxs(func(tx *wrappedTx) bool { + orderedTxs = append(orderedTxs, tx) + return true + }) + return orderedTxs } func TestStoreReservingTxs(t *testing.T) { diff --git a/mempool/v1/mempool.go b/mempool/v1/mempool.go index 6bcdff25d6..cc4136e3ce 100644 --- a/mempool/v1/mempool.go +++ b/mempool/v1/mempool.go @@ -46,8 +46,8 @@ type TxMempool struct { mtx *sync.RWMutex notifiedTxsAvailable bool txsAvailable chan struct{} // one value sent per height when mempool is not empty - preCheck mempool.PreCheckFunc - postCheck mempool.PostCheckFunc + preCheckFn mempool.PreCheckFunc + postCheckFn mempool.PostCheckFunc height int64 // the latest height passed to Update lastPurgeTime time.Time // the last time we attempted to purge transactions via the TTL @@ -98,14 +98,14 @@ func NewTxMempool( // returns an error. This is executed before CheckTx. It only applies to the // first created block. After that, Update() overwrites the existing value. func WithPreCheck(f mempool.PreCheckFunc) TxMempoolOption { - return func(txmp *TxMempool) { txmp.preCheck = f } + return func(txmp *TxMempool) { txmp.preCheckFn = f } } // WithPostCheck sets a filter for the mempool to reject a transaction if // f(tx, resp) returns an error. This is executed after CheckTx. It only applies // to the first created block. After that, Update overwrites the existing value. func WithPostCheck(f mempool.PostCheckFunc) TxMempoolOption { - return func(txmp *TxMempool) { txmp.postCheck = f } + return func(txmp *TxMempool) { txmp.postCheckFn = f } } // WithMetrics sets the mempool's metrics collector. @@ -185,50 +185,42 @@ func (txmp *TxMempool) TxsAvailable() <-chan struct{} { return txmp.txsAvailable // the size of tx, and adds tx instead. If no such transactions exist, tx is // discarded. func (txmp *TxMempool) CheckTx(tx types.Tx, cb func(*abci.Response), txInfo mempool.TxInfo) error { - // During the initial phase of CheckTx, we do not need to modify any state. - // A transaction will not actually be added to the mempool until it survives - // a call to the ABCI CheckTx method and size constraint checks. - height, err := func() (int64, error) { - txmp.mtx.RLock() - defer txmp.mtx.RUnlock() - - // Reject transactions in excess of the configured maximum transaction size. - if len(tx) > txmp.config.MaxTxBytes { - return 0, mempool.ErrTxTooLarge{Max: txmp.config.MaxTxBytes, Actual: len(tx)} - } - // If a precheck hook is defined, call it before invoking the application. - if txmp.preCheck != nil { - if err := txmp.preCheck(tx); err != nil { - txmp.metrics.FailedTxs.Add(1) - return 0, mempool.ErrPreCheck{Reason: err} - } - } + // Reject transactions in excess of the configured maximum transaction size. + if len(tx) > txmp.config.MaxTxBytes { + return mempool.ErrTxTooLarge{Max: txmp.config.MaxTxBytes, Actual: len(tx)} + } - // Early exit if the proxy connection has an error. - if err := txmp.proxyAppConn.Error(); err != nil { - return 0, err - } + // If a precheck hook is defined, call it before invoking the application. + if err := txmp.preCheck(tx); err != nil { + txmp.metrics.FailedTxs.Add(1) + return mempool.ErrPreCheck{Reason: err} + } + + // Early exit if the proxy connection has an error. + if err := txmp.proxyAppConn.Error(); err != nil { + return err + } - txKey := tx.Key() + txKey := tx.Key() - // Check for the transaction in the cache. - if !txmp.cache.Push(tx) { - // If the cached transaction is also in the pool, record its sender. - if elt, ok := txmp.txByKey[txKey]; ok { - txmp.metrics.AlreadySeenTxs.Add(1) - w := elt.Value.(*WrappedTx) - w.SetPeer(txInfo.SenderID) - } - return 0, mempool.ErrTxInCache + // Check for the transaction in the cache. + if !txmp.cache.Push(tx) { + // If the cached transaction is also in the pool, record its sender. + if elt, ok := txmp.txByKey[txKey]; ok { + txmp.metrics.AlreadySeenTxs.Add(1) + w := elt.Value.(*WrappedTx) + w.SetPeer(txInfo.SenderID) } - return txmp.height, nil - }() - if err != nil { - return err + return mempool.ErrTxInCache } + // At this point, we need to ensure that passing CheckTx and adding to + // the mempool is atomic. + txmp.Lock() + defer txmp.Unlock() + // Invoke an ABCI CheckTx for this transaction. rsp, err := txmp.proxyAppConn.CheckTxSync(abci.RequestCheckTx{Tx: tx}) if err != nil { @@ -239,9 +231,10 @@ func (txmp *TxMempool) CheckTx(tx types.Tx, cb func(*abci.Response), txInfo memp tx: tx, hash: tx.Key(), timestamp: time.Now().UTC(), - height: height, + height: txmp.height, } wtx.SetPeer(txInfo.SenderID) + // This won't add the transaction if the response code is non zero (i.e. there was an error) txmp.addNewTransaction(wtx, rsp) if cb != nil { cb(&abci.Response{Value: &abci.Response_CheckTx{CheckTx: rsp}}) @@ -426,10 +419,10 @@ func (txmp *TxMempool) Update( txmp.notifiedTxsAvailable = false if newPreFn != nil { - txmp.preCheck = newPreFn + txmp.preCheckFn = newPreFn } if newPostFn != nil { - txmp.postCheck = newPostFn + txmp.postCheckFn = newPostFn } txmp.metrics.SuccessfulTxs.Add(float64(len(blockTxs))) @@ -479,12 +472,9 @@ func (txmp *TxMempool) Update( // // Finally, the new transaction is added and size stats updated. func (txmp *TxMempool) addNewTransaction(wtx *WrappedTx, checkTxRes *abci.ResponseCheckTx) { - txmp.mtx.Lock() - defer txmp.mtx.Unlock() - var err error - if txmp.postCheck != nil { - err = txmp.postCheck(wtx.tx, checkTxRes) + if txmp.postCheckFn != nil { + err = txmp.postCheckFn(wtx.tx, checkTxRes) } if err != nil || checkTxRes.Code != abci.CodeTypeOK { @@ -659,8 +649,8 @@ func (txmp *TxMempool) handleRecheckResult(tx types.Tx, checkTxRes *abci.Respons // If a postcheck hook is defined, call it before checking the result. var err error - if txmp.postCheck != nil { - err = txmp.postCheck(tx, checkTxRes) + if txmp.postCheckFn != nil { + err = txmp.postCheckFn(tx, checkTxRes) } if checkTxRes.Code == abci.CodeTypeOK && err == nil { @@ -803,3 +793,12 @@ func (txmp *TxMempool) notifyTxsAvailable() { } } } + +func (txmp *TxMempool) preCheck(tx types.Tx) error { + txmp.mtx.Lock() + defer txmp.mtx.Unlock() + if txmp.preCheckFn != nil { + return txmp.preCheckFn(tx) + } + return nil +}