diff --git a/common/txmgr/address_state.go b/common/txmgr/address_state.go index 063086ac6e2..24ccf795eb0 100644 --- a/common/txmgr/address_state.go +++ b/common/txmgr/address_state.go @@ -381,24 +381,27 @@ func (as *addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) moveCo defer as.Unlock() if txAttempt.State != txmgrtypes.TxAttemptBroadcast { - return fmt.Errorf("move_confirmed_to_unconfirmed: attempt must be in broadcast state") + return fmt.Errorf("attempt must be in broadcast state") } tx, ok := as.confirmedTxs[txAttempt.TxID] if !ok || tx == nil { - return fmt.Errorf("move_confirmed_to_unconfirmed: no confirmed transaction with ID %d", txAttempt.TxID) + return fmt.Errorf("no confirmed transaction with ID %d", txAttempt.TxID) } if len(tx.TxAttempts) == 0 { - return fmt.Errorf("move_confirmed_to_unconfirmed: no attempts for transaction with ID %d", txAttempt.TxID) + return fmt.Errorf("no attempts for transaction with ID %d", txAttempt.TxID) } tx.State = TxUnconfirmed // Delete the receipt from the attempt - txAttempt.Receipts = nil - // Reset the broadcast information for the attempt - txAttempt.State = txmgrtypes.TxAttemptInProgress - txAttempt.BroadcastBeforeBlockNum = nil - tx.TxAttempts = append(tx.TxAttempts, txAttempt) + for i := 0; i < len(tx.TxAttempts); i++ { + if tx.TxAttempts[i].ID == txAttempt.ID { + tx.TxAttempts[i].Receipts = nil + tx.TxAttempts[i].State = txmgrtypes.TxAttemptInProgress + tx.TxAttempts[i].BroadcastBeforeBlockNum = nil + break + } + } as.unconfirmedTxs[tx.ID] = tx delete(as.confirmedTxs, tx.ID) diff --git a/common/txmgr/inmemory_store.go b/common/txmgr/inmemory_store.go index bd4e9a2f3a6..e5c99516b25 100644 --- a/common/txmgr/inmemory_store.go +++ b/common/txmgr/inmemory_store.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/big" + "sync" "time" "github.com/google/uuid" @@ -47,7 +48,8 @@ type inMemoryStore[ keyStore txmgrtypes.KeyStore[ADDR, CHAIN_ID, SEQ] persistentTxStore txmgrtypes.TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE] - addressStates map[ADDR]*addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE] + addressStatesLock sync.RWMutex + addressStates map[ADDR]*addressState[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE] } // NewInMemoryStore returns a new inMemoryStore @@ -331,7 +333,20 @@ func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) SaveS return nil } func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) UpdateTxForRebroadcast(ctx context.Context, etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], etxAttempt txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) error { - return nil + ms.addressStatesLock.RLock() + defer ms.addressStatesLock.RUnlock() + as, ok := ms.addressStates[etx.FromAddress] + if !ok { + return nil + } + + // Persist to persistent storage + if err := ms.persistentTxStore.UpdateTxForRebroadcast(ctx, etx, etxAttempt); err != nil { + return err + } + + // Update in memory store + return as.moveConfirmedToUnconfirmed(etxAttempt) } func (ms *inMemoryStore[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) IsTxFinalized(ctx context.Context, blockHeight int64, txID int64, chainID CHAIN_ID) (bool, error) { return false, nil diff --git a/core/chains/evm/txmgr/evm_inmemory_store_test.go b/core/chains/evm/txmgr/evm_inmemory_store_test.go index a102ee1c996..b0860b5b0ed 100644 --- a/core/chains/evm/txmgr/evm_inmemory_store_test.go +++ b/core/chains/evm/txmgr/evm_inmemory_store_test.go @@ -1,14 +1,129 @@ package txmgr_test import ( + "math/big" "testing" + "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + commontxmgr "github.com/smartcontractkit/chainlink/v2/common/txmgr" + txmgrtypes "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" + + evmgas "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" evmtxmgr "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" + evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/evmtest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" ) +func TestInMemoryStore_UpdateTxForRebroadcast(t *testing.T) { + t.Parallel() + + t.Run("delete all receipts for transaction", func(t *testing.T) { + db := pgtest.NewSqlxDB(t) + _, dbcfg, evmcfg := evmtxmgr.MakeTestConfigs(t) + persistentStore := cltest.NewTestTxStore(t, db) + kst := cltest.NewKeyStore(t, db, dbcfg) + _, fromAddress := cltest.MustInsertRandomKey(t, kst.Eth()) + + ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + lggr := logger.TestSugared(t) + chainID := ethClient.ConfiguredChainID() + ctx := testutils.Context(t) + + inMemoryStore, err := commontxmgr.NewInMemoryStore[ + *big.Int, + common.Address, common.Hash, common.Hash, + *evmtypes.Receipt, + evmtypes.Nonce, + evmgas.EvmFee, + ](ctx, lggr, chainID, kst.Eth(), persistentStore, evmcfg.Transactions()) + require.NoError(t, err) + + // Insert a transaction into persistent store + inTx := mustInsertConfirmedEthTxWithReceipt(t, persistentStore, fromAddress, 777, 1) + // Insert the transaction into the in-memory store + require.NoError(t, inMemoryStore.XXXTestInsertTx(fromAddress, &inTx)) + + txAttempt := inTx.TxAttempts[0] + err = inMemoryStore.UpdateTxForRebroadcast(ctx, inTx, txAttempt) + require.NoError(t, err) + + expTx, err := persistentStore.FindTxWithAttempts(ctx, inTx.ID) + require.NoError(t, err) + + fn := func(tx *evmtxmgr.Tx) bool { return true } + actTxs := inMemoryStore.XXXTestFindTxs(nil, fn, inTx.ID) + require.Equal(t, 1, len(actTxs)) + actTx := actTxs[0] + assertTxEqual(t, expTx, actTx) + assert.Equal(t, commontxmgr.TxUnconfirmed, actTx.State) + assert.Equal(t, txmgrtypes.TxAttemptInProgress, actTx.TxAttempts[0].State) + assert.Nil(t, actTx.TxAttempts[0].BroadcastBeforeBlockNum) + assert.Equal(t, 0, len(actTx.TxAttempts[0].Receipts)) + }) + + t.Run("error parity for in-memory vs persistent store", func(t *testing.T) { + db := pgtest.NewSqlxDB(t) + _, dbcfg, evmcfg := evmtxmgr.MakeTestConfigs(t) + persistentStore := cltest.NewTestTxStore(t, db) + kst := cltest.NewKeyStore(t, db, dbcfg) + _, fromAddress := cltest.MustInsertRandomKey(t, kst.Eth()) + + ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + lggr := logger.TestSugared(t) + chainID := ethClient.ConfiguredChainID() + ctx := testutils.Context(t) + + inMemoryStore, err := commontxmgr.NewInMemoryStore[ + *big.Int, + common.Address, common.Hash, common.Hash, + *evmtypes.Receipt, + evmtypes.Nonce, + evmgas.EvmFee, + ](ctx, lggr, chainID, kst.Eth(), persistentStore, evmcfg.Transactions()) + require.NoError(t, err) + + // Insert a transaction into persistent store + inTx := mustInsertConfirmedEthTxWithReceipt(t, persistentStore, fromAddress, 777, 1) + // Insert the transaction into the in-memory store + require.NoError(t, inMemoryStore.XXXTestInsertTx(fromAddress, &inTx)) + + txAttempt := inTx.TxAttempts[0] + + t.Run("error when attempt is not in Broadcast state", func(t *testing.T) { + txAttempt.State = txmgrtypes.TxAttemptInProgress + expErr := persistentStore.UpdateTxForRebroadcast(ctx, inTx, txAttempt) + actErr := inMemoryStore.UpdateTxForRebroadcast(ctx, inTx, txAttempt) + assert.Error(t, expErr) + assert.Error(t, actErr) + txAttempt.State = txmgrtypes.TxAttemptBroadcast + }) + t.Run("error when transaction is not in confirmed state", func(t *testing.T) { + inTx.State = commontxmgr.TxUnconfirmed + expErr := persistentStore.UpdateTxForRebroadcast(ctx, inTx, txAttempt) + actErr := inMemoryStore.UpdateTxForRebroadcast(ctx, inTx, txAttempt) + assert.Error(t, expErr) + assert.Error(t, actErr) + inTx.State = commontxmgr.TxConfirmed + }) + t.Run("wrong fromAddress has no error", func(t *testing.T) { + inTx.FromAddress = common.Address{} + expErr := persistentStore.UpdateTxForRebroadcast(ctx, inTx, txAttempt) + actErr := inMemoryStore.UpdateTxForRebroadcast(ctx, inTx, txAttempt) + assert.Equal(t, expErr, actErr) + assert.Nil(t, actErr) + inTx.FromAddress = fromAddress + }) + + }) +} + // assertTxEqual asserts that two transactions are equal func assertTxEqual(t *testing.T, exp, act evmtxmgr.Tx) { assert.Equal(t, exp.ID, act.ID) @@ -33,7 +148,7 @@ func assertTxEqual(t *testing.T, exp, act evmtxmgr.Tx) { assert.Equal(t, exp.SignalCallback, act.SignalCallback) assert.Equal(t, exp.CallbackCompleted, act.CallbackCompleted) - require.Len(t, exp.TxAttempts, len(act.TxAttempts)) + require.Equal(t, len(exp.TxAttempts), len(act.TxAttempts)) for i := 0; i < len(exp.TxAttempts); i++ { assertTxAttemptEqual(t, exp.TxAttempts[i], act.TxAttempts[i]) }