diff --git a/common/txmgr/address_state.go b/common/txmgr/address_state.go index a9e5ebf0aac..49743b12f88 100644 --- a/common/txmgr/address_state.go +++ b/common/txmgr/address_state.go @@ -171,30 +171,90 @@ func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) findTx filter func(*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) bool, txIDs ...int64, ) []txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] { - return nil + as.RLock() + defer as.RUnlock() + + // if txStates is empty then apply the filter to only the as.allTransactions map + if len(txStates) == 0 { + return as._findTxs(as.allTxs, filter, txIDs...) + } + + var txs []txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] + for _, txState := range txStates { + switch txState { + case TxUnstarted: + filter2 := func(tx *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) bool { + if tx.State != TxUnstarted { + return false + } + return filter(tx) + } + txs = append(txs, as._findTxs(as.allTxs, filter2, txIDs...)...) + case TxInProgress: + if as.inprogressTx != nil && filter(as.inprogressTx) { + txs = append(txs, *as.inprogressTx) + } + case TxUnconfirmed: + txs = append(txs, as._findTxs(as.unconfirmedTxs, filter, txIDs...)...) + case TxConfirmedMissingReceipt: + txs = append(txs, as._findTxs(as.confirmedMissingReceiptTxs, filter, txIDs...)...) + case TxConfirmed: + txs = append(txs, as._findTxs(as.confirmedTxs, filter, txIDs...)...) + case TxFatalError: + txs = append(txs, as._findTxs(as.fatalErroredTxs, filter, txIDs...)...) + default: + panic("findTxs: unknown transaction state") + } + } + + return txs } // pruneUnstartedTxQueue removes the transactions with the given IDs from the unstarted transaction queue. func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) pruneUnstartedTxQueue(ids []int64) { + as.Lock() + defer as.Unlock() + + txs := as.unstartedTxs.PruneByTxIDs(ids) + as._deleteTxs(txs...) } // deleteTxs removes the transactions with the given IDs from the address state. func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) deleteTxs(txs ...txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) { + as.Lock() + defer as.Unlock() + + as._deleteTxs(txs...) } // peekNextUnstartedTx returns the next unstarted transaction in the queue without removing it from the unstarted queue. -func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) peekNextUnstartedTx() (*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { - return nil, nil +// If there are no unstarted transactions, nil is returned. +func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) peekNextUnstartedTx() *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] { + as.RLock() + defer as.RUnlock() + + return as.unstartedTxs.PeekNextTx() } // peekInProgressTx returns the in-progress transaction without removing it from the in-progress state. -func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) peekInProgressTx() (*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { - return nil, nil +// If there is no in-progress transaction, nil is returned. +func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) peekInProgressTx() *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] { + as.RLock() + defer as.RUnlock() + + return as.inprogressTx } -// addTxToUnstarted adds the given transaction to the unstarted queue. -func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) addTxToUnstarted(tx *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) error { - return nil +// addTxToUnstartedQueue adds the given transaction to the unstarted queue. +func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) addTxToUnstartedQueue(tx *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) { + as.Lock() + defer as.Unlock() + + as.unstartedTxs.AddTx(tx) + as.allTxs[tx.ID] = tx + if tx.IdempotencyKey != nil { + as.idempotencyKeyToTx[*tx.IdempotencyKey] = tx + } } // moveUnstartedToInProgress moves the next unstarted transaction to the in-progress state. @@ -248,3 +308,50 @@ func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) moveIn func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) moveConfirmedToUnconfirmed(attempt txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) error { return nil } + +// This is not a concurrency-safe method and should only be called from within a lock +func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) _deleteTxs(txs ...txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) { + for _, tx := range txs { + if tx.IdempotencyKey != nil { + delete(as.idempotencyKeyToTx, *tx.IdempotencyKey) + } + txID := tx.ID + if as.inprogressTx != nil && as.inprogressTx.ID == txID { + as.inprogressTx = nil + } + delete(as.allTxs, txID) + delete(as.unconfirmedTxs, txID) + delete(as.confirmedMissingReceiptTxs, txID) + delete(as.confirmedTxs, txID) + delete(as.fatalErroredTxs, txID) + as.unstartedTxs.RemoveTxByID(txID) + } +} + +// This method is not concurrent safe and should only be called from within a lock +func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) _findTxs( + txIDsToTx map[int64]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], + filter func(*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) bool, + txIDs ...int64, +) []txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] { + var txs []txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] + // if txIDs is not empty then only apply the filter to those transactions + if len(txIDs) > 0 { + for _, txID := range txIDs { + tx := txIDsToTx[txID] + if tx != nil && filter(tx) { + txs = append(txs, *tx) + } + } + return txs + } + + // if txIDs is empty then apply the filter to all transactions + for _, tx := range txIDsToTx { + if filter(tx) { + txs = append(txs, *tx) + } + } + + return txs +} diff --git a/common/txmgr/inmemory_store.go b/common/txmgr/inmemory_store.go index bd4e9a2f3a6..78bf33587ce 100644 --- a/common/txmgr/inmemory_store.go +++ b/common/txmgr/inmemory_store.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/big" + "sync" "time" "github.com/google/uuid" @@ -47,7 +48,8 @@ type inMemoryStore[ keyStore txmgrtypes.KeyStore[ADDR, CHAIN_ID, SEQ] persistentTxStore txmgrtypes.TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE] - addressStates map[ADDR]*addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE] + addressStatesLock sync.RWMutex + addressStates map[ADDR]*addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE] } // NewInMemoryStore returns a new inMemoryStore @@ -108,7 +110,29 @@ func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Creat txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error, ) { - return txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]{}, nil + tx := txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]{} + if ms.chainID.String() != chainID.String() { + panic("invalid chain ID") + } + + ms.addressStatesLock.Lock() + as, ok := ms.addressStates[txRequest.FromAddress] + if !ok { + as = newAddressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE](ms.lggr, chainID, txRequest.FromAddress, ms.maxUnstarted, nil) + ms.addressStates[txRequest.FromAddress] = as + } + ms.addressStatesLock.Unlock() + + // Persist Transaction to persistent storage + tx, err := ms.persistentTxStore.CreateTransaction(ctx, txRequest, chainID) + if err != nil { + return tx, fmt.Errorf("create_transaction: %w", err) + } + + // Update in memory store + // Add the request to the Unstarted channel to be processed by the Broadcaster + as.addTxToUnstartedQueue(&tx) + return *ms.deepCopyTx(tx), nil } // FindTxWithIdempotencyKey returns a transaction with the given idempotency key @@ -153,7 +177,24 @@ func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Updat // GetTxInProgress returns the in_progress transaction for a given address. func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetTxInProgress(ctx context.Context, fromAddress ADDR) (*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { - return nil, nil + ms.addressStatesLock.RLock() + defer ms.addressStatesLock.RUnlock() + as, ok := ms.addressStates[fromAddress] + if !ok { + return nil, nil + } + + tx := as.peekInProgressTx() + if tx == nil { + return nil, nil + } + + if len(tx.TxAttempts) != 1 || tx.TxAttempts[0].State != txmgrtypes.TxAttemptInProgress { + return nil, fmt.Errorf("get_tx_in_progress: invariant violation: expected in_progress transaction %v to have exactly one unsent attempt. "+ + "Your database is in an inconsistent state and this node will not function correctly until the problem is resolved", tx.ID) + } + + return ms.deepCopyTx(*tx), nil } // UpdateTxAttemptInProgressToBroadcast updates a transaction attempt from in_progress to broadcast. @@ -168,8 +209,27 @@ func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Updat } // FindNextUnstartedTransactionFromAddress returns the next unstarted transaction for a given address. -func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindNextUnstartedTransactionFromAddress(_ context.Context, tx *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], fromAddress ADDR, chainID CHAIN_ID) error { - return nil +func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindNextUnstartedTransactionFromAddress(_ context.Context, fromAddress ADDR, chainID CHAIN_ID) ( + *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], + error, +) { + if ms.chainID.String() != chainID.String() { + panic("invalid chain ID") + } + ms.addressStatesLock.RLock() + defer ms.addressStatesLock.RUnlock() + as, ok := ms.addressStates[fromAddress] + if !ok { + return nil, fmt.Errorf("find_next_unstarted_transaction_from_address: %w: %q", ErrAddressNotFound, fromAddress) + } + + etx := as.peekNextUnstartedTx() + if etx == nil { + return nil, fmt.Errorf("find_next_unstarted_transaction_from_address: %w", ErrTxnNotFound) + } + tx := ms.deepCopyTx(*etx) + + return tx, nil } // SaveReplacementInProgressAttempt saves a replacement attempt for a transaction that is in_progress. @@ -260,7 +320,20 @@ func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindT } func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) PruneUnstartedTxQueue(ctx context.Context, queueSize uint32, subject uuid.UUID) ([]int64, error) { - return nil, nil + // Persist to persistent storage + ids, err := ms.persistentTxStore.PruneUnstartedTxQueue(ctx, queueSize, subject) + if err != nil { + return ids, err + } + + // Update in memory store + ms.addressStatesLock.RLock() + defer ms.addressStatesLock.RUnlock() + for _, as := range ms.addressStates { + as.pruneUnstartedTxQueue(ids) + } + + return ids, nil } func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) ReapTxHistory(ctx context.Context, minBlockNumberToKeep int64, timeThreshold time.Time, chainID CHAIN_ID) error { diff --git a/core/chains/evm/txmgr/evm_inmemory_store_test.go b/core/chains/evm/txmgr/evm_inmemory_store_test.go index a102ee1c996..d38b5a5cc7f 100644 --- a/core/chains/evm/txmgr/evm_inmemory_store_test.go +++ b/core/chains/evm/txmgr/evm_inmemory_store_test.go @@ -1,14 +1,334 @@ package txmgr_test import ( + "math/big" + "sort" "testing" + "github.com/ethereum/go-ethereum/common" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + commontxmgr "github.com/smartcontractkit/chainlink/v2/common/txmgr" + + txmgrtypes "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" + + evmassets "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" + evmgas "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" evmtxmgr "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" + evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/evmtest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" ) +func TestInMemoryStore_GetTxInProgress(t *testing.T) { + t.Parallel() + + db := pgtest.NewSqlxDB(t) + _, dbcfg, evmcfg := evmtxmgr.MakeTestConfigs(t) + persistentStore := cltest.NewTestTxStore(t, db) + kst := cltest.NewKeyStore(t, db, dbcfg) + _, fromAddress := cltest.MustInsertRandomKey(t, kst.Eth()) + _, otherAddress := cltest.MustInsertRandomKey(t, kst.Eth()) + + ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + lggr := logger.TestSugared(t) + chainID := ethClient.ConfiguredChainID() + ctx := testutils.Context(t) + + inMemoryStore, err := commontxmgr.NewInMemoryStore[ + *big.Int, + common.Address, common.Hash, common.Hash, + *evmtypes.Receipt, + evmtypes.Nonce, + evmgas.EvmFee, + ](ctx, lggr, chainID, kst.Eth(), persistentStore, evmcfg.Transactions()) + require.NoError(t, err) + + // insert the transaction into the persistent store + inTx := mustInsertInProgressEthTxWithAttempt(t, persistentStore, 123, fromAddress) + require.NotNil(t, inTx) + // insert the transaction into the in-memory store + require.NoError(t, inMemoryStore.XXXTestInsertTx(fromAddress, &inTx)) + + // insert non in-progress transaction for another address + otherTx := cltest.NewEthTx(otherAddress) + require.NoError(t, persistentStore.InsertTx(ctx, &otherTx)) + require.NoError(t, inMemoryStore.XXXTestInsertTx(otherAddress, &otherTx)) + + tcs := []struct { + name string + fromAddress common.Address + + hasErr bool + hasTx bool + }{ + {"finds the correct inprogress transaction", fromAddress, false, true}, + {"wrong fromAddress", common.Address{}, false, false}, + {"no inprogress transaction", otherAddress, false, false}, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + actTx, actErr := inMemoryStore.GetTxInProgress(ctx, tc.fromAddress) + expTx, expErr := persistentStore.GetTxInProgress(ctx, tc.fromAddress) + if tc.hasErr { + require.NotNil(t, actErr) + require.NotNil(t, expErr) + require.Equal(t, expErr, actErr) + } else { + require.Nil(t, actErr) + require.Nil(t, expErr) + } + if tc.hasTx { + require.NotNil(t, actTx) + require.NotNil(t, expTx) + assertTxEqual(t, *expTx, *actTx) + } else { + require.Nil(t, actTx) + require.Nil(t, expTx) + } + }) + } +} + +func TestInMemoryStore_FindNextUnstartedTransactionFromAddress(t *testing.T) { + t.Parallel() + + db := pgtest.NewSqlxDB(t) + _, dbcfg, evmcfg := evmtxmgr.MakeTestConfigs(t) + persistentStore := cltest.NewTestTxStore(t, db) + kst := cltest.NewKeyStore(t, db, dbcfg) + _, fromAddress := cltest.MustInsertRandomKey(t, kst.Eth()) + _, otherAddress := cltest.MustInsertRandomKey(t, kst.Eth()) + + ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + lggr := logger.TestSugared(t) + chainID := ethClient.ConfiguredChainID() + ctx := testutils.Context(t) + + inMemoryStore, err := commontxmgr.NewInMemoryStore[ + *big.Int, + common.Address, common.Hash, common.Hash, + *evmtypes.Receipt, + evmtypes.Nonce, + evmgas.EvmFee, + ](ctx, lggr, chainID, kst.Eth(), persistentStore, evmcfg.Transactions()) + require.NoError(t, err) + + // insert the transaction into the persistent store + inTx := mustCreateUnstartedGeneratedTx(t, persistentStore, fromAddress, chainID) + // insert the transaction into the in-memory store + require.NoError(t, inMemoryStore.XXXTestInsertTx(fromAddress, &inTx)) + + // insert non in-progress transaction for another address + otherTx := mustInsertInProgressEthTxWithAttempt(t, persistentStore, 13, otherAddress) + require.NoError(t, inMemoryStore.XXXTestInsertTx(otherAddress, &otherTx)) + + tcs := []struct { + name string + fromAddress common.Address + chainID *big.Int + + hasErr bool + hasTx bool + }{ + {"finds the correct inprogress transaction", fromAddress, chainID, false, true}, + {"no unstarted transaction", otherAddress, chainID, true, false}, + {"wrong chainID", fromAddress, big.NewInt(123), true, false}, + {"unknown address", common.Address{}, chainID, true, false}, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + actTx, actErr := inMemoryStore.FindNextUnstartedTransactionFromAddress(ctx, tc.fromAddress, tc.chainID) + expTx, expErr := persistentStore.FindNextUnstartedTransactionFromAddress(ctx, tc.fromAddress, tc.chainID) + if tc.hasErr { + require.NotNil(t, actErr) + require.NotNil(t, expErr) + } else { + require.Nil(t, actErr) + require.Nil(t, expErr) + } + if tc.hasTx { + require.NotNil(t, actTx) + require.NotNil(t, expTx) + assertTxEqual(t, *expTx, *actTx) + } else { + require.Nil(t, actTx) + require.Nil(t, expTx) + } + }) + } +} + +func TestInMemoryStore_CreateTransaction(t *testing.T) { + t.Parallel() + + db := pgtest.NewSqlxDB(t) + _, dbcfg, evmcfg := evmtxmgr.MakeTestConfigs(t) + persistentStore := cltest.NewTestTxStore(t, db) + kst := cltest.NewKeyStore(t, db, dbcfg) + _, fromAddress := cltest.MustInsertRandomKey(t, kst.Eth()) + + ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + lggr := logger.TestSugared(t) + chainID := ethClient.ConfiguredChainID() + ctx := testutils.Context(t) + + inMemoryStore, err := commontxmgr.NewInMemoryStore[ + *big.Int, + common.Address, common.Hash, common.Hash, + *evmtypes.Receipt, + evmtypes.Nonce, + evmgas.EvmFee, + ](ctx, lggr, chainID, kst.Eth(), persistentStore, evmcfg.Transactions()) + require.NoError(t, err) + + toAddress := testutils.NewAddress() + gasLimit := uint32(1000) + payload := []byte{1, 2, 3} + + t.Run("with queue under capacity inserts eth_tx", func(t *testing.T) { + subject := uuid.New() + strategy := newMockTxStrategy(t) + strategy.On("Subject").Return(uuid.NullUUID{UUID: subject, Valid: true}) + actTx, err := inMemoryStore.CreateTransaction(ctx, evmtxmgr.TxRequest{ + FromAddress: fromAddress, + ToAddress: toAddress, + EncodedPayload: payload, + FeeLimit: uint64(gasLimit), + Meta: nil, + Strategy: strategy, + }, chainID) + require.NoError(t, err) + + // check that the transaction was inserted into the persistent store + cltest.AssertCount(t, db, "evm.txes", 1) + + var dbEthTx evmtxmgr.DbEthTx + require.NoError(t, db.Get(&dbEthTx, `SELECT * FROM evm.txes ORDER BY id ASC LIMIT 1`)) + + assert.Equal(t, commontxmgr.TxUnstarted, dbEthTx.State) + assert.Equal(t, gasLimit, dbEthTx.GasLimit) + assert.Equal(t, fromAddress, dbEthTx.FromAddress) + assert.Equal(t, toAddress, dbEthTx.ToAddress) + assert.Equal(t, payload, dbEthTx.EncodedPayload) + assert.Equal(t, evmassets.NewEthValue(0), dbEthTx.Value) + assert.Equal(t, subject, dbEthTx.Subject.UUID) + + var expTx evmtxmgr.Tx + dbEthTx.ToTx(&expTx) + + // check that the in-memory store has the same transaction data as the persistent store + assertTxEqual(t, expTx, actTx) + }) +} + +func TestInMemoryStore_PruneUnstartedTxQueue(t *testing.T) { + t.Parallel() + + db := pgtest.NewSqlxDB(t) + _, dbcfg, evmcfg := evmtxmgr.MakeTestConfigs(t) + persistentStore := cltest.NewTestTxStore(t, db) + kst := cltest.NewKeyStore(t, db, dbcfg) + _, fromAddress := cltest.MustInsertRandomKey(t, kst.Eth()) + + ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + lggr := logger.TestSugared(t) + chainID := ethClient.ConfiguredChainID() + ctx := testutils.Context(t) + + inMemoryStore, err := commontxmgr.NewInMemoryStore[ + *big.Int, + common.Address, common.Hash, common.Hash, + *evmtypes.Receipt, + evmtypes.Nonce, + evmgas.EvmFee, + ](ctx, lggr, chainID, kst.Eth(), persistentStore, evmcfg.Transactions()) + require.NoError(t, err) + + t.Run("doesnt prune unstarted transactions if under maxQueueSize", func(t *testing.T) { + maxQueueSize := uint32(5) + nTxs := 3 + subject := uuid.NullUUID{UUID: uuid.New(), Valid: true} + strat := commontxmgr.NewDropOldestStrategy(subject.UUID, maxQueueSize, dbcfg.DefaultQueryTimeout()) + for i := 0; i < nTxs; i++ { + inTx := cltest.NewEthTx(fromAddress) + inTx.Subject = subject + // insert the transaction into the persistent store + require.NoError(t, persistentStore.InsertTx(ctx, &inTx)) + // insert the transaction into the in-memory store + require.NoError(t, inMemoryStore.XXXTestInsertTx(fromAddress, &inTx)) + } + + ids, err := strat.PruneQueue(ctx, inMemoryStore) + require.NoError(t, err) + assert.Equal(t, 0, len(ids)) + + AssertCountPerSubject(t, persistentStore, int64(nTxs), subject.UUID) + fn := func(tx *evmtxmgr.Tx) bool { return true } + states := []txmgrtypes.TxState{commontxmgr.TxUnstarted} + actTxs := inMemoryStore.XXXTestFindTxs(states, fn) + expTxs, err := persistentStore.FindTxesByFromAddressAndState(ctx, fromAddress, "unstarted") + require.NoError(t, err) + require.Equal(t, len(expTxs), len(actTxs)) + + // sort by ID to ensure the order is the same for comparison + sort.SliceStable(actTxs, func(i, j int) bool { + return actTxs[i].ID < actTxs[j].ID + }) + sort.SliceStable(expTxs, func(i, j int) bool { + return expTxs[i].ID < expTxs[j].ID + }) + for i := 0; i < len(expTxs); i++ { + assertTxEqual(t, *expTxs[i], actTxs[i]) + } + }) + t.Run("prunes unstarted transactions", func(t *testing.T) { + maxQueueSize := uint32(5) + nTxs := 5 + subject := uuid.NullUUID{UUID: uuid.New(), Valid: true} + strat := commontxmgr.NewDropOldestStrategy(subject.UUID, maxQueueSize, dbcfg.DefaultQueryTimeout()) + for i := 0; i < nTxs; i++ { + inTx := cltest.NewEthTx(fromAddress) + inTx.Subject = subject + // insert the transaction into the persistent store + require.NoError(t, persistentStore.InsertTx(ctx, &inTx)) + // insert the transaction into the in-memory store + require.NoError(t, inMemoryStore.XXXTestInsertTx(fromAddress, &inTx)) + } + + ids, err := strat.PruneQueue(ctx, inMemoryStore) + require.NoError(t, err) + assert.Equal(t, int(nTxs)-int(maxQueueSize-1), len(ids)) + + AssertCountPerSubject(t, persistentStore, int64(maxQueueSize-1), subject.UUID) + fn := func(tx *evmtxmgr.Tx) bool { return true } + states := []txmgrtypes.TxState{commontxmgr.TxUnstarted} + actTxs := inMemoryStore.XXXTestFindTxs(states, fn) + expTxs, err := persistentStore.FindTxesByFromAddressAndState(ctx, fromAddress, "unstarted") + require.NoError(t, err) + require.Equal(t, len(expTxs), len(actTxs)) + + // sort by ID to ensure the order is the same for comparison + sort.SliceStable(actTxs, func(i, j int) bool { + return actTxs[i].ID < actTxs[j].ID + }) + sort.SliceStable(expTxs, func(i, j int) bool { + return expTxs[i].ID < expTxs[j].ID + }) + for i := 0; i < len(expTxs); i++ { + assertTxEqual(t, *expTxs[i], actTxs[i]) + } + }) + +} + // assertTxEqual asserts that two transactions are equal func assertTxEqual(t *testing.T, exp, act evmtxmgr.Tx) { assert.Equal(t, exp.ID, act.ID) @@ -42,7 +362,6 @@ func assertTxEqual(t *testing.T, exp, act evmtxmgr.Tx) { func assertTxAttemptEqual(t *testing.T, exp, act evmtxmgr.TxAttempt) { assert.Equal(t, exp.ID, act.ID) assert.Equal(t, exp.TxID, act.TxID) - assert.Equal(t, exp.Tx, act.Tx) assert.Equal(t, exp.TxFee, act.TxFee) assert.Equal(t, exp.ChainSpecificFeeLimit, act.ChainSpecificFeeLimit) assert.Equal(t, exp.SignedRawTx, act.SignedRawTx)