diff --git a/chain/mempool.go b/chain/mempool.go index cc0e30d0ad..b15ce5127c 100644 --- a/chain/mempool.go +++ b/chain/mempool.go @@ -137,17 +137,37 @@ type mempoolConfig struct { // batchWaitInterval defines the default time to sleep between each // batched calls. batchWaitInterval time.Duration + + // rawMempoolGetter mounts to `m.getRawMempool` and is only changed in + // unit tests. + // + // TODO(yy): interface rpcclient.FutureGetRawMempoolResult so we can + // remove this hack. + rawMempoolGetter func() ([]*chainhash.Hash, error) + + // rawTxReceiver mounts to `m.getRawTxIgnoreErr` and is only changed in + // unit tests. + // + // TODO(yy): interface rpcclient.FutureGetRawTransactionResult so we + // can remove this hack. + rawTxReceiver func(getRawTxReceiver) *btcutil.Tx } // newMempool creates a new mempool object. func newMempool(cfg *mempoolConfig) *mempool { - return &mempool{ + m := &mempool{ cfg: cfg, txs: make(map[chainhash.Hash]bool), inputs: newCachedInputs(), initFin: make(chan struct{}), quit: make(chan struct{}), } + + // Mount the default methods. + m.cfg.rawMempoolGetter = m.getRawMempool + m.cfg.rawTxReceiver = m.getRawTxIgnoreErr + + return m } // Shutdown signals the mempool to exit. @@ -345,7 +365,7 @@ func (m *mempool) LoadMempool() error { now := time.Now() // Fetch the latest mempool. - txids, err := m.getRawMempool() + txids, err := m.cfg.rawMempoolGetter() if err != nil { log.Errorf("Unable to get raw mempool txs: %v", err) return err @@ -370,7 +390,7 @@ func (m *mempool) LoadMempool() error { // that's new to its internal mempool. func (m *mempool) UpdateMempoolTxes() []*wire.MsgTx { // Fetch the latest mempool. - txids, err := m.getRawMempool() + txids, err := m.cfg.rawMempoolGetter() if err != nil { log.Errorf("Unable to get raw mempool txs: %v", err) return nil @@ -416,6 +436,7 @@ func (m *mempool) UpdateMempoolTxes() []*wire.MsgTx { txesToNotify, err := m.batchGetRawTxes(newTxids, true) if err != nil { log.Error("Batch getrawtransaction got %v", err) + } return txesToNotify @@ -476,7 +497,7 @@ func (m *mempool) batchGetRawTxes(txids []*chainhash.Hash, // Iterate the recievers and fetch the response. for _, resp := range results { - tx := m.getRawTxIgnoreErr(resp) + tx := m.cfg.rawTxReceiver(resp) if tx == nil { continue } diff --git a/chain/mempool_test.go b/chain/mempool_test.go index 351a742ef8..6a77cf19ca 100644 --- a/chain/mempool_test.go +++ b/chain/mempool_test.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" "github.com/stretchr/testify/require" ) @@ -430,7 +431,7 @@ func TestUpdateMempoolTxes(t *testing.T) { }, } tx1Hash := tx1.TxHash() - btcTx1 := btcutil.NewTx(tx1) + btctx1 := btcutil.NewTx(tx1) // Create another transaction. op3 := wire.OutPoint{Hash: chainhash.Hash{3}} @@ -446,10 +447,34 @@ func TestUpdateMempoolTxes(t *testing.T) { // Create the current mempool state. mempool1 := []*chainhash.Hash{&tx1Hash, &tx2Hash} + // Create mock receivers. + mockTx1Receiver := make(rpcclient.FutureGetRawTransactionResult) + mockTx2Receiver := make(rpcclient.FutureGetRawTransactionResult) + mockTx3Receiver := make(rpcclient.FutureGetRawTransactionResult) + mockTx4Receiver := make(rpcclient.FutureGetRawTransactionResult) + // Mock the client to return the txes. - mockRPC.On("GetRawMempool").Return(mempool1) - mockRPC.On("GetRawTransaction", &tx1Hash).Return(btcTx1, nil).Once() - mockRPC.On("GetRawTransaction", &tx2Hash).Return(btctx2, nil).Once() + mockRPC.On("GetRawTransactionAsync", &tx1Hash).Return( + mockTx1Receiver).Once() + mockRPC.On("GetRawTransactionAsync", &tx2Hash).Return( + mockTx2Receiver).Once() + mockRPC.On("Send").Return(nil).Once() + + // Mock our rawMempoolGetter and rawTxReceiver. + m.cfg.rawMempoolGetter = func() ([]*chainhash.Hash, error) { + return mempool1, nil + } + m.cfg.rawTxReceiver = func(reciever getRawTxReceiver) *btcutil.Tx { + switch reciever { + case mockTx1Receiver: + return btctx1 + case mockTx2Receiver: + return btctx2 + } + + require.Fail("unexpected receiver") + return nil + } // Update our mempool using the above mempool state. newTxes := m.UpdateMempoolTxes() @@ -485,9 +510,27 @@ func TestUpdateMempoolTxes(t *testing.T) { mempool2 := []*chainhash.Hash{&tx1Hash, &tx3Hash, &tx4Hash} // Mock the client to return the txes. - mockRPC.On("GetRawMempool").Return(mempool2) - mockRPC.On("GetRawTransaction", &tx3Hash).Return(btctx3, nil).Once() - mockRPC.On("GetRawTransaction", &tx4Hash).Return(btctx4, nil).Once() + mockRPC.On("GetRawTransactionAsync", + &tx3Hash).Return(mockTx3Receiver).Once() + mockRPC.On("GetRawTransactionAsync", + &tx4Hash).Return(mockTx4Receiver).Once() + mockRPC.On("Send").Return(nil).Once() + + // Mock our rawMempoolGetter and rawTxReceiver. + m.cfg.rawMempoolGetter = func() ([]*chainhash.Hash, error) { + return mempool2, nil + } + m.cfg.rawTxReceiver = func(reciever getRawTxReceiver) *btcutil.Tx { + switch reciever { + case mockTx3Receiver: + return btctx3 + case mockTx4Receiver: + return btctx4 + } + + require.Fail("unexpected receiver") + return nil + } // Update our mempool using the above mempool state. newTxes = m.UpdateMempoolTxes() @@ -548,14 +591,14 @@ func TestUpdateMempoolTxesOnShutdown(t *testing.T) { }, } tx1Hash := tx1.TxHash() - btcTx1 := btcutil.NewTx(tx1) // Create the current mempool state. mempool := []*chainhash.Hash{&tx1Hash} - // Mock the client to return the txes. - mockRPC.On("GetRawMempool").Return(mempool) - mockRPC.On("GetRawTransaction", &tx1Hash).Return(btcTx1, nil) + // Mock our rawMempoolGetter and rawTxReceiver. + m.cfg.rawMempoolGetter = func() ([]*chainhash.Hash, error) { + return mempool, nil + } // Shutdown the mempool before updating the txes. m.Shutdown() @@ -567,7 +610,7 @@ func TestUpdateMempoolTxesOnShutdown(t *testing.T) { require.Empty(newTxes) // Assert GetRawTransaction is not called because mempool is quit. - mockRPC.AssertNotCalled(t, "GetRawTransaction") + mockRPC.AssertNotCalled(t, "GetRawTransactionAsync") } // TestGetRawTxIgnoreErr tests that the mempool's GetRawTxIgnoreErr method @@ -611,3 +654,5 @@ func TestGetRawTxIgnoreErr(t *testing.T) { mockRPC.AssertExpectations(t) mockReceiver.AssertExpectations(t) } + +// TODO(yy): add tests for `batchGetRawTxes` diff --git a/chain/mocks_test.go b/chain/mocks_test.go index 43192e5d0f..f19dac72d6 100644 --- a/chain/mocks_test.go +++ b/chain/mocks_test.go @@ -163,7 +163,7 @@ type mockRPCClient struct { mock.Mock } -// Compile time assertion that MockPeer implements lnpeer.Peer. +// Compile time assert the implementation. var _ batchClient = (*mockRPCClient)(nil) func (m *mockRPCClient) GetRawMempoolAsync() rpcclient. @@ -193,6 +193,7 @@ func (m *mockRPCClient) Send() error { // mockGetRawTxReceiver mocks the getRawTxReceiver interface. type mockGetRawTxReceiver struct { + *rpcclient.FutureGetRawTransactionResult mock.Mock } @@ -207,5 +208,5 @@ func (m *mockGetRawTxReceiver) Receive() (*btcutil.Tx, error) { return args.Get(0).(*btcutil.Tx), args.Error(1) } -// Compile time assertion that MockPeer implements lnpeer.Peer. +// Compile time assert the implementation. var _ getRawTxReceiver = (*mockGetRawTxReceiver)(nil)