Skip to content

Commit

Permalink
chain: modify mempool config to fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yyforyongyu committed Sep 20, 2023
1 parent 0c87dca commit 8be54ce
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 18 deletions.
29 changes: 25 additions & 4 deletions chain/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,37 @@ type mempoolConfig struct {
// batchWaitInterval defines the default time to sleep between each
// batched calls.
batchWaitInterval time.Duration

// rawMempoolGetter mounts to `m.getRawMempool` and is only changed in
// unit tests.
//
// TODO(yy): interface rpcclient.FutureGetRawMempoolResult so we can
// remove this hack.
rawMempoolGetter func() ([]*chainhash.Hash, error)

// rawTxReceiver mounts to `m.getRawTxIgnoreErr` and is only changed in
// unit tests.
//
// TODO(yy): interface rpcclient.FutureGetRawTransactionResult so we
// can remove this hack.
rawTxReceiver func(getRawTxReceiver) *btcutil.Tx
}

// newMempool creates a new mempool object.
func newMempool(cfg *mempoolConfig) *mempool {
return &mempool{
m := &mempool{
cfg: cfg,
txs: make(map[chainhash.Hash]bool),
inputs: newCachedInputs(),
initFin: make(chan struct{}),
quit: make(chan struct{}),
}

// Mount the default methods.
m.cfg.rawMempoolGetter = m.getRawMempool
m.cfg.rawTxReceiver = m.getRawTxIgnoreErr

return m
}

// Shutdown signals the mempool to exit.
Expand Down Expand Up @@ -345,7 +365,7 @@ func (m *mempool) LoadMempool() error {
now := time.Now()

// Fetch the latest mempool.
txids, err := m.getRawMempool()
txids, err := m.cfg.rawMempoolGetter()
if err != nil {
log.Errorf("Unable to get raw mempool txs: %v", err)
return err
Expand All @@ -370,7 +390,7 @@ func (m *mempool) LoadMempool() error {
// that's new to its internal mempool.
func (m *mempool) UpdateMempoolTxes() []*wire.MsgTx {
// Fetch the latest mempool.
txids, err := m.getRawMempool()
txids, err := m.cfg.rawMempoolGetter()
if err != nil {
log.Errorf("Unable to get raw mempool txs: %v", err)
return nil
Expand Down Expand Up @@ -416,6 +436,7 @@ func (m *mempool) UpdateMempoolTxes() []*wire.MsgTx {
txesToNotify, err := m.batchGetRawTxes(newTxids, true)
if err != nil {
log.Error("Batch getrawtransaction got %v", err)

}

return txesToNotify
Expand Down Expand Up @@ -476,7 +497,7 @@ func (m *mempool) batchGetRawTxes(txids []*chainhash.Hash,

// Iterate the recievers and fetch the response.
for _, resp := range results {
tx := m.getRawTxIgnoreErr(resp)
tx := m.cfg.rawTxReceiver(resp)
if tx == nil {
continue
}
Expand Down
69 changes: 57 additions & 12 deletions chain/mempool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/wire"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -430,7 +431,7 @@ func TestUpdateMempoolTxes(t *testing.T) {
},
}
tx1Hash := tx1.TxHash()
btcTx1 := btcutil.NewTx(tx1)
btctx1 := btcutil.NewTx(tx1)

// Create another transaction.
op3 := wire.OutPoint{Hash: chainhash.Hash{3}}
Expand All @@ -446,10 +447,34 @@ func TestUpdateMempoolTxes(t *testing.T) {
// Create the current mempool state.
mempool1 := []*chainhash.Hash{&tx1Hash, &tx2Hash}

// Create mock receivers.
mockTx1Receiver := make(rpcclient.FutureGetRawTransactionResult)
mockTx2Receiver := make(rpcclient.FutureGetRawTransactionResult)
mockTx3Receiver := make(rpcclient.FutureGetRawTransactionResult)
mockTx4Receiver := make(rpcclient.FutureGetRawTransactionResult)

// Mock the client to return the txes.
mockRPC.On("GetRawMempool").Return(mempool1)
mockRPC.On("GetRawTransaction", &tx1Hash).Return(btcTx1, nil).Once()
mockRPC.On("GetRawTransaction", &tx2Hash).Return(btctx2, nil).Once()
mockRPC.On("GetRawTransactionAsync", &tx1Hash).Return(
mockTx1Receiver).Once()
mockRPC.On("GetRawTransactionAsync", &tx2Hash).Return(
mockTx2Receiver).Once()
mockRPC.On("Send").Return(nil).Once()

// Mock our rawMempoolGetter and rawTxReceiver.
m.cfg.rawMempoolGetter = func() ([]*chainhash.Hash, error) {
return mempool1, nil
}
m.cfg.rawTxReceiver = func(reciever getRawTxReceiver) *btcutil.Tx {
switch reciever {
case mockTx1Receiver:
return btctx1
case mockTx2Receiver:
return btctx2
}

require.Fail("unexpected receiver")
return nil
}

// Update our mempool using the above mempool state.
newTxes := m.UpdateMempoolTxes()
Expand Down Expand Up @@ -485,9 +510,27 @@ func TestUpdateMempoolTxes(t *testing.T) {
mempool2 := []*chainhash.Hash{&tx1Hash, &tx3Hash, &tx4Hash}

// Mock the client to return the txes.
mockRPC.On("GetRawMempool").Return(mempool2)
mockRPC.On("GetRawTransaction", &tx3Hash).Return(btctx3, nil).Once()
mockRPC.On("GetRawTransaction", &tx4Hash).Return(btctx4, nil).Once()
mockRPC.On("GetRawTransactionAsync",
&tx3Hash).Return(mockTx3Receiver).Once()
mockRPC.On("GetRawTransactionAsync",
&tx4Hash).Return(mockTx4Receiver).Once()
mockRPC.On("Send").Return(nil).Once()

// Mock our rawMempoolGetter and rawTxReceiver.
m.cfg.rawMempoolGetter = func() ([]*chainhash.Hash, error) {
return mempool2, nil
}
m.cfg.rawTxReceiver = func(reciever getRawTxReceiver) *btcutil.Tx {
switch reciever {
case mockTx3Receiver:
return btctx3
case mockTx4Receiver:
return btctx4
}

require.Fail("unexpected receiver")
return nil
}

// Update our mempool using the above mempool state.
newTxes = m.UpdateMempoolTxes()
Expand Down Expand Up @@ -548,14 +591,14 @@ func TestUpdateMempoolTxesOnShutdown(t *testing.T) {
},
}
tx1Hash := tx1.TxHash()
btcTx1 := btcutil.NewTx(tx1)

// Create the current mempool state.
mempool := []*chainhash.Hash{&tx1Hash}

// Mock the client to return the txes.
mockRPC.On("GetRawMempool").Return(mempool)
mockRPC.On("GetRawTransaction", &tx1Hash).Return(btcTx1, nil)
// Mock our rawMempoolGetter and rawTxReceiver.
m.cfg.rawMempoolGetter = func() ([]*chainhash.Hash, error) {
return mempool, nil
}

// Shutdown the mempool before updating the txes.
m.Shutdown()
Expand All @@ -567,7 +610,7 @@ func TestUpdateMempoolTxesOnShutdown(t *testing.T) {
require.Empty(newTxes)

// Assert GetRawTransaction is not called because mempool is quit.
mockRPC.AssertNotCalled(t, "GetRawTransaction")
mockRPC.AssertNotCalled(t, "GetRawTransactionAsync")
}

// TestGetRawTxIgnoreErr tests that the mempool's GetRawTxIgnoreErr method
Expand Down Expand Up @@ -611,3 +654,5 @@ func TestGetRawTxIgnoreErr(t *testing.T) {
mockRPC.AssertExpectations(t)
mockReceiver.AssertExpectations(t)
}

// TODO(yy): add tests for `batchGetRawTxes`
5 changes: 3 additions & 2 deletions chain/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ type mockRPCClient struct {
mock.Mock
}

// Compile time assertion that MockPeer implements lnpeer.Peer.
// Compile time assert the implementation.
var _ batchClient = (*mockRPCClient)(nil)

func (m *mockRPCClient) GetRawMempoolAsync() rpcclient.
Expand Down Expand Up @@ -193,6 +193,7 @@ func (m *mockRPCClient) Send() error {

// mockGetRawTxReceiver mocks the getRawTxReceiver interface.
type mockGetRawTxReceiver struct {
*rpcclient.FutureGetRawTransactionResult
mock.Mock
}

Expand All @@ -207,5 +208,5 @@ func (m *mockGetRawTxReceiver) Receive() (*btcutil.Tx, error) {
return args.Get(0).(*btcutil.Tx), args.Error(1)
}

// Compile time assertion that MockPeer implements lnpeer.Peer.
// Compile time assert the implementation.
var _ getRawTxReceiver = (*mockGetRawTxReceiver)(nil)

0 comments on commit 8be54ce

Please sign in to comment.