Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
DylanTinianov committed Jun 27, 2024
1 parent 511a7a2 commit 8886d0c
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .changeset/orange-feet-share.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
Implemented new EVM Multinode design. The Multinode is now called by chain clients to retrieve the best healthy RPC rather than performing RPC calls directly.
Multinode performs various health checks on RPCs, which in turn increases reliability.
This new EVM Multinode design will also be implemented for non-EVM chains in the future.
#updated #changed
#updated #changed #internal
13 changes: 6 additions & 7 deletions common/client/multi_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,11 @@ func (c *MultiNode[CHAIN_ID, RPC_CLIENT]) DoAll(ctx context.Context, do func(ctx
if n.State() != NodeStateAlive {
continue
}
if do(ctx, n.RPC(), false) {
callsCompleted++
}
do(ctx, n.RPC(), false)
callsCompleted++
}
if callsCompleted == 0 {
return fmt.Errorf("no calls were completed")
return ErroringNodeError
}

for _, n := range c.sendOnlyNodes {
Expand All @@ -118,7 +117,7 @@ func (c *MultiNode[CHAIN_ID, RPC_CLIENT]) DoAll(ctx context.Context, do func(ctx
if n.State() != NodeStateAlive {
continue
}
do(ctx, n.RPC(), false)
do(ctx, n.RPC(), true)
}
return nil
}
Expand Down Expand Up @@ -167,11 +166,11 @@ func (c *MultiNode[CHAIN_ID, RPC_CLIENT]) HighestChainInfo() ChainInfo {
return ch
}

// Dial starts every node in the pool
// Start starts every node in the pool
//
// Nodes handle their own redialing and runloops, so this function does not
// return any error if the nodes aren't available
func (c *MultiNode[CHAIN_ID, RPC_CLIENT]) Dial(ctx context.Context) error {
func (c *MultiNode[CHAIN_ID, RPC_CLIENT]) Start(ctx context.Context) error {
return c.StartOnce("MultiNode", func() (merr error) {
if len(c.primaryNodes) == 0 {
return fmt.Errorf("no available nodes for chain %s", c.chainID.String())
Expand Down
24 changes: 12 additions & 12 deletions common/client/multi_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestMultiNode_Dial(t *testing.T) {
selectionMode: NodeSelectionModeRoundRobin,
chainID: types.RandomID(),
})
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, fmt.Sprintf("no available nodes for chain %s", mn.chainID.String()))
})
t.Run("Fails with wrong node's chainID", func(t *testing.T) {
Expand All @@ -89,7 +89,7 @@ func TestMultiNode_Dial(t *testing.T) {
chainID: multiNodeChainID,
nodes: []Node[types.ID, multiNodeRPCClient]{node},
})
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, fmt.Sprintf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", nodeName, nodeChainID, mn.chainID))
})
t.Run("Fails if node fails", func(t *testing.T) {
Expand All @@ -105,7 +105,7 @@ func TestMultiNode_Dial(t *testing.T) {
chainID: chainID,
nodes: []Node[types.ID, multiNodeRPCClient]{node},
})
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, expectedError.Error())
})

Expand All @@ -124,7 +124,7 @@ func TestMultiNode_Dial(t *testing.T) {
chainID: chainID,
nodes: []Node[types.ID, multiNodeRPCClient]{node1, node2},
})
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, expectedError.Error())
})
t.Run("Fails with wrong send only node's chainID", func(t *testing.T) {
Expand All @@ -143,7 +143,7 @@ func TestMultiNode_Dial(t *testing.T) {
nodes: []Node[types.ID, multiNodeRPCClient]{node},
sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{sendOnly},
})
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, fmt.Sprintf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", sendOnlyName, sendOnlyChainID, mn.chainID))
})

Expand All @@ -170,7 +170,7 @@ func TestMultiNode_Dial(t *testing.T) {
nodes: []Node[types.ID, multiNodeRPCClient]{node},
sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{sendOnly1, sendOnly2},
})
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, expectedError.Error())
})
t.Run("Starts successfully with healthy nodes", func(t *testing.T) {
Expand All @@ -184,7 +184,7 @@ func TestMultiNode_Dial(t *testing.T) {
sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{newHealthySendOnly(t, chainID)},
})
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
require.NoError(t, err)
selectedNode, err := mn.selectNode()
require.NoError(t, err)
Expand All @@ -208,7 +208,7 @@ func TestMultiNode_Report(t *testing.T) {
})
mn.reportInterval = tests.TestInterval
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
require.NoError(t, err)
tests.AssertLogCountEventually(t, observedLogs, "At least one primary node is dead: 1/2 nodes are alive", 2)
})
Expand All @@ -225,7 +225,7 @@ func TestMultiNode_Report(t *testing.T) {
})
mn.reportInterval = tests.TestInterval
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
require.NoError(t, err)
tests.AssertLogCountEventually(t, observedLogs, "no primary nodes available: 0/1 nodes are alive", 2)
err = mn.Healthy()
Expand All @@ -248,7 +248,7 @@ func TestMultiNode_CheckLease(t *testing.T) {
nodes: []Node[types.ID, multiNodeRPCClient]{node},
})
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
require.NoError(t, err)
tests.RequireLogMessage(t, observedLogs, "Best node switching is disabled")
})
Expand All @@ -265,7 +265,7 @@ func TestMultiNode_CheckLease(t *testing.T) {
leaseDuration: 0,
})
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
require.NoError(t, err)
tests.RequireLogMessage(t, observedLogs, "Best node switching is disabled")
})
Expand All @@ -287,7 +287,7 @@ func TestMultiNode_CheckLease(t *testing.T) {
})
defer func() { assert.NoError(t, mn.Close()) }()
mn.nodeSelector = nodeSelector
err := mn.Dial(tests.Context(t))
err := mn.Start(tests.Context(t))
require.NoError(t, err)
tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("Switching to best node from %q to %q", node.String(), bestNode.String()))
tests.AssertEventually(t, func() bool {
Expand Down
2 changes: 0 additions & 2 deletions common/client/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ func (n *node[CHAIN_ID, HEAD, RPC_CLIENT]) String() string {
}

func (n *node[CHAIN_ID, HEAD, RPC_CLIENT]) ConfiguredChainID() (chainID CHAIN_ID) {
n.stateMu.RLock()
defer n.stateMu.RUnlock()
return n.chainID
}

Expand Down
10 changes: 1 addition & 9 deletions common/client/node_lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) {

expectedError := errors.New("failed to subscribe to rpc")
rpc.On("SubscribeToHeads", mock.Anything).Return(nil, nil, expectedError).Once()
rpc.On("UnsubscribeAllExcept", nil, nil)
// might be called in unreachable loop
rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe()
node.declareAlive()
Expand All @@ -74,8 +73,6 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) {
sub.On("Err").Return((<-chan error)(errChan)).Once()
sub.On("Unsubscribe").Once()
rpc.On("SubscribeToHeads", mock.Anything).Return(nil, sub, nil).Once()
// disconnects all on transfer to unreachable
rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Once()
// might be called in unreachable loop
rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe()
node.declareAlive()
Expand Down Expand Up @@ -1143,6 +1140,7 @@ func TestUnit_NodeLifecycle_start(t *testing.T) {

newNode := func(t *testing.T, opts testNodeOpts) testNode {
node := newTestNode(t, opts)
opts.rpc.On("UnsubscribeAllExcept", nil, nil)
opts.rpc.On("Close").Return(nil).Once()

return node
Expand All @@ -1161,7 +1159,6 @@ func TestUnit_NodeLifecycle_start(t *testing.T) {

rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial"))
// disconnects all on transfer to unreachable
rpc.On("UnsubscribeAllExcept", nil, nil)
err := node.Start(tests.Context(t))
assert.NoError(t, err)
tests.AssertLogEventually(t, observedLogs, "Dial failed: Node is unreachable")
Expand All @@ -1186,7 +1183,6 @@ func TestUnit_NodeLifecycle_start(t *testing.T) {
assert.Equal(t, NodeStateDialed, node.State())
}).Return(nodeChainID, errors.New("failed to get chain id"))
// disconnects all on transfer to unreachable
rpc.On("UnsubscribeAllExcept", nil, nil)
err := node.Start(tests.Context(t))
assert.NoError(t, err)
tests.AssertLogEventually(t, observedLogs, "Failed to verify chain ID for node")
Expand All @@ -1208,7 +1204,6 @@ func TestUnit_NodeLifecycle_start(t *testing.T) {
rpc.On("Dial", mock.Anything).Return(nil)
rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil)
// disconnects all on transfer to unreachable
rpc.On("UnsubscribeAllExcept", nil, nil)
err := node.Start(tests.Context(t))
assert.NoError(t, err)
tests.AssertEventually(t, func() bool {
Expand All @@ -1234,7 +1229,6 @@ func TestUnit_NodeLifecycle_start(t *testing.T) {
}).Return(nodeChainID, nil).Once()
rpc.On("IsSyncing", mock.Anything).Return(false, errors.New("failed to check syncing status"))
// disconnects all on transfer to unreachable
rpc.On("UnsubscribeAllExcept", nil, nil)
// fail to redial to stay in unreachable state
rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial"))
err := node.Start(tests.Context(t))
Expand All @@ -1259,7 +1253,6 @@ func TestUnit_NodeLifecycle_start(t *testing.T) {
rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil)
rpc.On("IsSyncing", mock.Anything).Return(true, nil)
// disconnects all on transfer to unreachable
rpc.On("UnsubscribeAllExcept", nil, nil)
err := node.Start(tests.Context(t))
assert.NoError(t, err)
tests.AssertEventually(t, func() bool {
Expand Down Expand Up @@ -1459,7 +1452,6 @@ func TestUnit_NodeLifecycle_SyncingLoop(t *testing.T) {
opts.config.nodeIsSyncingEnabled = true
node := newTestNode(t, opts)
opts.rpc.On("Close").Return(nil).Once()
opts.rpc.On("UnsubscribeAllExcept", nil, nil)

node.setState(NodeStateDialed)
return node
Expand Down
19 changes: 2 additions & 17 deletions core/chains/evm/client/chain_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ type Client interface {
Close()
// ChainID locally stored for quick access
ConfiguredChainID() *big.Int
// ChainID RPC call
ChainID() (*big.Int, error)

// NodeStates returns a map of node Name->node state
// It might be nil or empty, e.g. for mock clients etc
Expand Down Expand Up @@ -231,16 +229,6 @@ func (c *chainClient) PendingCallContract(ctx context.Context, msg ethereum.Call
return rpc.PendingCallContract(ctx, msg)
}

// TODO-1663: change this to actual ChainID() call once client.go is deprecated.
func (c *chainClient) ChainID() (*big.Int, error) {
rpc, err := c.multiNode.SelectRPC()
if err != nil {
return nil, err
}
// TODO: Propagate context
return rpc.ChainID(context.Background())
}

func (c *chainClient) Close() {
_ = c.multiNode.Close()
}
Expand All @@ -258,7 +246,7 @@ func (c *chainClient) ConfiguredChainID() *big.Int {
}

func (c *chainClient) Dial(ctx context.Context) error {
return c.multiNode.Dial(ctx)
return c.multiNode.Start(ctx)
}

func (c *chainClient) EstimateGas(ctx context.Context, call ethereum.CallMsg) (uint64, error) {
Expand Down Expand Up @@ -390,10 +378,7 @@ func (c *chainClient) SubscribeNewHead(ctx context.Context) (<-chan *evmtypes.He
if err != nil {
return nil, nil, err
}
chainID, err := c.ChainID()
if err != nil {
return nil, nil, err
}
chainID := c.ConfiguredChainID()
forwardCh, csf := newChainIDSubForwarder(chainID, ch)
err = csf.start(sub, err)
if err != nil {
Expand Down
4 changes: 1 addition & 3 deletions core/chains/evm/client/chain_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,11 +828,9 @@ func TestEthClient_ErroringClient(t *testing.T) {
_, err = erroringClient.CallContract(ctx, ethereum.CallMsg{}, nil)
require.Equal(t, err, commonclient.ErroringNodeError)

// TODO-1663: test actual ChainID() call once client.go is deprecated.
id, err := erroringClient.ChainID()
id := erroringClient.ConfiguredChainID()
var expected *big.Int
require.Equal(t, id, expected)
require.Equal(t, err, commonclient.ErroringNodeError)

_, err = erroringClient.CodeAt(ctx, common.Address{}, nil)
require.Equal(t, err, commonclient.ErroringNodeError)
Expand Down

0 comments on commit 8886d0c

Please sign in to comment.