diff --git a/common/client/node_lifecycle_test.go b/common/client/node_lifecycle_test.go index 2215365d7dd..97fca5037d7 100644 --- a/common/client/node_lifecycle_test.go +++ b/common/client/node_lifecycle_test.go @@ -27,9 +27,6 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node := newTestNode(t, opts) opts.rpc.On("Close").Return(nil).Once() - t.Cleanup(func() { - assert.NoError(t, node.close()) - }) node.setState(nodeStateDialed) return node } @@ -57,6 +54,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateUnreachable }) + assert.NoError(t, node.close()) }) t.Run("if remote RPC connection is closed transitions to unreachable", func(t *testing.T) { t.Parallel() @@ -82,6 +80,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node.declareAlive() tests.AssertLogEventually(t, observedLogs, "Subscription was terminated") assert.Equal(t, nodeStateUnreachable, node.State()) + assert.NoError(t, node.close()) }) newSubscribedNode := func(t *testing.T, opts testNodeOpts) testNode { @@ -92,6 +91,21 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { opts.rpc.On("SetAliveLoopSub", sub).Once() return newDialedNode(t, opts) } + t.Run("Stays alive and waits for signal", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{}, + rpc: rpc, + lggr: lggr, + }) + node.declareAlive() + tests.AssertLogEventually(t, observedLogs, "Head liveness checking disabled") + tests.AssertLogEventually(t, observedLogs, "Polling disabled") + assert.Equal(t, nodeStateAlive, node.State()) + assert.NoError(t, node.close()) + }) t.Run("stays alive while below pollFailureThreshold and resets counter on success", func(t *testing.T) { t.Parallel() rpc := newMockNodeClient[types.ID, Head](t) @@ -131,9 +145,10 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { tests.AssertLogCountEventually(t, observedLogs, fmt.Sprintf("Poll failure, RPC endpoint %s failed to respond properly", node.String()), pollFailureThreshold) tests.AssertLogCountEventually(t, observedLogs, "Version poll successful", 2) assert.True(t, ensuredAlive.Load(), "expected to ensure that node was alive") + assert.NoError(t, node.close()) }) - t.Run("becomes unreachable when exceeds pollFailureThreshold", func(t *testing.T) { + t.Run("with threshold poll failures, transitions to unreachable", func(t *testing.T) { t.Parallel() rpc := newMockNodeClient[types.ID, Head](t) lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) @@ -157,8 +172,9 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return nodeStateUnreachable == node.State() }) + assert.NoError(t, node.close()) }) - t.Run("stays alive even, when exceeds pollFailureThreshold because it's last node", func(t *testing.T) { + t.Run("with threshold poll failures, but we are the last node alive, forcibly keeps it alive", func(t *testing.T) { t.Parallel() rpc := newMockNodeClient[types.ID, Head](t) lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) @@ -179,8 +195,9 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node.declareAlive() tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("RPC endpoint failed to respond to %d consecutive polls", pollFailureThreshold)) assert.Equal(t, nodeStateAlive, node.State()) + assert.NoError(t, node.close()) }) - t.Run("outOfSync when falls behind", func(t *testing.T) { + t.Run("when behind more than SyncThreshold, transitions to out of sync", func(t *testing.T) { t.Parallel() rpc := newMockNodeClient[types.ID, Head](t) lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) @@ -209,8 +226,9 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() node.declareAlive() tests.AssertLogEventually(t, observedLogs, "Failed to dial out-of-sync RPC node") + assert.NoError(t, node.close()) }) - t.Run("stays alive even when falls behind", func(t *testing.T) { + t.Run("when behind more than SyncThreshold but we are the last live node, forcibly stays alive", func(t *testing.T) { t.Parallel() rpc := newMockNodeClient[types.ID, Head](t) lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) @@ -231,7 +249,32 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { rpc.On("ClientVersion", mock.Anything).Return("", nil) node.declareAlive() tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("RPC endpoint has fallen behind; %s %s", msgCannotDisable, msgDegradedState)) + assert.NoError(t, node.close()) }) + t.Run("when behind but SyncThreshold=0, stay alive", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{ + pollInterval: tests.TestInterval, + syncThreshold: 0, + selectionMode: NodeSelectionModeRoundRobin, + }, + rpc: rpc, + lggr: lggr, + }) + node.stateLatestBlockNumber = 20 + node.nLiveNodes = func() (count int, blockNumber int64, totalDifficulty *utils.Big) { + return 1, node.stateLatestBlockNumber + 100, utils.NewBigI(10) + } + rpc.On("ClientVersion", mock.Anything).Return("", nil) + node.declareAlive() + tests.AssertLogCountEventually(t, observedLogs, "Version poll successful", 2) + assert.Equal(t, nodeStateAlive, node.State()) + assert.NoError(t, node.close()) + }) + t.Run("when no new heads received for threshold, transitions to out of sync", func(t *testing.T) { t.Parallel() rpc := newMockNodeClient[types.ID, Head](t) @@ -254,8 +297,9 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { // we check that we were in out of sync state on first Dial call return node.State() == nodeStateUnreachable }) + assert.NoError(t, node.close()) }) - t.Run("when no new heads received for threshold and no nodes stay alive", func(t *testing.T) { + t.Run("when no new heads received for threshold but we are the last live node, forcibly stays alive", func(t *testing.T) { t.Parallel() rpc := newMockNodeClient[types.ID, Head](t) lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) @@ -271,6 +315,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node.declareAlive() tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("RPC endpoint detected out of sync; %s %s", msgCannotDisable, msgDegradedState)) assert.Equal(t, nodeStateAlive, node.State()) + assert.NoError(t, node.close()) }) t.Run("rpc closed head channel", func(t *testing.T) { @@ -298,6 +343,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node.declareAlive() tests.AssertLogEventually(t, observedLogs, "Subscription channel unexpectedly closed") assert.Equal(t, nodeStateUnreachable, node.State()) + assert.NoError(t, node.close()) }) t.Run("updates block number and difficulty on new head", func(t *testing.T) { @@ -332,10 +378,28 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { state, block, diff := node.StateAndLatest() return state == nodeStateAlive && block == expectedBlockNumber == diff.Equal(expectedDiff) }) - + assert.NoError(t, node.close()) }) } +func writeHeads(ctx context.Context, t *testing.T, ch chan<- Head, blockNumbers ...int64) { + defer func() { + if rerr := recover(); rerr != nil { + t.Error(rerr) + } + }() + for _, blockNumber := range blockNumbers { + h := newMockHead(t) + h.On("BlockNumber").Return(blockNumber) + h.On("BlockDifficulty").Return(utils.NewBigI(100)) + select { + case ch <- h: + case <-ctx.Done(): + return + } + } +} + func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { t.Parallel() @@ -344,10 +408,6 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { opts.rpc.On("Close").Return(nil).Once() // disconnects all on transfer to unreachable or outOfSync opts.rpc.On("DisconnectAll") - - t.Cleanup(func() { - assert.NoError(t, node.close()) - }) node.setState(nodeStateAlive) return node } @@ -363,6 +423,40 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { node.wg.Add(1) node.outOfSyncLoop(stubIsOutOfSync) }) + t.Run("on old blocks stays outOfSync and returns on close", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + + rpc.On("Dial", mock.Anything).Return(nil).Once() + // might be called multiple times + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + + outOfSyncSubscription := mocks.NewSubscription(t) + outOfSyncSubscription.On("Err").Return((<-chan error)(nil)) + outOfSyncSubscription.On("Unsubscribe").Once() + blockNumbers := []int64{7, 11, 13} + ctx, cancel := context.WithCancel(tests.Context(t)) + defer cancel() + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Run(func(args mock.Arguments) { + ch := args.Get(1).(chan<- Head) + go writeHeads(ctx, t, ch, blockNumbers...) + }).Return(outOfSyncSubscription, nil).Once() + rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() + + node.declareOutOfSync(func(num int64, td *utils.Big) bool { + return true + }) + tests.AssertLogCountEventually(t, observedLogs, msgReceivedBlock, len(blockNumbers)) + assert.Equal(t, nodeStateOutOfSync, node.State()) + assert.NoError(t, node.close()) + }) t.Run("if initial dial fails, transitions to unreachable", func(t *testing.T) { t.Parallel() rpc := newMockNodeClient[types.ID, Head](t) @@ -377,6 +471,7 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateUnreachable }) + assert.NoError(t, node.close()) }) t.Run("if fail to get chainID, transitions to invalidChainID", func(t *testing.T) { t.Parallel() @@ -393,6 +488,7 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateInvalidChainID }) + assert.NoError(t, node.close()) }) t.Run("if chainID does not match, transitions to invalidChainID", func(t *testing.T) { t.Parallel() @@ -411,6 +507,7 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateInvalidChainID }) + assert.NoError(t, node.close()) }) t.Run("if fails to subscribe, becomes unreachable", func(t *testing.T) { t.Parallel() @@ -431,6 +528,7 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateUnreachable }) + assert.NoError(t, node.close()) }) t.Run("on subscription termination becomes unreachable", func(t *testing.T) { t.Parallel() @@ -459,6 +557,7 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateUnreachable }) + assert.NoError(t, node.close()) }) t.Run("becomes unreachable if head channel is closed", func(t *testing.T) { t.Parallel() @@ -488,8 +587,9 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateUnreachable }) + assert.NoError(t, node.close()) }) - t.Run("becomes alive when receives block of sufficient height", func(t *testing.T) { + t.Run("becomes alive if it receives a newer head", func(t *testing.T) { t.Parallel() rpc := newMockNodeClient[types.ID, Head](t) nodeChainID := types.RandomID() @@ -510,24 +610,9 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { const highestBlock = 1000 ctx, cancel := context.WithCancel(tests.Context(t)) defer cancel() - writeHeads := func(ch chan<- Head) { - newHead := func(height int64) Head { - h := newMockHead(t) - h.On("BlockNumber").Return(height) - h.On("BlockDifficulty").Return(utils.NewBigI(100)) - return h - } - for _, head := range []Head{newHead(highestBlock - 1), newHead(highestBlock)} { - select { - case ch <- head: - case <-ctx.Done(): - return - } - } - } rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Run(func(args mock.Arguments) { ch := args.Get(1).(chan<- Head) - go writeHeads(ch) + go writeHeads(ctx, t, ch, highestBlock-1, highestBlock) }).Return(outOfSyncSubscription, nil).Once() rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() @@ -547,6 +632,7 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateAlive }) + assert.NoError(t, node.close()) }) t.Run("becomes alive if there is no other nodes", func(t *testing.T) { t.Parallel() @@ -586,6 +672,7 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateAlive }) + assert.NoError(t, node.close()) }) } @@ -598,9 +685,6 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { // disconnects all on transfer to unreachable opts.rpc.On("DisconnectAll") - t.Cleanup(func() { - assert.NoError(t, node.close()) - }) node.setState(nodeStateAlive) return node } @@ -626,6 +710,7 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")) node.declareUnreachable() tests.AssertLogCountEventually(t, observedLogs, "Failed to redial RPC node; still unreachable", 2) + assert.NoError(t, node.close()) }) t.Run("on failed chainID verification, keep trying", func(t *testing.T) { t.Parallel() @@ -644,6 +729,7 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { }).Return(nodeChainID, errors.New("failed to get chain id")) node.declareUnreachable() tests.AssertLogCountEventually(t, observedLogs, "Failed to redial RPC node; verify failed", 2) + assert.NoError(t, node.close()) }) t.Run("on chain ID mismatch transitions to invalidChainID", func(t *testing.T) { t.Parallel() @@ -661,6 +747,7 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateInvalidChainID }) + assert.NoError(t, node.close()) }) t.Run("on valid chain ID becomes alive", func(t *testing.T) { t.Parallel() @@ -686,6 +773,7 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateAlive }) + assert.NoError(t, node.close()) }) } @@ -696,9 +784,6 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { opts.rpc.On("Close").Return(nil).Once() opts.rpc.On("DisconnectAll") - t.Cleanup(func() { - assert.NoError(t, node.close()) - }) node.setState(nodeStateDialed) return node } @@ -729,6 +814,7 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateUnreachable }) + assert.NoError(t, node.close()) }) t.Run("on chainID mismatch keeps trying", func(t *testing.T) { t.Parallel() @@ -744,10 +830,11 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) node.declareInvalidChainID() - tests.AssertLogEventually(t, observedLogs, "Failed to verify RPC node; remote endpoint returned the wrong chain ID") + tests.AssertLogCountEventually(t, observedLogs, "Failed to verify RPC node; remote endpoint returned the wrong chain ID", 2) tests.AssertEventually(t, func() bool { return node.State() == nodeStateInvalidChainID }) + assert.NoError(t, node.close()) }) t.Run("on valid chainID becomes alive", func(t *testing.T) { t.Parallel() @@ -772,6 +859,7 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateAlive }) + assert.NoError(t, node.close()) }) } @@ -782,9 +870,6 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { node := newTestNode(t, opts) opts.rpc.On("Close").Return(nil).Once() - t.Cleanup(func() { - assert.NoError(t, node.close()) - }) return node } t.Run("if fails on initial dial, becomes unreachable", func(t *testing.T) { @@ -807,6 +892,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateUnreachable }) + assert.NoError(t, node.close()) }) t.Run("if chainID verification fails, becomes unreachable", func(t *testing.T) { t.Parallel() @@ -831,6 +917,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateUnreachable }) + assert.NoError(t, node.close()) }) t.Run("on chain ID mismatch transitions to invalidChainID", func(t *testing.T) { t.Parallel() @@ -851,6 +938,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateInvalidChainID }) + assert.NoError(t, node.close()) }) t.Run("on valid chain ID becomes alive", func(t *testing.T) { t.Parallel() @@ -877,6 +965,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { tests.AssertEventually(t, func() bool { return node.State() == nodeStateAlive }) + assert.NoError(t, node.close()) }) } diff --git a/common/client/send_only_node_test.go b/common/client/send_only_node_test.go index 60412f2ddca..bfe55153656 100644 --- a/common/client/send_only_node_test.go +++ b/common/client/send_only_node_test.go @@ -95,7 +95,7 @@ func TestStartSendOnlyNode(t *testing.T) { return s.State() == nodeStateAlive }) }) - t.Run("Can remover from chainID mismatch", func(t *testing.T) { + t.Run("Can recover from chainID mismatch", func(t *testing.T) { t.Parallel() lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) client := newMockSendOnlyClient[types.ID](t) @@ -118,4 +118,22 @@ func TestStartSendOnlyNode(t *testing.T) { return s.State() == nodeStateAlive }) }) + t.Run("Start with Random ChainID", func(t *testing.T) { + t.Parallel() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) + client := newMockSendOnlyClient[types.ID](t) + client.On("Close").Once() + client.On("DialHTTP").Return(nil).Once() + configuredChainID := types.RandomID() + client.On("ChainID", mock.Anything).Return(configuredChainID, nil) + s := NewSendOnlyNode(lggr, url.URL{}, t.Name(), configuredChainID, client) + + defer func() { assert.NoError(t, s.Close()) }() + err := s.Start(tests.Context(t)) + assert.NoError(t, err) + tests.AssertEventually(t, func() bool { + return s.State() == nodeStateAlive + }) + assert.Equal(t, 0, observedLogs.Len()) // No warnings expected + }) }