diff --git a/.changeset/orange-feet-share.md b/.changeset/orange-feet-share.md
new file mode 100644
index 00000000000..1df7e85ca9e
--- /dev/null
+++ b/.changeset/orange-feet-share.md
@@ -0,0 +1,8 @@
+---
+"chainlink": minor
+---
+
+Implemented the new EVM MultiNode design. MultiNode is now called by chain clients to retrieve the best healthy RPC rather than performing RPC calls directly.
+MultiNode performs various health checks on RPCs, which in turn increases reliability.
+This new MultiNode design will also be adopted for non-EVM chains in the future.
+#updated #changed #internal
diff --git a/common/client/mock_node_client_test.go b/common/client/mock_node_client_test.go
deleted file mode 100644
index a7c0e4dbdb8..00000000000
--- a/common/client/mock_node_client_test.go
+++ /dev/null
@@ -1,273 +0,0 @@
-// Code generated by mockery v2.43.2. DO NOT EDIT.
-
-package client
-
-import (
-	context "context"
-
-	types "github.com/smartcontractkit/chainlink/v2/common/types"
-	mock "github.com/stretchr/testify/mock"
-)
-
-// mockNodeClient is an autogenerated mock type for the NodeClient type
-type mockNodeClient[CHAIN_ID types.ID, HEAD Head] struct {
-	mock.Mock
-}
-
-// ChainID provides a mock function with given fields: ctx
-func (_m *mockNodeClient[CHAIN_ID, HEAD]) ChainID(ctx context.Context) (CHAIN_ID, error) {
-	ret := _m.Called(ctx)
-
-	if len(ret) == 0 {
-		panic("no return value specified for ChainID")
-	}
-
-	var r0 CHAIN_ID
-	var r1 error
-	if rf, ok := ret.Get(0).(func(context.Context) (CHAIN_ID, error)); ok {
-		return rf(ctx)
-	}
-	if rf, ok := ret.Get(0).(func(context.Context) CHAIN_ID); ok {
-		r0 = rf(ctx)
-	} else {
-		r0 = ret.Get(0).(CHAIN_ID)
-	}
-
-	if rf, ok := ret.Get(1).(func(context.Context) error); ok {
-		r1 = rf(ctx)
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// ClientVersion provides a mock function with given fields: _a0
-func (_m *mockNodeClient[CHAIN_ID, HEAD]) ClientVersion(_a0 context.Context) (string, error) {
-	ret := _m.Called(_a0)
-
-	if len(ret) == 0 {
-		panic("no return value specified for ClientVersion")
-	}
-
-	var r0 string
-	var r1 error
-	if rf, ok := ret.Get(0).(func(context.Context) (string, error)); ok {
-		return rf(_a0)
-	}
-	if rf, ok := ret.Get(0).(func(context.Context) string); ok {
-		r0 = rf(_a0)
-	} else {
-		r0 = ret.Get(0).(string)
-	}
-
-	if rf, ok := ret.Get(1).(func(context.Context) error); ok {
-		r1 = rf(_a0)
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// Close provides a mock function with given fields:
-func (_m *mockNodeClient[CHAIN_ID, HEAD]) Close() {
-	_m.Called()
-}
-
-// Dial provides a mock function with given fields: ctx
-func (_m *mockNodeClient[CHAIN_ID, HEAD]) Dial(ctx context.Context) error {
-	ret := _m.Called(ctx)
-
-	if len(ret) == 0 {
-		panic("no return value specified for Dial")
-	}
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context) error); ok {
-		r0 = rf(ctx)
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// DialHTTP provides a mock function with given fields:
-func (_m *mockNodeClient[CHAIN_ID, HEAD]) DialHTTP() error {
-	ret := _m.Called()
-
-	if len(ret) == 0 {
-		panic("no return value specified for DialHTTP")
-	}
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func() error); ok {
-		r0 = rf()
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// DisconnectAll provides a mock function with given fields:
-func (_m *mockNodeClient[CHAIN_ID, HEAD]) DisconnectAll() {
-	_m.Called()
-}
-
-// GetInterceptedChainInfo provides a mock function with given fields:
-func
(_m *mockNodeClient[CHAIN_ID, HEAD]) GetInterceptedChainInfo() (ChainInfo, ChainInfo) { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for GetInterceptedChainInfo") - } - - var r0 ChainInfo - var r1 ChainInfo - if rf, ok := ret.Get(0).(func() (ChainInfo, ChainInfo)); ok { - return rf() - } - if rf, ok := ret.Get(0).(func() ChainInfo); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(ChainInfo) - } - - if rf, ok := ret.Get(1).(func() ChainInfo); ok { - r1 = rf() - } else { - r1 = ret.Get(1).(ChainInfo) - } - - return r0, r1 -} - -// IsSyncing provides a mock function with given fields: ctx -func (_m *mockNodeClient[CHAIN_ID, HEAD]) IsSyncing(ctx context.Context) (bool, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for IsSyncing") - } - - var r0 bool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (bool, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) bool); ok { - r0 = rf(ctx) - } else { - r0 = ret.Get(0).(bool) - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// LatestFinalizedBlock provides a mock function with given fields: ctx -func (_m *mockNodeClient[CHAIN_ID, HEAD]) LatestFinalizedBlock(ctx context.Context) (HEAD, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for LatestFinalizedBlock") - } - - var r0 HEAD - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (HEAD, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) HEAD); ok { - r0 = rf(ctx) - } else { - r0 = ret.Get(0).(HEAD) - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// SetAliveLoopSub provides a mock function with given fields: _a0 -func (_m *mockNodeClient[CHAIN_ID, HEAD]) SetAliveLoopSub(_a0 types.Subscription) { - _m.Called(_a0) -} - -// SubscribeNewHead provides a mock function with given fields: ctx, channel -func (_m *mockNodeClient[CHAIN_ID, HEAD]) SubscribeNewHead(ctx context.Context, channel chan<- HEAD) (types.Subscription, error) { - ret := _m.Called(ctx, channel) - - if len(ret) == 0 { - panic("no return value specified for SubscribeNewHead") - } - - var r0 types.Subscription - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, chan<- HEAD) (types.Subscription, error)); ok { - return rf(ctx, channel) - } - if rf, ok := ret.Get(0).(func(context.Context, chan<- HEAD) types.Subscription); ok { - r0 = rf(ctx, channel) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(types.Subscription) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, chan<- HEAD) error); ok { - r1 = rf(ctx, channel) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// SubscribersCount provides a mock function with given fields: -func (_m *mockNodeClient[CHAIN_ID, HEAD]) SubscribersCount() int32 { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for SubscribersCount") - } - - var r0 int32 - if rf, ok := ret.Get(0).(func() int32); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int32) - } - - return r0 -} - -// UnsubscribeAllExceptAliveLoop provides a mock function with given fields: -func (_m *mockNodeClient[CHAIN_ID, HEAD]) UnsubscribeAllExceptAliveLoop() { - _m.Called() -} - -// newMockNodeClient creates a new instance of mockNodeClient. 
It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func newMockNodeClient[CHAIN_ID types.ID, HEAD Head](t interface { - mock.TestingT - Cleanup(func()) -}) *mockNodeClient[CHAIN_ID, HEAD] { - mock := &mockNodeClient[CHAIN_ID, HEAD]{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/common/client/mock_node_selector_test.go b/common/client/mock_node_selector_test.go index 783fc50c290..e9387779bab 100644 --- a/common/client/mock_node_selector_test.go +++ b/common/client/mock_node_selector_test.go @@ -8,12 +8,12 @@ import ( ) // mockNodeSelector is an autogenerated mock type for the NodeSelector type -type mockNodeSelector[CHAIN_ID types.ID, HEAD Head, RPC NodeClient[CHAIN_ID, HEAD]] struct { +type mockNodeSelector[CHAIN_ID types.ID, RPC interface{}] struct { mock.Mock } // Name provides a mock function with given fields: -func (_m *mockNodeSelector[CHAIN_ID, HEAD, RPC]) Name() string { +func (_m *mockNodeSelector[CHAIN_ID, RPC]) Name() string { ret := _m.Called() if len(ret) == 0 { @@ -31,19 +31,19 @@ func (_m *mockNodeSelector[CHAIN_ID, HEAD, RPC]) Name() string { } // Select provides a mock function with given fields: -func (_m *mockNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { +func (_m *mockNodeSelector[CHAIN_ID, RPC]) Select() Node[CHAIN_ID, RPC] { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for Select") } - var r0 Node[CHAIN_ID, HEAD, RPC] - if rf, ok := ret.Get(0).(func() Node[CHAIN_ID, HEAD, RPC]); ok { + var r0 Node[CHAIN_ID, RPC] + if rf, ok := ret.Get(0).(func() Node[CHAIN_ID, RPC]); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(Node[CHAIN_ID, HEAD, RPC]) + r0 = ret.Get(0).(Node[CHAIN_ID, RPC]) } } @@ -52,11 +52,11 @@ func (_m *mockNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, R // newMockNodeSelector creates a new instance of mockNodeSelector. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
-func newMockNodeSelector[CHAIN_ID types.ID, HEAD Head, RPC NodeClient[CHAIN_ID, HEAD]](t interface { +func newMockNodeSelector[CHAIN_ID types.ID, RPC interface{}](t interface { mock.TestingT Cleanup(func()) -}) *mockNodeSelector[CHAIN_ID, HEAD, RPC] { - mock := &mockNodeSelector[CHAIN_ID, HEAD, RPC]{} +}) *mockNodeSelector[CHAIN_ID, RPC] { + mock := &mockNodeSelector[CHAIN_ID, RPC]{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/common/client/mock_node_test.go b/common/client/mock_node_test.go index 5109eb6bb90..2a051c07ec1 100644 --- a/common/client/mock_node_test.go +++ b/common/client/mock_node_test.go @@ -10,12 +10,12 @@ import ( ) // mockNode is an autogenerated mock type for the Node type -type mockNode[CHAIN_ID types.ID, HEAD Head, RPC NodeClient[CHAIN_ID, HEAD]] struct { +type mockNode[CHAIN_ID types.ID, RPC interface{}] struct { mock.Mock } // Close provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Close() error { +func (_m *mockNode[CHAIN_ID, RPC]) Close() error { ret := _m.Called() if len(ret) == 0 { @@ -33,7 +33,7 @@ func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Close() error { } // ConfiguredChainID provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) ConfiguredChainID() CHAIN_ID { +func (_m *mockNode[CHAIN_ID, RPC]) ConfiguredChainID() CHAIN_ID { ret := _m.Called() if len(ret) == 0 { @@ -51,7 +51,7 @@ func (_m *mockNode[CHAIN_ID, HEAD, RPC]) ConfiguredChainID() CHAIN_ID { } // HighestUserObservations provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) HighestUserObservations() ChainInfo { +func (_m *mockNode[CHAIN_ID, RPC]) HighestUserObservations() ChainInfo { ret := _m.Called() if len(ret) == 0 { @@ -69,7 +69,7 @@ func (_m *mockNode[CHAIN_ID, HEAD, RPC]) HighestUserObservations() ChainInfo { } // Name provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Name() string { +func (_m *mockNode[CHAIN_ID, RPC]) Name() string { ret := _m.Called() if len(ret) == 0 { @@ -87,7 +87,7 @@ func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Name() string { } // Order provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Order() int32 { +func (_m *mockNode[CHAIN_ID, RPC]) Order() int32 { ret := _m.Called() if len(ret) == 0 { @@ -105,7 +105,7 @@ func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Order() int32 { } // RPC provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) RPC() RPC { +func (_m *mockNode[CHAIN_ID, RPC]) RPC() RPC { ret := _m.Called() if len(ret) == 0 { @@ -123,12 +123,12 @@ func (_m *mockNode[CHAIN_ID, HEAD, RPC]) RPC() RPC { } // SetPoolChainInfoProvider provides a mock function with given fields: _a0 -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) SetPoolChainInfoProvider(_a0 PoolChainInfoProvider) { +func (_m *mockNode[CHAIN_ID, RPC]) SetPoolChainInfoProvider(_a0 PoolChainInfoProvider) { _m.Called(_a0) } // Start provides a mock function with given fields: _a0 -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Start(_a0 context.Context) error { +func (_m *mockNode[CHAIN_ID, RPC]) Start(_a0 context.Context) error { ret := _m.Called(_a0) if len(ret) == 0 { @@ -146,40 +146,40 @@ func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Start(_a0 context.Context) error { } // State provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) State() nodeState { +func (_m *mockNode[CHAIN_ID, RPC]) State() NodeState { ret := _m.Called() if len(ret) == 0 { panic("no return value 
specified for State") } - var r0 nodeState - if rf, ok := ret.Get(0).(func() nodeState); ok { + var r0 NodeState + if rf, ok := ret.Get(0).(func() NodeState); ok { r0 = rf() } else { - r0 = ret.Get(0).(nodeState) + r0 = ret.Get(0).(NodeState) } return r0 } // StateAndLatest provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) StateAndLatest() (nodeState, ChainInfo) { +func (_m *mockNode[CHAIN_ID, RPC]) StateAndLatest() (NodeState, ChainInfo) { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for StateAndLatest") } - var r0 nodeState + var r0 NodeState var r1 ChainInfo - if rf, ok := ret.Get(0).(func() (nodeState, ChainInfo)); ok { + if rf, ok := ret.Get(0).(func() (NodeState, ChainInfo)); ok { return rf() } - if rf, ok := ret.Get(0).(func() nodeState); ok { + if rf, ok := ret.Get(0).(func() NodeState); ok { r0 = rf() } else { - r0 = ret.Get(0).(nodeState) + r0 = ret.Get(0).(NodeState) } if rf, ok := ret.Get(1).(func() ChainInfo); ok { @@ -192,7 +192,7 @@ func (_m *mockNode[CHAIN_ID, HEAD, RPC]) StateAndLatest() (nodeState, ChainInfo) } // String provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) String() string { +func (_m *mockNode[CHAIN_ID, RPC]) String() string { ret := _m.Called() if len(ret) == 0 { @@ -209,36 +209,18 @@ func (_m *mockNode[CHAIN_ID, HEAD, RPC]) String() string { return r0 } -// SubscribersCount provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) SubscribersCount() int32 { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for SubscribersCount") - } - - var r0 int32 - if rf, ok := ret.Get(0).(func() int32); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int32) - } - - return r0 -} - // UnsubscribeAllExceptAliveLoop provides a mock function with given fields: -func (_m *mockNode[CHAIN_ID, HEAD, RPC]) UnsubscribeAllExceptAliveLoop() { +func (_m *mockNode[CHAIN_ID, RPC]) UnsubscribeAllExceptAliveLoop() { _m.Called() } // newMockNode creates a new instance of mockNode. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func newMockNode[CHAIN_ID types.ID, HEAD Head, RPC NodeClient[CHAIN_ID, HEAD]](t interface { +func newMockNode[CHAIN_ID types.ID, RPC interface{}](t interface { mock.TestingT Cleanup(func()) -}) *mockNode[CHAIN_ID, HEAD, RPC] { - mock := &mockNode[CHAIN_ID, HEAD, RPC]{} +}) *mockNode[CHAIN_ID, RPC] { + mock := &mockNode[CHAIN_ID, RPC]{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/common/client/mock_rpc_client_test.go b/common/client/mock_rpc_client_test.go new file mode 100644 index 00000000000..c1204ca5914 --- /dev/null +++ b/common/client/mock_rpc_client_test.go @@ -0,0 +1,243 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. 
+ +package client + +import ( + context "context" + + types "github.com/smartcontractkit/chainlink/v2/common/types" + mock "github.com/stretchr/testify/mock" +) + +// mockRPCClient is an autogenerated mock type for the RPCClient type +type mockRPCClient[CHAIN_ID types.ID, HEAD Head] struct { + mock.Mock +} + +// ChainID provides a mock function with given fields: ctx +func (_m *mockRPCClient[CHAIN_ID, HEAD]) ChainID(ctx context.Context) (CHAIN_ID, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ChainID") + } + + var r0 CHAIN_ID + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (CHAIN_ID, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) CHAIN_ID); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(CHAIN_ID) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Close provides a mock function with given fields: +func (_m *mockRPCClient[CHAIN_ID, HEAD]) Close() { + _m.Called() +} + +// Dial provides a mock function with given fields: ctx +func (_m *mockRPCClient[CHAIN_ID, HEAD]) Dial(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Dial") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetInterceptedChainInfo provides a mock function with given fields: +func (_m *mockRPCClient[CHAIN_ID, HEAD]) GetInterceptedChainInfo() (ChainInfo, ChainInfo) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetInterceptedChainInfo") + } + + var r0 ChainInfo + var r1 ChainInfo + if rf, ok := ret.Get(0).(func() (ChainInfo, ChainInfo)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() ChainInfo); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(ChainInfo) + } + + if rf, ok := ret.Get(1).(func() ChainInfo); ok { + r1 = rf() + } else { + r1 = ret.Get(1).(ChainInfo) + } + + return r0, r1 +} + +// IsSyncing provides a mock function with given fields: ctx +func (_m *mockRPCClient[CHAIN_ID, HEAD]) IsSyncing(ctx context.Context) (bool, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for IsSyncing") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (bool, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) bool); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Ping provides a mock function with given fields: _a0 +func (_m *mockRPCClient[CHAIN_ID, HEAD]) Ping(_a0 context.Context) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Ping") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SubscribeToFinalizedHeads provides a mock function with given fields: ctx +func (_m *mockRPCClient[CHAIN_ID, HEAD]) SubscribeToFinalizedHeads(ctx context.Context) (<-chan HEAD, types.Subscription, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for SubscribeToFinalizedHeads") + } + + var r0 <-chan HEAD + var r1 types.Subscription + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) (<-chan 
HEAD, types.Subscription, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) <-chan HEAD); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan HEAD) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) types.Subscription); ok { + r1 = rf(ctx) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(types.Subscription) + } + } + + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// SubscribeToHeads provides a mock function with given fields: ctx +func (_m *mockRPCClient[CHAIN_ID, HEAD]) SubscribeToHeads(ctx context.Context) (<-chan HEAD, types.Subscription, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for SubscribeToHeads") + } + + var r0 <-chan HEAD + var r1 types.Subscription + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) (<-chan HEAD, types.Subscription, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) <-chan HEAD); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan HEAD) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) types.Subscription); ok { + r1 = rf(ctx) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(types.Subscription) + } + } + + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// UnsubscribeAllExcept provides a mock function with given fields: subs +func (_m *mockRPCClient[CHAIN_ID, HEAD]) UnsubscribeAllExcept(subs ...types.Subscription) { + _va := make([]interface{}, len(subs)) + for _i := range subs { + _va[_i] = subs[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + _m.Called(_ca...) +} + +// newMockRPCClient creates a new instance of mockRPCClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func newMockRPCClient[CHAIN_ID types.ID, HEAD Head](t interface { + mock.TestingT + Cleanup(func()) +}) *mockRPCClient[CHAIN_ID, HEAD] { + mock := &mockRPCClient[CHAIN_ID, HEAD]{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/common/client/mock_rpc_test.go b/common/client/mock_rpc_test.go index 81bac04547d..b2e65998785 100644 --- a/common/client/mock_rpc_test.go +++ b/common/client/mock_rpc_test.go @@ -665,34 +665,43 @@ func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS return r0 } -// SubscribeNewHead provides a mock function with given fields: ctx, channel -func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, BATCH_ELEM]) SubscribeNewHead(ctx context.Context, channel chan<- HEAD) (types.Subscription, error) { - ret := _m.Called(ctx, channel) +// SubscribeNewHead provides a mock function with given fields: ctx +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, BATCH_ELEM]) SubscribeNewHead(ctx context.Context) (<-chan HEAD, types.Subscription, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for SubscribeNewHead") } - var r0 types.Subscription - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, chan<- HEAD) (types.Subscription, error)); ok { - return rf(ctx, channel) + var r0 <-chan HEAD + var r1 types.Subscription + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) (<-chan HEAD, types.Subscription, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func(context.Context, chan<- HEAD) types.Subscription); ok { - r0 = rf(ctx, channel) + if rf, ok := ret.Get(0).(func(context.Context) <-chan HEAD); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(types.Subscription) + r0 = ret.Get(0).(<-chan HEAD) } } - if rf, ok := ret.Get(1).(func(context.Context, chan<- HEAD) error); ok { - r1 = rf(ctx, channel) + if rf, ok := ret.Get(1).(func(context.Context) types.Subscription); ok { + r1 = rf(ctx) } else { - r1 = ret.Error(1) + if ret.Get(1) != nil { + r1 = ret.Get(1).(types.Subscription) + } } - return r0, r1 + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 } // SubscribersCount provides a mock function with given fields: diff --git a/common/client/mock_send_only_client_test.go b/common/client/mock_send_only_client_test.go index 606506e657a..fbe37f7ca08 100644 --- a/common/client/mock_send_only_client_test.go +++ b/common/client/mock_send_only_client_test.go @@ -47,17 +47,17 @@ func (_m *mockSendOnlyClient[CHAIN_ID]) Close() { _m.Called() } -// DialHTTP provides a mock function with given fields: -func (_m *mockSendOnlyClient[CHAIN_ID]) DialHTTP() error { - ret := _m.Called() +// Dial provides a mock function with given fields: ctx +func (_m *mockSendOnlyClient[CHAIN_ID]) Dial(ctx context.Context) error { + ret := _m.Called(ctx) if len(ret) == 0 { - panic("no return value specified for DialHTTP") + panic("no return value specified for Dial") } var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) } else { r0 = ret.Error(0) } diff --git a/common/client/mock_send_only_node_test.go b/common/client/mock_send_only_node_test.go index a39df992aef..4fc51fe5c64 100644 --- a/common/client/mock_send_only_node_test.go +++ 
b/common/client/mock_send_only_node_test.go @@ -10,7 +10,7 @@ import ( ) // mockSendOnlyNode is an autogenerated mock type for the SendOnlyNode type -type mockSendOnlyNode[CHAIN_ID types.ID, RPC sendOnlyClient[CHAIN_ID]] struct { +type mockSendOnlyNode[CHAIN_ID types.ID, RPC interface{}] struct { mock.Mock } @@ -105,18 +105,18 @@ func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) Start(_a0 context.Context) error { } // State provides a mock function with given fields: -func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) State() nodeState { +func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) State() NodeState { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for State") } - var r0 nodeState - if rf, ok := ret.Get(0).(func() nodeState); ok { + var r0 NodeState + if rf, ok := ret.Get(0).(func() NodeState); ok { r0 = rf() } else { - r0 = ret.Get(0).(nodeState) + r0 = ret.Get(0).(NodeState) } return r0 @@ -142,7 +142,7 @@ func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) String() string { // newMockSendOnlyNode creates a new instance of mockSendOnlyNode. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func newMockSendOnlyNode[CHAIN_ID types.ID, RPC sendOnlyClient[CHAIN_ID]](t interface { +func newMockSendOnlyNode[CHAIN_ID types.ID, RPC interface{}](t interface { mock.TestingT Cleanup(func()) }) *mockSendOnlyNode[CHAIN_ID, RPC] { diff --git a/common/client/models.go b/common/client/models.go index 8b616137669..6a6afe431e3 100644 --- a/common/client/models.go +++ b/common/client/models.go @@ -22,12 +22,6 @@ const ( sendTxReturnCodeLen // tracks the number of errors. Must always be last ) -// sendTxSevereErrors - error codes which signal that transaction would never be accepted in its current form by the node -var sendTxSevereErrors = []SendTxReturnCode{Fatal, Underpriced, Unsupported, ExceedsMaxFee, FeeOutOfValidRange, Unknown} - -// sendTxSuccessfulCodes - error codes which signal that transaction was accepted by the node -var sendTxSuccessfulCodes = []SendTxReturnCode{Successful, TransactionAlreadyKnown} - func (c SendTxReturnCode) String() string { switch c { case Successful: diff --git a/common/client/multi_node.go b/common/client/multi_node.go index 4d4ea925fe8..3bfa74c4393 100644 --- a/common/client/multi_node.go +++ b/common/client/multi_node.go @@ -3,19 +3,15 @@ package client import ( "context" "fmt" - "math" "math/big" - "slices" "sync" "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" - feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" "github.com/smartcontractkit/chainlink/v2/common/types" ) @@ -25,11 +21,6 @@ var ( Name: "multi_node_states", Help: "The number of RPC nodes currently in the given state for the given chain", }, []string{"network", "chainId", "state"}) - // PromMultiNodeInvariantViolations reports violation of our assumptions - PromMultiNodeInvariantViolations = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "multi_node_invariant_violations", - Help: "The number of invariant violations", - }, []string{"network", "chainId", "invariant"}) ErroringNodeError = fmt.Errorf("no live nodes available") ) @@ -37,129 +28,57 @@ var ( // It also handles multiple node RPC connections simultaneously. 
type MultiNode[
	CHAIN_ID types.ID,
-	SEQ types.Sequence,
-	ADDR types.Hashable,
-	BLOCK_HASH types.Hashable,
-	TX any,
-	TX_HASH types.Hashable,
-	EVENT any,
-	EVENT_OPS any,
-	TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH],
-	FEE feetypes.Fee,
-	HEAD types.Head[BLOCK_HASH],
-	RPC_CLIENT RPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, BATCH_ELEM],
-	BATCH_ELEM any,
-] interface {
-	clientAPI[
-		CHAIN_ID,
-		SEQ,
-		ADDR,
-		BLOCK_HASH,
-		TX,
-		TX_HASH,
-		EVENT,
-		EVENT_OPS,
-		TX_RECEIPT,
-		FEE,
-		HEAD,
-		BATCH_ELEM,
-	]
-	Close() error
-	NodeStates() map[string]string
-	SelectNodeRPC() (RPC_CLIENT, error)
-
-	BatchCallContextAll(ctx context.Context, b []BATCH_ELEM) error
-	ConfiguredChainID() CHAIN_ID
-}
-
-type multiNode[
-	CHAIN_ID types.ID,
-	SEQ types.Sequence,
-	ADDR types.Hashable,
-	BLOCK_HASH types.Hashable,
-	TX any,
-	TX_HASH types.Hashable,
-	EVENT any,
-	EVENT_OPS any,
-	TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH],
-	FEE feetypes.Fee,
-	HEAD types.Head[BLOCK_HASH],
-	RPC_CLIENT RPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, BATCH_ELEM],
-	BATCH_ELEM any,
+	RPC any,
] struct {
	services.StateMachine
-	nodes     []Node[CHAIN_ID, HEAD, RPC_CLIENT]
-	sendonlys []SendOnlyNode[CHAIN_ID, RPC_CLIENT]
+	primaryNodes  []Node[CHAIN_ID, RPC]
+	sendOnlyNodes []SendOnlyNode[CHAIN_ID, RPC]
	chainID       CHAIN_ID
	lggr          logger.SugaredLogger
	selectionMode string
-	noNewHeadsThreshold time.Duration
-	nodeSelector NodeSelector[CHAIN_ID, HEAD, RPC_CLIENT]
+	nodeSelector NodeSelector[CHAIN_ID, RPC]
	leaseDuration time.Duration
	leaseTicker   *time.Ticker
	chainFamily string
	reportInterval        time.Duration
	deathDeclarationDelay time.Duration
-	sendTxSoftTimeout time.Duration // defines max waiting time from first response til responses evaluation
	activeMu   sync.RWMutex
-	activeNode Node[CHAIN_ID, HEAD, RPC_CLIENT]
+	activeNode Node[CHAIN_ID, RPC]
	chStop services.StopChan
	wg     sync.WaitGroup
-
-	classifySendTxError func(tx TX, err error) SendTxReturnCode
}

func NewMultiNode[
	CHAIN_ID types.ID,
-	SEQ types.Sequence,
-	ADDR types.Hashable,
-	BLOCK_HASH types.Hashable,
-	TX any,
-	TX_HASH types.Hashable,
-	EVENT any,
-	EVENT_OPS any,
-	TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH],
-	FEE feetypes.Fee,
-	HEAD types.Head[BLOCK_HASH],
-	RPC_CLIENT RPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, BATCH_ELEM],
-	BATCH_ELEM any,
+	RPC any,
](
	lggr logger.Logger,
-	selectionMode string,
-	leaseDuration time.Duration,
-	noNewHeadsThreshold time.Duration,
-	nodes []Node[CHAIN_ID, HEAD, RPC_CLIENT],
-	sendonlys []SendOnlyNode[CHAIN_ID, RPC_CLIENT],
-	chainID CHAIN_ID,
-	chainFamily string,
-	classifySendTxError func(tx TX, err error) SendTxReturnCode,
-	sendTxSoftTimeout time.Duration,
+	selectionMode string, // type of the "best" RPC selector (e.g. HighestHead, RoundRobin, etc.)
+	leaseDuration time.Duration, // defines the interval at which a new "best" RPC is selected
+	primaryNodes []Node[CHAIN_ID, RPC],
+	sendOnlyNodes []SendOnlyNode[CHAIN_ID, RPC],
+	chainID CHAIN_ID, // configured chain ID (used to verify that the passed primaryNodes belong to the same chain)
+	chainFamily string, // name of the chain family, used in metrics
	deathDeclarationDelay time.Duration,
-) MultiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM] {
-	nodeSelector := newNodeSelector(selectionMode, nodes)
+) *MultiNode[CHAIN_ID, RPC] {
+	nodeSelector := newNodeSelector(selectionMode, primaryNodes)
	// Prometheus' default interval is 15s, set this to under 7.5s to avoid
	// aliasing (see: https://en.wikipedia.org/wiki/Nyquist_frequency)
	const reportInterval = 6500 * time.Millisecond
-	if sendTxSoftTimeout == 0 {
-		sendTxSoftTimeout = QueryTimeout / 2
-	}
-	c := &multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]{
-		nodes:     nodes,
-		sendonlys: sendonlys,
+	c := &MultiNode[CHAIN_ID, RPC]{
+		primaryNodes:  primaryNodes,
+		sendOnlyNodes: sendOnlyNodes,
		chainID:       chainID,
		lggr:          logger.Sugared(lggr).Named("MultiNode").With("chainID", chainID.String()),
		selectionMode: selectionMode,
-		noNewHeadsThreshold: noNewHeadsThreshold,
		nodeSelector:  nodeSelector,
		chStop:        make(services.StopChan),
		leaseDuration: leaseDuration,
		chainFamily:   chainFamily,
-		classifySendTxError: classifySendTxError,
		reportInterval:        reportInterval,
		deathDeclarationDelay: deathDeclarationDelay,
-		sendTxSoftTimeout: sendTxSoftTimeout,
	}

	c.lggr.Debugf("The MultiNode is configured to use NodeSelectionMode: %s", selectionMode)
@@ -167,17 +86,60 @@ func NewMultiNode[
	return c
}

-// Dial starts every node in the pool
+// ChainID returns the chain ID the MultiNode was configured with.
+func (c *MultiNode[CHAIN_ID, RPC]) ChainID() CHAIN_ID {
+	return c.chainID
+}
+
+// DoAll runs the do function sequentially on every healthy primary node and then on every
+// healthy send-only node. It returns ctx.Err() if the context is cancelled mid-iteration,
+// and ErroringNodeError if no healthy primary node was available.
+func (c *MultiNode[CHAIN_ID, RPC]) DoAll(ctx context.Context, do func(ctx context.Context, rpc RPC, isSendOnly bool)) error {
+	callsCompleted := 0
+	for _, n := range c.primaryNodes {
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+		if n.State() != NodeStateAlive {
+			continue
+		}
+		do(ctx, n.RPC(), false)
+		callsCompleted++
+	}
+	if callsCompleted == 0 {
+		return ErroringNodeError
+	}
+
+	for _, n := range c.sendOnlyNodes {
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+		if n.State() != NodeStateAlive {
+			continue
+		}
+		do(ctx, n.RPC(), true)
+	}
+	return nil
+}
+
+// NodeStates returns the current state of every primary and send-only node, keyed by the
+// node's string representation.
+func (c *MultiNode[CHAIN_ID, RPC]) NodeStates() map[string]NodeState {
+	states := map[string]NodeState{}
+	for _, n := range c.primaryNodes {
+		states[n.String()] = n.State()
+	}
+	for _, n := range c.sendOnlyNodes {
+		states[n.String()] = n.State()
+	}
+	return states
+}
+
+// Start starts every node in the pool
//
// Nodes handle their own redialing and runloops, so this function does not
// return any error if the nodes aren't available
-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) Dial(ctx context.Context) error {
+func (c *MultiNode[CHAIN_ID, RPC]) Start(ctx context.Context) error {
	return c.StartOnce("MultiNode", func() (merr error) {
-		if len(c.nodes) == 0 {
+		if len(c.primaryNodes) == 0 {
			return fmt.Errorf("no available nodes for chain %s", c.chainID.String())
		}
		var ms services.MultiStart
-		for _, n := range c.nodes {
+		for _, n := range c.primaryNodes {
			if n.ConfiguredChainID().String() != c.chainID.String() {
				return ms.CloseBecause(fmt.Errorf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", n.String(), n.ConfiguredChainID().String(), c.chainID.String()))
			}
@@ -187,7 +149,7 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
				return err
			}
		}
-		for _, s := range c.sendonlys {
+		for _, s := range c.sendOnlyNodes {
			if s.ConfiguredChainID().String() != c.chainID.String() {
				return ms.CloseBecause(fmt.Errorf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", s.String(), s.ConfiguredChainID().String(), c.chainID.String()))
			}
@@ -211,18 +173,18 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
}

// Close tears down the MultiNode and closes all nodes
-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) Close() error {
+func (c *MultiNode[CHAIN_ID, RPC]) Close() error {
	return c.StopOnce("MultiNode", func() error {
		close(c.chStop)
		c.wg.Wait()
-		return services.CloseAll(services.MultiCloser(c.nodes), services.MultiCloser(c.sendonlys))
+		return services.CloseAll(services.MultiCloser(c.primaryNodes), services.MultiCloser(c.sendOnlyNodes))
	})
}

-// SelectNodeRPC returns an RPC of an active node. If there are no active nodes it returns an error.
+// SelectRPC returns an RPC of an active node. If there are no active nodes, it returns an error.
// Call this method from your chain-specific client implementation to access any chain-specific rpc calls.
-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) SelectNodeRPC() (rpc RPC_CLIENT, err error) {
+func (c *MultiNode[CHAIN_ID, RPC]) SelectRPC() (rpc RPC, err error) {
	n, err := c.selectNode()
	if err != nil {
		return rpc, err
@@ -230,12 +192,12 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
	return n.RPC(), nil
}

-// selectNode returns the active Node, if it is still nodeStateAlive, otherwise it selects a new one from the NodeSelector.
-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) selectNode() (node Node[CHAIN_ID, HEAD, RPC_CLIENT], err error) {
+// selectNode returns the active Node if it is still NodeStateAlive; otherwise it selects a new one from the NodeSelector.
+func (c *MultiNode[CHAIN_ID, RPC]) selectNode() (node Node[CHAIN_ID, RPC], err error) {
	c.activeMu.RLock()
	node = c.activeNode
	c.activeMu.RUnlock()
-	if node != nil && node.State() == nodeStateAlive {
+	if node != nil && node.State() == NodeStateAlive {
		return // still alive
	}
@@ -243,7 +205,7 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
	c.activeMu.Lock()
	defer c.activeMu.Unlock()
	node = c.activeNode
-	if node != nil && node.State() == nodeStateAlive {
+	if node != nil && node.State() == NodeStateAlive {
		return // another goroutine beat us here
	}
@@ -265,13 +227,13 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
// LatestChainInfo - returns the number of live nodes available in the pool, so we can prevent the last alive node in a pool from being marked as out-of-sync.
// Returns the highest ChainInfo most recently received by the alive nodes.
// E.g. if Node A's most recent block is 10 (highest observed 15) and Node B's most recent is 12 (highest observed 14), this method returns 12.
-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) LatestChainInfo() (int, ChainInfo) { +func (c *MultiNode[CHAIN_ID, RPC]) LatestChainInfo() (int, ChainInfo) { var nLiveNodes int ch := ChainInfo{ TotalDifficulty: big.NewInt(0), } - for _, n := range c.nodes { - if s, nodeChainInfo := n.StateAndLatest(); s == nodeStateAlive { + for _, n := range c.primaryNodes { + if s, nodeChainInfo := n.StateAndLatest(); s == NodeStateAlive { nLiveNodes++ ch.BlockNumber = max(ch.BlockNumber, nodeChainInfo.BlockNumber) ch.FinalizedBlockNumber = max(ch.FinalizedBlockNumber, nodeChainInfo.FinalizedBlockNumber) @@ -282,11 +244,11 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP } // HighestUserObservations - returns highest ChainInfo ever observed by any user of the MultiNode -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) HighestUserObservations() ChainInfo { +func (c *MultiNode[CHAIN_ID, RPC]) HighestUserObservations() ChainInfo { ch := ChainInfo{ TotalDifficulty: big.NewInt(0), } - for _, n := range c.nodes { + for _, n := range c.primaryNodes { nodeChainInfo := n.HighestUserObservations() ch.BlockNumber = max(ch.BlockNumber, nodeChainInfo.BlockNumber) ch.FinalizedBlockNumber = max(ch.FinalizedBlockNumber, nodeChainInfo.FinalizedBlockNumber) @@ -295,12 +257,12 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP return ch } -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) checkLease() { +func (c *MultiNode[CHAIN_ID, RPC]) checkLease() { bestNode := c.nodeSelector.Select() - for _, n := range c.nodes { + for _, n := range c.primaryNodes { // Terminate client subscriptions. Services are responsible for reconnecting, which will be routed to the new // best node. 
The aliveLoop subscription is kept so the node can continue reporting its health.
-		if n.State() == nodeStateAlive && n != bestNode && n.SubscribersCount() > 1 {
+		if n.State() == NodeStateAlive && n != bestNode {
			c.lggr.Infof("Switching to best node from %q to %q", n.String(), bestNode.String())
			n.UnsubscribeAllExceptAliveLoop()
		}
@@ -316,7 +278,7 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
	}
}

-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) checkLeaseLoop() {
+func (c *MultiNode[CHAIN_ID, RPC]) checkLeaseLoop() {
	defer c.wg.Done()
	c.leaseTicker = time.NewTicker(c.leaseDuration)
	defer c.leaseTicker.Stop()
@@ -331,11 +293,11 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
	}
}

-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) runLoop() {
+func (c *MultiNode[CHAIN_ID, RPC]) runLoop() {
	defer c.wg.Done()
-	nodeStates := make([]nodeWithState, len(c.nodes))
-	for i, n := range c.nodes {
+	nodeStates := make([]nodeWithState, len(c.primaryNodes))
+	for i, n := range c.primaryNodes {
		nodeStates[i] = nodeWithState{
			Node:  n.String(),
			State: n.State().String(),
@@ -364,15 +326,15 @@ type nodeWithState struct {
	DeadSince *time.Time
}

-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) report(nodesStateInfo []nodeWithState) {
+func (c *MultiNode[CHAIN_ID, RPC]) report(nodesStateInfo []nodeWithState) {
	start := time.Now()
	var dead int
-	counts := make(map[nodeState]int)
-	for i, n := range c.nodes {
+	counts := make(map[NodeState]int)
+	for i, n := range c.primaryNodes {
		state := n.State()
		counts[state]++
		nodesStateInfo[i].State = state.String()
-		if state == nodeStateAlive {
+		if state == NodeStateAlive {
			nodesStateInfo[i].DeadSince = nil
			continue
		}
@@ -390,7 +352,7 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
		PromMultiNodeRPCNodeStates.WithLabelValues(c.chainFamily, c.chainID.String(), state.String()).Set(float64(count))
	}

-	total := len(c.nodes)
+	total := len(c.primaryNodes)
	live := total - dead
	c.lggr.Tracew(fmt.Sprintf("MultiNode state: %d/%d nodes are alive", live, total), "nodeStates", nodesStateInfo)
	if total == dead {
@@ -401,455 +363,3 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
		c.lggr.Errorw(fmt.Sprintf("At least one primary node is dead: %d/%d nodes are alive", live, total), "nodeStates", nodesStateInfo)
	}
}
-
-// ClientAPI methods
-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) BalanceAt(ctx context.Context, account ADDR, blockNumber *big.Int) (*big.Int, error) {
-	n, err := c.selectNode()
-	if err != nil {
-		return nil, err
-	}
-	return n.RPC().BalanceAt(ctx, account, blockNumber)
-}
-
-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) BatchCallContext(ctx context.Context, b []BATCH_ELEM) error {
-	n, err := c.selectNode()
-	if err != nil {
-		return err
-	}
-	return n.RPC().BatchCallContext(ctx, b)
-}
-
-// BatchCallContextAll calls BatchCallContext for every single node including
-// sendonlys.
-// CAUTION: This should only be used for mass re-transmitting transactions, it -// might have unexpected effects to use it for anything else. -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) BatchCallContextAll(ctx context.Context, b []BATCH_ELEM) error { - var wg sync.WaitGroup - defer wg.Wait() - - main, selectionErr := c.selectNode() - var all []SendOnlyNode[CHAIN_ID, RPC_CLIENT] - for _, n := range c.nodes { - all = append(all, n) - } - all = append(all, c.sendonlys...) - for _, n := range all { - if n == main { - // main node is used at the end for the return value - continue - } - - if n.State() != nodeStateAlive { - continue - } - // Parallel call made to all other nodes with ignored return value - wg.Add(1) - go func(n SendOnlyNode[CHAIN_ID, RPC_CLIENT]) { - defer wg.Done() - err := n.RPC().BatchCallContext(ctx, b) - if err != nil { - c.lggr.Debugw("Secondary node BatchCallContext failed", "err", err) - } else { - c.lggr.Trace("Secondary node BatchCallContext success") - } - }(n) - } - - if selectionErr != nil { - return selectionErr - } - return main.RPC().BatchCallContext(ctx, b) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) BlockByHash(ctx context.Context, hash BLOCK_HASH) (h HEAD, err error) { - n, err := c.selectNode() - if err != nil { - return h, err - } - return n.RPC().BlockByHash(ctx, hash) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) BlockByNumber(ctx context.Context, number *big.Int) (h HEAD, err error) { - n, err := c.selectNode() - if err != nil { - return h, err - } - return n.RPC().BlockByNumber(ctx, number) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { - n, err := c.selectNode() - if err != nil { - return err - } - return n.RPC().CallContext(ctx, result, method, args...) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) CallContract( - ctx context.Context, - attempt interface{}, - blockNumber *big.Int, -) (rpcErr []byte, extractErr error) { - n, err := c.selectNode() - if err != nil { - return rpcErr, err - } - return n.RPC().CallContract(ctx, attempt, blockNumber) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) PendingCallContract( - ctx context.Context, - attempt interface{}, -) (rpcErr []byte, extractErr error) { - n, err := c.selectNode() - if err != nil { - return rpcErr, err - } - return n.RPC().PendingCallContract(ctx, attempt) -} - -// ChainID makes a direct RPC call. In most cases it should be better to use the configured chain id instead by -// calling ConfiguredChainID. 
-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) ChainID(ctx context.Context) (id CHAIN_ID, err error) { - n, err := c.selectNode() - if err != nil { - return id, err - } - return n.RPC().ChainID(ctx) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) CodeAt(ctx context.Context, account ADDR, blockNumber *big.Int) (code []byte, err error) { - n, err := c.selectNode() - if err != nil { - return code, err - } - return n.RPC().CodeAt(ctx, account, blockNumber) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) ConfiguredChainID() CHAIN_ID { - return c.chainID -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) EstimateGas(ctx context.Context, call any) (gas uint64, err error) { - n, err := c.selectNode() - if err != nil { - return gas, err - } - return n.RPC().EstimateGas(ctx, call) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) FilterEvents(ctx context.Context, query EVENT_OPS) (e []EVENT, err error) { - n, err := c.selectNode() - if err != nil { - return e, err - } - return n.RPC().FilterEvents(ctx, query) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) LatestBlockHeight(ctx context.Context) (h *big.Int, err error) { - n, err := c.selectNode() - if err != nil { - return h, err - } - return n.RPC().LatestBlockHeight(ctx) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) LINKBalance(ctx context.Context, accountAddress ADDR, linkAddress ADDR) (b *assets.Link, err error) { - n, err := c.selectNode() - if err != nil { - return b, err - } - return n.RPC().LINKBalance(ctx, accountAddress, linkAddress) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) NodeStates() (states map[string]string) { - states = make(map[string]string) - for _, n := range c.nodes { - states[n.Name()] = n.State().String() - } - for _, s := range c.sendonlys { - states[s.Name()] = s.State().String() - } - return -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) PendingSequenceAt(ctx context.Context, addr ADDR) (s SEQ, err error) { - n, err := c.selectNode() - if err != nil { - return s, err - } - return n.RPC().PendingSequenceAt(ctx, addr) -} - -type sendTxErrors map[SendTxReturnCode][]error - -// String - returns string representation of the errors map. 
Required by logger to properly represent the value -func (errs sendTxErrors) String() string { - return fmt.Sprint(map[SendTxReturnCode][]error(errs)) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) SendEmptyTransaction( - ctx context.Context, - newTxAttempt func(seq SEQ, feeLimit uint32, fee FEE, fromAddress ADDR) (attempt any, err error), - seq SEQ, - gasLimit uint32, - fee FEE, - fromAddress ADDR, -) (txhash string, err error) { - n, err := c.selectNode() - if err != nil { - return txhash, err - } - return n.RPC().SendEmptyTransaction(ctx, newTxAttempt, seq, gasLimit, fee, fromAddress) -} - -type sendTxResult struct { - Err error - ResultCode SendTxReturnCode -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) broadcastTxAsync(ctx context.Context, - n SendOnlyNode[CHAIN_ID, RPC_CLIENT], tx TX) sendTxResult { - txErr := n.RPC().SendTransaction(ctx, tx) - c.lggr.Debugw("Node sent transaction", "name", n.String(), "tx", tx, "err", txErr) - resultCode := c.classifySendTxError(tx, txErr) - if !slices.Contains(sendTxSuccessfulCodes, resultCode) { - c.lggr.Warnw("RPC returned error", "name", n.String(), "tx", tx, "err", txErr) - } - - return sendTxResult{Err: txErr, ResultCode: resultCode} -} - -// collectTxResults - refer to SendTransaction comment for implementation details, -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) collectTxResults(ctx context.Context, tx TX, healthyNodesNum int, txResults <-chan sendTxResult) error { - if healthyNodesNum == 0 { - return ErroringNodeError - } - // combine context and stop channel to ensure we stop, when signal received - ctx, cancel := c.chStop.Ctx(ctx) - defer cancel() - requiredResults := int(math.Ceil(float64(healthyNodesNum) * sendTxQuorum)) - errorsByCode := sendTxErrors{} - var softTimeoutChan <-chan time.Time - var resultsCount int -loop: - for { - select { - case <-ctx.Done(): - c.lggr.Debugw("Failed to collect of the results before context was done", "tx", tx, "errorsByCode", errorsByCode) - return ctx.Err() - case result := <-txResults: - errorsByCode[result.ResultCode] = append(errorsByCode[result.ResultCode], result.Err) - resultsCount++ - if slices.Contains(sendTxSuccessfulCodes, result.ResultCode) || resultsCount >= requiredResults { - break loop - } - case <-softTimeoutChan: - c.lggr.Debugw("Send Tx soft timeout expired - returning responses we've collected so far", "tx", tx, "resultsCount", resultsCount, "requiredResults", requiredResults) - break loop - } - - if softTimeoutChan == nil { - tm := time.NewTimer(c.sendTxSoftTimeout) - softTimeoutChan = tm.C - // we are fine with stopping timer at the end of function - //nolint - defer tm.Stop() - } - } - - // ignore critical error as it's reported in reportSendTxAnomalies - result, _ := aggregateTxResults(errorsByCode) - return result -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) reportSendTxAnomalies(tx TX, txResults <-chan sendTxResult) { - defer c.wg.Done() - resultsByCode := sendTxErrors{} - // txResults eventually will be closed - for txResult := range txResults { - resultsByCode[txResult.ResultCode] = append(resultsByCode[txResult.ResultCode], txResult.Err) - } - - _, criticalErr := aggregateTxResults(resultsByCode) - if criticalErr != nil 
{ - c.lggr.Criticalw("observed invariant violation on SendTransaction", "tx", tx, "resultsByCode", resultsByCode, "err", criticalErr) - c.SvcErrBuffer.Append(criticalErr) - PromMultiNodeInvariantViolations.WithLabelValues(c.chainFamily, c.chainID.String(), criticalErr.Error()).Inc() - } -} - -func aggregateTxResults(resultsByCode sendTxErrors) (txResult error, err error) { - severeErrors, hasSevereErrors := findFirstIn(resultsByCode, sendTxSevereErrors) - successResults, hasSuccess := findFirstIn(resultsByCode, sendTxSuccessfulCodes) - if hasSuccess { - // We assume that primary node would never report false positive txResult for a transaction. - // Thus, if such case occurs it's probably due to misconfiguration or a bug and requires manual intervention. - if hasSevereErrors { - const errMsg = "found contradictions in nodes replies on SendTransaction: got success and severe error" - // return success, since at least 1 node has accepted our broadcasted Tx, and thus it can now be included onchain - return successResults[0], fmt.Errorf(errMsg) - } - - // other errors are temporary - we are safe to return success - return successResults[0], nil - } - - if hasSevereErrors { - return severeErrors[0], nil - } - - // return temporary error - for _, result := range resultsByCode { - return result[0], nil - } - - err = fmt.Errorf("expected at least one response on SendTransaction") - return err, err -} - -const sendTxQuorum = 0.7 - -// SendTransaction - broadcasts transaction to all the send-only and primary nodes regardless of their health. -// A returned nil or error does not guarantee that the transaction will or won't be included. Additional checks must be -// performed to determine the final state. -// -// Send-only nodes' results are ignored as they tend to return false-positive responses. Broadcast to them is necessary -// to speed up the propagation of TX in the network. -// -// Handling of primary nodes' results consists of collection and aggregation. -// In the collection step, we gather as many results as possible while minimizing waiting time. This operation succeeds -// on one of the following conditions: -// * Received at least one success -// * Received at least one result and `sendTxSoftTimeout` expired -// * Received results from the sufficient number of nodes defined by sendTxQuorum. -// The aggregation is based on the following conditions: -// * If there is at least one success - returns success -// * If there is at least one terminal error - returns terminal error -// * If there is both success and terminal error - returns success and reports invariant violation -// * Otherwise, returns any (effectively random) of the errors. 
-func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) SendTransaction(ctx context.Context, tx TX) error { - if len(c.nodes) == 0 { - return ErroringNodeError - } - - healthyNodesNum := 0 - txResults := make(chan sendTxResult, len(c.nodes)) - // Must wrap inside IfNotStopped to avoid waitgroup racing with Close - ok := c.IfNotStopped(func() { - // fire-n-forget, as sendOnlyNodes can not be trusted with result reporting - for _, n := range c.sendonlys { - if n.State() != nodeStateAlive { - continue - } - c.wg.Add(1) - go func(n SendOnlyNode[CHAIN_ID, RPC_CLIENT]) { - defer c.wg.Done() - c.broadcastTxAsync(ctx, n, tx) - }(n) - } - - var primaryBroadcastWg sync.WaitGroup - txResultsToReport := make(chan sendTxResult, len(c.nodes)) - for _, n := range c.nodes { - if n.State() != nodeStateAlive { - continue - } - - healthyNodesNum++ - primaryBroadcastWg.Add(1) - go func(n SendOnlyNode[CHAIN_ID, RPC_CLIENT]) { - defer primaryBroadcastWg.Done() - result := c.broadcastTxAsync(ctx, n, tx) - // both channels are sufficiently buffered, so we won't be locked - txResultsToReport <- result - txResults <- result - }(n) - } - - c.wg.Add(1) - go func() { - // wait for primary nodes to finish the broadcast before closing the channel - primaryBroadcastWg.Wait() - close(txResultsToReport) - close(txResults) - c.wg.Done() - }() - - c.wg.Add(1) - go c.reportSendTxAnomalies(tx, txResultsToReport) - }) - if !ok { - return fmt.Errorf("aborted while broadcasting tx - multiNode is stopped: %w", context.Canceled) - } - - return c.collectTxResults(ctx, tx, healthyNodesNum, txResults) -} - -// findFirstIn - returns first existing value for the slice of keys -func findFirstIn[K comparable, V any](set map[K]V, keys []K) (V, bool) { - for _, k := range keys { - if v, ok := set[k]; ok { - return v, true - } - } - var v V - return v, false -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) SequenceAt(ctx context.Context, account ADDR, blockNumber *big.Int) (s SEQ, err error) { - n, err := c.selectNode() - if err != nil { - return s, err - } - return n.RPC().SequenceAt(ctx, account, blockNumber) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) SimulateTransaction(ctx context.Context, tx TX) error { - n, err := c.selectNode() - if err != nil { - return err - } - return n.RPC().SimulateTransaction(ctx, tx) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) SubscribeNewHead(ctx context.Context, channel chan<- HEAD) (s types.Subscription, err error) { - n, err := c.selectNode() - if err != nil { - return s, err - } - return n.RPC().SubscribeNewHead(ctx, channel) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) TokenBalance(ctx context.Context, account ADDR, tokenAddr ADDR) (b *big.Int, err error) { - n, err := c.selectNode() - if err != nil { - return b, err - } - return n.RPC().TokenBalance(ctx, account, tokenAddr) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) TransactionByHash(ctx context.Context, txHash TX_HASH) (tx TX, err error) { - n, err := c.selectNode() - if err != nil { - return tx, err - } - return 
n.RPC().TransactionByHash(ctx, txHash) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) TransactionReceipt(ctx context.Context, txHash TX_HASH) (txr TX_RECEIPT, err error) { - n, err := c.selectNode() - if err != nil { - return txr, err - } - return n.RPC().TransactionReceipt(ctx, txHash) -} - -func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT, BATCH_ELEM]) LatestFinalizedBlock(ctx context.Context) (head HEAD, err error) { - n, err := c.selectNode() - if err != nil { - return head, err - } - - return n.RPC().LatestFinalizedBlock(ctx) -} diff --git a/common/client/multi_node_test.go b/common/client/multi_node_test.go index ffef0c29d56..dc8140d52ce 100644 --- a/common/client/multi_node_test.go +++ b/common/client/multi_node_test.go @@ -1,44 +1,38 @@ package client import ( - "context" - "errors" "fmt" "math/big" "math/rand" "testing" "time" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/zap" - "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink/v2/common/types" ) -type multiNodeRPCClient RPC[types.ID, *big.Int, Hashable, Hashable, any, Hashable, any, any, - types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable], any] +type multiNodeRPCClient RPCClient[types.ID, types.Head[Hashable]] type testMultiNode struct { - *multiNode[types.ID, *big.Int, Hashable, Hashable, any, Hashable, any, any, - types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable], multiNodeRPCClient, any] + *MultiNode[types.ID, multiNodeRPCClient] } type multiNodeOpts struct { logger logger.Logger selectionMode string leaseDuration time.Duration - noNewHeadsThreshold time.Duration - nodes []Node[types.ID, types.Head[Hashable], multiNodeRPCClient] + nodes []Node[types.ID, multiNodeRPCClient] sendonlys []SendOnlyNode[types.ID, multiNodeRPCClient] chainID types.ID chainFamily string - classifySendTxError func(tx any, err error) SendTxReturnCode - sendTxSoftTimeout time.Duration deathDeclarationDelay time.Duration } @@ -47,46 +41,32 @@ func newTestMultiNode(t *testing.T, opts multiNodeOpts) testMultiNode { opts.logger = logger.Test(t) } - result := NewMultiNode[types.ID, *big.Int, Hashable, Hashable, any, Hashable, any, any, - types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable], multiNodeRPCClient, any](opts.logger, - opts.selectionMode, opts.leaseDuration, opts.noNewHeadsThreshold, opts.nodes, opts.sendonlys, - opts.chainID, opts.chainFamily, opts.classifySendTxError, opts.sendTxSoftTimeout, opts.deathDeclarationDelay) + result := NewMultiNode[types.ID, multiNodeRPCClient]( + opts.logger, opts.selectionMode, opts.leaseDuration, opts.nodes, opts.sendonlys, opts.chainID, opts.chainFamily, opts.deathDeclarationDelay) return testMultiNode{ - result.(*multiNode[types.ID, *big.Int, Hashable, Hashable, any, Hashable, any, any, - types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable], multiNodeRPCClient, any]), + result, } } -func newMultiNodeRPCClient(t *testing.T) *mockRPC[types.ID, *big.Int, Hashable, Hashable, any, Hashable, any, any, - types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable], any] { - return newMockRPC[types.ID, *big.Int, Hashable, Hashable, any, 
Hashable, any, any, - types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable], any](t) -} - -func newHealthyNode(t *testing.T, chainID types.ID) *mockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] { - return newNodeWithState(t, chainID, nodeStateAlive) -} - -func newNodeWithState(t *testing.T, chainID types.ID, state nodeState) *mockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] { - node := newDialableNode(t, chainID) - node.On("State").Return(state).Maybe() - return node +func newHealthyNode(t *testing.T, chainID types.ID) *mockNode[types.ID, multiNodeRPCClient] { + return newNodeWithState(t, chainID, NodeStateAlive) } -func newDialableNode(t *testing.T, chainID types.ID) *mockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] { - node := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) +func newNodeWithState(t *testing.T, chainID types.ID, state NodeState) *mockNode[types.ID, multiNodeRPCClient] { + node := newMockNode[types.ID, multiNodeRPCClient](t) node.On("ConfiguredChainID").Return(chainID).Once() node.On("Start", mock.Anything).Return(nil).Once() node.On("Close").Return(nil).Once() node.On("String").Return(fmt.Sprintf("healthy_node_%d", rand.Int())).Maybe() node.On("SetPoolChainInfoProvider", mock.Anything).Once() + node.On("State").Return(state).Maybe() return node } func TestMultiNode_Dial(t *testing.T) { t.Parallel() - newMockNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] + newMockNode := newMockNode[types.ID, multiNodeRPCClient] newMockSendOnlyNode := newMockSendOnlyNode[types.ID, multiNodeRPCClient] t.Run("Fails without nodes", func(t *testing.T) { @@ -95,7 +75,7 @@ func TestMultiNode_Dial(t *testing.T) { selectionMode: NodeSelectionModeRoundRobin, chainID: types.RandomID(), }) - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) assert.EqualError(t, err, fmt.Sprintf("no available nodes for chain %s", mn.chainID.String())) }) t.Run("Fails with wrong node's chainID", func(t *testing.T) { @@ -109,9 +89,9 @@ func TestMultiNode_Dial(t *testing.T) { mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: multiNodeChainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + nodes: []Node[types.ID, multiNodeRPCClient]{node}, }) - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) assert.EqualError(t, err, fmt.Sprintf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", nodeName, nodeChainID, mn.chainID)) }) t.Run("Fails if node fails", func(t *testing.T) { @@ -119,15 +99,15 @@ func TestMultiNode_Dial(t *testing.T) { node := newMockNode(t) chainID := types.RandomID() node.On("ConfiguredChainID").Return(chainID).Once() - node.On("SetPoolChainInfoProvider", mock.Anything).Once() expectedError := errors.New("failed to start node") node.On("Start", mock.Anything).Return(expectedError).Once() + node.On("SetPoolChainInfoProvider", mock.Anything).Once() mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: chainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + nodes: []Node[types.ID, multiNodeRPCClient]{node}, }) - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) assert.EqualError(t, err, expectedError.Error()) }) @@ -137,16 +117,16 @@ func TestMultiNode_Dial(t *testing.T) { node1 := newHealthyNode(t, chainID) node2 := newMockNode(t) node2.On("ConfiguredChainID").Return(chainID).Once() 
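A quick aside for readers unfamiliar with the mockery-generated helpers used throughout these tests: they are thin wrappers over testify's mock package. Below is a hand-rolled, self-contained illustration of the On/Return/Once/Maybe expectation pattern; the mockLifecycle type is invented for the example and is not part of the codebase:

package example

import (
	"testing"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// mockLifecycle is a hand-rolled stand-in for a mockery-generated mock.
type mockLifecycle struct{ mock.Mock }

func (m *mockLifecycle) Start() error  { return m.Called().Error(0) }
func (m *mockLifecycle) State() string { return m.Called().String(0) }

func TestOnReturnPattern(t *testing.T) {
	node := new(mockLifecycle)
	node.On("Start").Return(nil).Once()      // expected exactly once
	node.On("State").Return("Alive").Maybe() // optional expectation

	require.NoError(t, node.Start())
	require.Equal(t, "Alive", node.State())
	node.AssertExpectations(t) // fails the test if "Start" was never called
}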
- node2.On("SetPoolChainInfoProvider", mock.Anything).Once() expectedError := errors.New("failed to start node") node2.On("Start", mock.Anything).Return(expectedError).Once() + node2.On("SetPoolChainInfoProvider", mock.Anything).Once() mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: chainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node1, node2}, + nodes: []Node[types.ID, multiNodeRPCClient]{node1, node2}, }) - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) assert.EqualError(t, err, expectedError.Error()) }) t.Run("Fails with wrong send only node's chainID", func(t *testing.T) { @@ -162,10 +142,10 @@ func TestMultiNode_Dial(t *testing.T) { mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: multiNodeChainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + nodes: []Node[types.ID, multiNodeRPCClient]{node}, sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{sendOnly}, }) - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) assert.EqualError(t, err, fmt.Sprintf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", sendOnlyName, sendOnlyChainID, mn.chainID)) }) @@ -189,10 +169,10 @@ func TestMultiNode_Dial(t *testing.T) { mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: chainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + nodes: []Node[types.ID, multiNodeRPCClient]{node}, sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{sendOnly1, sendOnly2}, }) - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) assert.EqualError(t, err, expectedError.Error()) }) t.Run("Starts successfully with healthy nodes", func(t *testing.T) { @@ -202,11 +182,11 @@ func TestMultiNode_Dial(t *testing.T) { mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: chainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + nodes: []Node[types.ID, multiNodeRPCClient]{node}, sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{newHealthySendOnly(t, chainID)}, }) defer func() { assert.NoError(t, mn.Close()) }() - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) require.NoError(t, err) selectedNode, err := mn.selectNode() require.NoError(t, err) @@ -220,36 +200,36 @@ func TestMultiNode_Report(t *testing.T) { t.Parallel() chainID := types.RandomID() node1 := newHealthyNode(t, chainID) - node2 := newNodeWithState(t, chainID, nodeStateOutOfSync) + node2 := newNodeWithState(t, chainID, NodeStateOutOfSync) lggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: chainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node1, node2}, + nodes: []Node[types.ID, multiNodeRPCClient]{node1, node2}, logger: lggr, }) mn.reportInterval = tests.TestInterval mn.deathDeclarationDelay = tests.TestInterval defer func() { assert.NoError(t, mn.Close()) }() - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) require.NoError(t, err) tests.AssertLogCountEventually(t, observedLogs, "At least one primary node is dead: 1/2 nodes are alive", 2) }) t.Run("Report critical error on all node failure", func(t *testing.T) { t.Parallel() chainID := types.RandomID() - node := newNodeWithState(t, chainID, 
nodeStateOutOfSync) + node := newNodeWithState(t, chainID, NodeStateOutOfSync) lggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: chainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + nodes: []Node[types.ID, multiNodeRPCClient]{node}, logger: lggr, }) mn.reportInterval = tests.TestInterval mn.deathDeclarationDelay = tests.TestInterval defer func() { assert.NoError(t, mn.Close()) }() - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) require.NoError(t, err) tests.AssertLogCountEventually(t, observedLogs, "no primary nodes available: 0/1 nodes are alive", 2) err = mn.Healthy() @@ -269,10 +249,10 @@ func TestMultiNode_CheckLease(t *testing.T) { selectionMode: NodeSelectionModeRoundRobin, chainID: chainID, logger: lggr, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + nodes: []Node[types.ID, multiNodeRPCClient]{node}, }) defer func() { assert.NoError(t, mn.Close()) }() - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) require.NoError(t, err) tests.RequireLogMessage(t, observedLogs, "Best node switching is disabled") }) @@ -285,11 +265,11 @@ func TestMultiNode_CheckLease(t *testing.T) { selectionMode: NodeSelectionModeHighestHead, chainID: chainID, logger: lggr, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + nodes: []Node[types.ID, multiNodeRPCClient]{node}, leaseDuration: 0, }) defer func() { assert.NoError(t, mn.Close()) }() - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) require.NoError(t, err) tests.RequireLogMessage(t, observedLogs, "Best node switching is disabled") }) @@ -297,22 +277,21 @@ func TestMultiNode_CheckLease(t *testing.T) { t.Parallel() chainID := types.RandomID() node := newHealthyNode(t, chainID) - node.On("SubscribersCount").Return(int32(2)) node.On("UnsubscribeAllExceptAliveLoop") bestNode := newHealthyNode(t, chainID) - nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector := newMockNodeSelector[types.ID, multiNodeRPCClient](t) nodeSelector.On("Select").Return(bestNode) lggr, observedLogs := logger.TestObserved(t, zap.InfoLevel) mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeHighestHead, chainID: chainID, logger: lggr, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node, bestNode}, + nodes: []Node[types.ID, multiNodeRPCClient]{node, bestNode}, leaseDuration: tests.TestInterval, }) defer func() { assert.NoError(t, mn.Close()) }() mn.nodeSelector = nodeSelector - err := mn.Dial(tests.Context(t)) + err := mn.Start(tests.Context(t)) require.NoError(t, err) tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("Switching to best node from %q to %q", node.String(), bestNode.String())) tests.AssertEventually(t, func() bool { @@ -325,10 +304,10 @@ func TestMultiNode_CheckLease(t *testing.T) { t.Run("NodeStates returns proper states", func(t *testing.T) { t.Parallel() chainID := types.NewIDFromInt(10) - nodes := map[string]nodeState{ - "node_1": nodeStateAlive, - "node_2": nodeStateUnreachable, - "node_3": nodeStateDialed, + nodes := map[string]NodeState{ + "node_1": NodeStateAlive, + "node_2": NodeStateUnreachable, + "node_3": NodeStateDialed, } opts := multiNodeOpts{ @@ -336,21 +315,21 @@ func TestMultiNode_CheckLease(t *testing.T) { chainID: chainID, } - expectedResult := map[string]string{} + expectedResult := 
map[string]NodeState{} for name, state := range nodes { - node := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - node.On("Name").Return(name).Once() + node := newMockNode[types.ID, multiNodeRPCClient](t) node.On("State").Return(state).Once() + node.On("String").Return(name).Once() opts.nodes = append(opts.nodes, node) sendOnly := newMockSendOnlyNode[types.ID, multiNodeRPCClient](t) sendOnlyName := "send_only_" + name - sendOnly.On("Name").Return(sendOnlyName).Once() sendOnly.On("State").Return(state).Once() + sendOnly.On("String").Return(sendOnlyName).Once() opts.sendonlys = append(opts.sendonlys, sendOnly) - expectedResult[name] = state.String() - expectedResult[sendOnlyName] = state.String() + expectedResult[name] = state + expectedResult[sendOnlyName] = state } mn := newTestMultiNode(t, opts) @@ -364,17 +343,17 @@ func TestMultiNode_selectNode(t *testing.T) { t.Run("Returns same node, if it's still healthy", func(t *testing.T) { t.Parallel() chainID := types.RandomID() - node1 := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - node1.On("State").Return(nodeStateAlive).Once() + node1 := newMockNode[types.ID, multiNodeRPCClient](t) + node1.On("State").Return(NodeStateAlive).Once() node1.On("String").Return("node1").Maybe() - node2 := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + node2 := newMockNode[types.ID, multiNodeRPCClient](t) node2.On("String").Return("node2").Maybe() mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: chainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node1, node2}, + nodes: []Node[types.ID, multiNodeRPCClient]{node1, node2}, }) - nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector := newMockNodeSelector[types.ID, multiNodeRPCClient](t) nodeSelector.On("Select").Return(node1).Once() mn.nodeSelector = nodeSelector prevActiveNode, err := mn.selectNode() @@ -387,24 +366,24 @@ func TestMultiNode_selectNode(t *testing.T) { t.Run("Updates node if active is not healthy", func(t *testing.T) { t.Parallel() chainID := types.RandomID() - oldBest := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + oldBest := newMockNode[types.ID, multiNodeRPCClient](t) oldBest.On("String").Return("oldBest").Maybe() - oldBest.On("UnsubscribeAllExceptAliveLoop").Once() - newBest := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + oldBest.On("UnsubscribeAllExceptAliveLoop") + newBest := newMockNode[types.ID, multiNodeRPCClient](t) newBest.On("String").Return("newBest").Maybe() mn := newTestMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: chainID, - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{oldBest, newBest}, + nodes: []Node[types.ID, multiNodeRPCClient]{oldBest, newBest}, }) - nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector := newMockNodeSelector[types.ID, multiNodeRPCClient](t) nodeSelector.On("Select").Return(oldBest).Once() mn.nodeSelector = nodeSelector activeNode, err := mn.selectNode() require.NoError(t, err) require.Equal(t, oldBest.String(), activeNode.String()) // old best died, so we should replace it - oldBest.On("State").Return(nodeStateOutOfSync).Twice() + oldBest.On("State").Return(NodeStateOutOfSync).Twice() nodeSelector.On("Select").Return(newBest).Once() newActiveNode, err := mn.selectNode() require.NoError(t, err) @@ -419,7 +398,7 @@ func 
TestMultiNode_selectNode(t *testing.T) { chainID: chainID, logger: lggr, }) - nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector := newMockNodeSelector[types.ID, multiNodeRPCClient](t) nodeSelector.On("Select").Return(nil).Once() nodeSelector.On("Name").Return("MockedNodeSelector").Once() mn.nodeSelector = nodeSelector @@ -435,7 +414,7 @@ func TestMultiNode_ChainInfo(t *testing.T) { type nodeParams struct { LatestChainInfo ChainInfo HighestUserObservations ChainInfo - State nodeState + State NodeState } testCases := []struct { Name string @@ -468,7 +447,7 @@ func TestMultiNode_ChainInfo(t *testing.T) { }, NodeParams: []nodeParams{ { - State: nodeStateOutOfSync, + State: NodeStateOutOfSync, LatestChainInfo: ChainInfo{ BlockNumber: 1000, FinalizedBlockNumber: 990, @@ -481,7 +460,7 @@ func TestMultiNode_ChainInfo(t *testing.T) { }, }, { - State: nodeStateAlive, + State: NodeStateAlive, LatestChainInfo: ChainInfo{ BlockNumber: 20, FinalizedBlockNumber: 10, @@ -494,7 +473,7 @@ func TestMultiNode_ChainInfo(t *testing.T) { }, }, { - State: nodeStateAlive, + State: NodeStateAlive, LatestChainInfo: ChainInfo{ BlockNumber: 19, FinalizedBlockNumber: 9, @@ -507,7 +486,7 @@ func TestMultiNode_ChainInfo(t *testing.T) { }, }, { - State: nodeStateAlive, + State: NodeStateAlive, LatestChainInfo: ChainInfo{ BlockNumber: 11, FinalizedBlockNumber: 1, @@ -532,10 +511,10 @@ func TestMultiNode_ChainInfo(t *testing.T) { tc := testCases[i] t.Run(tc.Name, func(t *testing.T) { for _, params := range tc.NodeParams { - node := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + node := newMockNode[types.ID, multiNodeRPCClient](t) + mn.primaryNodes = append(mn.primaryNodes, node) node.On("StateAndLatest").Return(params.State, params.LatestChainInfo) node.On("HighestUserObservations").Return(params.HighestUserObservations) - mn.nodes = append(mn.nodes, node) } nNodes, latestChainInfo := mn.LatestChainInfo() @@ -548,103 +527,9 @@ func TestMultiNode_ChainInfo(t *testing.T) { } } -func TestMultiNode_BatchCallContextAll(t *testing.T) { - t.Parallel() - t.Run("Fails if failed to select active node", func(t *testing.T) { - chainID := types.RandomID() - mn := newTestMultiNode(t, multiNodeOpts{ - selectionMode: NodeSelectionModeRoundRobin, - chainID: chainID, - }) - nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - nodeSelector.On("Select").Return(nil).Once() - nodeSelector.On("Name").Return("MockedNodeSelector").Once() - mn.nodeSelector = nodeSelector - err := mn.BatchCallContextAll(tests.Context(t), nil) - require.EqualError(t, err, ErroringNodeError.Error()) - }) - t.Run("Returns error if RPC call fails for active node", func(t *testing.T) { - chainID := types.RandomID() - rpc := newMultiNodeRPCClient(t) - expectedError := errors.New("rpc failed to do the batch call") - rpc.On("BatchCallContext", mock.Anything, mock.Anything).Return(expectedError).Once() - node := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - node.On("RPC").Return(rpc) - nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - nodeSelector.On("Select").Return(node).Once() - mn := newTestMultiNode(t, multiNodeOpts{ - selectionMode: NodeSelectionModeRoundRobin, - chainID: chainID, - }) - mn.nodeSelector = nodeSelector - err := mn.BatchCallContextAll(tests.Context(t), nil) - require.EqualError(t, err, expectedError.Error()) - }) - t.Run("Waits for all nodes to complete the call and logs 
results", func(t *testing.T) { - // setup RPCs - failedRPC := newMultiNodeRPCClient(t) - failedRPC.On("BatchCallContext", mock.Anything, mock.Anything). - Return(errors.New("rpc failed to do the batch call")).Once() - okRPC := newMultiNodeRPCClient(t) - okRPC.On("BatchCallContext", mock.Anything, mock.Anything).Return(nil).Twice() - - // setup ok and failed auxiliary nodes - okNode := newMockSendOnlyNode[types.ID, multiNodeRPCClient](t) - okNode.On("RPC").Return(okRPC).Once() - okNode.On("State").Return(nodeStateAlive) - failedNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - failedNode.On("RPC").Return(failedRPC).Once() - failedNode.On("State").Return(nodeStateAlive) - - // setup main node - mainNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - mainNode.On("RPC").Return(okRPC) - nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - nodeSelector.On("Select").Return(mainNode).Once() - lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) - mn := newTestMultiNode(t, multiNodeOpts{ - selectionMode: NodeSelectionModeRoundRobin, - chainID: types.RandomID(), - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{failedNode, mainNode}, - sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{okNode}, - logger: lggr, - }) - mn.nodeSelector = nodeSelector - - err := mn.BatchCallContextAll(tests.Context(t), nil) - require.NoError(t, err) - tests.RequireLogMessage(t, observedLogs, "Secondary node BatchCallContext failed") - }) - t.Run("Does not call BatchCallContext for unhealthy nodes", func(t *testing.T) { - // setup RPCs - okRPC := newMultiNodeRPCClient(t) - okRPC.On("BatchCallContext", mock.Anything, mock.Anything).Return(nil).Twice() - - // setup ok and failed auxiliary nodes - healthyNode := newMockSendOnlyNode[types.ID, multiNodeRPCClient](t) - healthyNode.On("RPC").Return(okRPC).Once() - healthyNode.On("State").Return(nodeStateAlive) - deadNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - deadNode.On("State").Return(nodeStateUnreachable) - - // setup main node - mainNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - mainNode.On("RPC").Return(okRPC) - nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) - nodeSelector.On("Select").Return(mainNode).Once() - mn := newTestMultiNode(t, multiNodeOpts{ - selectionMode: NodeSelectionModeRoundRobin, - chainID: types.RandomID(), - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{deadNode, mainNode}, - sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{healthyNode, deadNode}, - }) - mn.nodeSelector = nodeSelector - - err := mn.BatchCallContextAll(tests.Context(t), nil) - require.NoError(t, err) - }) -} +/* TODO: Add test covereage for DoAll() +/* TODO: Implement TransactionSender func TestMultiNode_SendTransaction(t *testing.T) { t.Parallel() classifySendTxError := func(tx any, err error) SendTxReturnCode { @@ -654,7 +539,7 @@ func TestMultiNode_SendTransaction(t *testing.T) { return Successful } - newNodeWithState := func(t *testing.T, state nodeState, txErr error, sendTxRun func(args mock.Arguments)) *mockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] { + newNodeWithState := func(t *testing.T, state NodeState, txErr error, sendTxRun func(args mock.Arguments)) *mockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] { rpc := newMultiNodeRPCClient(t) rpc.On("SendTransaction", mock.Anything, 
mock.Anything).Return(txErr).Run(sendTxRun).Maybe() node := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) @@ -666,7 +551,7 @@ func TestMultiNode_SendTransaction(t *testing.T) { } newNode := func(t *testing.T, txErr error, sendTxRun func(args mock.Arguments)) *mockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] { - return newNodeWithState(t, nodeStateAlive, txErr, sendTxRun) + return newNodeWithState(t, NodeStateAlive, txErr, sendTxRun) } newStartedMultiNode := func(t *testing.T, opts multiNodeOpts) testMultiNode { mn := newTestMultiNode(t, opts) @@ -803,14 +688,14 @@ func TestMultiNode_SendTransaction(t *testing.T) { require.NoError(t, err) require.NoError(t, mn.Close()) err = mn.SendTransaction(tests.Context(t), nil) - require.EqualError(t, err, "aborted while broadcasting tx - multiNode is stopped: context canceled") + require.EqualError(t, err, "aborted while broadcasting tx - MultiNode is stopped: context canceled") }) t.Run("Returns error if there are no healthy primary nodes", func(t *testing.T) { mn := newStartedMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, chainID: types.RandomID(), - nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{newNodeWithState(t, nodeStateUnreachable, nil, nil)}, - sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{newNodeWithState(t, nodeStateUnreachable, nil, nil)}, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{newNodeWithState(t, NodeStateUnreachable, nil, nil)}, + sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{newNodeWithState(t, NodeStateUnreachable, nil, nil)}, classifySendTxError: classifySendTxError, }) err := mn.SendTransaction(tests.Context(t), nil) @@ -822,8 +707,8 @@ func TestMultiNode_SendTransaction(t *testing.T) { unexpectedCall := func(args mock.Arguments) { panic("SendTx must not be called for unhealthy node") } - unhealthyNode := newNodeWithState(t, nodeStateUnreachable, nil, unexpectedCall) - unhealthySendOnlyNode := newNodeWithState(t, nodeStateUnreachable, nil, unexpectedCall) + unhealthyNode := newNodeWithState(t, NodeStateUnreachable, nil, unexpectedCall) + unhealthySendOnlyNode := newNodeWithState(t, NodeStateUnreachable, nil, unexpectedCall) lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) mn := newStartedMultiNode(t, multiNodeOpts{ selectionMode: NodeSelectionModeRoundRobin, @@ -944,3 +829,4 @@ func TestMultiNode_SendTransaction_aggregateTxResults(t *testing.T) { } assert.Empty(t, codesToCover, "all of the SendTxReturnCode must be covered by this test") } +*/ diff --git a/common/client/node.go b/common/client/node.go index 7871c622eb4..7ef0460e538 100644 --- a/common/client/node.go +++ b/common/client/node.go @@ -57,27 +57,28 @@ type ChainConfig interface { //go:generate mockery --quiet --name Node --structname mockNode --filename "mock_node_test.go" --inpackage --case=underscore type Node[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], + RPC any, ] interface { // State returns the most accurate state of the Node at the moment of the call. // While some of the checks may be performed in the background and State may return a cached value, critical ones, like // `FinalizedBlockOutOfSync`, must be executed upon every call. - State() nodeState + State() NodeState // StateAndLatest returns the NodeState with the latest ChainInfo observed by the Node during the current lifecycle.
- StateAndLatest() (nodeState, ChainInfo) + StateAndLatest() (NodeState, ChainInfo) // HighestUserObservations - returns the highest ChainInfo ever observed by the underlying RPC, excluding results of health-check requests HighestUserObservations() ChainInfo SetPoolChainInfoProvider(PoolChainInfoProvider) // Name is a unique identifier for this node. Name() string + // String - returns a string representation of the node, useful for debugging (name + URLs used to connect to the RPC) String() string RPC() RPC - SubscribersCount() int32 // UnsubscribeAllExceptAliveLoop - closes all subscriptions except the aliveLoop subscription UnsubscribeAllExceptAliveLoop() ConfiguredChainID() CHAIN_ID + // Order - returns the priority order configured for the RPC Order() int32 + // Start - starts health checks Start(context.Context) error Close() error } @@ -85,7 +86,7 @@ type Node[ type node[ CHAIN_ID types.ID, HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], + RPC RPCClient[CHAIN_ID, HEAD], ] struct { services.StateMachine lfcLog logger.Logger @@ -103,19 +104,22 @@ type node[ rpc RPC stateMu sync.RWMutex // protects state* fields - state nodeState + state NodeState poolInfoProvider PoolChainInfoProvider stopCh services.StopChan // wg waits for subsidiary goroutines wg sync.WaitGroup + + aliveLoopSub types.Subscription + finalizedBlockSub types.Subscription } func NewNode[ CHAIN_ID types.ID, HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], + RPC RPCClient[CHAIN_ID, HEAD], ]( nodeCfg NodeConfig, chainCfg ChainConfig, @@ -128,7 +132,7 @@ func NewNode[ nodeOrder int32, rpc RPC, chainFamily string, -) Node[CHAIN_ID, HEAD, RPC] { +) Node[CHAIN_ID, RPC] { n := new(node[CHAIN_ID, HEAD, RPC]) n.name = name n.id = id @@ -175,12 +179,18 @@ func (n *node[CHAIN_ID, HEAD, RPC]) RPC() RPC { return n.rpc } -func (n *node[CHAIN_ID, HEAD, RPC]) SubscribersCount() int32 { - return n.rpc.SubscribersCount() +// unsubscribeAllExceptAliveLoop is not thread-safe; it should only be called +// while holding the stateMu lock. +func (n *node[CHAIN_ID, HEAD, RPC]) unsubscribeAllExceptAliveLoop() { + aliveLoopSub := n.aliveLoopSub + finalizedBlockSub := n.finalizedBlockSub + n.rpc.UnsubscribeAllExcept(aliveLoopSub, finalizedBlockSub) } func (n *node[CHAIN_ID, HEAD, RPC]) UnsubscribeAllExceptAliveLoop() { - n.rpc.UnsubscribeAllExceptAliveLoop() + n.stateMu.Lock() + defer n.stateMu.Unlock() + n.unsubscribeAllExceptAliveLoop() } func (n *node[CHAIN_ID, HEAD, RPC]) Close() error { @@ -197,7 +207,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) close() error { defer n.stateMu.Unlock() close(n.stopCh) - n.state = nodeStateClosed + n.state = NodeStateClosed return nil } @@ -218,7 +228,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) Start(startCtx context.Context) error { // Node lifecycle is synchronous: only one goroutine should be running at a // time. func (n *node[CHAIN_ID, HEAD, RPC]) start(startCtx context.Context) { - if n.state != nodeStateUndialed { + if n.state != NodeStateUndialed { panic(fmt.Sprintf("cannot dial node with state %v", n.state)) } @@ -227,7 +237,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) start(startCtx context.Context) { n.declareUnreachable() return } - n.setState(nodeStateDialed) + n.setState(NodeStateDialed) state := n.verifyConn(startCtx, n.lfcLog) n.declareState(state) } // verifyChainID checks that the connection to the node matches the given chain ID // Not thread-safe // Pure verifyChainID: does not mutate the node's "state" field.
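The SubscribersCount/UnsubscribeAllExceptAliveLoop replacement above establishes a small locking convention: the unexported helper assumes stateMu is already held, while the exported method acquires it. Here is a condensed, runnable sketch of that convention; the Subscription interface and rpcStub type are simplified stand-ins, not the real types:

package main

import "sync"

// Subscription is a simplified stand-in for types.Subscription.
type Subscription interface{ Unsubscribe() }

// rpcStub stands in for the real RPC client; UnsubscribeAllExcept drops every
// subscription except the ones passed in (a no-op here).
type rpcStub struct{}

func (rpcStub) UnsubscribeAllExcept(keep ...Subscription) {}

type node struct {
	stateMu           sync.RWMutex
	rpc               rpcStub
	aliveLoopSub      Subscription
	finalizedBlockSub Subscription
}

// unsubscribeAllExceptAliveLoop assumes the caller already holds stateMu, so
// the lifecycle subscriptions cannot be swapped out mid-call.
func (n *node) unsubscribeAllExceptAliveLoop() {
	n.rpc.UnsubscribeAllExcept(n.aliveLoopSub, n.finalizedBlockSub)
}

// UnsubscribeAllExceptAliveLoop is the exported, thread-safe entry point.
func (n *node) UnsubscribeAllExceptAliveLoop() {
	n.stateMu.Lock()
	defer n.stateMu.Unlock()
	n.unsubscribeAllExceptAliveLoop()
}

func main() {
	n := &node{}
	n.UnsubscribeAllExceptAliveLoop() // safe to call from any goroutine
}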
-func (n *node[CHAIN_ID, HEAD, RPC]) verifyChainID(callerCtx context.Context, lggr logger.Logger) nodeState { +func (n *node[CHAIN_ID, HEAD, RPC]) verifyChainID(callerCtx context.Context, lggr logger.Logger) NodeState { promPoolRPCNodeVerifies.WithLabelValues(n.chainFamily, n.chainID.String(), n.name).Inc() promFailed := func() { promPoolRPCNodeVerifiesFailed.WithLabelValues(n.chainFamily, n.chainID.String(), n.name).Inc() @@ -244,11 +254,11 @@ func (n *node[CHAIN_ID, HEAD, RPC]) verifyChainID(callerCtx context.Context, lgg st := n.getCachedState() switch st { - case nodeStateClosed: + case NodeStateClosed: // The node is already closed, and any subsequent transition is invalid. // To make spotting such transitions a bit easier, return the invalid node state. - return nodeStateLen - case nodeStateDialed, nodeStateOutOfSync, nodeStateInvalidChainID, nodeStateSyncing: + return NodeStateLen + case NodeStateDialed, NodeStateOutOfSync, NodeStateInvalidChainID, NodeStateSyncing: default: panic(fmt.Sprintf("cannot verify node in state %v", st)) } @@ -258,7 +268,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) verifyChainID(callerCtx context.Context, lgg if chainID, err = n.rpc.ChainID(callerCtx); err != nil { promFailed() lggr.Errorw("Failed to verify chain ID for node", "err", err, "nodeState", n.getCachedState()) - return nodeStateUnreachable + return NodeStateUnreachable } else if chainID.String() != n.chainID.String() { promFailed() err = fmt.Errorf( @@ -269,30 +279,30 @@ func (n *node[CHAIN_ID, HEAD, RPC]) verifyChainID(callerCtx context.Context, lgg errInvalidChainID, ) lggr.Errorw("Failed to verify RPC node; remote endpoint returned the wrong chain ID", "err", err, "nodeState", n.getCachedState()) - return nodeStateInvalidChainID + return NodeStateInvalidChainID } promPoolRPCNodeVerifiesSuccess.WithLabelValues(n.chainFamily, n.chainID.String(), n.name).Inc() - return nodeStateAlive + return NodeStateAlive } // createVerifiedConn - establishes new connection with the RPC and verifies that it's valid: chainID matches, and it's not syncing. -// Returns desired state if one of the verifications fails. Otherwise, returns nodeStateAlive. -func (n *node[CHAIN_ID, HEAD, RPC]) createVerifiedConn(ctx context.Context, lggr logger.Logger) nodeState { +// Returns desired state if one of the verifications fails. Otherwise, returns NodeStateAlive. +func (n *node[CHAIN_ID, HEAD, RPC]) createVerifiedConn(ctx context.Context, lggr logger.Logger) NodeState { if err := n.rpc.Dial(ctx); err != nil { n.lfcLog.Errorw("Dial failed: Node is unreachable", "err", err, "nodeState", n.getCachedState()) - return nodeStateUnreachable + return NodeStateUnreachable } return n.verifyConn(ctx, lggr) } // verifyConn - verifies that current connection is valid: chainID matches, and it's not syncing. -// Returns desired state if one of the verifications fails. Otherwise, returns nodeStateAlive. -func (n *node[CHAIN_ID, HEAD, RPC]) verifyConn(ctx context.Context, lggr logger.Logger) nodeState { +// Returns desired state if one of the verifications fails. Otherwise, returns NodeStateAlive. +func (n *node[CHAIN_ID, HEAD, RPC]) verifyConn(ctx context.Context, lggr logger.Logger) NodeState { state := n.verifyChainID(ctx, lggr) - if state != nodeStateAlive { + if state != NodeStateAlive { return state } @@ -300,23 +310,16 @@ func (n *node[CHAIN_ID, HEAD, RPC]) verifyConn(ctx context.Context, lggr logger. 
isSyncing, err := n.rpc.IsSyncing(ctx) if err != nil { lggr.Errorw("Unexpected error while verifying RPC node synchronization status", "err", err, "nodeState", n.getCachedState()) - return nodeStateUnreachable + return NodeStateUnreachable } if isSyncing { lggr.Errorw("Verification failed: Node is syncing", "nodeState", n.getCachedState()) - return nodeStateSyncing + return NodeStateSyncing } } - return nodeStateAlive -} - -// disconnectAll disconnects all clients connected to the node -// WARNING: NOT THREAD-SAFE -// This must be called from within the n.stateMu lock -func (n *node[CHAIN_ID, HEAD, RPC]) disconnectAll() { - n.rpc.DisconnectAll() + return NodeStateAlive } func (n *node[CHAIN_ID, HEAD, RPC]) Order() int32 { diff --git a/common/client/node_fsm.go b/common/client/node_fsm.go index 5a5e2554431..734a15e4be7 100644 --- a/common/client/node_fsm.go +++ b/common/client/node_fsm.go @@ -11,106 +11,106 @@ import ( var ( promPoolRPCNodeTransitionsToAlive = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "pool_rpc_node_num_transitions_to_alive", - Help: transitionString(nodeStateAlive), + Help: transitionString(NodeStateAlive), }, []string{"chainID", "nodeName"}) promPoolRPCNodeTransitionsToInSync = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "pool_rpc_node_num_transitions_to_in_sync", - Help: fmt.Sprintf("%s to %s", transitionString(nodeStateOutOfSync), nodeStateAlive), + Help: fmt.Sprintf("%s to %s", transitionString(NodeStateOutOfSync), NodeStateAlive), }, []string{"chainID", "nodeName"}) promPoolRPCNodeTransitionsToOutOfSync = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "pool_rpc_node_num_transitions_to_out_of_sync", - Help: transitionString(nodeStateOutOfSync), + Help: transitionString(NodeStateOutOfSync), }, []string{"chainID", "nodeName"}) promPoolRPCNodeTransitionsToUnreachable = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "pool_rpc_node_num_transitions_to_unreachable", - Help: transitionString(nodeStateUnreachable), + Help: transitionString(NodeStateUnreachable), }, []string{"chainID", "nodeName"}) promPoolRPCNodeTransitionsToInvalidChainID = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "pool_rpc_node_num_transitions_to_invalid_chain_id", - Help: transitionString(nodeStateInvalidChainID), + Help: transitionString(NodeStateInvalidChainID), }, []string{"chainID", "nodeName"}) promPoolRPCNodeTransitionsToUnusable = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "pool_rpc_node_num_transitions_to_unusable", - Help: transitionString(nodeStateUnusable), + Help: transitionString(NodeStateUnusable), }, []string{"chainID", "nodeName"}) promPoolRPCNodeTransitionsToSyncing = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "pool_rpc_node_num_transitions_to_syncing", - Help: transitionString(nodeStateSyncing), + Help: transitionString(NodeStateSyncing), }, []string{"chainID", "nodeName"}) ) -// nodeState represents the current state of the node +// NodeState represents the current state of the node // Node is a FSM (finite state machine) -type nodeState int +type NodeState int -func (n nodeState) String() string { +func (n NodeState) String() string { switch n { - case nodeStateUndialed: + case NodeStateUndialed: return "Undialed" - case nodeStateDialed: + case NodeStateDialed: return "Dialed" - case nodeStateInvalidChainID: + case NodeStateInvalidChainID: return "InvalidChainID" - case nodeStateAlive: + case NodeStateAlive: return "Alive" - case nodeStateUnreachable: + case NodeStateUnreachable: return "Unreachable" - case nodeStateUnusable: 
+ case NodeStateUnusable: return "Unusable" - case nodeStateOutOfSync: + case NodeStateOutOfSync: return "OutOfSync" - case nodeStateClosed: + case NodeStateClosed: return "Closed" - case nodeStateSyncing: + case NodeStateSyncing: return "Syncing" - case nodeStateFinalizedBlockOutOfSync: + case NodeStateFinalizedBlockOutOfSync: return "FinalizedBlockOutOfSync" default: - return fmt.Sprintf("nodeState(%d)", n) + return fmt.Sprintf("NodeState(%d)", n) } } // GoString prints a prettier state -func (n nodeState) GoString() string { - return fmt.Sprintf("nodeState%s(%d)", n.String(), n) +func (n NodeState) GoString() string { + return fmt.Sprintf("NodeState%s(%d)", n.String(), n) } const ( - // nodeStateUndialed is the first state of a virgin node - nodeStateUndialed = nodeState(iota) - // nodeStateDialed is after a node has successfully dialed but before it has verified the correct chain ID - nodeStateDialed - // nodeStateInvalidChainID is after chain ID verification failed - nodeStateInvalidChainID - // nodeStateAlive is a healthy node after chain ID verification succeeded - nodeStateAlive - // nodeStateUnreachable is a node that cannot be dialed or has disconnected - nodeStateUnreachable - // nodeStateOutOfSync is a node that is accepting connections but exceeded + // NodeStateUndialed is the first state of a virgin node + NodeStateUndialed = NodeState(iota) + // NodeStateDialed is after a node has successfully dialed but before it has verified the correct chain ID + NodeStateDialed + // NodeStateInvalidChainID is after chain ID verification failed + NodeStateInvalidChainID + // NodeStateAlive is a healthy node after chain ID verification succeeded + NodeStateAlive + // NodeStateUnreachable is a node that cannot be dialed or has disconnected + NodeStateUnreachable + // NodeStateOutOfSync is a node that is accepting connections but exceeded // the failure threshold without sending any new heads. It will be // disconnected, then put into a revive loop and re-awakened after redial // if a new head arrives - nodeStateOutOfSync - // nodeStateUnusable is a sendonly node that has an invalid URL that can never be reached - nodeStateUnusable - // nodeStateClosed is after the connection has been closed and the node is at the end of its lifecycle - nodeStateClosed - // nodeStateSyncing is a node that is actively back-filling blockchain. Usually, it's a newly set up node that is - // still syncing the chain. The main difference from `nodeStateOutOfSync` is that it represents state relative - // to other primary nodes configured in the MultiNode. In contrast, `nodeStateSyncing` represents the internal state of + NodeStateOutOfSync + // NodeStateUnusable is a sendonly node that has an invalid URL that can never be reached + NodeStateUnusable + // NodeStateClosed is after the connection has been closed and the node is at the end of its lifecycle + NodeStateClosed + // NodeStateSyncing is a node that is actively back-filling blockchain. Usually, it's a newly set up node that is + // still syncing the chain. The main difference from `NodeStateOutOfSync` is that it represents state relative + // to other primary nodes configured in the MultiNode. In contrast, `NodeStateSyncing` represents the internal state of // the node (RPC). 
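The constant list continues below with NodeStateSyncing and the remaining states. As a quick aside on why the rename matters: because the type is now exported, callers outside the package (chain clients, health reporters) can branch on node health directly. A trimmed-down, runnable sketch with only a subset of the states, for illustration:

package main

import "fmt"

// NodeState mirrors a subset of the exported enum above; the values and
// String method follow the same pattern, trimmed for brevity.
type NodeState int

const (
	NodeStateUndialed NodeState = iota
	NodeStateDialed
	NodeStateAlive
	NodeStateOutOfSync
)

func (s NodeState) String() string {
	switch s {
	case NodeStateUndialed:
		return "Undialed"
	case NodeStateDialed:
		return "Dialed"
	case NodeStateAlive:
		return "Alive"
	case NodeStateOutOfSync:
		return "OutOfSync"
	default:
		return fmt.Sprintf("NodeState(%d)", int(s))
	}
}

func main() {
	// With the type exported, external callers can filter nodes by health
	// instead of relying on package-private helpers.
	for _, s := range []NodeState{NodeStateAlive, NodeStateOutOfSync} {
		fmt.Printf("%s healthy=%v\n", s, s == NodeStateAlive)
	}
}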
- nodeStateSyncing + NodeStateSyncing // nodeStateFinalizedBlockOutOfSync - node is lagging behind on latest finalized block - nodeStateFinalizedBlockOutOfSync + NodeStateFinalizedBlockOutOfSync // nodeStateLen tracks the number of states - nodeStateLen + NodeStateLen ) // allNodeStates represents all possible states a node can be in -var allNodeStates []nodeState +var allNodeStates []NodeState func init() { - for s := nodeState(0); s < nodeStateLen; s++ { + for s := NodeState(0); s < NodeStateLen; s++ { allNodeStates = append(allNodeStates, s) } } @@ -118,29 +118,29 @@ func init() { // FSM methods // State allows reading the current state of the node. -func (n *node[CHAIN_ID, HEAD, RPC]) State() nodeState { +func (n *node[CHAIN_ID, HEAD, RPC]) State() NodeState { n.stateMu.RLock() defer n.stateMu.RUnlock() return n.recalculateState() } -func (n *node[CHAIN_ID, HEAD, RPC]) getCachedState() nodeState { +func (n *node[CHAIN_ID, HEAD, RPC]) getCachedState() NodeState { n.stateMu.RLock() defer n.stateMu.RUnlock() return n.state } -func (n *node[CHAIN_ID, HEAD, RPC]) recalculateState() nodeState { - if n.state != nodeStateAlive { +func (n *node[CHAIN_ID, HEAD, RPC]) recalculateState() NodeState { + if n.state != NodeStateAlive { return n.state } // double check that node is not lagging on finalized block if n.nodePoolCfg.EnforceRepeatableRead() && n.isFinalizedBlockOutOfSync() { - return nodeStateFinalizedBlockOutOfSync + return NodeStateFinalizedBlockOutOfSync } - return nodeStateAlive + return NodeStateAlive } func (n *node[CHAIN_ID, HEAD, RPC]) isFinalizedBlockOutOfSync() bool { @@ -158,7 +158,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) isFinalizedBlockOutOfSync() bool { } // StateAndLatest returns nodeState with the latest ChainInfo observed by Node during current lifecycle. -func (n *node[CHAIN_ID, HEAD, RPC]) StateAndLatest() (nodeState, ChainInfo) { +func (n *node[CHAIN_ID, HEAD, RPC]) StateAndLatest() (NodeState, ChainInfo) { n.stateMu.RLock() defer n.stateMu.RUnlock() latest, _ := n.rpc.GetInterceptedChainInfo() @@ -178,7 +178,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) SetPoolChainInfoProvider(poolInfoProvider Po // This is low-level; care should be taken by the caller to ensure the new state is a valid transition. // State changes should always be synchronous: only one goroutine at a time should change state. 
// n.stateMu should not be locked for long periods of time because external clients expect a timely response from n.State() -func (n *node[CHAIN_ID, HEAD, RPC]) setState(s nodeState) { +func (n *node[CHAIN_ID, HEAD, RPC]) setState(s NodeState) { n.stateMu.Lock() defer n.stateMu.Unlock() n.state = s @@ -199,14 +199,14 @@ func (n *node[CHAIN_ID, HEAD, RPC]) transitionToAlive(fn func()) { promPoolRPCNodeTransitionsToAlive.WithLabelValues(n.chainID.String(), n.name).Inc() n.stateMu.Lock() defer n.stateMu.Unlock() - if n.state == nodeStateClosed { + if n.state == NodeStateClosed { return } switch n.state { - case nodeStateDialed, nodeStateInvalidChainID, nodeStateSyncing: - n.state = nodeStateAlive + case NodeStateDialed, NodeStateInvalidChainID, NodeStateSyncing: + n.state = NodeStateAlive default: - panic(transitionFail(n.state, nodeStateAlive)) + panic(transitionFail(n.state, NodeStateAlive)) } fn() } @@ -226,14 +226,14 @@ func (n *node[CHAIN_ID, HEAD, RPC]) transitionToInSync(fn func()) { promPoolRPCNodeTransitionsToInSync.WithLabelValues(n.chainID.String(), n.name).Inc() n.stateMu.Lock() defer n.stateMu.Unlock() - if n.state == nodeStateClosed { + if n.state == NodeStateClosed { return } switch n.state { - case nodeStateOutOfSync, nodeStateSyncing: - n.state = nodeStateAlive + case NodeStateOutOfSync, NodeStateSyncing: + n.state = NodeStateAlive default: - panic(transitionFail(n.state, nodeStateAlive)) + panic(transitionFail(n.state, NodeStateAlive)) } fn() } @@ -252,15 +252,15 @@ func (n *node[CHAIN_ID, HEAD, RPC]) transitionToOutOfSync(fn func()) { promPoolRPCNodeTransitionsToOutOfSync.WithLabelValues(n.chainID.String(), n.name).Inc() n.stateMu.Lock() defer n.stateMu.Unlock() - if n.state == nodeStateClosed { + if n.state == NodeStateClosed { return } switch n.state { - case nodeStateAlive: - n.disconnectAll() - n.state = nodeStateOutOfSync + case NodeStateAlive: + n.unsubscribeAllExceptAliveLoop() + n.state = NodeStateOutOfSync default: - panic(transitionFail(n.state, nodeStateOutOfSync)) + panic(transitionFail(n.state, NodeStateOutOfSync)) } fn() } @@ -277,31 +277,31 @@ func (n *node[CHAIN_ID, HEAD, RPC]) transitionToUnreachable(fn func()) { promPoolRPCNodeTransitionsToUnreachable.WithLabelValues(n.chainID.String(), n.name).Inc() n.stateMu.Lock() defer n.stateMu.Unlock() - if n.state == nodeStateClosed { + if n.state == NodeStateClosed { return } switch n.state { - case nodeStateUndialed, nodeStateDialed, nodeStateAlive, nodeStateOutOfSync, nodeStateInvalidChainID, nodeStateSyncing: - n.disconnectAll() - n.state = nodeStateUnreachable + case NodeStateUndialed, NodeStateDialed, NodeStateAlive, NodeStateOutOfSync, NodeStateInvalidChainID, NodeStateSyncing: + n.unsubscribeAllExceptAliveLoop() + n.state = NodeStateUnreachable default: - panic(transitionFail(n.state, nodeStateUnreachable)) + panic(transitionFail(n.state, NodeStateUnreachable)) } fn() } -func (n *node[CHAIN_ID, HEAD, RPC]) declareState(state nodeState) { - if n.getCachedState() == nodeStateClosed { +func (n *node[CHAIN_ID, HEAD, RPC]) declareState(state NodeState) { + if n.getCachedState() == NodeStateClosed { return } switch state { - case nodeStateInvalidChainID: + case NodeStateInvalidChainID: n.declareInvalidChainID() - case nodeStateUnreachable: + case NodeStateUnreachable: n.declareUnreachable() - case nodeStateSyncing: + case NodeStateSyncing: n.declareSyncing() - case nodeStateAlive: + case NodeStateAlive: n.declareAlive() default: panic(fmt.Sprintf("%#v state declaration is not implemented", state)) @@ -320,15 
+320,15 @@ func (n *node[CHAIN_ID, HEAD, RPC]) transitionToInvalidChainID(fn func()) { promPoolRPCNodeTransitionsToInvalidChainID.WithLabelValues(n.chainID.String(), n.name).Inc() n.stateMu.Lock() defer n.stateMu.Unlock() - if n.state == nodeStateClosed { + if n.state == NodeStateClosed { return } switch n.state { - case nodeStateDialed, nodeStateOutOfSync, nodeStateSyncing: - n.disconnectAll() - n.state = nodeStateInvalidChainID + case NodeStateDialed, NodeStateOutOfSync, NodeStateSyncing: + n.unsubscribeAllExceptAliveLoop() + n.state = NodeStateInvalidChainID default: - panic(transitionFail(n.state, nodeStateInvalidChainID)) + panic(transitionFail(n.state, NodeStateInvalidChainID)) } fn() } @@ -345,27 +345,27 @@ func (n *node[CHAIN_ID, HEAD, RPC]) transitionToSyncing(fn func()) { promPoolRPCNodeTransitionsToSyncing.WithLabelValues(n.chainID.String(), n.name).Inc() n.stateMu.Lock() defer n.stateMu.Unlock() - if n.state == nodeStateClosed { + if n.state == NodeStateClosed { return } switch n.state { - case nodeStateDialed, nodeStateOutOfSync, nodeStateInvalidChainID: - n.disconnectAll() - n.state = nodeStateSyncing + case NodeStateDialed, NodeStateOutOfSync, NodeStateInvalidChainID: + n.unsubscribeAllExceptAliveLoop() + n.state = NodeStateSyncing default: - panic(transitionFail(n.state, nodeStateSyncing)) + panic(transitionFail(n.state, NodeStateSyncing)) } if !n.nodePoolCfg.NodeIsSyncingEnabled() { - panic("unexpected transition to nodeStateSyncing, while it's disabled") + panic("unexpected transition to NodeStateSyncing, while it's disabled") } fn() } -func transitionString(state nodeState) string { +func transitionString(state NodeState) string { return fmt.Sprintf("Total number of times node has transitioned to %s", state) } -func transitionFail(from nodeState, to nodeState) string { +func transitionFail(from NodeState, to NodeState) string { return fmt.Sprintf("cannot transition from %#v to %#v", from, to) } diff --git a/common/client/node_fsm_test.go b/common/client/node_fsm_test.go index dc0ca0e7de8..62a5264b32e 100644 --- a/common/client/node_fsm_test.go +++ b/common/client/node_fsm_test.go @@ -5,6 +5,8 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/assert" "github.com/smartcontractkit/chainlink/v2/common/types" @@ -29,68 +31,68 @@ func TestUnit_Node_StateTransitions(t *testing.T) { t.Run("setState", func(t *testing.T) { n := newTestNode(t, testNodeOpts{rpc: nil, config: testNodeConfig{nodeIsSyncingEnabled: true}}) - assert.Equal(t, nodeStateUndialed, n.State()) - n.setState(nodeStateAlive) - assert.Equal(t, nodeStateAlive, n.State()) - n.setState(nodeStateUndialed) - assert.Equal(t, nodeStateUndialed, n.State()) + assert.Equal(t, NodeStateUndialed, n.State()) + n.setState(NodeStateAlive) + assert.Equal(t, NodeStateAlive, n.State()) + n.setState(NodeStateUndialed) + assert.Equal(t, NodeStateUndialed, n.State()) }) t.Run("transitionToAlive", func(t *testing.T) { - const destinationState = nodeStateAlive - allowedStates := []nodeState{nodeStateDialed, nodeStateInvalidChainID, nodeStateSyncing} - rpc := newMockNodeClient[types.ID, Head](t) + const destinationState = NodeStateAlive + allowedStates := []NodeState{NodeStateDialed, NodeStateInvalidChainID, NodeStateSyncing} + rpc := newMockRPCClient[types.ID, Head](t) testTransition(t, rpc, testNode.transitionToAlive, destinationState, allowedStates...) 
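All of the transition helpers exercised by testTransition share the same guard shape: lock stateMu, no-op if the node is already Closed, allow only the whitelisted source states, and panic on anything else. A condensed, runnable sketch of that shape (two states only; illustrative, not the production code):

package main

import (
	"fmt"
	"sync"
)

type NodeState int

const (
	NodeStateDialed NodeState = iota
	NodeStateAlive
	NodeStateClosed
)

type node struct {
	stateMu sync.RWMutex
	state   NodeState
}

// transitionToAlive mirrors the guard shape tested above: take the lock,
// treat Closed as a terminal no-op, permit only whitelisted source states,
// and run the callback while the lock is still held so the transition and
// its side effects are atomic.
func (n *node) transitionToAlive(fn func()) {
	n.stateMu.Lock()
	defer n.stateMu.Unlock()
	if n.state == NodeStateClosed {
		return
	}
	switch n.state {
	case NodeStateDialed:
		n.state = NodeStateAlive
	default:
		panic(fmt.Sprintf("cannot transition from %d to %d", n.state, NodeStateAlive))
	}
	fn()
}

func main() {
	n := &node{state: NodeStateDialed}
	n.transitionToAlive(func() { fmt.Println("now alive") }) // prints "now alive"

	closed := &node{state: NodeStateClosed}
	closed.transitionToAlive(func() { fmt.Println("never printed") }) // silent no-op
}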
}) t.Run("transitionToInSync", func(t *testing.T) { - const destinationState = nodeStateAlive - allowedStates := []nodeState{nodeStateOutOfSync, nodeStateSyncing} - rpc := newMockNodeClient[types.ID, Head](t) + const destinationState = NodeStateAlive + allowedStates := []NodeState{NodeStateOutOfSync, NodeStateSyncing} + rpc := newMockRPCClient[types.ID, Head](t) testTransition(t, rpc, testNode.transitionToInSync, destinationState, allowedStates...) }) t.Run("transitionToOutOfSync", func(t *testing.T) { - const destinationState = nodeStateOutOfSync - allowedStates := []nodeState{nodeStateAlive} - rpc := newMockNodeClient[types.ID, Head](t) - rpc.On("DisconnectAll").Once() + const destinationState = NodeStateOutOfSync + allowedStates := []NodeState{NodeStateAlive} + rpc := newMockRPCClient[types.ID, Head](t) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) testTransition(t, rpc, testNode.transitionToOutOfSync, destinationState, allowedStates...) }) t.Run("transitionToUnreachable", func(t *testing.T) { - const destinationState = nodeStateUnreachable - allowedStates := []nodeState{nodeStateUndialed, nodeStateDialed, nodeStateAlive, nodeStateOutOfSync, nodeStateInvalidChainID, nodeStateSyncing} - rpc := newMockNodeClient[types.ID, Head](t) - rpc.On("DisconnectAll").Times(len(allowedStates)) + const destinationState = NodeStateUnreachable + allowedStates := []NodeState{NodeStateUndialed, NodeStateDialed, NodeStateAlive, NodeStateOutOfSync, NodeStateInvalidChainID, NodeStateSyncing} + rpc := newMockRPCClient[types.ID, Head](t) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) testTransition(t, rpc, testNode.transitionToUnreachable, destinationState, allowedStates...) }) t.Run("transitionToInvalidChain", func(t *testing.T) { - const destinationState = nodeStateInvalidChainID - allowedStates := []nodeState{nodeStateDialed, nodeStateOutOfSync, nodeStateSyncing} - rpc := newMockNodeClient[types.ID, Head](t) - rpc.On("DisconnectAll").Times(len(allowedStates)) + const destinationState = NodeStateInvalidChainID + allowedStates := []NodeState{NodeStateDialed, NodeStateOutOfSync, NodeStateSyncing} + rpc := newMockRPCClient[types.ID, Head](t) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) testTransition(t, rpc, testNode.transitionToInvalidChainID, destinationState, allowedStates...) }) t.Run("transitionToSyncing", func(t *testing.T) { - const destinationState = nodeStateSyncing - allowedStates := []nodeState{nodeStateDialed, nodeStateOutOfSync, nodeStateInvalidChainID} - rpc := newMockNodeClient[types.ID, Head](t) - rpc.On("DisconnectAll").Times(len(allowedStates)) + const destinationState = NodeStateSyncing + allowedStates := []NodeState{NodeStateDialed, NodeStateOutOfSync, NodeStateInvalidChainID} + rpc := newMockRPCClient[types.ID, Head](t) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) testTransition(t, rpc, testNode.transitionToSyncing, destinationState, allowedStates...) 
}) t.Run("transitionToSyncing panics if nodeIsSyncing is disabled", func(t *testing.T) { - rpc := newMockNodeClient[types.ID, Head](t) - rpc.On("DisconnectAll").Once() + rpc := newMockRPCClient[types.ID, Head](t) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) node := newTestNode(t, testNodeOpts{rpc: rpc}) - node.setState(nodeStateDialed) + node.setState(NodeStateDialed) fn := new(fnMock) defer fn.AssertNotCalled(t) - assert.PanicsWithValue(t, "unexpected transition to nodeStateSyncing, while it's disabled", func() { + assert.PanicsWithValue(t, "unexpected transition to NodeStateSyncing, while it's disabled", func() { node.transitionToSyncing(fn.Fn) }) }) } -func testTransition(t *testing.T, rpc *mockNodeClient[types.ID, Head], transition func(node testNode, fn func()), destinationState nodeState, allowedStates ...nodeState) { +func testTransition(t *testing.T, rpc *mockRPCClient[types.ID, Head], transition func(node testNode, fn func()), destinationState NodeState, allowedStates ...NodeState) { node := newTestNode(t, testNodeOpts{rpc: rpc, config: testNodeConfig{nodeIsSyncingEnabled: true}}) for _, allowedState := range allowedStates { m := new(fnMock) @@ -101,13 +103,13 @@ func testTransition(t *testing.T, rpc *mockNodeClient[types.ID, Head], transitio } // noop on attempt to transition from Closed state m := new(fnMock) - node.setState(nodeStateClosed) + node.setState(NodeStateClosed) transition(node, m.Fn) m.AssertNotCalled(t) - assert.Equal(t, nodeStateClosed, node.State(), "Expected node to remain in closed state on transition attempt") + assert.Equal(t, NodeStateClosed, node.State(), "Expected node to remain in closed state on transition attempt") for _, nodeState := range allNodeStates { - if slices.Contains(allowedStates, nodeState) || nodeState == nodeStateClosed { + if slices.Contains(allowedStates, nodeState) || nodeState == NodeStateClosed { continue } @@ -124,7 +126,7 @@ func testTransition(t *testing.T, rpc *mockNodeClient[types.ID, Head], transitio func TestNodeState_String(t *testing.T) { t.Run("Ensure all states are meaningful when converted to string", func(t *testing.T) { for _, ns := range allNodeStates { - // ensure that string representation is not nodeState(%d) + // ensure that string representation is not NodeState(%d) assert.NotContains(t, ns.String(), strconv.FormatInt(int64(ns), 10), "Expected node state to have readable name") } }) diff --git a/common/client/node_lifecycle.go b/common/client/node_lifecycle.go index 39e17bb4972..26307a4f32a 100644 --- a/common/client/node_lifecycle.go +++ b/common/client/node_lifecycle.go @@ -7,6 +7,8 @@ import ( "math/big" "time" + "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -77,8 +79,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { // sanity check state := n.getCachedState() switch state { - case nodeStateAlive: - case nodeStateClosed: + case NodeStateAlive: + case NodeStateClosed: return default: panic(fmt.Sprintf("aliveLoop can only run for node in Alive state, got: %s", state)) @@ -92,17 +94,22 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { lggr := logger.Sugared(n.lfcLog).Named("Alive").With("noNewHeadsTimeoutThreshold", noNewHeadsTimeoutThreshold, "pollInterval", pollInterval, "pollFailureThreshold", pollFailureThreshold) lggr.Tracew("Alive loop starting", "nodeState", n.getCachedState()) - headsC := make(chan HEAD) - sub, err := n.rpc.SubscribeNewHead(ctx, headsC) + 
headsC, sub, err := n.rpc.SubscribeToHeads(ctx) if err != nil { lggr.Errorw("Initial subscribe for heads failed", "nodeState", n.getCachedState()) n.declareUnreachable() return } - // TODO: nit fix. If multinode switches primary node before we set sub as AliveSub, sub will be closed and we'll - // falsely transition this node to unreachable state - n.rpc.SetAliveLoopSub(sub) - defer sub.Unsubscribe() + + n.stateMu.Lock() + n.aliveLoopSub = sub + n.stateMu.Unlock() + defer func() { + defer sub.Unsubscribe() + n.stateMu.Lock() + n.aliveLoopSub = nil + n.stateMu.Unlock() + }() var outOfSyncT *time.Ticker var outOfSyncTC <-chan time.Time @@ -131,12 +138,26 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { lggr.Debug("Polling disabled") } - var pollFinalizedHeadCh <-chan time.Time - if n.chainCfg.FinalityTagEnabled() && n.nodePoolCfg.FinalizedBlockPollInterval() > 0 { + var finalizedHeadCh <-chan HEAD + if n.chainCfg.FinalityTagEnabled() { + var finalizedHeadSub types.Subscription lggr.Debugw("Finalized block polling enabled") - pollT := time.NewTicker(n.nodePoolCfg.FinalizedBlockPollInterval()) - defer pollT.Stop() - pollFinalizedHeadCh = pollT.C + finalizedHeadCh, finalizedHeadSub, err = n.rpc.SubscribeToFinalizedHeads(ctx) + if err != nil { + lggr.Errorw("Failed to subscribe to finalized heads", "err", err) + n.declareUnreachable() + return + } + + n.stateMu.Lock() + n.finalizedBlockSub = finalizedHeadSub + n.stateMu.Unlock() + defer func() { + finalizedHeadSub.Unsubscribe() + n.stateMu.Lock() + n.finalizedBlockSub = nil + n.stateMu.Unlock() + }() } localHighestChainInfo, _ := n.rpc.GetInterceptedChainInfo() @@ -148,12 +169,10 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { return case <-pollCh: promPoolRPCNodePolls.WithLabelValues(n.chainID.String(), n.name).Inc() - lggr.Tracew("Polling for version", "nodeState", n.getCachedState(), "pollFailures", pollFailures) - version, err := func(ctx context.Context) (string, error) { - ctx, cancel := context.WithTimeout(ctx, pollInterval) - defer cancel() - return n.RPC().ClientVersion(ctx) - }(ctx) + lggr.Tracew("Pinging RPC", "nodeState", n.State(), "pollFailures", pollFailures) + pollCtx, cancel := context.WithTimeout(ctx, pollInterval) + err := n.RPC().Ping(pollCtx) + cancel() if err != nil { // prevent overflow if pollFailures < math.MaxUint32 { @@ -162,7 +181,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { } lggr.Warnw(fmt.Sprintf("Poll failure, RPC endpoint %s failed to respond properly", n.String()), "err", err, "pollFailures", pollFailures, "nodeState", n.getCachedState()) } else { - lggr.Debugw("Version poll successful", "nodeState", n.getCachedState(), "clientVersion", version) + lggr.Debugw("Ping successful", "nodeState", n.State()) promPoolRPCNodePollsSuccess.WithLabelValues(n.chainID.String(), n.name).Inc() pollFailures = 0 } @@ -232,17 +251,12 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { } n.declareOutOfSync(func(num int64, td *big.Int) bool { return num < localHighestChainInfo.BlockNumber }) return - case <-pollFinalizedHeadCh: - latestFinalized, err := func(ctx context.Context) (HEAD, error) { - ctx, cancel := context.WithTimeout(ctx, n.nodePoolCfg.FinalizedBlockPollInterval()) - defer cancel() - return n.RPC().LatestFinalizedBlock(ctx) - }(ctx) - if err != nil { - lggr.Warnw("Failed to fetch latest finalized block", "err", err) - continue + case latestFinalized, open := <-finalizedHeadCh: + if !open { + lggr.Errorw("Subscription channel unexpectedly closed", "nodeState", n.State()) + n.declareUnreachable() + 
return } - if !latestFinalized.IsValid() { lggr.Warn("Latest finalized block is not valid") continue @@ -302,8 +316,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td // sanity check state := n.getCachedState() switch state { - case nodeStateOutOfSync: - case nodeStateClosed: + case NodeStateOutOfSync: + case NodeStateClosed: return default: panic(fmt.Sprintf("outOfSyncLoop can only run for node in OutOfSync state, got: %s", state)) @@ -317,15 +331,14 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td // Need to redial since out-of-sync nodes are automatically disconnected state := n.createVerifiedConn(ctx, lggr) - if state != nodeStateAlive { + if state != NodeStateAlive { n.declareState(state) return } lggr.Tracew("Successfully subscribed to heads feed on out-of-sync RPC node", "nodeState", n.getCachedState()) - ch := make(chan HEAD) - sub, err := n.rpc.SubscribeNewHead(ctx, ch) + ch, sub, err := n.rpc.SubscribeToHeads(ctx) if err != nil { lggr.Errorw("Failed to subscribe heads on out-of-sync RPC node", "nodeState", n.getCachedState(), "err", err) n.declareUnreachable() @@ -375,8 +388,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { // sanity check state := n.getCachedState() switch state { - case nodeStateUnreachable: - case nodeStateClosed: + case NodeStateUnreachable: + case NodeStateClosed: return default: panic(fmt.Sprintf("unreachableLoop can only run for node in Unreachable state, got: %s", state)) @@ -403,14 +416,14 @@ func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { continue } - n.setState(nodeStateDialed) + n.setState(NodeStateDialed) state := n.verifyConn(ctx, lggr) switch state { - case nodeStateUnreachable: - n.setState(nodeStateUnreachable) + case NodeStateUnreachable: + n.setState(NodeStateUnreachable) continue - case nodeStateAlive: + case NodeStateAlive: lggr.Infow(fmt.Sprintf("Successfully redialled and verified RPC node %s. Node was offline for %s", n.String(), time.Since(unreachableAt)), "nodeState", n.getCachedState()) fallthrough default: @@ -430,8 +443,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { // sanity check state := n.getCachedState() switch state { - case nodeStateInvalidChainID: - case nodeStateClosed: + case NodeStateInvalidChainID: + case NodeStateClosed: return default: panic(fmt.Sprintf("invalidChainIDLoop can only run for node in InvalidChainID state, got: %s", state)) @@ -444,7 +457,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { // Need to redial since invalid chain ID nodes are automatically disconnected state := n.createVerifiedConn(ctx, lggr) - if state != nodeStateInvalidChainID { + if state != NodeStateInvalidChainID { n.declareState(state) return } @@ -460,9 +473,9 @@ func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { case <-time.After(chainIDRecheckBackoff.Duration()): state := n.verifyConn(ctx, lggr) switch state { - case nodeStateInvalidChainID: + case NodeStateInvalidChainID: continue - case nodeStateAlive: + case NodeStateAlive: lggr.Infow(fmt.Sprintf("Successfully verified RPC node. 
Node was offline for %s", time.Since(invalidAt)), "nodeState", n.getCachedState()) fallthrough default: @@ -482,11 +495,11 @@ func (n *node[CHAIN_ID, HEAD, RPC]) syncingLoop() { // sanity check state := n.getCachedState() switch state { - case nodeStateSyncing: - case nodeStateClosed: + case NodeStateSyncing: + case NodeStateClosed: return default: - panic(fmt.Sprintf("syncingLoop can only run for node in nodeStateSyncing state, got: %s", state)) + panic(fmt.Sprintf("syncingLoop can only run for node in NodeStateSyncing state, got: %s", state)) } } @@ -496,7 +509,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) syncingLoop() { lggr.Debugw(fmt.Sprintf("Periodically re-checking RPC node %s with syncing status", n.String()), "nodeState", n.getCachedState()) // Need to redial since syncing nodes are automatically disconnected state := n.createVerifiedConn(ctx, lggr) - if state != nodeStateSyncing { + if state != NodeStateSyncing { n.declareState(state) return } diff --git a/common/client/node_lifecycle_test.go b/common/client/node_lifecycle_test.go index 863a15a1fad..081c8374090 100644 --- a/common/client/node_lifecycle_test.go +++ b/common/client/node_lifecycle_test.go @@ -29,37 +29,37 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node := newTestNode(t, opts) opts.rpc.On("Close").Return(nil).Once() - node.setState(nodeStateDialed) + node.setState(NodeStateDialed) return node } t.Run("returns on closed", func(t *testing.T) { node := newTestNode(t, testNodeOpts{}) - node.setState(nodeStateClosed) + node.setState(NodeStateClosed) node.wg.Add(1) node.aliveLoop() }) t.Run("if initial subscribe fails, transitions to unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) node := newDialedNode(t, testNodeOpts{ rpc: rpc, }) defer func() { assert.NoError(t, node.close()) }() expectedError := errors.New("failed to subscribe to rpc") - rpc.On("DisconnectAll").Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, expectedError).Once() + rpc.On("SubscribeToHeads", mock.Anything).Return(nil, nil, expectedError).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) // might be called in unreachable loop rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() node.declareAlive() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("if remote RPC connection is closed transitions to unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) lggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) node := newDialedNode(t, testNodeOpts{ @@ -74,28 +74,25 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { close(errChan) sub.On("Err").Return((<-chan error)(errChan)).Once() sub.On("Unsubscribe").Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(sub, nil).Once() - rpc.On("SetAliveLoopSub", sub).Once() - // disconnects all on transfer to unreachable - rpc.On("DisconnectAll").Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) + rpc.On("SubscribeToHeads", mock.Anything).Return(nil, sub, nil).Once() // might be called in unreachable loop rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() node.declareAlive() tests.AssertLogEventually(t, observedLogs, "Subscription was terminated") - assert.Equal(t, nodeStateUnreachable, node.State()) + 
assert.Equal(t, NodeStateUnreachable, node.State()) }) newSubscribedNode := func(t *testing.T, opts testNodeOpts) testNode { sub := mocks.NewSubscription(t) - sub.On("Err").Return((<-chan error)(nil)) + sub.On("Err").Return(nil) sub.On("Unsubscribe").Once() - opts.rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(sub, nil).Once() - opts.rpc.On("SetAliveLoopSub", sub).Once() + opts.rpc.On("SubscribeToHeads", mock.Anything).Return(make(<-chan Head), sub, nil).Once() return newDialedNode(t, opts) } t.Run("Stays alive and waits for signal", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newSubscribedNode(t, testNodeOpts{ @@ -107,11 +104,11 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node.declareAlive() tests.AssertLogEventually(t, observedLogs, "Head liveness checking disabled") tests.AssertLogEventually(t, observedLogs, "Polling disabled") - assert.Equal(t, nodeStateAlive, node.State()) + assert.Equal(t, NodeStateAlive, node.State()) }) t.Run("stays alive while below pollFailureThreshold and resets counter on success", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}) lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) const pollFailureThreshold = 3 @@ -127,33 +124,33 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { pollError := errors.New("failed to get ClientVersion") // 1. Return error several times, but below threshold - rpc.On("ClientVersion", mock.Anything).Return("", pollError).Run(func(_ mock.Arguments) { + rpc.On("Ping", mock.Anything).Return(pollError).Run(func(_ mock.Arguments) { // stays healthy while below threshold - assert.Equal(t, nodeStateAlive, node.State()) + assert.Equal(t, NodeStateAlive, node.State()) }).Times(pollFailureThreshold - 1) // 2. Successful call that is expected to reset counter - rpc.On("ClientVersion", mock.Anything).Return("client_version", nil).Once() + rpc.On("Ping", mock.Anything).Return(nil).Once() // 3. Return error. If we have not reset the timer, we'll transition to nonAliveState - rpc.On("ClientVersion", mock.Anything).Return("", pollError).Once() + rpc.On("Ping", mock.Anything).Return(pollError).Once() // 4. 
Once during the call, check if node is alive var ensuredAlive atomic.Bool - rpc.On("ClientVersion", mock.Anything).Return("client_version", nil).Run(func(_ mock.Arguments) { + rpc.On("Ping", mock.Anything).Return(nil).Run(func(_ mock.Arguments) { if ensuredAlive.Load() { return } ensuredAlive.Store(true) - assert.Equal(t, nodeStateAlive, node.State()) + assert.Equal(t, NodeStateAlive, node.State()) }).Once() // redundant call to stay in alive state - rpc.On("ClientVersion", mock.Anything).Return("client_version", nil) + rpc.On("Ping", mock.Anything).Return(nil) node.declareAlive() tests.AssertLogCountEventually(t, observedLogs, fmt.Sprintf("Poll failure, RPC endpoint %s failed to respond properly", node.String()), pollFailureThreshold) - tests.AssertLogCountEventually(t, observedLogs, "Version poll successful", 2) + tests.AssertLogCountEventually(t, observedLogs, "Ping successful", 2) assert.True(t, ensuredAlive.Load(), "expected to ensure that node was alive") }) t.Run("with threshold poll failures, transitions to unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}) lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) const pollFailureThreshold = 3 @@ -167,20 +164,20 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { }) defer func() { assert.NoError(t, node.close()) }() pollError := errors.New("failed to get ClientVersion") - rpc.On("ClientVersion", mock.Anything).Return("", pollError) + rpc.On("Ping", mock.Anything).Return(pollError) // disconnects all on transfer to unreachable - rpc.On("DisconnectAll").Once() // might be called in unreachable loop rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) node.declareAlive() tests.AssertLogCountEventually(t, observedLogs, fmt.Sprintf("Poll failure, RPC endpoint %s failed to respond properly", node.String()), pollFailureThreshold) tests.AssertEventually(t, func() bool { - return nodeStateUnreachable == node.State() + return NodeStateUnreachable == node.State() }) }) t.Run("with threshold poll failures, but we are the last node alive, forcibly keeps it alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) const pollFailureThreshold = 3 node := newSubscribedNode(t, testNodeOpts{ @@ -199,14 +196,14 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node.SetPoolChainInfoProvider(poolInfo) rpc.On("GetInterceptedChainInfo").Return(ChainInfo{BlockNumber: 20}, ChainInfo{BlockNumber: 20}) pollError := errors.New("failed to get ClientVersion") - rpc.On("ClientVersion", mock.Anything).Return("", pollError) + rpc.On("Ping", mock.Anything).Return(pollError) node.declareAlive() tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("RPC endpoint failed to respond to %d consecutive polls", pollFailureThreshold)) - assert.Equal(t, nodeStateAlive, node.State()) + assert.Equal(t, NodeStateAlive, node.State()) }) t.Run("when behind more than SyncThreshold, transitions to out of sync", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) const syncThreshold = 10 node := newSubscribedNode(t, testNodeOpts{ @@ -219,6 +216,8 @@ func 
TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { lggr: lggr, }) defer func() { assert.NoError(t, node.close()) }() + rpc.On("Ping", mock.Anything).Return(nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) const mostRecentBlock = 20 rpc.On("GetInterceptedChainInfo").Return(ChainInfo{BlockNumber: mostRecentBlock}, ChainInfo{BlockNumber: 30}) poolInfo := newMockPoolChainInfoProvider(t) @@ -227,23 +226,22 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { TotalDifficulty: big.NewInt(10), }).Once() node.SetPoolChainInfoProvider(poolInfo) - rpc.On("ClientVersion", mock.Anything).Return("", nil) // tries to redial in outOfSync rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Run(func(_ mock.Arguments) { - assert.Equal(t, nodeStateOutOfSync, node.State()) + assert.Equal(t, NodeStateOutOfSync, node.State()) }).Once() // disconnects all on transfer to unreachable or outOfSync - rpc.On("DisconnectAll").Maybe() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Maybe() // might be called in unreachable loop rpc.On("Dial", mock.Anything).Run(func(_ mock.Arguments) { - require.Equal(t, nodeStateOutOfSync, node.State()) + require.Equal(t, NodeStateOutOfSync, node.State()) }).Return(errors.New("failed to dial")).Maybe() node.declareAlive() tests.AssertLogEventually(t, observedLogs, "Dial failed: Node is unreachable") }) t.Run("when behind more than SyncThreshold but we are the last live node, forcibly stays alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) const syncThreshold = 10 node := newSubscribedNode(t, testNodeOpts{ @@ -256,6 +254,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { lggr: lggr, }) defer func() { assert.NoError(t, node.close()) }() + rpc.On("Ping", mock.Anything).Return(nil) const mostRecentBlock = 20 rpc.On("GetInterceptedChainInfo").Return(ChainInfo{BlockNumber: mostRecentBlock}, ChainInfo{BlockNumber: 30}) poolInfo := newMockPoolChainInfoProvider(t) @@ -264,13 +263,12 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { TotalDifficulty: big.NewInt(10), }).Once() node.SetPoolChainInfoProvider(poolInfo) - rpc.On("ClientVersion", mock.Anything).Return("", nil) node.declareAlive() tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("RPC endpoint has fallen behind; %s %s", msgCannotDisable, msgDegradedState)) }) t.Run("when behind but SyncThreshold=0, stay alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newSubscribedNode(t, testNodeOpts{ config: testNodeConfig{ @@ -282,17 +280,18 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { lggr: lggr, }) defer func() { assert.NoError(t, node.close()) }() + rpc.On("Ping", mock.Anything).Return(nil) const mostRecentBlock = 20 rpc.On("GetInterceptedChainInfo").Return(ChainInfo{BlockNumber: mostRecentBlock}, ChainInfo{BlockNumber: 30}) - rpc.On("ClientVersion", mock.Anything).Return("", nil) node.declareAlive() - tests.AssertLogCountEventually(t, observedLogs, "Version poll successful", 2) - assert.Equal(t, nodeStateAlive, node.State()) + tests.AssertLogCountEventually(t, observedLogs, "Ping successful", 2) + assert.Equal(t, NodeStateAlive, node.State()) }) t.Run("when no new heads received for threshold, transitions to out of sync", func(t *testing.T) { t.Parallel() - rpc := 
newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) node := newSubscribedNode(t, testNodeOpts{ config: testNodeConfig{}, chainConfig: clientMocks.ChainConfig{ @@ -303,22 +302,22 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() // tries to redial in outOfSync rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Run(func(_ mock.Arguments) { - assert.Equal(t, nodeStateOutOfSync, node.State()) + assert.Equal(t, NodeStateOutOfSync, node.State()) }).Once() // disconnects all on transfer to unreachable or outOfSync - rpc.On("DisconnectAll").Maybe() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Maybe() // might be called in unreachable loop rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() node.declareAlive() tests.AssertEventually(t, func() bool { // right after outOfSync we'll transfer to unreachable due to returned error on Dial // we check that we were in out of sync state on first Dial call - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("when no new heads received for threshold but we are the last live node, forcibly stays alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newSubscribedNode(t, testNodeOpts{ @@ -338,20 +337,20 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node.SetPoolChainInfoProvider(poolInfo) node.declareAlive() tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("RPC endpoint detected out of sync; %s %s", msgCannotDisable, msgDegradedState)) - assert.Equal(t, nodeStateAlive, node.State()) + assert.Equal(t, NodeStateAlive, node.State()) }) t.Run("rpc closed head channel", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) - rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() + rpc := newMockRPCClient[types.ID, Head](t) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) sub := mocks.NewSubscription(t) sub.On("Err").Return((<-chan error)(nil)) sub.On("Unsubscribe").Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - ch := args.Get(1).(chan<- Head) + ch := make(chan Head) + rpc.On("SubscribeToHeads", mock.Anything).Run(func(args mock.Arguments) { + rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() close(ch) - }).Return(sub, nil).Once() - rpc.On("SetAliveLoopSub", sub).Once() + }).Return((<-chan Head)(ch), sub, nil).Once() lggr, observedLogs := logger.TestObserved(t, zap.ErrorLevel) node := newDialedNode(t, testNodeOpts{ lggr: lggr, @@ -363,28 +362,26 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { }) defer func() { assert.NoError(t, node.close()) }() // disconnects all on transfer to unreachable or outOfSync - rpc.On("DisconnectAll").Once() // might be called in unreachable loop rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() node.declareAlive() tests.AssertLogEventually(t, observedLogs, "Subscription channel unexpectedly closed") - assert.Equal(t, nodeStateUnreachable, node.State()) + assert.Equal(t, NodeStateUnreachable, node.State()) }) 
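The hunks above replace the old SubscribeNewHead/SetAliveLoopSub pairing with SubscribeToHeads, which hands back the head channel together with the subscription, and the tests now require the alive loop to treat a closed head channel, not just a fired sub.Err(), as a liveness failure. A minimal consumer sketch of that contract (a simplification of aliveLoop: the select scaffolding is reduced, and the onNewHead helper named here is an assumption, not the real code):

    headsC, sub, err := n.rpc.SubscribeToHeads(ctx)
    if err != nil {
        lggr.Errorw("Initial subscribe for heads failed", "nodeState", n.getCachedState())
        n.declareUnreachable()
        return
    }
    defer sub.Unsubscribe()

    for {
        select {
        case head, open := <-headsC:
            if !open {
                // The RPC closed the channel out from under us; treat it as a dead connection.
                lggr.Errorw("Subscription channel unexpectedly closed", "nodeState", n.State())
                n.declareUnreachable()
                return
            }
            onNewHead(head) // hypothetical helper: record ChainInfo, reset the no-new-heads timer
        case err := <-sub.Err():
            lggr.Errorw("Subscription was terminated", "err", err, "nodeState", n.getCachedState())
            n.declareUnreachable()
            return
        }
    }
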
t.Run("If finality tag is not enabled updates finalized block metric using finality depth and latest head", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) sub := mocks.NewSubscription(t) - sub.On("Err").Return((<-chan error)(nil)) + sub.On("Err").Return(nil) sub.On("Unsubscribe").Once() const blockNumber = 1000 const finalityDepth = 10 const expectedBlock = 990 - rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - ch := args.Get(1).(chan<- Head) + ch := make(chan Head) + rpc.On("SubscribeToHeads", mock.Anything).Run(func(args mock.Arguments) { + rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() go writeHeads(t, ch, head{BlockNumber: blockNumber - 1}, head{BlockNumber: blockNumber}, head{BlockNumber: blockNumber - 1}) - }).Return(sub, nil).Once() - rpc.On("SetAliveLoopSub", sub).Once() + }).Return((<-chan Head)(ch), sub, nil).Once() name := "node-" + rand.Str(5) node := newDialedNode(t, testNodeOpts{ config: testNodeConfig{}, @@ -403,16 +400,16 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { return float64(expectedBlock) == m.Gauge.GetValue() }) }) - t.Run("Logs warning if failed to get finalized block", func(t *testing.T) { + t.Run("Logs warning if failed to subscrive to latest finalized blocks", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) - rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() - rpc.On("LatestFinalizedBlock", mock.Anything).Return(newMockHead(t), errors.New("failed to get finalized block")) + rpc := newMockRPCClient[types.ID, Head](t) sub := mocks.NewSubscription(t) - sub.On("Err").Return((<-chan error)(nil)) - sub.On("Unsubscribe").Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(sub, nil).Once() - rpc.On("SetAliveLoopSub", sub).Once() + sub.On("Err").Return(nil).Maybe() + sub.On("Unsubscribe") + rpc.On("SubscribeToHeads", mock.Anything).Return(make(<-chan Head), sub, nil).Once() + expectedError := errors.New("failed to subscribe to finalized heads") + rpc.On("SubscribeToFinalizedHeads", mock.Anything).Return(nil, sub, expectedError).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Maybe() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newDialedNode(t, testNodeOpts{ config: testNodeConfig{ @@ -426,20 +423,23 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { }) defer func() { assert.NoError(t, node.close()) }() node.declareAlive() - tests.AssertLogEventually(t, observedLogs, "Failed to fetch latest finalized block") + tests.AssertLogEventually(t, observedLogs, "Failed to subscribe to finalized heads") }) t.Run("Logs warning if latest finalized block is not valid", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) + sub := mocks.NewSubscription(t) + sub.On("Err").Return(nil) + sub.On("Unsubscribe") + rpc.On("SubscribeToHeads", mock.Anything).Return(make(<-chan Head), sub, nil).Once() + ch := make(chan Head, 1) head := newMockHead(t) head.On("IsValid").Return(false) - rpc.On("LatestFinalizedBlock", mock.Anything).Return(head, nil) + rpc.On("SubscribeToFinalizedHeads", mock.Anything).Run(func(args mock.Arguments) { + ch <- head + }).Return((<-chan Head)(ch), sub, nil).Once() + rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() - 
sub := mocks.NewSubscription(t) - sub.On("Err").Return((<-chan error)(nil)) - sub.On("Unsubscribe").Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(sub, nil).Once() - rpc.On("SetAliveLoopSub", sub).Once() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newDialedNode(t, testNodeOpts{ config: testNodeConfig{ @@ -455,24 +455,20 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { node.declareAlive() tests.AssertLogEventually(t, observedLogs, "Latest finalized block is not valid") }) - t.Run("If finality tag and finalized block polling are enabled updates latest finalized block metric", func(t *testing.T) { + t.Run("If finality tag is enabled updates latest finalized block metric", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) const expectedBlock = 1101 const finalityDepth = 10 - rpc.On("LatestFinalizedBlock", mock.Anything).Return(head{BlockNumber: expectedBlock - 1}.ToMockHead(t), nil).Once() - rpc.On("LatestFinalizedBlock", mock.Anything).Return(head{BlockNumber: expectedBlock}.ToMockHead(t), nil) sub := mocks.NewSubscription(t) - sub.On("Err").Return((<-chan error)(nil)) - sub.On("Unsubscribe").Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - ch := args.Get(1).(chan<- Head) - // ensure that "calculated" finalized head is larger than actual, to ensure we are correctly setting - // the metric - go writeHeads(t, ch, head{BlockNumber: expectedBlock*2 + finalityDepth}) - }).Return(sub, nil).Once() + sub.On("Err").Return(nil) + sub.On("Unsubscribe") + ch := make(chan Head, 1) + rpc.On("SubscribeToHeads", mock.Anything).Return(make(<-chan Head), sub, nil).Once() + rpc.On("SubscribeToFinalizedHeads", mock.Anything).Run(func(args mock.Arguments) { + go writeHeads(t, ch, head{BlockNumber: expectedBlock - 1}, head{BlockNumber: expectedBlock}) + }).Return((<-chan Head)(ch), sub, nil).Once() rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Once() - rpc.On("SetAliveLoopSub", sub).Once() name := "node-" + rand.Str(5) node := newDialedNode(t, testNodeOpts{ config: testNodeConfig{ @@ -522,12 +518,13 @@ func writeHeads(t *testing.T, ch chan<- Head, heads ...head) { } } -func setupRPCForAliveLoop(t *testing.T, rpc *mockNodeClient[types.ID, Head]) { +func setupRPCForAliveLoop(t *testing.T, rpc *mockRPCClient[types.ID, Head]) { rpc.On("Dial", mock.Anything).Return(nil).Maybe() aliveSubscription := mocks.NewSubscription(t) - aliveSubscription.On("Err").Return((<-chan error)(nil)).Maybe() + aliveSubscription.On("Err").Return(nil).Maybe() aliveSubscription.On("Unsubscribe").Maybe() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(aliveSubscription, nil).Maybe() + rpc.On("SubscribeToHeads", mock.Anything).Return(make(<-chan Head), aliveSubscription, nil).Maybe() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Maybe() rpc.On("SetAliveLoopSub", mock.Anything).Maybe() rpc.On("GetInterceptedChainInfo").Return(ChainInfo{}, ChainInfo{}).Maybe() } @@ -539,8 +536,7 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { node := newTestNode(t, opts) opts.rpc.On("Close").Return(nil).Once() // disconnects all on transfer to unreachable or outOfSync - opts.rpc.On("DisconnectAll") - node.setState(nodeStateAlive) + node.setState(NodeStateAlive) return node } @@ -551,13 +547,13 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { t.Run("returns on closed", func(t *testing.T) { t.Parallel() 
node := newTestNode(t, testNodeOpts{}) - node.setState(nodeStateClosed) + node.setState(NodeStateClosed) node.wg.Add(1) node.outOfSyncLoop(stubIsOutOfSync) }) t.Run("on old blocks stays outOfSync and returns on close", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newAliveNode(t, testNodeOpts{ @@ -569,26 +565,28 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) outOfSyncSubscription := mocks.NewSubscription(t) outOfSyncSubscription.On("Err").Return((<-chan error)(nil)) outOfSyncSubscription.On("Unsubscribe").Once() heads := []head{{BlockNumber: 7}, {BlockNumber: 11}, {BlockNumber: 13}} - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - ch := args.Get(1).(chan<- Head) + ch := make(chan Head) + rpc.On("SubscribeToHeads", mock.Anything).Run(func(args mock.Arguments) { go writeHeads(t, ch, heads...) - }).Return(outOfSyncSubscription, nil).Once() + }).Return((<-chan Head)(ch), outOfSyncSubscription, nil).Once() + rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() node.declareOutOfSync(func(num int64, td *big.Int) bool { return true }) tests.AssertLogCountEventually(t, observedLogs, msgReceivedBlock, len(heads)) - assert.Equal(t, nodeStateOutOfSync, node.State()) + assert.Equal(t, NodeStateOutOfSync, node.State()) }) t.Run("if initial dial fails, transitions to unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) node := newAliveNode(t, testNodeOpts{ rpc: rpc, }) @@ -597,14 +595,15 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { expectedError := errors.New("failed to dial rpc") // might be called again in unreachable loop, so no need to set once rpc.On("Dial", mock.Anything).Return(expectedError) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) node.declareOutOfSync(stubIsOutOfSync) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("if fail to get chainID, transitions to unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) node := newAliveNode(t, testNodeOpts{ rpc: rpc, }) @@ -614,17 +613,18 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil).Once() // for unreachable rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) expectedError := errors.New("failed to get chain ID") // might be called multiple times rpc.On("ChainID", mock.Anything).Return(types.NewIDFromInt(0), expectedError) node.declareOutOfSync(stubIsOutOfSync) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("if chainID does not match, transitions to invalidChainID", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) rpcChainID := types.NewIDFromInt(11) node := newAliveNode(t, 
testNodeOpts{ @@ -635,16 +635,17 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { // one for out-of-sync & one for invalid chainID rpc.On("Dial", mock.Anything).Return(nil).Twice() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) // might be called multiple times rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) node.declareOutOfSync(stubIsOutOfSync) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateInvalidChainID + return node.State() == NodeStateInvalidChainID }) }) t.Run("if syncing, transitions to syncing", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) node := newAliveNode(t, testNodeOpts{ rpc: rpc, @@ -655,16 +656,17 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil) rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) // might be called multiple times rpc.On("IsSyncing", mock.Anything).Return(true, nil) node.declareOutOfSync(stubIsOutOfSync) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateSyncing + return node.State() == NodeStateSyncing }) }) t.Run("if fails to fetch syncing status, transitions to unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) node := newAliveNode(t, testNodeOpts{ rpc: rpc, @@ -675,6 +677,7 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { // one for out-of-sync rpc.On("Dial", mock.Anything).Return(nil).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) // for unreachable rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() @@ -682,12 +685,12 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { rpc.On("IsSyncing", mock.Anything).Return(false, errors.New("failed to check syncing")) node.declareOutOfSync(stubIsOutOfSync) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("if fails to subscribe, becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() node := newAliveNode(t, testNodeOpts{ rpc: rpc, @@ -698,16 +701,17 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() expectedError := errors.New("failed to subscribe") - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, expectedError) + rpc.On("SubscribeToHeads", mock.Anything).Return(nil, nil, expectedError).Once() rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) node.declareOutOfSync(stubIsOutOfSync) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("on subscription termination becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.ErrorLevel) node := 
newAliveNode(t, testNodeOpts{ @@ -719,23 +723,23 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() - + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) sub := mocks.NewSubscription(t) errChan := make(chan error, 1) errChan <- errors.New("subscription was terminate") sub.On("Err").Return((<-chan error)(errChan)) sub.On("Unsubscribe").Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(sub, nil).Once() + rpc.On("SubscribeToHeads", mock.Anything).Return(make(<-chan Head), sub, nil).Once() rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() node.declareOutOfSync(stubIsOutOfSync) tests.AssertLogEventually(t, observedLogs, "Subscription was terminated") tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("becomes unreachable if head channel is closed", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.ErrorLevel) node := newAliveNode(t, testNodeOpts{ @@ -747,25 +751,26 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) sub := mocks.NewSubscription(t) sub.On("Err").Return((<-chan error)(nil)) sub.On("Unsubscribe").Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - ch := args.Get(1).(chan<- Head) + ch := make(chan Head) + rpc.On("SubscribeToHeads", mock.Anything).Run(func(args mock.Arguments) { close(ch) - }).Return(sub, nil).Once() + }).Return((<-chan Head)(ch), sub, nil).Once() rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() node.declareOutOfSync(stubIsOutOfSync) tests.AssertLogEventually(t, observedLogs, "Subscription channel unexpectedly closed") tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("becomes alive if it receives a newer head", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newAliveNode(t, testNodeOpts{ @@ -777,17 +782,17 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) outOfSyncSubscription := mocks.NewSubscription(t) outOfSyncSubscription.On("Err").Return((<-chan error)(nil)) outOfSyncSubscription.On("Unsubscribe").Once() const highestBlock = 1000 - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - ch := args.Get(1).(chan<- Head) + ch := make(chan Head) + rpc.On("SubscribeToHeads", mock.Anything).Run(func(args mock.Arguments) { go writeHeads(t, ch, head{BlockNumber: highestBlock - 1}, head{BlockNumber: highestBlock}) - }).Return(outOfSyncSubscription, nil).Once() + }).Return((<-chan Head)(ch), outOfSyncSubscription, nil).Once() rpc.On("GetInterceptedChainInfo").Return(ChainInfo{BlockNumber: 
highestBlock}, ChainInfo{BlockNumber: highestBlock}) - setupRPCForAliveLoop(t, rpc) node.declareOutOfSync(func(num int64, td *big.Int) bool { @@ -796,12 +801,12 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { tests.AssertLogEventually(t, observedLogs, msgReceivedBlock) tests.AssertLogEventually(t, observedLogs, msgInSync) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateAlive + return node.State() == NodeStateAlive }) }) t.Run("becomes alive if there is no other nodes", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newAliveNode(t, testNodeOpts{ @@ -823,18 +828,18 @@ func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) outOfSyncSubscription := mocks.NewSubscription(t) outOfSyncSubscription.On("Err").Return((<-chan error)(nil)) outOfSyncSubscription.On("Unsubscribe").Once() - rpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(outOfSyncSubscription, nil).Once() - + rpc.On("SubscribeToHeads", mock.Anything).Return(make(<-chan Head), outOfSyncSubscription, nil).Once() setupRPCForAliveLoop(t, rpc) node.declareOutOfSync(stubIsOutOfSync) tests.AssertLogEventually(t, observedLogs, "RPC endpoint is still out of sync, but there are no other available nodes. This RPC node will be forcibly moved back into the live pool in a degraded state") tests.AssertEventually(t, func() bool { - return node.State() == nodeStateAlive + return node.State() == NodeStateAlive }) }) } @@ -846,21 +851,20 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { node := newTestNode(t, opts) opts.rpc.On("Close").Return(nil).Once() // disconnects all on transfer to unreachable - opts.rpc.On("DisconnectAll") - node.setState(nodeStateAlive) + node.setState(NodeStateAlive) return node } t.Run("returns on closed", func(t *testing.T) { t.Parallel() node := newTestNode(t, testNodeOpts{}) - node.setState(nodeStateClosed) + node.setState(NodeStateClosed) node.wg.Add(1) node.unreachableLoop() }) t.Run("on failed redial, keeps trying", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newAliveNode(t, testNodeOpts{ @@ -871,12 +875,13 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) node.declareUnreachable() tests.AssertLogCountEventually(t, observedLogs, "Failed to redial RPC node; still unreachable", 2) }) t.Run("on failed chainID verification, keep trying", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newAliveNode(t, testNodeOpts{ @@ -887,15 +892,16 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Once() 
rpc.On("ChainID", mock.Anything).Run(func(_ mock.Arguments) { - assert.Equal(t, nodeStateDialed, node.State()) + assert.Equal(t, NodeStateDialed, node.State()) }).Return(nodeChainID, errors.New("failed to get chain id")) node.declareUnreachable() tests.AssertLogCountEventually(t, observedLogs, "Failed to verify chain ID for node", 2) }) t.Run("on chain ID mismatch transitions to invalidChainID", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) rpcChainID := types.NewIDFromInt(11) node := newAliveNode(t, testNodeOpts{ @@ -906,14 +912,16 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil) rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) + node.declareUnreachable() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateInvalidChainID + return node.State() == NodeStateInvalidChainID }) }) t.Run("on syncing status check failure, keeps trying", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newAliveNode(t, testNodeOpts{ @@ -925,8 +933,9 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Once() rpc.On("ChainID", mock.Anything).Run(func(_ mock.Arguments) { - assert.Equal(t, nodeStateDialed, node.State()) + assert.Equal(t, NodeStateDialed, node.State()) }).Return(nodeChainID, nil) rpc.On("IsSyncing", mock.Anything).Return(false, errors.New("failed to check syncing status")) node.declareUnreachable() @@ -934,7 +943,7 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { }) t.Run("on syncing, transitions to syncing state", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() node := newAliveNode(t, testNodeOpts{ rpc: rpc, @@ -946,17 +955,18 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil) rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) rpc.On("IsSyncing", mock.Anything).Return(true, nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) setupRPCForAliveLoop(t, rpc) node.declareUnreachable() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateSyncing + return node.State() == NodeStateSyncing }) }) t.Run("on successful verification becomes alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() node := newAliveNode(t, testNodeOpts{ rpc: rpc, @@ -965,20 +975,18 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { }) defer func() { assert.NoError(t, node.close()) }() - rpc.On("Dial", mock.Anything).Return(nil) rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) rpc.On("IsSyncing", mock.Anything).Return(false, nil) - setupRPCForAliveLoop(t, rpc) node.declareUnreachable() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateAlive + return node.State() == NodeStateAlive }) }) t.Run("on successful verification without isSyncing becomes alive", func(t *testing.T) 
{ t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() node := newAliveNode(t, testNodeOpts{ rpc: rpc, @@ -988,12 +996,13 @@ func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil) rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Once() setupRPCForAliveLoop(t, rpc) node.declareUnreachable() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateAlive + return node.State() == NodeStateAlive }) }) } @@ -1003,21 +1012,20 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { newDialedNode := func(t *testing.T, opts testNodeOpts) testNode { node := newTestNode(t, opts) opts.rpc.On("Close").Return(nil).Once() - opts.rpc.On("DisconnectAll") - node.setState(nodeStateDialed) + node.setState(NodeStateDialed) return node } t.Run("returns on closed", func(t *testing.T) { t.Parallel() node := newTestNode(t, testNodeOpts{}) - node.setState(nodeStateClosed) + node.setState(NodeStateClosed) node.wg.Add(1) node.invalidChainIDLoop() }) t.Run("on invalid dial becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() node := newDialedNode(t, testNodeOpts{ rpc: rpc, @@ -1026,14 +1034,16 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) + node.declareInvalidChainID() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("on failed chainID call becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newDialedNode(t, testNodeOpts{ @@ -1047,15 +1057,17 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { // once for chainID and maybe another one for unreachable rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) + node.declareInvalidChainID() tests.AssertLogEventually(t, observedLogs, "Failed to verify chain ID for node") tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("on chainID mismatch keeps trying", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) rpcChainID := types.NewIDFromInt(11) lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) @@ -1068,15 +1080,17 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) + node.declareInvalidChainID() tests.AssertLogCountEventually(t, observedLogs, "Failed to verify RPC node; remote endpoint returned the wrong chain ID", 2) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateInvalidChainID + 
return node.State() == NodeStateInvalidChainID }) }) t.Run("on successful verification without isSyncing becomes alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) rpcChainID := types.NewIDFromInt(11) node := newDialedNode(t, testNodeOpts{ @@ -1085,20 +1099,18 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { }) defer func() { assert.NoError(t, node.close()) }() - rpc.On("Dial", mock.Anything).Return(nil).Once() + setupRPCForAliveLoop(t, rpc) rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil).Once() rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() - setupRPCForAliveLoop(t, rpc) - node.declareInvalidChainID() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateAlive + return node.State() == NodeStateAlive }) }) t.Run("on successful verification becomes alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) rpcChainID := types.NewIDFromInt(11) node := newDialedNode(t, testNodeOpts{ @@ -1108,7 +1120,6 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { }) defer func() { assert.NoError(t, node.close()) }() - rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil).Once() rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() rpc.On("IsSyncing", mock.Anything).Return(false, nil).Once() @@ -1117,7 +1128,7 @@ func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { node.declareInvalidChainID() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateAlive + return node.State() == NodeStateAlive }) }) } @@ -1127,13 +1138,14 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { newNode := func(t *testing.T, opts testNodeOpts) testNode { node := newTestNode(t, opts) + opts.rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Maybe() opts.rpc.On("Close").Return(nil).Once() return node } t.Run("if fails on initial dial, becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newNode(t, testNodeOpts{ @@ -1144,18 +1156,16 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")) - // disconnects all on transfer to unreachable - rpc.On("DisconnectAll") err := node.Start(tests.Context(t)) assert.NoError(t, err) tests.AssertLogEventually(t, observedLogs, "Dial failed: Node is unreachable") tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("if chainID verification fails, becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newNode(t, testNodeOpts{ @@ -1166,21 +1176,21 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) rpc.On("ChainID", mock.Anything).Run(func(_ mock.Arguments) { 
- assert.Equal(t, nodeStateDialed, node.State()) + assert.Equal(t, NodeStateDialed, node.State()) }).Return(nodeChainID, errors.New("failed to get chain id")) // disconnects all on transfer to unreachable - rpc.On("DisconnectAll") err := node.Start(tests.Context(t)) assert.NoError(t, err) tests.AssertLogEventually(t, observedLogs, "Failed to verify chain ID for node") tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("on chain ID mismatch transitions to invalidChainID", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) rpcChainID := types.NewIDFromInt(11) node := newNode(t, testNodeOpts{ @@ -1190,18 +1200,18 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) // disconnects all on transfer to unreachable - rpc.On("DisconnectAll") err := node.Start(tests.Context(t)) assert.NoError(t, err) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateInvalidChainID + return node.State() == NodeStateInvalidChainID }) }) t.Run("if syncing verification fails, becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newNode(t, testNodeOpts{ @@ -1213,24 +1223,24 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(nil).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) rpc.On("ChainID", mock.Anything).Run(func(_ mock.Arguments) { - assert.Equal(t, nodeStateDialed, node.State()) + assert.Equal(t, NodeStateDialed, node.State()) }).Return(nodeChainID, nil).Once() rpc.On("IsSyncing", mock.Anything).Return(false, errors.New("failed to check syncing status")) // disconnects all on transfer to unreachable - rpc.On("DisconnectAll") // fail to redial to stay in unreachable state rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")) err := node.Start(tests.Context(t)) assert.NoError(t, err) tests.AssertLogEventually(t, observedLogs, "Unexpected error while verifying RPC node synchronization status") tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("on isSyncing transitions to syncing", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) node := newNode(t, testNodeOpts{ rpc: rpc, @@ -1240,19 +1250,19 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) rpc.On("IsSyncing", mock.Anything).Return(true, nil) // disconnects all on transfer to unreachable - rpc.On("DisconnectAll") err := node.Start(tests.Context(t)) assert.NoError(t, err) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateSyncing + return node.State() == NodeStateSyncing }) }) 
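Taken together, the Start tests encode the verification ladder a node must climb before it can serve traffic: Dial, then ChainID compared against the configured chain, then an optional IsSyncing probe, with each failure mapped to Unreachable, InvalidChainID, or Syncing. A compressed sketch of that ladder (verifyLadder and the NodeIsSyncingEnabled accessor are assumed names; the real verifyConn/createVerifiedConn also handle retries, backoff, and locking):

    func (n *node[CHAIN_ID, HEAD, RPC]) verifyLadder(ctx context.Context) NodeState {
        if err := n.rpc.Dial(ctx); err != nil {
            return NodeStateUnreachable
        }
        chainID, err := n.rpc.ChainID(ctx)
        if err != nil {
            return NodeStateUnreachable
        }
        if chainID.String() != n.chainID.String() {
            // Healthy endpoint on the wrong chain: quarantine rather than retry blindly.
            return NodeStateInvalidChainID
        }
        if n.nodePoolCfg.NodeIsSyncingEnabled() { // assumed accessor; the tests toggle nodeIsSyncingEnabled
            syncing, err := n.rpc.IsSyncing(ctx)
            if err != nil {
                return NodeStateUnreachable
            }
            if syncing {
                return NodeStateSyncing
            }
        }
        return NodeStateAlive
    }
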
t.Run("on successful verification becomes alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() node := newNode(t, testNodeOpts{ rpc: rpc, @@ -1261,21 +1271,19 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { }) defer func() { assert.NoError(t, node.close()) }() - rpc.On("Dial", mock.Anything).Return(nil) rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) rpc.On("IsSyncing", mock.Anything).Return(false, nil) - setupRPCForAliveLoop(t, rpc) err := node.Start(tests.Context(t)) assert.NoError(t, err) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateAlive + return node.State() == NodeStateAlive }) }) t.Run("on successful verification without isSyncing becomes alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() node := newNode(t, testNodeOpts{ rpc: rpc, @@ -1283,15 +1291,13 @@ func TestUnit_NodeLifecycle_start(t *testing.T) { }) defer func() { assert.NoError(t, node.close()) }() - rpc.On("Dial", mock.Anything).Return(nil) rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) - setupRPCForAliveLoop(t, rpc) err := node.Start(tests.Context(t)) assert.NoError(t, err) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateAlive + return node.State() == NodeStateAlive }) }) } @@ -1442,21 +1448,21 @@ func TestUnit_NodeLifecycle_SyncingLoop(t *testing.T) { opts.config.nodeIsSyncingEnabled = true node := newTestNode(t, opts) opts.rpc.On("Close").Return(nil).Once() - opts.rpc.On("DisconnectAll") + opts.rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything).Maybe() - node.setState(nodeStateDialed) + node.setState(NodeStateDialed) return node } t.Run("returns on closed", func(t *testing.T) { t.Parallel() node := newTestNode(t, testNodeOpts{}) - node.setState(nodeStateClosed) + node.setState(NodeStateClosed) node.wg.Add(1) node.syncingLoop() }) t.Run("on invalid dial becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() node := newDialedNode(t, testNodeOpts{ rpc: rpc, @@ -1465,14 +1471,15 @@ func TestUnit_NodeLifecycle_SyncingLoop(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) node.declareSyncing() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("on failed chainID call becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newDialedNode(t, testNodeOpts{ @@ -1483,18 +1490,19 @@ func TestUnit_NodeLifecycle_SyncingLoop(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("ChainID", mock.Anything).Return(nodeChainID, errors.New("failed to get chain id")) + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) // once for syncing and maybe another one for unreachable rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() node.declareSyncing() tests.AssertLogEventually(t, 
observedLogs, "Failed to verify chain ID for node") tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("on chainID mismatch transitions to invalidChainID", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.NewIDFromInt(10) rpcChainID := types.NewIDFromInt(11) lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) @@ -1506,16 +1514,17 @@ func TestUnit_NodeLifecycle_SyncingLoop(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(nil).Twice() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) node.declareSyncing() tests.AssertLogCountEventually(t, observedLogs, "Failed to verify RPC node; remote endpoint returned the wrong chain ID", 2) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateInvalidChainID + return node.State() == NodeStateInvalidChainID }) }) t.Run("on failed Syncing check - becomes unreachable", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newDialedNode(t, testNodeOpts{ @@ -1531,15 +1540,16 @@ func TestUnit_NodeLifecycle_SyncingLoop(t *testing.T) { rpc.On("IsSyncing", mock.Anything).Return(false, errors.New("failed to check if syncing")).Once() rpc.On("Dial", mock.Anything).Return(nil).Once() rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) node.declareSyncing() tests.AssertLogEventually(t, observedLogs, "Unexpected error while verifying RPC node synchronization status") tests.AssertEventually(t, func() bool { - return node.State() == nodeStateUnreachable + return node.State() == NodeStateUnreachable }) }) t.Run("on IsSyncing - keeps trying", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() lggr, observedLogs := logger.TestObserved(t, zap.DebugLevel) node := newDialedNode(t, testNodeOpts{ @@ -1552,15 +1562,16 @@ func TestUnit_NodeLifecycle_SyncingLoop(t *testing.T) { rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() rpc.On("IsSyncing", mock.Anything).Return(true, nil) rpc.On("Dial", mock.Anything).Return(nil).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) node.declareSyncing() tests.AssertLogCountEventually(t, observedLogs, "Verification failed: Node is syncing", 2) tests.AssertEventually(t, func() bool { - return node.State() == nodeStateSyncing + return node.State() == NodeStateSyncing }) }) t.Run("on successful verification becomes alive", func(t *testing.T) { t.Parallel() - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) nodeChainID := types.RandomID() node := newDialedNode(t, testNodeOpts{ rpc: rpc, @@ -1569,23 +1580,28 @@ func TestUnit_NodeLifecycle_SyncingLoop(t *testing.T) { defer func() { assert.NoError(t, node.close()) }() rpc.On("Dial", mock.Anything).Return(nil).Once() + rpc.On("UnsubscribeAllExcept", mock.Anything, mock.Anything) rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() rpc.On("IsSyncing", mock.Anything).Return(true, nil).Once() rpc.On("IsSyncing", 
mock.Anything).Return(false, nil).Once() + sub := mocks.NewSubscription(t) + sub.On("Err").Return(nil) + sub.On("Unsubscribe").Once() + rpc.On("SubscribeToHeads", mock.Anything).Return(make(<-chan Head), sub, nil).Once() setupRPCForAliveLoop(t, rpc) node.declareSyncing() tests.AssertEventually(t, func() bool { - return node.State() == nodeStateAlive + return node.State() == NodeStateAlive }) }) } func TestNode_State(t *testing.T) { t.Run("If not Alive, returns as is", func(t *testing.T) { - for state := nodeState(0); state < nodeStateLen; state++ { - if state == nodeStateAlive { + for state := NodeState(0); state < NodeStateLen; state++ { + if state == NodeStateAlive { continue } @@ -1596,8 +1612,8 @@ func TestNode_State(t *testing.T) { }) t.Run("If repeatable read is not enforced, returns alive", func(t *testing.T) { node := newTestNode(t, testNodeOpts{}) - node.setState(nodeStateAlive) - assert.Equal(t, nodeStateAlive, node.State()) + node.setState(NodeStateAlive) + assert.Equal(t, NodeStateAlive, node.State()) }) testCases := []struct { Name string @@ -1605,7 +1621,7 @@ func TestNode_State(t *testing.T) { IsFinalityTagEnabled bool PoolChainInfo ChainInfo NodeChainInfo ChainInfo - ExpectedState nodeState + ExpectedState NodeState }{ { Name: "If finality lag does not exceeds offset, returns alive (FinalityDepth)", @@ -1616,7 +1632,7 @@ func TestNode_State(t *testing.T) { NodeChainInfo: ChainInfo{ BlockNumber: 5, }, - ExpectedState: nodeStateAlive, + ExpectedState: NodeStateAlive, }, { Name: "If finality lag does not exceeds offset, returns alive (FinalityTag)", @@ -1628,7 +1644,7 @@ func TestNode_State(t *testing.T) { NodeChainInfo: ChainInfo{ FinalizedBlockNumber: 5, }, - ExpectedState: nodeStateAlive, + ExpectedState: NodeStateAlive, }, { Name: "If finality lag exceeds offset, returns nodeStateFinalizedBlockOutOfSync (FinalityDepth)", @@ -1639,7 +1655,7 @@ func TestNode_State(t *testing.T) { NodeChainInfo: ChainInfo{ BlockNumber: 4, }, - ExpectedState: nodeStateFinalizedBlockOutOfSync, + ExpectedState: NodeStateFinalizedBlockOutOfSync, }, { Name: "If finality lag exceeds offset, returns nodeStateFinalizedBlockOutOfSync (FinalityTag)", @@ -1651,12 +1667,12 @@ func TestNode_State(t *testing.T) { NodeChainInfo: ChainInfo{ FinalizedBlockNumber: 4, }, - ExpectedState: nodeStateFinalizedBlockOutOfSync, + ExpectedState: NodeStateFinalizedBlockOutOfSync, }, } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { - rpc := newMockNodeClient[types.ID, Head](t) + rpc := newMockRPCClient[types.ID, Head](t) rpc.On("GetInterceptedChainInfo").Return(tc.NodeChainInfo, tc.PoolChainInfo).Once() node := newTestNode(t, testNodeOpts{ config: testNodeConfig{ @@ -1671,7 +1687,7 @@ func TestNode_State(t *testing.T) { poolInfo := newMockPoolChainInfoProvider(t) poolInfo.On("HighestUserObservations").Return(tc.PoolChainInfo).Once() node.SetPoolChainInfoProvider(poolInfo) - node.setState(nodeStateAlive) + node.setState(NodeStateAlive) assert.Equal(t, tc.ExpectedState, node.State()) }) } diff --git a/common/client/node_selector.go b/common/client/node_selector.go index 45604ebe8d9..d1bb58c6273 100644 --- a/common/client/node_selector.go +++ b/common/client/node_selector.go @@ -16,30 +16,28 @@ const ( //go:generate mockery --quiet --name NodeSelector --structname mockNodeSelector --filename "mock_node_selector_test.go" --inpackage --case=underscore type NodeSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], + RPC any, ] interface { // Select returns a Node, or nil if none can be 
selected. // Implementation must be thread-safe. - Select() Node[CHAIN_ID, HEAD, RPC] + Select() Node[CHAIN_ID, RPC] // Name returns the strategy name, e.g. "HighestHead" or "RoundRobin" Name() string } func newNodeSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -](selectionMode string, nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { + RPC any, +](selectionMode string, nodes []Node[CHAIN_ID, RPC]) NodeSelector[CHAIN_ID, RPC] { switch selectionMode { case NodeSelectionModeHighestHead: - return NewHighestHeadNodeSelector[CHAIN_ID, HEAD, RPC](nodes) + return NewHighestHeadNodeSelector[CHAIN_ID, RPC](nodes) case NodeSelectionModeRoundRobin: - return NewRoundRobinSelector[CHAIN_ID, HEAD, RPC](nodes) + return NewRoundRobinSelector[CHAIN_ID, RPC](nodes) case NodeSelectionModeTotalDifficulty: - return NewTotalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC](nodes) + return NewTotalDifficultyNodeSelector[CHAIN_ID, RPC](nodes) case NodeSelectionModePriorityLevel: - return NewPriorityLevelNodeSelector[CHAIN_ID, HEAD, RPC](nodes) + return NewPriorityLevelNodeSelector[CHAIN_ID, RPC](nodes) default: panic(fmt.Sprintf("unsupported NodeSelectionMode: %s", selectionMode)) } diff --git a/common/client/node_selector_highest_head.go b/common/client/node_selector_highest_head.go index 25a931fc01b..dbf402c7062 100644 --- a/common/client/node_selector_highest_head.go +++ b/common/client/node_selector_highest_head.go @@ -8,25 +8,23 @@ import ( type highestHeadNodeSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -] []Node[CHAIN_ID, HEAD, RPC] + RPC any, +] []Node[CHAIN_ID, RPC] func NewHighestHeadNodeSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -](nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { - return highestHeadNodeSelector[CHAIN_ID, HEAD, RPC](nodes) + RPC any, +](nodes []Node[CHAIN_ID, RPC]) NodeSelector[CHAIN_ID, RPC] { + return highestHeadNodeSelector[CHAIN_ID, RPC](nodes) } -func (s highestHeadNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { +func (s highestHeadNodeSelector[CHAIN_ID, RPC]) Select() Node[CHAIN_ID, RPC] { var highestHeadNumber int64 = math.MinInt64 - var highestHeadNodes []Node[CHAIN_ID, HEAD, RPC] + var highestHeadNodes []Node[CHAIN_ID, RPC] for _, n := range s { state, currentChainInfo := n.StateAndLatest() currentHeadNumber := currentChainInfo.BlockNumber - if state == nodeStateAlive && currentHeadNumber >= highestHeadNumber { + if state == NodeStateAlive && currentHeadNumber >= highestHeadNumber { if highestHeadNumber < currentHeadNumber { highestHeadNumber = currentHeadNumber highestHeadNodes = nil @@ -37,6 +35,6 @@ func (s highestHeadNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HE return firstOrHighestPriority(highestHeadNodes) } -func (s highestHeadNodeSelector[CHAIN_ID, HEAD, RPC]) Name() string { +func (s highestHeadNodeSelector[CHAIN_ID, RPC]) Name() string { return NodeSelectionModeHighestHead } diff --git a/common/client/node_selector_highest_head_test.go b/common/client/node_selector_highest_head_test.go index e245924589c..6519e7f0bf2 100644 --- a/common/client/node_selector_highest_head_test.go +++ b/common/client/node_selector_highest_head_test.go @@ -9,40 +9,40 @@ import ( ) func TestHighestHeadNodeSelectorName(t *testing.T) { - selector := newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]](NodeSelectionModeHighestHead, nil) + selector := newNodeSelector[types.ID, RPCClient[types.ID, 
Head]](NodeSelectionModeHighestHead, nil) assert.Equal(t, selector.Name(), NodeSelectionModeHighestHead) } func TestHighestHeadNodeSelector(t *testing.T) { t.Parallel() - type nodeClient NodeClient[types.ID, Head] + type nodeClient RPCClient[types.ID, Head] - var nodes []Node[types.ID, Head, nodeClient] + var nodes []Node[types.ID, nodeClient] for i := 0; i < 3; i++ { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) if i == 0 { // first node is out of sync - node.On("StateAndLatest").Return(nodeStateOutOfSync, ChainInfo{BlockNumber: int64(-1)}) + node.On("StateAndLatest").Return(NodeStateOutOfSync, ChainInfo{BlockNumber: int64(-1)}) } else if i == 1 { // second node is alive, LatestReceivedBlockNumber = 1 - node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(1)}) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(1)}) } else { // third node is alive, LatestReceivedBlockNumber = 2 (best node) - node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(2)}) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(2)}) } node.On("Order").Maybe().Return(int32(1)) nodes = append(nodes, node) } - selector := newNodeSelector[types.ID, Head, nodeClient](NodeSelectionModeHighestHead, nodes) + selector := newNodeSelector[types.ID, nodeClient](NodeSelectionModeHighestHead, nodes) assert.Same(t, nodes[2], selector.Select()) t.Run("stick to the same node", func(t *testing.T) { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) // fourth node is alive, LatestReceivedBlockNumber = 2 (same as 3rd) - node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(2)}) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(2)}) node.On("Order").Return(int32(1)) nodes = append(nodes, node) @@ -51,9 +51,9 @@ func TestHighestHeadNodeSelector(t *testing.T) { }) t.Run("another best node", func(t *testing.T) { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) // fifth node is alive, LatestReceivedBlockNumber = 3 (better than 3rd and 4th) - node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(3)}) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(3)}) node.On("Order").Return(int32(1)) nodes = append(nodes, node) @@ -62,13 +62,13 @@ func TestHighestHeadNodeSelector(t *testing.T) { }) t.Run("nodes never update latest block number", func(t *testing.T) { - node1 := newMockNode[types.ID, Head, nodeClient](t) - node1.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(-1)}) + node1 := newMockNode[types.ID, nodeClient](t) + node1.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(-1)}) node1.On("Order").Return(int32(1)) - node2 := newMockNode[types.ID, Head, nodeClient](t) - node2.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(-1)}) + node2 := newMockNode[types.ID, nodeClient](t) + node2.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(-1)}) node2.On("Order").Return(int32(1)) - selector := newNodeSelector(NodeSelectionModeHighestHead, []Node[types.ID, Head, nodeClient]{node1, node2}) + selector := newNodeSelector(NodeSelectionModeHighestHead, []Node[types.ID, nodeClient]{node1, node2}) assert.Same(t, node1, selector.Select()) }) } @@ -76,17 +76,17 @@ func TestHighestHeadNodeSelector(t 
*testing.T) { func TestHighestHeadNodeSelector_None(t *testing.T) { t.Parallel() - type nodeClient NodeClient[types.ID, Head] - var nodes []Node[types.ID, Head, nodeClient] + type nodeClient RPCClient[types.ID, Head] + var nodes []Node[types.ID, nodeClient] for i := 0; i < 3; i++ { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) if i == 0 { // first node is out of sync - node.On("StateAndLatest").Return(nodeStateOutOfSync, ChainInfo{BlockNumber: int64(-1)}) + node.On("StateAndLatest").Return(NodeStateOutOfSync, ChainInfo{BlockNumber: int64(-1)}) } else { // others are unreachable - node.On("StateAndLatest").Return(nodeStateUnreachable, ChainInfo{BlockNumber: int64(-1)}) + node.On("StateAndLatest").Return(NodeStateUnreachable, ChainInfo{BlockNumber: int64(-1)}) } nodes = append(nodes, node) } @@ -98,13 +98,13 @@ func TestHighestHeadNodeSelector_None(t *testing.T) { func TestHighestHeadNodeSelectorWithOrder(t *testing.T) { t.Parallel() - type nodeClient NodeClient[types.ID, Head] - var nodes []Node[types.ID, Head, nodeClient] + type nodeClient RPCClient[types.ID, Head] + var nodes []Node[types.ID, nodeClient] t.Run("same head and order", func(t *testing.T) { for i := 0; i < 3; i++ { - node := newMockNode[types.ID, Head, nodeClient](t) - node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(1)}) + node := newMockNode[types.ID, nodeClient](t) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(1)}) node.On("Order").Return(int32(2)) nodes = append(nodes, node) } @@ -114,61 +114,61 @@ func TestHighestHeadNodeSelectorWithOrder(t *testing.T) { }) t.Run("same head but different order", func(t *testing.T) { - node1 := newMockNode[types.ID, Head, nodeClient](t) - node1.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(3)}) + node1 := newMockNode[types.ID, nodeClient](t) + node1.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(3)}) node1.On("Order").Return(int32(3)) - node2 := newMockNode[types.ID, Head, nodeClient](t) - node2.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(3)}) + node2 := newMockNode[types.ID, nodeClient](t) + node2.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(3)}) node2.On("Order").Return(int32(1)) - node3 := newMockNode[types.ID, Head, nodeClient](t) - node3.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(3)}) + node3 := newMockNode[types.ID, nodeClient](t) + node3.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(3)}) node3.On("Order").Return(int32(2)) - nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3} + nodes := []Node[types.ID, nodeClient]{node1, node2, node3} selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) //Should select the second node as it has the highest priority assert.Same(t, nodes[1], selector.Select()) }) t.Run("different head but same order", func(t *testing.T) { - node1 := newMockNode[types.ID, Head, nodeClient](t) - node1.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(1)}) + node1 := newMockNode[types.ID, nodeClient](t) + node1.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(1)}) node1.On("Order").Maybe().Return(int32(3)) - node2 := newMockNode[types.ID, Head, nodeClient](t) - node2.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(2)}) + node2 := newMockNode[types.ID, nodeClient](t) + 
node2.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(2)}) node2.On("Order").Maybe().Return(int32(3)) - node3 := newMockNode[types.ID, Head, nodeClient](t) - node3.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(3)}) + node3 := newMockNode[types.ID, nodeClient](t) + node3.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(3)}) node3.On("Order").Return(int32(3)) - nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3} + nodes := []Node[types.ID, nodeClient]{node1, node2, node3} selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) //Should select the third node as it has the highest head assert.Same(t, nodes[2], selector.Select()) }) t.Run("different head and different order", func(t *testing.T) { - node1 := newMockNode[types.ID, Head, nodeClient](t) - node1.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(10)}) + node1 := newMockNode[types.ID, nodeClient](t) + node1.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(10)}) node1.On("Order").Maybe().Return(int32(3)) - node2 := newMockNode[types.ID, Head, nodeClient](t) - node2.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(11)}) + node2 := newMockNode[types.ID, nodeClient](t) + node2.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(11)}) node2.On("Order").Maybe().Return(int32(4)) - node3 := newMockNode[types.ID, Head, nodeClient](t) - node3.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(11)}) + node3 := newMockNode[types.ID, nodeClient](t) + node3.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(11)}) node3.On("Order").Maybe().Return(int32(3)) - node4 := newMockNode[types.ID, Head, nodeClient](t) - node4.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: int64(10)}) + node4 := newMockNode[types.ID, nodeClient](t) + node4.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: int64(10)}) node4.On("Order").Maybe().Return(int32(1)) - nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3, node4} + nodes := []Node[types.ID, nodeClient]{node1, node2, node3, node4} selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) //Should select the third node as it has the highest head and will win the priority tie-breaker assert.Same(t, nodes[2], selector.Select()) diff --git a/common/client/node_selector_priority_level.go b/common/client/node_selector_priority_level.go index 45cc62de077..d9a45c2d5de 100644 --- a/common/client/node_selector_priority_level.go +++ b/common/client/node_selector_priority_level.go @@ -10,34 +10,31 @@ import ( type priorityLevelNodeSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], + RPC any, ] struct { - nodes []Node[CHAIN_ID, HEAD, RPC] + nodes []Node[CHAIN_ID, RPC] roundRobinCount []atomic.Uint32 } type nodeWithPriority[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], + RPC any, ] struct { - node Node[CHAIN_ID, HEAD, RPC] + node Node[CHAIN_ID, RPC] priority int32 } func NewPriorityLevelNodeSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -](nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { - return &priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]{ + RPC any, +](nodes []Node[CHAIN_ID, RPC]) NodeSelector[CHAIN_ID, RPC] { + return &priorityLevelNodeSelector[CHAIN_ID, RPC]{ nodes: nodes, roundRobinCount: make([]atomic.Uint32, nrOfPriorityTiers(nodes)), } } 
-func (s priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { +func (s priorityLevelNodeSelector[CHAIN_ID, RPC]) Select() Node[CHAIN_ID, RPC] { nodes := s.getHighestPriorityAliveTier() if len(nodes) == 0 { @@ -52,17 +49,17 @@ func (s priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, return nodes[idx].node } -func (s priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]) Name() string { +func (s priorityLevelNodeSelector[CHAIN_ID, RPC]) Name() string { return NodeSelectionModePriorityLevel } -// getHighestPriorityAliveTier filters nodes that are not in state nodeStateAlive and +// getHighestPriorityAliveTier filters nodes that are not in state NodeStateAlive and // returns only the highest tier of alive nodes -func (s priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]) getHighestPriorityAliveTier() []nodeWithPriority[CHAIN_ID, HEAD, RPC] { - var nodes []nodeWithPriority[CHAIN_ID, HEAD, RPC] +func (s priorityLevelNodeSelector[CHAIN_ID, RPC]) getHighestPriorityAliveTier() []nodeWithPriority[CHAIN_ID, RPC] { + var nodes []nodeWithPriority[CHAIN_ID, RPC] for _, n := range s.nodes { - if n.State() == nodeStateAlive { - nodes = append(nodes, nodeWithPriority[CHAIN_ID, HEAD, RPC]{n, n.Order()}) + if n.State() == NodeStateAlive { + nodes = append(nodes, nodeWithPriority[CHAIN_ID, RPC]{n, n.Order()}) } } @@ -76,14 +73,13 @@ func (s priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]) getHighestPriorityAliveT // removeLowerTiers take a slice of nodeWithPriority[CHAIN_ID, BLOCK_HASH, HEAD, RPC] and keeps only the highest tier func removeLowerTiers[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -](nodes []nodeWithPriority[CHAIN_ID, HEAD, RPC]) []nodeWithPriority[CHAIN_ID, HEAD, RPC] { + RPC any, +](nodes []nodeWithPriority[CHAIN_ID, RPC]) []nodeWithPriority[CHAIN_ID, RPC] { sort.SliceStable(nodes, func(i, j int) bool { return nodes[i].priority > nodes[j].priority }) - var nodes2 []nodeWithPriority[CHAIN_ID, HEAD, RPC] + var nodes2 []nodeWithPriority[CHAIN_ID, RPC] currentPriority := nodes[len(nodes)-1].priority for _, n := range nodes { @@ -98,9 +94,8 @@ func removeLowerTiers[ // nrOfPriorityTiers calculates the total number of priority tiers func nrOfPriorityTiers[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -](nodes []Node[CHAIN_ID, HEAD, RPC]) int32 { + RPC any, +](nodes []Node[CHAIN_ID, RPC]) int32 { highestPriority := int32(0) for _, n := range nodes { priority := n.Order() @@ -114,11 +109,10 @@ func nrOfPriorityTiers[ // firstOrHighestPriority takes a list of nodes and returns the first one with the highest priority func firstOrHighestPriority[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -](nodes []Node[CHAIN_ID, HEAD, RPC]) Node[CHAIN_ID, HEAD, RPC] { + RPC any, +](nodes []Node[CHAIN_ID, RPC]) Node[CHAIN_ID, RPC] { hp := int32(math.MaxInt32) - var node Node[CHAIN_ID, HEAD, RPC] + var node Node[CHAIN_ID, RPC] for _, n := range nodes { if n.Order() < hp { hp = n.Order() diff --git a/common/client/node_selector_priority_level_test.go b/common/client/node_selector_priority_level_test.go index 15a7a7ac60b..b85a6209a3b 100644 --- a/common/client/node_selector_priority_level_test.go +++ b/common/client/node_selector_priority_level_test.go @@ -9,17 +9,17 @@ import ( ) func TestPriorityLevelNodeSelectorName(t *testing.T) { - selector := newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]](NodeSelectionModePriorityLevel, nil) + selector := newNodeSelector[types.ID, RPCClient[types.ID, 
Head]](NodeSelectionModePriorityLevel, nil) assert.Equal(t, selector.Name(), NodeSelectionModePriorityLevel) } func TestPriorityLevelNodeSelector(t *testing.T) { t.Parallel() - type nodeClient NodeClient[types.ID, Head] + type nodeClient RPCClient[types.ID, Head] type testNode struct { order int32 - state nodeState + state NodeState } type testCase struct { name string @@ -31,34 +31,34 @@ func TestPriorityLevelNodeSelector(t *testing.T) { { name: "TwoNodesSameOrder: Highest Allowed Order", nodes: []testNode{ - {order: 1, state: nodeStateAlive}, - {order: 1, state: nodeStateAlive}, + {order: 1, state: NodeStateAlive}, + {order: 1, state: NodeStateAlive}, }, expect: []int{0, 1, 0, 1, 0, 1}, }, { name: "TwoNodesSameOrder: Lowest Allowed Order", nodes: []testNode{ - {order: 100, state: nodeStateAlive}, - {order: 100, state: nodeStateAlive}, + {order: 100, state: NodeStateAlive}, + {order: 100, state: NodeStateAlive}, }, expect: []int{0, 1, 0, 1, 0, 1}, }, { name: "NoneAvailable", nodes: []testNode{ - {order: 1, state: nodeStateOutOfSync}, - {order: 1, state: nodeStateUnreachable}, - {order: 1, state: nodeStateUnreachable}, + {order: 1, state: NodeStateOutOfSync}, + {order: 1, state: NodeStateUnreachable}, + {order: 1, state: NodeStateUnreachable}, }, expect: []int{}, // no nodes should be selected }, { name: "DifferentOrder", nodes: []testNode{ - {order: 1, state: nodeStateAlive}, - {order: 2, state: nodeStateAlive}, - {order: 3, state: nodeStateAlive}, + {order: 1, state: NodeStateAlive}, + {order: 2, state: NodeStateAlive}, + {order: 3, state: NodeStateAlive}, }, expect: []int{0, 0}, // only the highest order node should be selected }, @@ -66,9 +66,9 @@ func TestPriorityLevelNodeSelector(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var nodes []Node[types.ID, Head, nodeClient] + var nodes []Node[types.ID, nodeClient] for _, tn := range tc.nodes { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) node.On("State").Return(tn.state) node.On("Order").Return(tn.order) nodes = append(nodes, node) diff --git a/common/client/node_selector_round_robin.go b/common/client/node_selector_round_robin.go index 5cdad7f52ee..50b648594e6 100644 --- a/common/client/node_selector_round_robin.go +++ b/common/client/node_selector_round_robin.go @@ -8,27 +8,25 @@ import ( type roundRobinSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], + RPC any, ] struct { - nodes []Node[CHAIN_ID, HEAD, RPC] + nodes []Node[CHAIN_ID, RPC] roundRobinCount atomic.Uint32 } func NewRoundRobinSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -](nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { - return &roundRobinSelector[CHAIN_ID, HEAD, RPC]{ + RPC any, +](nodes []Node[CHAIN_ID, RPC]) NodeSelector[CHAIN_ID, RPC] { + return &roundRobinSelector[CHAIN_ID, RPC]{ nodes: nodes, } } -func (s *roundRobinSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { - var liveNodes []Node[CHAIN_ID, HEAD, RPC] +func (s *roundRobinSelector[CHAIN_ID, RPC]) Select() Node[CHAIN_ID, RPC] { + var liveNodes []Node[CHAIN_ID, RPC] for _, n := range s.nodes { - if n.State() == nodeStateAlive { + if n.State() == NodeStateAlive { liveNodes = append(liveNodes, n) } } @@ -45,6 +43,6 @@ func (s *roundRobinSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, return liveNodes[idx] } -func (s *roundRobinSelector[CHAIN_ID, HEAD, RPC]) Name() string { +func (s *roundRobinSelector[CHAIN_ID, 
RPC]) Name() string { return NodeSelectionModeRoundRobin } diff --git a/common/client/node_selector_round_robin_test.go b/common/client/node_selector_round_robin_test.go index e5078d858f1..6b59e299248 100644 --- a/common/client/node_selector_round_robin_test.go +++ b/common/client/node_selector_round_robin_test.go @@ -9,24 +9,24 @@ import ( ) func TestRoundRobinNodeSelectorName(t *testing.T) { - selector := newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]](NodeSelectionModeRoundRobin, nil) + selector := newNodeSelector[types.ID, RPCClient[types.ID, Head]](NodeSelectionModeRoundRobin, nil) assert.Equal(t, selector.Name(), NodeSelectionModeRoundRobin) } func TestRoundRobinNodeSelector(t *testing.T) { t.Parallel() - type nodeClient NodeClient[types.ID, Head] - var nodes []Node[types.ID, Head, nodeClient] + type nodeClient RPCClient[types.ID, Head] + var nodes []Node[types.ID, nodeClient] for i := 0; i < 3; i++ { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) if i == 0 { // first node is out of sync - node.On("State").Return(nodeStateOutOfSync) + node.On("State").Return(NodeStateOutOfSync) } else { // second & third nodes are alive - node.On("State").Return(nodeStateAlive) + node.On("State").Return(NodeStateAlive) } nodes = append(nodes, node) } @@ -41,17 +41,17 @@ func TestRoundRobinNodeSelector(t *testing.T) { func TestRoundRobinNodeSelector_None(t *testing.T) { t.Parallel() - type nodeClient NodeClient[types.ID, Head] - var nodes []Node[types.ID, Head, nodeClient] + type nodeClient RPCClient[types.ID, Head] + var nodes []Node[types.ID, nodeClient] for i := 0; i < 3; i++ { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) if i == 0 { // first node is out of sync - node.On("State").Return(nodeStateOutOfSync) + node.On("State").Return(NodeStateOutOfSync) } else { // others are unreachable - node.On("State").Return(nodeStateUnreachable) + node.On("State").Return(NodeStateUnreachable) } nodes = append(nodes, node) } diff --git a/common/client/node_selector_test.go b/common/client/node_selector_test.go index 226cb67168d..f652bfc50ad 100644 --- a/common/client/node_selector_test.go +++ b/common/client/node_selector_test.go @@ -12,7 +12,7 @@ func TestNodeSelector(t *testing.T) { // rest of the tests are located in specific node selectors tests t.Run("panics on unknown type", func(t *testing.T) { assert.Panics(t, func() { - _ = newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]]("unknown", nil) + _ = newNodeSelector[types.ID, RPCClient[types.ID, Head]]("unknown", nil) }) }) } diff --git a/common/client/node_selector_total_difficulty.go b/common/client/node_selector_total_difficulty.go index 6b45e75528b..cc69a00e2ff 100644 --- a/common/client/node_selector_total_difficulty.go +++ b/common/client/node_selector_total_difficulty.go @@ -8,27 +8,25 @@ import ( type totalDifficultyNodeSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -] []Node[CHAIN_ID, HEAD, RPC] + RPC any, +] []Node[CHAIN_ID, RPC] func NewTotalDifficultyNodeSelector[ CHAIN_ID types.ID, - HEAD Head, - RPC NodeClient[CHAIN_ID, HEAD], -](nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { - return totalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC](nodes) + RPC any, +](nodes []Node[CHAIN_ID, RPC]) NodeSelector[CHAIN_ID, RPC] { + return totalDifficultyNodeSelector[CHAIN_ID, RPC](nodes) } -func (s totalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { 
+func (s totalDifficultyNodeSelector[CHAIN_ID, RPC]) Select() Node[CHAIN_ID, RPC] { // NodeNoNewHeadsThreshold may not be enabled, in this case all nodes have td == nil var highestTD *big.Int - var nodes []Node[CHAIN_ID, HEAD, RPC] - var aliveNodes []Node[CHAIN_ID, HEAD, RPC] + var nodes []Node[CHAIN_ID, RPC] + var aliveNodes []Node[CHAIN_ID, RPC] for _, n := range s { state, currentChainInfo := n.StateAndLatest() - if state != nodeStateAlive { + if state != NodeStateAlive { continue } @@ -50,6 +48,6 @@ func (s totalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID return firstOrHighestPriority(nodes) } -func (s totalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC]) Name() string { +func (s totalDifficultyNodeSelector[CHAIN_ID, RPC]) Name() string { return NodeSelectionModeTotalDifficulty } diff --git a/common/client/node_selector_total_difficulty_test.go b/common/client/node_selector_total_difficulty_test.go index 0bc214918d7..7fce0f96042 100644 --- a/common/client/node_selector_total_difficulty_test.go +++ b/common/client/node_selector_total_difficulty_test.go @@ -10,27 +10,27 @@ import ( ) func TestTotalDifficultyNodeSelectorName(t *testing.T) { - selector := newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]](NodeSelectionModeTotalDifficulty, nil) + selector := newNodeSelector[types.ID, RPCClient[types.ID, Head]](NodeSelectionModeTotalDifficulty, nil) assert.Equal(t, selector.Name(), NodeSelectionModeTotalDifficulty) } func TestTotalDifficultyNodeSelector(t *testing.T) { t.Parallel() - type nodeClient NodeClient[types.ID, Head] - var nodes []Node[types.ID, Head, nodeClient] + type nodeClient RPCClient[types.ID, Head] + var nodes []Node[types.ID, nodeClient] for i := 0; i < 3; i++ { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) if i == 0 { // first node is out of sync - node.On("StateAndLatest").Return(nodeStateOutOfSync, ChainInfo{BlockNumber: -1}) + node.On("StateAndLatest").Return(NodeStateOutOfSync, ChainInfo{BlockNumber: -1}) } else if i == 1 { // second node is alive - node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(7)}) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(7)}) } else { // third node is alive and best - node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 2, TotalDifficulty: big.NewInt(8)}) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 2, TotalDifficulty: big.NewInt(8)}) } node.On("Order").Maybe().Return(int32(1)) nodes = append(nodes, node) @@ -40,9 +40,9 @@ func TestTotalDifficultyNodeSelector(t *testing.T) { assert.Same(t, nodes[2], selector.Select()) t.Run("stick to the same node", func(t *testing.T) { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) // fourth node is alive (same as 3rd) - node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 2, TotalDifficulty: big.NewInt(8)}) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 2, TotalDifficulty: big.NewInt(8)}) node.On("Order").Maybe().Return(int32(1)) nodes = append(nodes, node) @@ -51,9 +51,9 @@ func TestTotalDifficultyNodeSelector(t *testing.T) { }) t.Run("another best node", func(t *testing.T) { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) // fifth node is alive (better than 3rd and 4th) - 
node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 3, TotalDifficulty: big.NewInt(11)}) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 3, TotalDifficulty: big.NewInt(11)}) node.On("Order").Maybe().Return(int32(1)) nodes = append(nodes, node) @@ -62,13 +62,13 @@ func TestTotalDifficultyNodeSelector(t *testing.T) { }) t.Run("nodes never update latest block number", func(t *testing.T) { - node1 := newMockNode[types.ID, Head, nodeClient](t) - node1.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: -1, TotalDifficulty: nil}) + node1 := newMockNode[types.ID, nodeClient](t) + node1.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: -1, TotalDifficulty: nil}) node1.On("Order").Maybe().Return(int32(1)) - node2 := newMockNode[types.ID, Head, nodeClient](t) - node2.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: -1, TotalDifficulty: nil}) + node2 := newMockNode[types.ID, nodeClient](t) + node2.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: -1, TotalDifficulty: nil}) node2.On("Order").Maybe().Return(int32(1)) - nodes := []Node[types.ID, Head, nodeClient]{node1, node2} + nodes := []Node[types.ID, nodeClient]{node1, node2} selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) assert.Same(t, node1, selector.Select()) @@ -78,17 +78,17 @@ func TestTotalDifficultyNodeSelector(t *testing.T) { func TestTotalDifficultyNodeSelector_None(t *testing.T) { t.Parallel() - type nodeClient NodeClient[types.ID, Head] - var nodes []Node[types.ID, Head, nodeClient] + type nodeClient RPCClient[types.ID, Head] + var nodes []Node[types.ID, nodeClient] for i := 0; i < 3; i++ { - node := newMockNode[types.ID, Head, nodeClient](t) + node := newMockNode[types.ID, nodeClient](t) if i == 0 { // first node is out of sync - node.On("StateAndLatest").Return(nodeStateOutOfSync, ChainInfo{BlockNumber: -1, TotalDifficulty: nil}) + node.On("StateAndLatest").Return(NodeStateOutOfSync, ChainInfo{BlockNumber: -1, TotalDifficulty: nil}) } else { // others are unreachable - node.On("StateAndLatest").Return(nodeStateUnreachable, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(7)}) + node.On("StateAndLatest").Return(NodeStateUnreachable, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(7)}) } nodes = append(nodes, node) } @@ -100,13 +100,13 @@ func TestTotalDifficultyNodeSelector_None(t *testing.T) { func TestTotalDifficultyNodeSelectorWithOrder(t *testing.T) { t.Parallel() - type nodeClient NodeClient[types.ID, Head] - var nodes []Node[types.ID, Head, nodeClient] + type nodeClient RPCClient[types.ID, Head] + var nodes []Node[types.ID, nodeClient] t.Run("same td and order", func(t *testing.T) { for i := 0; i < 3; i++ { - node := newMockNode[types.ID, Head, nodeClient](t) - node.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(10)}) + node := newMockNode[types.ID, nodeClient](t) + node.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(10)}) node.On("Order").Return(int32(2)) nodes = append(nodes, node) } @@ -116,61 +116,61 @@ func TestTotalDifficultyNodeSelectorWithOrder(t *testing.T) { }) t.Run("same td but different order", func(t *testing.T) { - node1 := newMockNode[types.ID, Head, nodeClient](t) - node1.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 3, TotalDifficulty: big.NewInt(10)}) + node1 := newMockNode[types.ID, nodeClient](t) + node1.On("StateAndLatest").Return(NodeStateAlive, 
ChainInfo{BlockNumber: 3, TotalDifficulty: big.NewInt(10)}) node1.On("Order").Return(int32(3)) - node2 := newMockNode[types.ID, Head, nodeClient](t) - node2.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 3, TotalDifficulty: big.NewInt(10)}) + node2 := newMockNode[types.ID, nodeClient](t) + node2.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 3, TotalDifficulty: big.NewInt(10)}) node2.On("Order").Return(int32(1)) - node3 := newMockNode[types.ID, Head, nodeClient](t) - node3.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 3, TotalDifficulty: big.NewInt(10)}) + node3 := newMockNode[types.ID, nodeClient](t) + node3.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 3, TotalDifficulty: big.NewInt(10)}) node3.On("Order").Return(int32(2)) - nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3} + nodes := []Node[types.ID, nodeClient]{node1, node2, node3} selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) //Should select the second node as it has the highest priority assert.Same(t, nodes[1], selector.Select()) }) t.Run("different td but same order", func(t *testing.T) { - node1 := newMockNode[types.ID, Head, nodeClient](t) - node1.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(10)}) + node1 := newMockNode[types.ID, nodeClient](t) + node1.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(10)}) node1.On("Order").Maybe().Return(int32(3)) - node2 := newMockNode[types.ID, Head, nodeClient](t) - node2.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(11)}) + node2 := newMockNode[types.ID, nodeClient](t) + node2.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(11)}) node2.On("Order").Maybe().Return(int32(3)) - node3 := newMockNode[types.ID, Head, nodeClient](t) - node3.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(12)}) + node3 := newMockNode[types.ID, nodeClient](t) + node3.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(12)}) node3.On("Order").Return(int32(3)) - nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3} + nodes := []Node[types.ID, nodeClient]{node1, node2, node3} selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) //Should select the third node as it has the highest td assert.Same(t, nodes[2], selector.Select()) }) t.Run("different head and different order", func(t *testing.T) { - node1 := newMockNode[types.ID, Head, nodeClient](t) - node1.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(100)}) + node1 := newMockNode[types.ID, nodeClient](t) + node1.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(100)}) node1.On("Order").Maybe().Return(int32(4)) - node2 := newMockNode[types.ID, Head, nodeClient](t) - node2.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(110)}) + node2 := newMockNode[types.ID, nodeClient](t) + node2.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(110)}) node2.On("Order").Maybe().Return(int32(5)) - node3 := newMockNode[types.ID, Head, nodeClient](t) - node3.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(110)}) 
+ node3 := newMockNode[types.ID, nodeClient](t) + node3.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(110)}) node3.On("Order").Maybe().Return(int32(1)) - node4 := newMockNode[types.ID, Head, nodeClient](t) - node4.On("StateAndLatest").Return(nodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(105)}) + node4 := newMockNode[types.ID, nodeClient](t) + node4.On("StateAndLatest").Return(NodeStateAlive, ChainInfo{BlockNumber: 1, TotalDifficulty: big.NewInt(105)}) node4.On("Order").Maybe().Return(int32(2)) - nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3, node4} + nodes := []Node[types.ID, nodeClient]{node1, node2, node3, node4} selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) //Should select the third node as it has the highest td and will win the priority tie-breaker assert.Same(t, nodes[2], selector.Select()) diff --git a/common/client/node_test.go b/common/client/node_test.go index 3b971e84902..87f3b589e12 100644 --- a/common/client/node_test.go +++ b/common/client/node_test.go @@ -55,7 +55,7 @@ func (n testNodeConfig) DeathDeclarationDelay() time.Duration { } type testNode struct { - *node[types.ID, Head, NodeClient[types.ID, Head]] + *node[types.ID, Head, RPCClient[types.ID, Head]] } type testNodeOpts struct { @@ -68,7 +68,7 @@ type testNodeOpts struct { id int32 chainID types.ID nodeOrder int32 - rpc *mockNodeClient[types.ID, Head] + rpc *mockRPCClient[types.ID, Head] chainFamily string } @@ -93,10 +93,10 @@ func newTestNode(t *testing.T, opts testNodeOpts) testNode { opts.id = 42 } - nodeI := NewNode[types.ID, Head, NodeClient[types.ID, Head]](opts.config, opts.chainConfig, opts.lggr, + nodeI := NewNode[types.ID, Head, RPCClient[types.ID, Head]](opts.config, opts.chainConfig, opts.lggr, opts.wsuri, opts.httpuri, opts.name, opts.id, opts.chainID, opts.nodeOrder, opts.rpc, opts.chainFamily) return testNode{ - nodeI.(*node[types.ID, Head, NodeClient[types.ID, Head]]), + nodeI.(*node[types.ID, Head, RPCClient[types.ID, Head]]), } } diff --git a/common/client/send_only_node.go b/common/client/send_only_node.go index b63e93b703d..ba70ec32461 100644 --- a/common/client/send_only_node.go +++ b/common/client/send_only_node.go @@ -18,7 +18,7 @@ type sendOnlyClient[ ] interface { Close() ChainID(context.Context) (CHAIN_ID, error) - DialHTTP() error + Dial(ctx context.Context) error } // SendOnlyNode represents one node used as a sendonly @@ -26,7 +26,7 @@ type sendOnlyClient[ //go:generate mockery --quiet --name SendOnlyNode --structname mockSendOnlyNode --filename "mock_send_only_node_test.go" --inpackage --case=underscore type SendOnlyNode[ CHAIN_ID types.ID, - RPC sendOnlyClient[CHAIN_ID], + RPC any, ] interface { // Start may attempt to connect to the node, but should only return error for misconfiguration - never for temporary errors. Start(context.Context) error @@ -36,8 +36,8 @@ type SendOnlyNode[ RPC() RPC String() string - // State returns nodeState - State() nodeState + // State returns NodeState + State() NodeState // Name is a unique identifier for this node. 
Name() string } @@ -51,7 +51,7 @@ type sendOnlyNode[ services.StateMachine stateMu sync.RWMutex // protects state* fields - state nodeState + state NodeState rpc RPC uri url.URL @@ -96,18 +96,18 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) Start(ctx context.Context) error { // Start setups up and verifies the sendonly node // Should only be called once in a node's lifecycle func (s *sendOnlyNode[CHAIN_ID, RPC]) start(startCtx context.Context) { - if s.State() != nodeStateUndialed { + if s.State() != NodeStateUndialed { panic(fmt.Sprintf("cannot dial node with state %v", s.state)) } - err := s.rpc.DialHTTP() + err := s.rpc.Dial(startCtx) if err != nil { promPoolRPCNodeTransitionsToUnusable.WithLabelValues(s.chainID.String(), s.name).Inc() s.log.Errorw("Dial failed: SendOnly Node is unusable", "err", err) - s.setState(nodeStateUnusable) + s.setState(NodeStateUnusable) return } - s.setState(nodeStateDialed) + s.setState(NodeStateDialed) if s.chainID.String() == "0" { // Skip verification if chainID is zero @@ -119,7 +119,7 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) start(startCtx context.Context) { if err != nil { promPoolRPCNodeTransitionsToUnreachable.WithLabelValues(s.chainID.String(), s.name).Inc() s.log.Errorw(fmt.Sprintf("Verify failed: %v", err), "err", err) - s.setState(nodeStateUnreachable) + s.setState(NodeStateUnreachable) } else { promPoolRPCNodeTransitionsToInvalidChainID.WithLabelValues(s.chainID.String(), s.name).Inc() s.log.Errorf( @@ -128,7 +128,7 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) start(startCtx context.Context) { s.chainID.String(), s.name, ) - s.setState(nodeStateInvalidChainID) + s.setState(NodeStateInvalidChainID) } // Since it has failed, spin up the verifyLoop that will keep // retrying until success @@ -139,8 +139,8 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) start(startCtx context.Context) { } promPoolRPCNodeTransitionsToAlive.WithLabelValues(s.chainID.String(), s.name).Inc() - s.setState(nodeStateAlive) - s.log.Infow("Sendonly RPC Node is online", "nodeState", s.state) + s.setState(NodeStateAlive) + s.log.Infow("Sendonly RPC Node is online", "NodeState", s.state) } func (s *sendOnlyNode[CHAIN_ID, RPC]) Close() error { @@ -148,7 +148,7 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) Close() error { s.rpc.Close() close(s.chStop) s.wg.Wait() - s.setState(nodeStateClosed) + s.setState(NodeStateClosed) return nil }) } @@ -165,7 +165,7 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) String() string { return fmt.Sprintf("(%s)%s:%s", Secondary.String(), s.name, s.uri.Redacted()) } -func (s *sendOnlyNode[CHAIN_ID, RPC]) setState(state nodeState) (changed bool) { +func (s *sendOnlyNode[CHAIN_ID, RPC]) setState(state NodeState) (changed bool) { s.stateMu.Lock() defer s.stateMu.Unlock() if s.state == state { @@ -175,7 +175,7 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) setState(state nodeState) (changed bool) { return true } -func (s *sendOnlyNode[CHAIN_ID, RPC]) State() nodeState { +func (s *sendOnlyNode[CHAIN_ID, RPC]) State() NodeState { s.stateMu.RLock() defer s.stateMu.RUnlock() return s.state diff --git a/common/client/send_only_node_lifecycle.go b/common/client/send_only_node_lifecycle.go index c66d267ed42..a6ac112488b 100644 --- a/common/client/send_only_node_lifecycle.go +++ b/common/client/send_only_node_lifecycle.go @@ -26,7 +26,7 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) verifyLoop() { chainID, err := s.rpc.ChainID(ctx) if err != nil { ok := s.IfStarted(func() { - if changed := s.setState(nodeStateUnreachable); changed { + if changed := s.setState(NodeStateUnreachable); changed { 
promPoolRPCNodeTransitionsToUnreachable.WithLabelValues(s.chainID.String(), s.name).Inc() } }) @@ -37,7 +37,7 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) verifyLoop() { continue } else if chainID.String() != s.chainID.String() { ok := s.IfStarted(func() { - if changed := s.setState(nodeStateInvalidChainID); changed { + if changed := s.setState(NodeStateInvalidChainID); changed { promPoolRPCNodeTransitionsToInvalidChainID.WithLabelValues(s.chainID.String(), s.name).Inc() } }) @@ -54,14 +54,14 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) verifyLoop() { continue } ok := s.IfStarted(func() { - if changed := s.setState(nodeStateAlive); changed { + if changed := s.setState(NodeStateAlive); changed { promPoolRPCNodeTransitionsToAlive.WithLabelValues(s.chainID.String(), s.name).Inc() } }) if !ok { return } - s.log.Infow("Sendonly RPC Node is online", "nodeState", s.state) + s.log.Infow("Sendonly RPC Node is online", "NodeState", s.state) return } } diff --git a/common/client/send_only_node_test.go b/common/client/send_only_node_test.go index 79f4bfd60e3..352fb5b92ea 100644 --- a/common/client/send_only_node_test.go +++ b/common/client/send_only_node_test.go @@ -46,14 +46,14 @@ func TestStartSendOnlyNode(t *testing.T) { client := newMockSendOnlyClient[types.ID](t) client.On("Close").Once() expectedError := errors.New("some http error") - client.On("DialHTTP").Return(expectedError).Once() + client.On("Dial", mock.Anything).Return(expectedError).Once() s := NewSendOnlyNode(lggr, url.URL{}, t.Name(), types.RandomID(), client) defer func() { assert.NoError(t, s.Close()) }() err := s.Start(tests.Context(t)) require.NoError(t, err) - assert.Equal(t, nodeStateUnusable, s.State()) + assert.Equal(t, NodeStateUnusable, s.State()) tests.RequireLogMessage(t, observedLogs, "Dial failed: SendOnly Node is unusable") }) t.Run("Default ChainID(0) produces warn and skips checks", func(t *testing.T) { @@ -61,14 +61,14 @@ func TestStartSendOnlyNode(t *testing.T) { lggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) client := newMockSendOnlyClient[types.ID](t) client.On("Close").Once() - client.On("DialHTTP").Return(nil).Once() + client.On("Dial", mock.Anything).Return(nil).Once() s := NewSendOnlyNode(lggr, url.URL{}, t.Name(), types.NewIDFromInt(0), client) defer func() { assert.NoError(t, s.Close()) }() err := s.Start(tests.Context(t)) require.NoError(t, err) - assert.Equal(t, nodeStateAlive, s.State()) + assert.Equal(t, NodeStateAlive, s.State()) tests.RequireLogMessage(t, observedLogs, "sendonly rpc ChainID verification skipped") }) t.Run("Can recover from chainID verification failure", func(t *testing.T) { @@ -76,7 +76,7 @@ func TestStartSendOnlyNode(t *testing.T) { lggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) client := newMockSendOnlyClient[types.ID](t) client.On("Close").Once() - client.On("DialHTTP").Return(nil) + client.On("Dial", mock.Anything).Return(nil) expectedError := errors.New("failed to get chain ID") chainID := types.RandomID() const failuresCount = 2 @@ -89,10 +89,10 @@ func TestStartSendOnlyNode(t *testing.T) { err := s.Start(tests.Context(t)) require.NoError(t, err) - assert.Equal(t, nodeStateUnreachable, s.State()) + assert.Equal(t, NodeStateUnreachable, s.State()) tests.AssertLogCountEventually(t, observedLogs, fmt.Sprintf("Verify failed: %v", expectedError), failuresCount) tests.AssertEventually(t, func() bool { - return s.State() == nodeStateAlive + return s.State() == NodeStateAlive }) }) t.Run("Can recover from chainID mismatch", func(t *testing.T) { @@ -100,7 +100,7 @@ func 
TestStartSendOnlyNode(t *testing.T) { lggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) client := newMockSendOnlyClient[types.ID](t) client.On("Close").Once() - client.On("DialHTTP").Return(nil).Once() + client.On("Dial", mock.Anything).Return(nil).Once() configuredChainID := types.NewIDFromInt(11) rpcChainID := types.NewIDFromInt(20) const failuresCount = 2 @@ -112,10 +112,10 @@ func TestStartSendOnlyNode(t *testing.T) { err := s.Start(tests.Context(t)) require.NoError(t, err) - assert.Equal(t, nodeStateInvalidChainID, s.State()) + assert.Equal(t, NodeStateInvalidChainID, s.State()) tests.AssertLogCountEventually(t, observedLogs, "sendonly rpc ChainID doesn't match local chain ID", failuresCount) tests.AssertEventually(t, func() bool { - return s.State() == nodeStateAlive + return s.State() == NodeStateAlive }) }) t.Run("Start with Random ChainID", func(t *testing.T) { @@ -123,7 +123,7 @@ func TestStartSendOnlyNode(t *testing.T) { lggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) client := newMockSendOnlyClient[types.ID](t) client.On("Close").Once() - client.On("DialHTTP").Return(nil).Once() + client.On("Dial", mock.Anything).Return(nil).Once() configuredChainID := types.RandomID() client.On("ChainID", mock.Anything).Return(configuredChainID, nil) s := NewSendOnlyNode(lggr, url.URL{}, t.Name(), configuredChainID, client) @@ -132,7 +132,7 @@ func TestStartSendOnlyNode(t *testing.T) { err := s.Start(tests.Context(t)) assert.NoError(t, err) tests.AssertEventually(t, func() bool { - return s.State() == nodeStateAlive + return s.State() == NodeStateAlive }) assert.Equal(t, 0, observedLogs.Len()) // No warnings expected }) diff --git a/common/client/types.go b/common/client/types.go index d2706e4fec1..fd0d4e85397 100644 --- a/common/client/types.go +++ b/common/client/types.go @@ -2,7 +2,6 @@ package client import ( "context" - evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" "math/big" "github.com/smartcontractkit/chainlink-common/pkg/assets" @@ -13,10 +12,10 @@ import ( // RPCClient includes all the necessary generalized RPC methods along with any additional chain-specific methods. // -//go:generate mockery --quiet --name RPCClient --output ./mocks --case=underscore +//go:generate mockery --quiet --name RPCClient --structname mockRPCClient --filename "mock_rpc_client_test.go" --inpackage --case=underscore type RPCClient[ CHAIN_ID types.ID, - HEAD *evmtypes.Head, + HEAD Head, ] interface { // ChainID - fetches ChainID from the RPC to verify that it matches config ChainID(ctx context.Context) (CHAIN_ID, error) @@ -34,6 +33,17 @@ type RPCClient[ UnsubscribeAllExcept(subs ...types.Subscription) // Close - closes all subscriptions and aborts all RPC calls Close() + // GetInterceptedChainInfo - returns latest and highest observed by application layer ChainInfo. + // latest ChainInfo is the most recent value received within a NodeClient's current lifecycle between Dial and DisconnectAll. + // highestUserObservations ChainInfo is the highest ChainInfo observed excluding health checks calls. + // Its values must not be reset. + // The results of corresponding calls, to get the most recent head and the latest finalized head, must be + // intercepted and reflected in ChainInfo before being returned to a caller. Otherwise, MultiNode is not able to + // provide repeatable read guarantee. + // DisconnectAll must reset latest ChainInfo to default value. 
+ // Ensure the implementation does not have a race condition where values are reset before a request completes, + // leaving the latest ChainInfo with information from the previous cycle. + GetInterceptedChainInfo() (latest, highestUserObservations ChainInfo) } // RPC includes all the necessary methods for a multi-node client to interact directly with any RPC endpoint. @@ -83,8 +93,6 @@ type Head interface { } // NodeClient includes all the necessary RPC methods required by a node. -// -//go:generate mockery --quiet --name NodeClient --structname mockNodeClient --filename "mock_node_client_test.go" --inpackage --case=underscore type NodeClient[ CHAIN_ID types.ID, HEAD Head, @@ -183,7 +191,7 @@ type connection[ ] interface { ChainID(ctx context.Context) (CHAIN_ID, error) Dial(ctx context.Context) error - SubscribeNewHead(ctx context.Context, channel chan<- HEAD) (types.Subscription, error) + SubscribeNewHead(ctx context.Context) (<-chan HEAD, types.Subscription, error) } // PoolChainInfoProvider - provides aggregation of nodes pool ChainInfo diff --git a/common/headtracker/head_listener.go b/common/headtracker/head_listener.go index 25715b35280..cc0afd0f2b4 100644 --- a/common/headtracker/head_listener.go +++ b/common/headtracker/head_listener.go @@ -58,7 +58,7 @@ type headListener[ client htrktypes.Client[HTH, S, ID, BLOCK_HASH] logger logger.Logger chStop services.StopChan - chHeaders chan HTH + chHeaders <-chan HTH headSubscription types.Subscription connected atomic.Bool receivingHeads atomic.Bool @@ -218,12 +218,10 @@ func (hl *headListener[HTH, S, ID, BLOCK_HASH]) subscribe(ctx context.Context) b } func (hl *headListener[HTH, S, ID, BLOCK_HASH]) subscribeToHead(ctx context.Context) error { - hl.chHeaders = make(chan HTH) var err error - hl.headSubscription, err = hl.client.SubscribeNewHead(ctx, hl.chHeaders) + hl.chHeaders, hl.headSubscription, err = hl.client.SubscribeNewHead(ctx) if err != nil { - close(hl.chHeaders) return fmt.Errorf("Client#SubscribeNewHead: %w", err) } diff --git a/common/headtracker/types/client.go b/common/headtracker/types/client.go index a1e419809b5..b697c336f58 100644 --- a/common/headtracker/types/client.go +++ b/common/headtracker/types/client.go @@ -14,7 +14,7 @@ type Client[H types.Head[BLOCK_HASH], S types.Subscription, ID types.ID, BLOCK_H ConfiguredChainID() (id ID) // SubscribeNewHead is the method in which the client receives new Head. // It can be implemented differently for each chain i.e websocket, polling, etc - SubscribeNewHead(ctx context.Context, ch chan<- H) (S, error) + SubscribeNewHead(ctx context.Context) (<-chan H, S, error) // LatestFinalizedBlock - returns the latest block that was marked as finalized LatestFinalizedBlock(ctx context.Context) (head H, err error) } diff --git a/common/types/subscription.go b/common/types/subscription.go index e0cd0a1660d..b341fb42c44 100644 --- a/common/types/subscription.go +++ b/common/types/subscription.go @@ -7,7 +7,8 @@ package types // This is a generic interface for Subscription to represent used by clients. type Subscription interface { // Unsubscribe cancels the sending of events to the data channel - // and closes the error channel. + // and closes the error channel. Unsubscribe should be callable multiple + // times without causing an error. Unsubscribe() // Err returns the subscription error channel. The error channel receives // a value if there is an issue with the subscription (e.g.
the network connection diff --git a/core/chains/evm/client/chain_client.go b/core/chains/evm/client/chain_client.go index f18c900b038..3062d6dddea 100644 --- a/core/chains/evm/client/chain_client.go +++ b/core/chains/evm/client/chain_client.go @@ -3,18 +3,17 @@ package client import ( "context" "math/big" + "sync" "time" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/rpc" + ethrpc "github.com/ethereum/go-ethereum/rpc" commonassets "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/logger" - commonclient "github.com/smartcontractkit/chainlink/v2/common/client" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/chaintype" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -33,12 +32,10 @@ type Client interface { Close() // ChainID locally stored for quick access ConfiguredChainID() *big.Int - // ChainID RPC call - ChainID() (*big.Int, error) // NodeStates returns a map of node Name->node state // It might be nil or empty, e.g. for mock clients etc - NodeStates() map[string]string + NodeStates() map[string]commonclient.NodeState TokenBalance(ctx context.Context, address common.Address, contractAddress common.Address) (*big.Int, error) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) @@ -46,12 +43,12 @@ type Client interface { // Wrapped RPC methods CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error - BatchCallContext(ctx context.Context, b []rpc.BatchElem) error + BatchCallContext(ctx context.Context, b []ethrpc.BatchElem) error // BatchCallContextAll calls BatchCallContext for every single node including // sendonlys. // CAUTION: This should only be used for mass re-transmitting transactions, it // might have unexpected effects to use it for anything else. - BatchCallContextAll(ctx context.Context, b []rpc.BatchElem) error + BatchCallContextAll(ctx context.Context, b []ethrpc.BatchElem) error // HeadByNumber and HeadByHash is a reimplemented version due to a // difference in how block header hashes are calculated by Parity nodes @@ -59,7 +56,7 @@ type Client interface { // correct hash from the RPC response. HeadByNumber(ctx context.Context, n *big.Int) (*evmtypes.Head, error) HeadByHash(ctx context.Context, n common.Hash) (*evmtypes.Head, error) - SubscribeNewHead(ctx context.Context, ch chan<- *evmtypes.Head) (ethereum.Subscription, error) + SubscribeNewHead(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) // LatestFinalizedBlock - returns the latest finalized block as it's returned from an RPC. // CAUTION: Using this method might cause local finality violations. It's highly recommended // to use HeadTracker to get latest finalized block. 
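[Illustrative sketch, not part of the diff] The Client interface above now returns the heads channel from SubscribeNewHead instead of accepting a caller-supplied one. A minimal consumer of the new signature could look like the following, assuming the surrounding package's imports; processHead is a hypothetical handler used purely for illustration:

func consumeHeads(ctx context.Context, c Client, processHead func(*evmtypes.Head)) error {
	heads, sub, err := c.SubscribeNewHead(ctx)
	if err != nil {
		return err
	}
	// Per the updated Subscription contract, Unsubscribe is safe to call
	// multiple times, so deferring it here cannot conflict with other paths.
	defer sub.Unsubscribe()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case err := <-sub.Err():
			return err
		case head, open := <-heads:
			if !open {
				// The producer closed the channel; nothing more to consume.
				return nil
			}
			processHead(head)
		}
	}
}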
@@ -105,20 +102,9 @@ func ContextWithDefaultTimeout() (ctx context.Context, cancel context.CancelFunc } type chainClient struct { - multiNode commonclient.MultiNode[ + multiNode *commonclient.MultiNode[ *big.Int, - evmtypes.Nonce, - common.Address, - common.Hash, - *types.Transaction, - common.Hash, - types.Log, - ethereum.FilterQuery, - *evmtypes.Receipt, - *assets.Wei, - *evmtypes.Head, - RPCClient, - rpc.BatchElem, + *RpcClient, ] logger logger.SugaredLogger chainType chaintype.ChainType @@ -129,11 +115,9 @@ func NewChainClient( lggr logger.Logger, selectionMode string, leaseDuration time.Duration, - noNewHeadsThreshold time.Duration, - nodes []commonclient.Node[*big.Int, *evmtypes.Head, RPCClient], - sendonlys []commonclient.SendOnlyNode[*big.Int, RPCClient], + nodes []commonclient.Node[*big.Int, *RpcClient], + sendonlys []commonclient.SendOnlyNode[*big.Int, *RpcClient], chainID *big.Int, - chainType chaintype.ChainType, clientErrors evmconfig.ClientErrors, deathDeclarationDelay time.Duration, ) Client { @@ -141,17 +125,13 @@ func NewChainClient( lggr, selectionMode, leaseDuration, - noNewHeadsThreshold, nodes, sendonlys, chainID, "EVM", - func(tx *types.Transaction, err error) commonclient.SendTxReturnCode { - return ClassifySendError(err, clientErrors, logger.Sugared(logger.Nop()), tx, common.Address{}, chainType.IsL2()) - }, - 0, // use the default value provided by the implementation deathDeclarationDelay, ) + return &chainClient{ multiNode: multiNode, logger: logger.Sugared(lggr), @@ -160,24 +140,61 @@ func NewChainClient( } func (c *chainClient) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { - return c.multiNode.BalanceAt(ctx, account, blockNumber) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.BalanceAt(ctx, account, blockNumber) } // Request specific errors for batch calls are returned to the individual BatchElem. // Ensure the same BatchElem slice provided by the caller is passed through the call stack // to ensure the caller has access to the errors. 
-func (c *chainClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error { - return c.multiNode.BatchCallContext(ctx, b) +func (c *chainClient) BatchCallContext(ctx context.Context, b []ethrpc.BatchElem) error { + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return err + } + return rpc.BatchCallContext(ctx, b) } // Similar to BatchCallContext, ensure the provided BatchElem slice is passed through -func (c *chainClient) BatchCallContextAll(ctx context.Context, b []rpc.BatchElem) error { - return c.multiNode.BatchCallContextAll(ctx, b) +func (c *chainClient) BatchCallContextAll(ctx context.Context, b []ethrpc.BatchElem) error { + var wg sync.WaitGroup + defer wg.Wait() + + // Select main RPC to use for return value + main, selectionErr := c.multiNode.SelectRPC() + if selectionErr != nil { + return selectionErr + } + + doFunc := func(ctx context.Context, rpc *RpcClient, isSendOnly bool) { + if rpc == main { + return + } + // Parallel call made to all other nodes with ignored return value + wg.Add(1) + go func(rpc *RpcClient) { + defer wg.Done() + err := rpc.BatchCallContext(ctx, b) + if err != nil { + c.logger.Debugw("Secondary node BatchCallContext failed", "err", err) + } else { + c.logger.Debug("Secondary node BatchCallContext success") + } + }(rpc) + } + + if err := c.multiNode.DoAll(ctx, doFunc); err != nil { + return err + } + return main.BatchCallContext(ctx, b) } // TODO-1663: return custom Block type instead of geth's once client.go is deprecated. func (c *chainClient) BlockByHash(ctx context.Context, hash common.Hash) (b *types.Block, err error) { - rpc, err := c.multiNode.SelectNodeRPC() + rpc, err := c.multiNode.SelectRPC() if err != nil { return b, err } @@ -186,7 +203,7 @@ func (c *chainClient) BlockByHash(ctx context.Context, hash common.Hash) (b *typ // TODO-1663: return custom Block type instead of geth's once client.go is deprecated. func (c *chainClient) BlockByNumber(ctx context.Context, number *big.Int) (b *types.Block, err error) { - rpc, err := c.multiNode.SelectNodeRPC() + rpc, err := c.multiNode.SelectRPC() if err != nil { return b, err } @@ -194,48 +211,66 @@ func (c *chainClient) BlockByNumber(ctx context.Context, number *big.Int) (b *ty } func (c *chainClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { - return c.multiNode.CallContext(ctx, result, method, args...) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return err + } + return rpc.CallContext(ctx, result, method, args...) } func (c *chainClient) CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { - return c.multiNode.CallContract(ctx, msg, blockNumber) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.CallContract(ctx, msg, blockNumber) } func (c *chainClient) PendingCallContract(ctx context.Context, msg ethereum.CallMsg) ([]byte, error) { - return c.multiNode.PendingCallContract(ctx, msg) -} - -// TODO-1663: change this to actual ChainID() call once client.go is deprecated. 
-func (c *chainClient) ChainID() (*big.Int, error) { - //return c.multiNode.ChainID(ctx), nil - return c.multiNode.ConfiguredChainID(), nil + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.PendingCallContract(ctx, msg) } func (c *chainClient) Close() { - c.multiNode.Close() + _ = c.multiNode.Close() } func (c *chainClient) CodeAt(ctx context.Context, account common.Address, blockNumber *big.Int) ([]byte, error) { - return c.multiNode.CodeAt(ctx, account, blockNumber) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.CodeAt(ctx, account, blockNumber) } func (c *chainClient) ConfiguredChainID() *big.Int { - return c.multiNode.ConfiguredChainID() + return c.multiNode.ChainID() } func (c *chainClient) Dial(ctx context.Context) error { - return c.multiNode.Dial(ctx) + return c.multiNode.Start(ctx) } func (c *chainClient) EstimateGas(ctx context.Context, call ethereum.CallMsg) (uint64, error) { - return c.multiNode.EstimateGas(ctx, call) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return 0, err + } + return rpc.EstimateGas(ctx, call) } func (c *chainClient) FilterLogs(ctx context.Context, q ethereum.FilterQuery) ([]types.Log, error) { - return c.multiNode.FilterEvents(ctx, q) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.FilterEvents(ctx, q) } func (c *chainClient) HeaderByHash(ctx context.Context, h common.Hash) (head *types.Header, err error) { - rpc, err := c.multiNode.SelectNodeRPC() + rpc, err := c.multiNode.SelectRPC() if err != nil { return head, err } @@ -243,7 +278,7 @@ func (c *chainClient) HeaderByHash(ctx context.Context, h common.Hash) (head *ty } func (c *chainClient) HeaderByNumber(ctx context.Context, n *big.Int) (head *types.Header, err error) { - rpc, err := c.multiNode.SelectNodeRPC() + rpc, err := c.multiNode.SelectRPC() if err != nil { return head, err } @@ -251,11 +286,19 @@ func (c *chainClient) HeaderByNumber(ctx context.Context, n *big.Int) (head *typ } func (c *chainClient) HeadByHash(ctx context.Context, h common.Hash) (*evmtypes.Head, error) { - return c.multiNode.BlockByHash(ctx, h) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.BlockByHash(ctx, h) } func (c *chainClient) HeadByNumber(ctx context.Context, n *big.Int) (*evmtypes.Head, error) { - return c.multiNode.BlockByNumber(ctx, n) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.BlockByNumber(ctx, n) } func (c *chainClient) IsL2() bool { @@ -263,19 +306,27 @@ func (c *chainClient) IsL2() bool { } func (c *chainClient) LINKBalance(ctx context.Context, address common.Address, linkAddress common.Address) (*commonassets.Link, error) { - return c.multiNode.LINKBalance(ctx, address, linkAddress) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.LINKBalance(ctx, address, linkAddress) } func (c *chainClient) LatestBlockHeight(ctx context.Context) (*big.Int, error) { - return c.multiNode.LatestBlockHeight(ctx) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.LatestBlockHeight(ctx) } -func (c *chainClient) NodeStates() map[string]string { +func (c *chainClient) NodeStates() map[string]commonclient.NodeState { return c.multiNode.NodeStates() } func (c *chainClient) PendingCodeAt(ctx context.Context, account common.Address) (b []byte, err error) { - rpc, err := c.multiNode.SelectNodeRPC() + rpc, err := 
c.multiNode.SelectRPC() if err != nil { return b, err } @@ -284,12 +335,20 @@ func (c *chainClient) PendingCodeAt(ctx context.Context, account common.Address) // TODO-1663: change this to evmtypes.Nonce(int64) once client.go is deprecated. func (c *chainClient) PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) { - n, err := c.multiNode.PendingSequenceAt(ctx, account) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return 0, err + } + n, err := rpc.PendingSequenceAt(ctx, account) return uint64(n), err } func (c *chainClient) SendTransaction(ctx context.Context, tx *types.Transaction) error { - return c.multiNode.SendTransaction(ctx, tx) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return err + } + return rpc.SendTransaction(ctx, tx) } func (c *chainClient) SendTransactionReturnCode(ctx context.Context, tx *types.Transaction, fromAddress common.Address) (commonclient.SendTxReturnCode, error) { @@ -299,23 +358,37 @@ func (c *chainClient) SendTransactionReturnCode(ctx context.Context, tx *types.T } func (c *chainClient) SequenceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (evmtypes.Nonce, error) { - return c.multiNode.SequenceAt(ctx, account, blockNumber) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return 0, err + } + return rpc.SequenceAt(ctx, account, blockNumber) } func (c *chainClient) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (s ethereum.Subscription, err error) { - rpc, err := c.multiNode.SelectNodeRPC() + rpc, err := c.multiNode.SelectRPC() if err != nil { return s, err } return rpc.SubscribeFilterLogs(ctx, q, ch) } -func (c *chainClient) SubscribeNewHead(ctx context.Context, ch chan<- *evmtypes.Head) (ethereum.Subscription, error) { - return c.multiNode.SubscribeNewHead(ctx, ch) +func (c *chainClient) SubscribeNewHead(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) { + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, nil, err + } + + ch, sub, err := rpc.SubscribeToHeads(ctx) + if err != nil { + return nil, nil, err + } + + return ch, sub, nil } func (c *chainClient) SuggestGasPrice(ctx context.Context) (p *big.Int, err error) { - rpc, err := c.multiNode.SelectNodeRPC() + rpc, err := c.multiNode.SelectRPC() if err != nil { return p, err } @@ -323,7 +396,7 @@ func (c *chainClient) SuggestGasPrice(ctx context.Context) (p *big.Int, err erro } func (c *chainClient) SuggestGasTipCap(ctx context.Context) (t *big.Int, err error) { - rpc, err := c.multiNode.SelectNodeRPC() + rpc, err := c.multiNode.SelectRPC() if err != nil { return t, err } @@ -331,16 +404,24 @@ func (c *chainClient) SuggestGasTipCap(ctx context.Context) (t *big.Int, err err } func (c *chainClient) TokenBalance(ctx context.Context, address common.Address, contractAddress common.Address) (*big.Int, error) { - return c.multiNode.TokenBalance(ctx, address, contractAddress) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.TokenBalance(ctx, address, contractAddress) } func (c *chainClient) TransactionByHash(ctx context.Context, txHash common.Hash) (*types.Transaction, error) { - return c.multiNode.TransactionByHash(ctx, txHash) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.TransactionByHash(ctx, txHash) } // TODO-1663: return custom Receipt type instead of geth's once client.go is deprecated. 
func (c *chainClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (r *types.Receipt, err error) { - rpc, err := c.multiNode.SelectNodeRPC() + rpc, err := c.multiNode.SelectRPC() if err != nil { return r, err } @@ -349,7 +430,11 @@ func (c *chainClient) TransactionReceipt(ctx context.Context, txHash common.Hash } func (c *chainClient) LatestFinalizedBlock(ctx context.Context) (*evmtypes.Head, error) { - return c.multiNode.LatestFinalizedBlock(ctx) + rpc, err := c.multiNode.SelectRPC() + if err != nil { + return nil, err + } + return rpc.LatestFinalizedBlock(ctx) } func (c *chainClient) CheckTxValidity(ctx context.Context, from common.Address, to common.Address, data []byte) *SendError { diff --git a/core/chains/evm/client/chain_client_test.go b/core/chains/evm/client/chain_client_test.go index 33955c16451..09f402389eb 100644 --- a/core/chains/evm/client/chain_client_test.go +++ b/core/chains/evm/client/chain_client_test.go @@ -3,14 +3,11 @@ package client_test import ( "context" "encoding/json" - "errors" "fmt" "math/big" - "net/http/httptest" "net/url" "os" "strings" - "sync/atomic" "testing" "time" @@ -18,10 +15,8 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/rpc" pkgerrors "github.com/pkg/errors" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -30,7 +25,6 @@ import ( commonclient "github.com/smartcontractkit/chainlink/v2/common/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client/mocks" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/testutils" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" @@ -408,6 +402,7 @@ func TestEthClient_SendTransaction_NoSecondaryURL(t *testing.T) { assert.NoError(t, err) } +/* TODO: Implement tx sender func TestEthClient_SendTransaction_WithSecondaryURLs(t *testing.T) { t.Parallel() @@ -450,6 +445,7 @@ func TestEthClient_SendTransaction_WithSecondaryURLs(t *testing.T) { // synchronization. We have to rely on timing instead. 
require.Eventually(t, func() bool { return service.sentCount.Load() == int32(2) }, tests.WaitTimeout(t), 500*time.Millisecond) } +*/ func TestEthClient_SendTransactionReturnCode(t *testing.T) { t.Parallel() @@ -691,6 +687,7 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { }) } +/* type sendTxService struct { chainID *big.Int sentCount atomic.Int32 @@ -704,6 +701,7 @@ func (x *sendTxService) SendRawTransaction(ctx context.Context, signRawTx hexuti x.sentCount.Add(1) return nil } +*/ func TestEthClient_SubscribeNewHead(t *testing.T) { t.Parallel() @@ -729,8 +727,7 @@ func TestEthClient_SubscribeNewHead(t *testing.T) { err := ethClient.Dial(tests.Context(t)) require.NoError(t, err) - headCh := make(chan *evmtypes.Head) - sub, err := ethClient.SubscribeNewHead(ctx, headCh) + headCh, sub, err := ethClient.SubscribeNewHead(ctx) require.NoError(t, err) select { @@ -739,65 +736,13 @@ func TestEthClient_SubscribeNewHead(t *testing.T) { case <-ctx.Done(): t.Fatal(ctx.Err()) case h := <-headCh: + fmt.Println("Received head", h) require.NotNil(t, h.EVMChainID) require.Zero(t, chainId.Cmp(h.EVMChainID.ToInt())) } sub.Unsubscribe() } -func newMockRpc(t *testing.T) *mocks.RPCClient { - mockRpc := mocks.NewRPCClient(t) - mockRpc.On("Dial", mock.Anything).Return(nil).Once() - mockRpc.On("Close").Return(nil).Once() - mockRpc.On("ChainID", mock.Anything).Return(testutils.FixtureChainID, nil).Once() - // node does not always manage to fully setup aliveLoop, so we have to make calls optional to avoid flakes - mockRpc.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(client.NewMockSubscription(), nil).Maybe() - mockRpc.On("SetAliveLoopSub", mock.Anything).Return().Maybe() - return mockRpc -} - -func TestChainClient_BatchCallContext(t *testing.T) { - t.Parallel() - - t.Run("batch requests return errors", func(t *testing.T) { - ctx := tests.Context(t) - rpcError := errors.New("something went wrong") - blockNumResp := "" - blockNum := hexutil.EncodeBig(big.NewInt(42)) - b := []rpc.BatchElem{ - { - Method: "eth_getBlockByNumber", - Args: []interface{}{blockNum, true}, - Result: &types.Block{}, - }, - { - Method: "eth_blockNumber", - Result: &blockNumResp, - }, - } - - mockRpc := newMockRpc(t) - mockRpc.On("GetInterceptedChainInfo").Return(commonclient.ChainInfo{}, commonclient.ChainInfo{}).Maybe() - mockRpc.On("BatchCallContext", mock.Anything, b).Run(func(args mock.Arguments) { - reqs := args.Get(1).([]rpc.BatchElem) - for i := 0; i < len(reqs); i++ { - elem := &reqs[i] - elem.Error = rpcError - } - }).Return(nil).Once() - - client := client.NewChainClientWithMockedRpc(t, commonclient.NodeSelectionModeRoundRobin, time.Second*0, time.Second*0, testutils.FixtureChainID, mockRpc) - err := client.Dial(ctx) - require.NoError(t, err) - - err = client.BatchCallContext(ctx, b) - require.NoError(t, err) - for _, elem := range b { - require.ErrorIs(t, rpcError, elem.Error) - } - }) -} - func TestEthClient_ErroringClient(t *testing.T) { t.Parallel() ctx := tests.Context(t) @@ -826,11 +771,8 @@ func TestEthClient_ErroringClient(t *testing.T) { _, err = erroringClient.CallContract(ctx, ethereum.CallMsg{}, nil) require.Equal(t, err, commonclient.ErroringNodeError) - // TODO-1663: test actual ChainID() call once client.go is deprecated. 
- id, err := erroringClient.ChainID() - require.Equal(t, id, testutils.FixtureChainID) - //require.Equal(t, err, commonclient.ErroringNodeError) - require.Equal(t, err, nil) + id := erroringClient.ConfiguredChainID() + require.Equal(t, id, big.NewInt(0)) _, err = erroringClient.CodeAt(ctx, common.Address{}, nil) require.Equal(t, err, commonclient.ErroringNodeError) @@ -884,7 +826,7 @@ func TestEthClient_ErroringClient(t *testing.T) { _, err = erroringClient.SubscribeFilterLogs(ctx, ethereum.FilterQuery{}, nil) require.Equal(t, err, commonclient.ErroringNodeError) - _, err = erroringClient.SubscribeNewHead(ctx, nil) + _, _, err = erroringClient.SubscribeNewHead(ctx) require.Equal(t, err, commonclient.ErroringNodeError) _, err = erroringClient.SuggestGasPrice(ctx) diff --git a/core/chains/evm/client/evm_client.go b/core/chains/evm/client/evm_client.go index 7d121f2eca1..59c6e576a71 100644 --- a/core/chains/evm/client/evm_client.go +++ b/core/chains/evm/client/evm_client.go @@ -10,13 +10,12 @@ import ( evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/chaintype" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" - evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) func NewEvmClient(cfg evmconfig.NodePool, chainCfg commonclient.ChainConfig, clientErrors evmconfig.ClientErrors, lggr logger.Logger, chainID *big.Int, nodes []*toml.Node, chainType chaintype.ChainType) Client { var empty url.URL - var primaries []commonclient.Node[*big.Int, *evmtypes.Head, RPCClient] - var sendonlys []commonclient.SendOnlyNode[*big.Int, RPCClient] + var primaries []commonclient.Node[*big.Int, *RpcClient] + var sendonlys []commonclient.SendOnlyNode[*big.Int, *RpcClient] for i, node := range nodes { if node.SendOnly != nil && *node.SendOnly { rpc := NewRPCClient(cfg, lggr, empty, (*url.URL)(node.HTTPURL), *node.Name, int32(i), chainID, @@ -34,6 +33,6 @@ func NewEvmClient(cfg evmconfig.NodePool, chainCfg commonclient.ChainConfig, cli } } - return NewChainClient(lggr, cfg.SelectionMode(), cfg.LeaseDuration(), chainCfg.NodeNoNewHeadsThreshold(), - primaries, sendonlys, chainID, chainType, clientErrors, cfg.DeathDeclarationDelay()) + return NewChainClient(lggr, cfg.SelectionMode(), cfg.LeaseDuration(), + primaries, sendonlys, chainID, clientErrors, cfg.DeathDeclarationDelay()) } diff --git a/core/chains/evm/client/helpers_test.go b/core/chains/evm/client/helpers_test.go index e1017a5564f..8bc6f303acf 100644 --- a/core/chains/evm/client/helpers_test.go +++ b/core/chains/evm/client/helpers_test.go @@ -14,7 +14,6 @@ import ( commonclient "github.com/smartcontractkit/chainlink/v2/common/client" clientMocks "github.com/smartcontractkit/chainlink/v2/common/client/mocks" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/chaintype" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) @@ -140,27 +139,29 @@ func NewChainClientWithTestNode( } lggr := logger.Test(t) - rpc := NewRPCClient(lggr, *parsed, rpcHTTPURL, "eth-primary-rpc-0", id, chainID, commonclient.Primary) + nodePoolCfg := TestNodePoolConfig{ + NodeFinalizedBlockPollInterval: 1 * time.Second, + } + rpc := NewRPCClient(nodePoolCfg, lggr, *parsed, rpcHTTPURL, "eth-primary-rpc-0", id, chainID, commonclient.Primary) - n := commonclient.NewNode[*big.Int, 
*evmtypes.Head, RPCClient]( + n := commonclient.NewNode[*big.Int, *evmtypes.Head, *RpcClient]( nodeCfg, clientMocks.ChainConfig{NoNewHeadsThresholdVal: noNewHeadsThreshold}, lggr, *parsed, rpcHTTPURL, "eth-primary-node-0", id, chainID, 1, rpc, "EVM") - primaries := []commonclient.Node[*big.Int, *evmtypes.Head, RPCClient]{n} + primaries := []commonclient.Node[*big.Int, *RpcClient]{n} - var sendonlys []commonclient.SendOnlyNode[*big.Int, RPCClient] + var sendonlys []commonclient.SendOnlyNode[*big.Int, *RpcClient] for i, u := range sendonlyRPCURLs { if u.Scheme != "http" && u.Scheme != "https" { return nil, pkgerrors.Errorf("sendonly ethereum rpc url scheme must be http(s): %s", u.String()) } var empty url.URL - rpc := NewRPCClient(lggr, empty, &sendonlyRPCURLs[i], fmt.Sprintf("eth-sendonly-rpc-%d", i), id, chainID, commonclient.Secondary) - s := commonclient.NewSendOnlyNode[*big.Int, RPCClient]( + rpc := NewRPCClient(nodePoolCfg, lggr, empty, &sendonlyRPCURLs[i], fmt.Sprintf("eth-sendonly-rpc-%d", i), id, chainID, commonclient.Secondary) + s := commonclient.NewSendOnlyNode[*big.Int, *RpcClient]( lggr, u, fmt.Sprintf("eth-sendonly-%d", i), chainID, rpc) sendonlys = append(sendonlys, s) } - var chainType chaintype.ChainType clientErrors := NewTestClientErrors() - c := NewChainClient(lggr, nodeCfg.SelectionMode(), leaseDuration, noNewHeadsThreshold, primaries, sendonlys, chainID, chainType, &clientErrors, 0) + c := NewChainClient(lggr, nodeCfg.SelectionMode(), leaseDuration, primaries, sendonlys, chainID, &clientErrors, 0) t.Cleanup(c.Close) return c, nil } @@ -174,8 +175,7 @@ func NewChainClientWithEmptyNode( ) Client { lggr := logger.Test(t) - var chainType chaintype.ChainType - c := NewChainClient(lggr, selectionMode, leaseDuration, noNewHeadsThreshold, nil, nil, chainID, chainType, nil, 0) + c := NewChainClient(lggr, selectionMode, leaseDuration, nil, nil, chainID, nil, 0) t.Cleanup(c.Close) return c } @@ -186,22 +186,20 @@ func NewChainClientWithMockedRpc( leaseDuration time.Duration, noNewHeadsThreshold time.Duration, chainID *big.Int, - rpc RPCClient, + rpc *RpcClient, ) Client { lggr := logger.Test(t) - var chainType chaintype.ChainType - cfg := TestNodePoolConfig{ NodeSelectionMode: commonclient.NodeSelectionModeRoundRobin, } parsed, _ := url.ParseRequestURI("ws://test") - n := commonclient.NewNode[*big.Int, *evmtypes.Head, RPCClient]( + n := commonclient.NewNode[*big.Int, *evmtypes.Head, *RpcClient]( cfg, clientMocks.ChainConfig{NoNewHeadsThresholdVal: noNewHeadsThreshold}, lggr, *parsed, nil, "eth-primary-node-0", 1, chainID, 1, rpc, "EVM") - primaries := []commonclient.Node[*big.Int, *evmtypes.Head, RPCClient]{n} + primaries := []commonclient.Node[*big.Int, *RpcClient]{n} clientErrors := NewTestClientErrors() - c := NewChainClient(lggr, selectionMode, leaseDuration, noNewHeadsThreshold, primaries, nil, chainID, chainType, &clientErrors, 0) + c := NewChainClient(lggr, selectionMode, leaseDuration, primaries, nil, chainID, &clientErrors, 0) t.Cleanup(c.Close) return c } diff --git a/core/chains/evm/client/mocks/client.go b/core/chains/evm/client/mocks/client.go index 8a5b29cf4cb..1bcb64e5e6d 100644 --- a/core/chains/evm/client/mocks/client.go +++ b/core/chains/evm/client/mocks/client.go @@ -208,36 +208,6 @@ func (_m *Client) CallContract(ctx context.Context, msg ethereum.CallMsg, blockN return r0, r1 } -// ChainID provides a mock function with given fields: -func (_m *Client) ChainID() (*big.Int, error) { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified 
for ChainID") - } - - var r0 *big.Int - var r1 error - if rf, ok := ret.Get(0).(func() (*big.Int, error)); ok { - return rf() - } - if rf, ok := ret.Get(0).(func() *big.Int); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*big.Int) - } - } - - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // CheckTxValidity provides a mock function with given fields: ctx, from, to, data func (_m *Client) CheckTxValidity(ctx context.Context, from common.Address, to common.Address, data []byte) *client.SendError { ret := _m.Called(ctx, from, to, data) @@ -618,19 +588,19 @@ func (_m *Client) LatestFinalizedBlock(ctx context.Context) (*evmtypes.Head, err } // NodeStates provides a mock function with given fields: -func (_m *Client) NodeStates() map[string]string { +func (_m *Client) NodeStates() map[string]commonclient.NodeState { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for NodeStates") } - var r0 map[string]string - if rf, ok := ret.Get(0).(func() map[string]string); ok { + var r0 map[string]commonclient.NodeState + if rf, ok := ret.Get(0).(func() map[string]commonclient.NodeState); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(map[string]string) + r0 = ret.Get(0).(map[string]commonclient.NodeState) } } @@ -829,34 +799,43 @@ func (_m *Client) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuer return r0, r1 } -// SubscribeNewHead provides a mock function with given fields: ctx, ch -func (_m *Client) SubscribeNewHead(ctx context.Context, ch chan<- *evmtypes.Head) (ethereum.Subscription, error) { - ret := _m.Called(ctx, ch) +// SubscribeNewHead provides a mock function with given fields: ctx +func (_m *Client) SubscribeNewHead(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for SubscribeNewHead") } - var r0 ethereum.Subscription - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, chan<- *evmtypes.Head) (ethereum.Subscription, error)); ok { - return rf(ctx, ch) + var r0 <-chan *evmtypes.Head + var r1 ethereum.Subscription + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func(context.Context, chan<- *evmtypes.Head) ethereum.Subscription); ok { - r0 = rf(ctx, ch) + if rf, ok := ret.Get(0).(func(context.Context) <-chan *evmtypes.Head); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(ethereum.Subscription) + r0 = ret.Get(0).(<-chan *evmtypes.Head) } } - if rf, ok := ret.Get(1).(func(context.Context, chan<- *evmtypes.Head) error); ok { - r1 = rf(ctx, ch) + if rf, ok := ret.Get(1).(func(context.Context) ethereum.Subscription); ok { + r1 = rf(ctx) } else { - r1 = ret.Error(1) + if ret.Get(1) != nil { + r1 = ret.Get(1).(ethereum.Subscription) + } } - return r0, r1 + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 } // SuggestGasPrice provides a mock function with given fields: ctx diff --git a/core/chains/evm/client/null_client.go b/core/chains/evm/client/null_client.go index 3129bcff9b0..52d418c1405 100644 --- a/core/chains/evm/client/null_client.go +++ b/core/chains/evm/client/null_client.go @@ -90,9 +90,9 @@ func (nc *NullClient) SubscribeFilterLogs(ctx context.Context, q ethereum.Filter return newNullSubscription(nc.lggr), 
nil } -func (nc *NullClient) SubscribeNewHead(ctx context.Context, ch chan<- *evmtypes.Head) (ethereum.Subscription, error) { +func (nc *NullClient) SubscribeNewHead(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) { nc.lggr.Debug("SubscribeNewHead") - return newNullSubscription(nc.lggr), nil + return nil, newNullSubscription(nc.lggr), nil } // @@ -221,7 +221,7 @@ func (nc *NullClient) SuggestGasTipCap(ctx context.Context) (tipCap *big.Int, er } // NodeStates implements evmclient.Client -func (nc *NullClient) NodeStates() map[string]string { return nil } +func (nc *NullClient) NodeStates() map[string]commonclient.NodeState { return nil } func (nc *NullClient) IsL2() bool { nc.lggr.Debug("IsL2") diff --git a/core/chains/evm/client/null_client_test.go b/core/chains/evm/client/null_client_test.go index bc6c166030f..75880fb1769 100644 --- a/core/chains/evm/client/null_client_test.go +++ b/core/chains/evm/client/null_client_test.go @@ -15,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" - evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) func TestNullClient(t *testing.T) { @@ -62,8 +61,7 @@ func TestNullClient(t *testing.T) { require.Nil(t, h) require.Equal(t, 1, logs.FilterMessage("HeadByNumber").Len()) - chHeads := make(chan *evmtypes.Head) - sub, err := nc.SubscribeNewHead(ctx, chHeads) + _, sub, err := nc.SubscribeNewHead(ctx) require.NoError(t, err) require.Equal(t, 1, logs.FilterMessage("SubscribeNewHead").Len()) require.Nil(t, sub.Err()) diff --git a/core/chains/evm/client/rpc_client.go b/core/chains/evm/client/rpc_client.go index d31bb432500..9301e92d71b 100644 --- a/core/chains/evm/client/rpc_client.go +++ b/core/chains/evm/client/rpc_client.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "fmt" "math/big" "net/url" @@ -75,35 +76,7 @@ var ( }, []string{"evmChainID", "nodeName", "rpcHost", "isSendOnly", "success", "rpcCallName"}) ) -// RPCClient includes all the necessary generalized RPC methods along with any additional chain-specific methods. 
-// -//go:generate mockery --quiet --name RPCClient --output ./mocks --case=underscore -type RPCClient interface { - commonclient.RPC[ - *big.Int, - evmtypes.Nonce, - common.Address, - common.Hash, - *types.Transaction, - common.Hash, - types.Log, - ethereum.FilterQuery, - *evmtypes.Receipt, - *assets.Wei, - *evmtypes.Head, - rpc.BatchElem, - ] - BlockByHashGeth(ctx context.Context, hash common.Hash) (b *types.Block, err error) - BlockByNumberGeth(ctx context.Context, number *big.Int) (b *types.Block, err error) - HeaderByHash(ctx context.Context, h common.Hash) (head *types.Header, err error) - HeaderByNumber(ctx context.Context, n *big.Int) (head *types.Header, err error) - PendingCodeAt(ctx context.Context, account common.Address) (b []byte, err error) - SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (s ethereum.Subscription, err error) - SuggestGasPrice(ctx context.Context) (p *big.Int, err error) - SuggestGasTipCap(ctx context.Context) (t *big.Int, err error) - TransactionReceiptGeth(ctx context.Context, txHash common.Hash) (r *types.Receipt, err error) - GetInterceptedChainInfo() (latest, highestUserObservations commonclient.ChainInfo) -} +const rpcSubscriptionMethodNewHeads = "newHeads" type rawclient struct { rpc *rpc.Client @@ -128,9 +101,6 @@ type RpcClient struct { // close the underlying subscription subs []ethereum.Subscription - // Need to track the aliveLoop subscription, so we do not cancel it when checking lease on the MultiNode - aliveLoopSub ethereum.Subscription - // chStopInFlight can be closed to immediately cancel all in-flight requests on // this RpcClient. Closing and replacing should be serialized through // stateMu since it can happen on state transitions as well as RpcClient Close. @@ -177,14 +147,46 @@ func NewRPCClient( return r } -func (r *RpcClient) SubscribeToHeads(ctx context.Context) (<-chan *evmtypes.Head, commontypes.Subscription, error) { +func (r *RpcClient) SubscribeToHeads(ctx context.Context) (ch <-chan *evmtypes.Head, sub commontypes.Subscription, err error) { + ctx, cancel, chStopInFlight, ws, _ := r.acquireQueryCtx(ctx) + defer cancel() + + args := []interface{}{rpcSubscriptionMethodNewHeads} + start := time.Now() + lggr := r.newRqLggr().With("args", args) + + lggr.Debug("RPC call: evmclient.Client#EthSubscribe") + defer func() { + duration := time.Since(start) + r.logResult(lggr, err, duration, r.getRPCDomain(), "EthSubscribe") + err = r.wrapWS(err) + }() + channel := make(chan *evmtypes.Head) - sub, err := r.subscribe(ctx, channel) - return channel, sub, err + forwarder := newSubForwarder(channel, func(head *evmtypes.Head) *evmtypes.Head { + head.EVMChainID = ubig.New(r.chainID) + r.onNewHead(ctx, chStopInFlight, head) + return head + }, r.wrapRPCClientError) + + err = forwarder.start(ws.rpc.EthSubscribe(ctx, forwarder.srcCh, args...)) + if err != nil { + return nil, nil, err + } + + err = r.registerSub(forwarder, chStopInFlight) + if err != nil { + return nil, nil, err + } + + return channel, forwarder, err } func (r *RpcClient) SubscribeToFinalizedHeads(_ context.Context) (<-chan *evmtypes.Head, commontypes.Subscription, error) { interval := r.cfg.FinalizedBlockPollInterval() + if interval == 0 { + return nil, nil, errors.New("FinalizedBlockPollInterval is 0") + } timeout := interval poller, channel := commonclient.NewPoller[*evmtypes.Head](interval, r.LatestFinalizedBlock, timeout, r.rpcLog) if err := poller.Start(); err != nil { @@ -194,26 +196,29 @@ func (r *RpcClient) SubscribeToFinalizedHeads(_ 
context.Context) (<-chan *evmtyp } func (r *RpcClient) Ping(ctx context.Context) error { - _, err := r.ClientVersion(ctx) + version, err := r.ClientVersion(ctx) if err != nil { return fmt.Errorf("ping failed: %v", err) } + r.rpcLog.Debugf("ping client version: %s", version) return err } func (r *RpcClient) UnsubscribeAllExcept(subs ...commontypes.Subscription) { + r.stateMu.Lock() + defer r.stateMu.Unlock() + + keepSubs := map[commontypes.Subscription]struct{}{} + for _, sub := range subs { + keepSubs[sub] = struct{}{} + } + for _, sub := range r.subs { - var keep bool - for _, s := range subs { - if sub == s { - keep = true - break - } - } - if !keep { + if _, keep := keepSubs[sub]; !keep { sub.Unsubscribe() } } + r.latestChainInfo = commonclient.ChainInfo{} } // Not thread-safe, pure dial. @@ -277,16 +282,13 @@ func (r *RpcClient) Close() { r.ws.rpc.Close() } }() - - r.stateMu.Lock() - defer r.stateMu.Unlock() r.cancelInflightRequests() } // cancelInflightRequests closes and replaces the chStopInFlight -// WARNING: NOT THREAD-SAFE -// This must be called from within the r.stateMu lock func (r *RpcClient) cancelInflightRequests() { + r.stateMu.Lock() + defer r.stateMu.Unlock() close(r.chStopInFlight) r.chStopInFlight = make(chan struct{}) } @@ -355,34 +357,6 @@ func (r *RpcClient) registerSub(sub ethereum.Subscription, stopInFLightCh chan s return nil } -// DisconnectAll disconnects all clients connected to the rpcClient -func (r *RpcClient) DisconnectAll() { - r.stateMu.Lock() - defer r.stateMu.Unlock() - if r.ws.rpc != nil { - r.ws.rpc.Close() - } - r.cancelInflightRequests() - r.unsubscribeAll() - r.latestChainInfo = commonclient.ChainInfo{} -} - -// unsubscribeAll unsubscribes all subscriptions -// WARNING: NOT THREAD-SAFE -// This must be called from within the r.stateMu lock -func (r *RpcClient) unsubscribeAll() { - for _, sub := range r.subs { - sub.Unsubscribe() - } - r.subs = nil -} -func (r *RpcClient) SetAliveLoopSub(sub commontypes.Subscription) { - r.stateMu.Lock() - defer r.stateMu.Unlock() - - r.aliveLoopSub = sub -} - // SubscribersCount returns the number of client subscribed to the node func (r *RpcClient) SubscribersCount() int32 { r.stateMu.RLock() @@ -390,19 +364,6 @@ func (r *RpcClient) SubscribersCount() int32 { return int32(len(r.subs)) } -// UnsubscribeAllExceptAliveLoop disconnects all subscriptions to the node except the alive loop subscription -// while holding the n.stateMu lock -func (r *RpcClient) UnsubscribeAllExceptAliveLoop() { - r.stateMu.Lock() - defer r.stateMu.Unlock() - - for _, s := range r.subs { - if s != r.aliveLoopSub { - s.Unsubscribe() - } - } -} - // RPC wrappers // CallContext implementation @@ -449,60 +410,6 @@ func (r *RpcClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) err return err } -func (r *RpcClient) subscribe(ctx context.Context, channel chan<- *evmtypes.Head, args ...interface{}) (commontypes.Subscription, error) { - ctx, cancel, ws, _ := r.makeLiveQueryCtxAndSafeGetClients(ctx) - defer cancel() - lggr := r.newRqLggr().With("args", args) - - lggr.Debug("RPC call: evmclient.Client#EthSubscribe") - start := time.Now() - var sub commontypes.Subscription - sub, err := ws.rpc.EthSubscribe(ctx, channel, args...) 
- if err == nil { - err = r.registerSub(sub, r.chStopInFlight) - if err != nil { - sub.Unsubscribe() - return nil, err - } - } - duration := time.Since(start) - - r.logResult(lggr, err, duration, r.getRPCDomain(), "EthSubscribe") - - return sub, r.wrapWS(err) -} - -func (r *RpcClient) SubscribeNewHead(ctx context.Context, channel chan<- *evmtypes.Head) (_ commontypes.Subscription, err error) { - ctx, cancel, chStopInFlight, ws, _ := r.acquireQueryCtx(ctx) - defer cancel() - args := []interface{}{"newHeads"} - lggr := r.newRqLggr().With("args", args) - - lggr.Debug("RPC call: evmclient.Client#EthSubscribe") - start := time.Now() - defer func() { - duration := time.Since(start) - r.logResult(lggr, err, duration, r.getRPCDomain(), "EthSubscribe") - err = r.wrapWS(err) - }() - subForwarder := newSubForwarder(channel, func(head *evmtypes.Head) *evmtypes.Head { - head.EVMChainID = ubig.New(r.chainID) - r.onNewHead(ctx, chStopInFlight, head) - return head - }, r.wrapRPCClientError) - err = subForwarder.start(ws.rpc.EthSubscribe(ctx, subForwarder.srcCh, args...)) - if err != nil { - return - } - - err = r.registerSub(subForwarder, chStopInFlight) - if err != nil { - return - } - - return subForwarder, nil -} - // GethClient wrappers func (r *RpcClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (receipt *evmtypes.Receipt, err error) { @@ -609,7 +516,11 @@ func (r *RpcClient) HeaderByHash(ctx context.Context, hash common.Hash) (header } func (r *RpcClient) LatestFinalizedBlock(ctx context.Context) (*evmtypes.Head, error) { - return r.blockByNumber(ctx, rpc.FinalizedBlockNumber.String()) + head, err := r.blockByNumber(ctx, rpc.FinalizedBlockNumber.String()) + if err != nil { + return nil, err + } + return head, nil } func (r *RpcClient) BlockByNumber(ctx context.Context, number *big.Int) (head *evmtypes.Head, err error) { diff --git a/core/chains/evm/client/rpc_client_test.go b/core/chains/evm/client/rpc_client_test.go index 682c4352457..91651c94210 100644 --- a/core/chains/evm/client/rpc_client_test.go +++ b/core/chains/evm/client/rpc_client_test.go @@ -7,6 +7,7 @@ import ( "math/big" "net/url" "testing" + "time" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/core/types" @@ -15,10 +16,8 @@ import ( "github.com/tidwall/gjson" "go.uber.org/zap" - "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" commonclient "github.com/smartcontractkit/chainlink/v2/common/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/testutils" @@ -41,6 +40,10 @@ func TestRPCClient_SubscribeNewHead(t *testing.T) { chainId := big.NewInt(123456) lggr := logger.Test(t) + nodePoolCfg := client.TestNodePoolConfig{ + NodeFinalizedBlockPollInterval: 1 * time.Second, + } + serverCallBack := func(method string, params gjson.Result) (resp testutils.JSONRPCResponse) { if method == "eth_unsubscribe" { resp.Result = "true" @@ -56,7 +59,7 @@ func TestRPCClient_SubscribeNewHead(t *testing.T) { server := testutils.NewWSServer(t, chainId, serverCallBack) wsURL := server.WSURL() - rpc := client.NewRPCClient(lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) + rpc := client.NewRPCClient(nodePoolCfg, lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) defer rpc.Close() require.NoError(t, rpc.Dial(ctx)) // set to default values @@ -68,8 +71,7 @@ func 
TestRPCClient_SubscribeNewHead(t *testing.T) { assert.Equal(t, int64(0), highestUserObservations.FinalizedBlockNumber) assert.Nil(t, highestUserObservations.TotalDifficulty) - ch := make(chan *evmtypes.Head) - sub, err := rpc.SubscribeNewHead(tests.Context(t), ch) + ch, sub, err := rpc.SubscribeToHeads(tests.Context(t)) require.NoError(t, err) defer sub.Unsubscribe() go server.MustWriteBinaryMessageSync(t, makeNewHeadWSMessage(&evmtypes.Head{Number: 256, TotalDifficulty: big.NewInt(1000)})) @@ -92,8 +94,8 @@ func TestRPCClient_SubscribeNewHead(t *testing.T) { assertHighestUserObservations(highestUserObservations) - // DisconnectAll resets latest - rpc.DisconnectAll() + // UnsubscribeAllExcept resets latest + rpc.UnsubscribeAllExcept() latest, highestUserObservations = rpc.GetInterceptedChainInfo() assert.Equal(t, int64(0), latest.BlockNumber) @@ -106,11 +108,10 @@ func TestRPCClient_SubscribeNewHead(t *testing.T) { server := testutils.NewWSServer(t, chainId, serverCallBack) wsURL := server.WSURL() - rpc := client.NewRPCClient(lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) + rpc := client.NewRPCClient(nodePoolCfg, lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) defer rpc.Close() require.NoError(t, rpc.Dial(ctx)) - ch := make(chan *evmtypes.Head) - sub, err := rpc.SubscribeNewHead(commonclient.CtxAddHealthCheckFlag(tests.Context(t)), ch) + ch, sub, err := rpc.SubscribeToHeads(commonclient.CtxAddHealthCheckFlag(tests.Context(t))) require.NoError(t, err) defer sub.Unsubscribe() go server.MustWriteBinaryMessageSync(t, makeNewHeadWSMessage(&evmtypes.Head{Number: 256, TotalDifficulty: big.NewInt(1000)})) @@ -129,37 +130,36 @@ func TestRPCClient_SubscribeNewHead(t *testing.T) { t.Run("Block's chain ID matched configured", func(t *testing.T) { server := testutils.NewWSServer(t, chainId, serverCallBack) wsURL := server.WSURL() - rpc := client.NewRPCClient(lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) + rpc := client.NewRPCClient(nodePoolCfg, lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) defer rpc.Close() require.NoError(t, rpc.Dial(ctx)) - ch := make(chan *evmtypes.Head) - sub, err := rpc.SubscribeNewHead(tests.Context(t), ch) + ch, sub, err := rpc.SubscribeToHeads(tests.Context(t)) require.NoError(t, err) defer sub.Unsubscribe() go server.MustWriteBinaryMessageSync(t, makeNewHeadWSMessage(&evmtypes.Head{Number: 256})) head := <-ch require.Equal(t, chainId, head.ChainID()) }) - t.Run("Failed SubscribeNewHead returns and logs proper error", func(t *testing.T) { + t.Run("Failed SubscribeToHeads returns and logs proper error", func(t *testing.T) { server := testutils.NewWSServer(t, chainId, func(reqMethod string, reqParams gjson.Result) (resp testutils.JSONRPCResponse) { return resp }) wsURL := server.WSURL() observedLggr, observed := logger.TestObserved(t, zap.DebugLevel) - rpc := client.NewRPCClient(observedLggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) + rpc := client.NewRPCClient(nodePoolCfg, observedLggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) require.NoError(t, rpc.Dial(ctx)) server.Close() - _, err := rpc.SubscribeNewHead(ctx, make(chan *evmtypes.Head)) + _, _, err := rpc.SubscribeToHeads(ctx) require.ErrorContains(t, err, "RPCClient returned error (rpc)") tests.AssertLogEventually(t, observed, "evmclient.Client#EthSubscribe RPC call failure") }) t.Run("Subscription error is properly wrapped", func(t *testing.T) { server := testutils.NewWSServer(t, chainId, serverCallBack) wsURL := server.WSURL() - rpc :=
client.NewRPCClient(lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) + rpc := client.NewRPCClient(nodePoolCfg, lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) defer rpc.Close() require.NoError(t, rpc.Dial(ctx)) - sub, err := rpc.SubscribeNewHead(ctx, make(chan *evmtypes.Head)) + _, sub, err := rpc.SubscribeToHeads(ctx) require.NoError(t, err) go server.MustWriteBinaryMessageSync(t, "invalid msg") select { @@ -184,7 +184,7 @@ func TestRPCClient_SubscribeFilterLogs(t *testing.T) { }) wsURL := server.WSURL() observedLggr, observed := logger.TestObserved(t, zap.DebugLevel) - rpc := client.NewRPCClient(observedLggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) + rpc := client.NewRPCClient(client.TestNodePoolConfig{}, observedLggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) require.NoError(t, rpc.Dial(ctx)) server.Close() _, err := rpc.SubscribeFilterLogs(ctx, ethereum.FilterQuery{}, make(chan types.Log)) @@ -201,7 +201,7 @@ func TestRPCClient_SubscribeFilterLogs(t *testing.T) { return resp }) wsURL := server.WSURL() - rpc := client.NewRPCClient(lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) + rpc := client.NewRPCClient(client.TestNodePoolConfig{}, lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary) defer rpc.Close() require.NoError(t, rpc.Dial(ctx)) sub, err := rpc.SubscribeFilterLogs(ctx, ethereum.FilterQuery{}, make(chan types.Log)) @@ -250,7 +250,7 @@ func TestRPCClient_LatestFinalizedBlock(t *testing.T) { } server := createRPCServer() - rpc := client.NewRPCClient(lggr, *server.URL, nil, "rpc", 1, chainId, commonclient.Primary) + rpc := client.NewRPCClient(client.TestNodePoolConfig{}, lggr, *server.URL, nil, "rpc", 1, chainId, commonclient.Primary) require.NoError(t, rpc.Dial(ctx)) defer rpc.Close() server.Head = &evmtypes.Head{Number: 128} @@ -290,7 +290,7 @@ func TestRPCClient_LatestFinalizedBlock(t *testing.T) { assert.Equal(t, int64(256), latest.FinalizedBlockNumber) // DisconnectAll resets latest ChainInfo - rpc.DisconnectAll() + rpc.UnsubscribeAllExcept() latest, highestUserObservations = rpc.GetInterceptedChainInfo() assert.Equal(t, int64(0), highestUserObservations.BlockNumber) assert.Equal(t, int64(128), highestUserObservations.FinalizedBlockNumber) diff --git a/core/chains/evm/client/simulated_backend_client.go b/core/chains/evm/client/simulated_backend_client.go index 6bcc1f36960..b8cf07b39c8 100644 --- a/core/chains/evm/client/simulated_backend_client.go +++ b/core/chains/evm/client/simulated_backend_client.go @@ -298,15 +298,15 @@ func (h *headSubscription) Err() <-chan error { return h.subscription.Err() } // to convert those into evmtypes.Head. func (c *SimulatedBackendClient) SubscribeNewHead( ctx context.Context, - channel chan<- *evmtypes.Head, -) (ethereum.Subscription, error) { +) (<-chan *evmtypes.Head, ethereum.Subscription, error) { subscription := &headSubscription{unSub: make(chan chan struct{})} ch := make(chan *types.Header) + channel := make(chan *evmtypes.Head) var err error subscription.subscription, err = c.b.SubscribeNewHead(ctx, ch) if err != nil { - return nil, fmt.Errorf("%w: could not subscribe to new heads on "+ + return nil, nil, fmt.Errorf("%w: could not subscribe to new heads on "+ "simulated backend", err) } go func() { @@ -334,7 +334,7 @@ func (c *SimulatedBackendClient) SubscribeNewHead( } } }() - return subscription, err + return channel, subscription, err } // HeaderByNumber returns the geth header type. 
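[Illustrative sketch, not part of the diff] The SubscribeNewHead rewrite above mirrors the subForwarder pattern: subscribe to geth's raw *types.Header channel, then translate each header into an *evmtypes.Head on a channel the client owns and returns. A condensed version of that forwarding loop, assuming the file's existing imports (types, evmtypes, ubig, big, time); forwardHeads and done are hypothetical names, and the field mapping is simplified relative to the real conversion:

func forwardHeads(src <-chan *types.Header, dst chan<- *evmtypes.Head, chainID *big.Int, done <-chan struct{}) {
	for {
		select {
		case <-done:
			return
		case h, open := <-src:
			if !open {
				return
			}
			// Translate the raw geth header into the head type used by
			// the head tracker, stamping the configured chain ID.
			head := &evmtypes.Head{
				Number:     h.Number.Int64(),
				Hash:       h.Hash(),
				ParentHash: h.ParentHash,
				Timestamp:  time.Unix(int64(h.Time), 0),
			}
			head.EVMChainID = ubig.New(chainID)
			select {
			case dst <- head:
			case <-done:
				return
			}
		}
	}
}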
@@ -505,7 +505,7 @@ func (c *SimulatedBackendClient) Backend() *backends.SimulatedBackend { } // NodeStates implements evmclient.Client -func (c *SimulatedBackendClient) NodeStates() map[string]string { return nil } +func (c *SimulatedBackendClient) NodeStates() map[string]commonclient.NodeState { return nil } // Commit imports all the pending transactions as a single block and starts a // fresh new state. diff --git a/core/chains/evm/client/sub_forwarder.go b/core/chains/evm/client/sub_forwarder.go index 93e9b106b4a..a007cb68eaa 100644 --- a/core/chains/evm/client/sub_forwarder.go +++ b/core/chains/evm/client/sub_forwarder.go @@ -36,7 +36,7 @@ func newSubForwarder[T any](destCh chan<- T, interceptResult func(T) T, intercep // start spawns the forwarding loop for sub. func (c *subForwarder[T]) start(sub ethereum.Subscription, err error) error { if err != nil { - close(c.srcCh) + close(c.destCh) return err } c.srcSub = sub diff --git a/core/chains/evm/headtracker/head_broadcaster_test.go b/core/chains/evm/headtracker/head_broadcaster_test.go index 7ac61ab34b0..4ab4928cdc0 100644 --- a/core/chains/evm/headtracker/head_broadcaster_test.go +++ b/core/chains/evm/headtracker/head_broadcaster_test.go @@ -55,11 +55,12 @@ func TestHeadBroadcaster_Subscribe(t *testing.T) { ethClient := testutils.NewEthClientMockWithDefaultChain(t) chchHeaders := make(chan chan<- *evmtypes.Head, 1) - ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything). + chHead := make(chan *evmtypes.Head) + ethClient.On("SubscribeNewHead", mock.Anything). Run(func(args mock.Arguments) { - chchHeaders <- args.Get(1).(chan<- *evmtypes.Head) + chchHeaders <- chHead }). - Return(sub, nil) + Return((<-chan *evmtypes.Head)(chHead), sub, nil) ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(testutils.Head(1), nil) sub.On("Unsubscribe").Return() diff --git a/core/chains/evm/headtracker/head_listener_test.go b/core/chains/evm/headtracker/head_listener_test.go index 29b090bbffe..5da9178084d 100644 --- a/core/chains/evm/headtracker/head_listener_test.go +++ b/core/chains/evm/headtracker/head_listener_test.go @@ -51,12 +51,11 @@ func Test_HeadListener_HappyPath(t *testing.T) { subscribeAwaiter := testutils.NewAwaiter() unsubscribeAwaiter := testutils.NewAwaiter() - var chHeads chan<- *evmtypes.Head + chHeads := make(chan *evmtypes.Head) var chErr = make(chan error) var chSubErr <-chan error = chErr sub := commonmocks.NewSubscription(t) - ethClient.On("SubscribeNewHead", mock.Anything, mock.AnythingOfType("chan<- *types.Head")).Return(sub, nil).Once().Run(func(args mock.Arguments) { - chHeads = args.Get(1).(chan<- *evmtypes.Head) + ethClient.On("SubscribeNewHead", mock.Anything).Return((<-chan *evmtypes.Head)(chHeads), sub, nil).Once().Run(func(args mock.Arguments) { subscribeAwaiter.ItHappened() }) sub.On("Err").Return(chSubErr) @@ -110,13 +109,12 @@ func Test_HeadListener_NotReceivingHeads(t *testing.T) { return nil } + chHeads := make(chan *evmtypes.Head) subscribeAwaiter := testutils.NewAwaiter() - var chHeads chan<- *evmtypes.Head var chErr = make(chan error) var chSubErr <-chan error = chErr sub := commonmocks.NewSubscription(t) - ethClient.On("SubscribeNewHead", mock.Anything, mock.AnythingOfType("chan<- *types.Head")).Return(sub, nil).Once().Run(func(args mock.Arguments) { - chHeads = args.Get(1).(chan<- *evmtypes.Head) + ethClient.On("SubscribeNewHead", mock.Anything).Return((<-chan *evmtypes.Head)(chHeads), sub, nil).Once().Run(func(args mock.Arguments) { subscribeAwaiter.ItHappened() }) 
sub.On("Err").Return(chSubErr) @@ -182,11 +180,10 @@ func Test_HeadListener_SubscriptionErr(t *testing.T) { // initially and once again after exactly one head has been received sub.On("Err").Return(chSubErr).Twice() + headsCh := make(chan *evmtypes.Head) subscribeAwaiter := testutils.NewAwaiter() - var headsCh chan<- *evmtypes.Head // Initial subscribe - ethClient.On("SubscribeNewHead", mock.Anything, mock.AnythingOfType("chan<- *types.Head")).Return(sub, nil).Once().Run(func(args mock.Arguments) { - headsCh = args.Get(1).(chan<- *evmtypes.Head) + ethClient.On("SubscribeNewHead", mock.Anything).Return((<-chan *evmtypes.Head)(headsCh), sub, nil).Once().Run(func(args mock.Arguments) { subscribeAwaiter.ItHappened() }) go func() { @@ -216,9 +213,8 @@ func Test_HeadListener_SubscriptionErr(t *testing.T) { sub2.On("Err").Return(chSubErr2) subscribeAwaiter2 := testutils.NewAwaiter() - var headsCh2 chan<- *evmtypes.Head - ethClient.On("SubscribeNewHead", mock.Anything, mock.AnythingOfType("chan<- *types.Head")).Return(sub2, nil).Once().Run(func(args mock.Arguments) { - headsCh2 = args.Get(1).(chan<- *evmtypes.Head) + headsCh2 := make(chan *evmtypes.Head) + ethClient.On("SubscribeNewHead", mock.Anything).Return((<-chan *evmtypes.Head)(headsCh2), sub2, nil).Once().Run(func(args mock.Arguments) { subscribeAwaiter2.ItHappened() }) diff --git a/core/chains/evm/headtracker/head_tracker_test.go b/core/chains/evm/headtracker/head_tracker_test.go index 21ff1b1a929..cc205453ccb 100644 --- a/core/chains/evm/headtracker/head_tracker_test.go +++ b/core/chains/evm/headtracker/head_tracker_test.go @@ -68,7 +68,7 @@ func TestHeadTracker_New(t *testing.T) { } ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything). Maybe(). - Return(mockEth.NewSub(t), nil) + Return(nil, mockEth.NewSub(t), nil) orm := headtracker.NewORM(*testutils.FixtureChainID, db) assert.Nil(t, orm.IdempotentInsertHead(tests.Context(t), testutils.Head(1))) @@ -148,14 +148,13 @@ func TestHeadTracker_Get(t *testing.T) { mockEth := &testutils.MockEth{ EthClient: ethClient, } - ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything). + ethClient.On("SubscribeNewHead", mock.Anything). Maybe(). Return( - func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription { + func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) { defer close(chStarted) - return mockEth.NewSub(t) + return make(<-chan *evmtypes.Head), mockEth.NewSub(t), nil }, - func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil }, ) ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(testutils.Head(0), nil).Maybe() @@ -201,11 +200,12 @@ func TestHeadTracker_Start_NewHeads(t *testing.T) { ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(testutils.Head(0), nil).Once() // for backfill ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(testutils.Head(0), nil).Maybe() - ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything). + ch := make(chan *evmtypes.Head) + ethClient.On("SubscribeNewHead", mock.Anything). Run(func(mock.Arguments) { close(chStarted) }). 
diff --git a/core/chains/evm/headtracker/head_tracker_test.go b/core/chains/evm/headtracker/head_tracker_test.go
index 21ff1b1a929..cc205453ccb 100644
--- a/core/chains/evm/headtracker/head_tracker_test.go
+++ b/core/chains/evm/headtracker/head_tracker_test.go
@@ -68,7 +68,7 @@ func TestHeadTracker_New(t *testing.T) {
 	}
 	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
 		Maybe().
-		Return(mockEth.NewSub(t), nil)
+		Return(nil, mockEth.NewSub(t), nil)
 	orm := headtracker.NewORM(*testutils.FixtureChainID, db)
 	assert.Nil(t, orm.IdempotentInsertHead(tests.Context(t), testutils.Head(1)))
@@ -148,14 +148,13 @@ func TestHeadTracker_Get(t *testing.T) {
 			mockEth := &testutils.MockEth{
 				EthClient: ethClient,
 			}
-			ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+			ethClient.On("SubscribeNewHead", mock.Anything).
 				Maybe().
 				Return(
-					func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription {
+					func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
 						defer close(chStarted)
-						return mockEth.NewSub(t)
+						return make(<-chan *evmtypes.Head), mockEth.NewSub(t), nil
 					},
-					func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil },
 				)
 			ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(testutils.Head(0), nil).Maybe()
@@ -201,11 +200,12 @@ func TestHeadTracker_Start_NewHeads(t *testing.T) {
 	ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(testutils.Head(0), nil).Once()
 	// for backfill
 	ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(testutils.Head(0), nil).Maybe()
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+	ch := make(chan *evmtypes.Head)
+	ethClient.On("SubscribeNewHead", mock.Anything).
 		Run(func(mock.Arguments) {
 			close(chStarted)
 		}).
-		Return(sub, nil)
+		Return((<-chan *evmtypes.Head)(ch), sub, nil)
 
 	ht := createHeadTracker(t, ethClient, config.EVM(), config.EVM().HeadTracker(), orm)
 	ht.Start(t)
@@ -243,7 +243,7 @@ func TestHeadTracker_Start(t *testing.T) {
 		ethClient := evmtest.NewEthClientMockWithDefaultChain(t)
 		mockEth := &testutils.MockEth{EthClient: ethClient}
 		sub := mockEth.NewSub(t)
-		ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(sub, nil).Maybe()
+		ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, sub, nil).Maybe()
 		return createHeadTracker(t, ethClient, config.EVM(), config.EVM().HeadTracker(), orm)
 	}
 	t.Run("Starts even if failed to get initialHead", func(t *testing.T) {
@@ -271,7 +271,7 @@ func TestHeadTracker_Start(t *testing.T) {
 		head := testutils.Head(1000)
 		ht.ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(head, nil).Once()
 		ht.ethClient.On("LatestFinalizedBlock", mock.Anything).Return(nil, nil).Once()
-		ht.ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, errors.New("failed to connect")).Maybe()
+		ht.ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, nil, errors.New("failed to connect")).Maybe()
 		ht.Start(t)
 		tests.AssertLogEventually(t, ht.observer, "Error handling initial head")
 	})
@@ -286,7 +286,7 @@ func TestHeadTracker_Start(t *testing.T) {
 		ht.ethClient.On("LatestFinalizedBlock", mock.Anything).Return(finalizedHead, nil).Once()
 		// on backfill
 		ht.ethClient.On("LatestFinalizedBlock", mock.Anything).Return(nil, errors.New("backfill call to finalized failed")).Maybe()
-		ht.ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, errors.New("failed to connect")).Maybe()
+		ht.ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, nil, errors.New("failed to connect")).Maybe()
 		ht.Start(t)
 		tests.AssertLogEventually(t, ht.observer, "Loaded chain from DB")
 	})
@@ -300,7 +300,7 @@ func TestHeadTracker_Start(t *testing.T) {
 		require.NoError(t, ht.orm.IdempotentInsertHead(ctx, testutils.Head(finalizedHead.Number-1)))
 		// on backfill
 		ht.ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(nil, errors.New("backfill call to finalized failed")).Maybe()
-		ht.ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, errors.New("failed to connect")).Maybe()
+		ht.ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, nil, errors.New("failed to connect")).Maybe()
 		ht.Start(t)
 		tests.AssertLogEventually(t, ht.observer, "Loaded chain from DB")
 	}
@@ -339,14 +339,14 @@ func TestHeadTracker_CallsHeadTrackableCallbacks(t *testing.T) {
 
 	chchHeaders := make(chan testutils.RawSub[*evmtypes.Head], 1)
 	mockEth := &testutils.MockEth{EthClient: ethClient}
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+	chHead := make(chan *evmtypes.Head)
+	ethClient.On("SubscribeNewHead", mock.Anything).
 		Return(
-			func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription {
+			func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
 				sub := mockEth.NewSub(t)
-				chchHeaders <- testutils.NewRawSub(ch, sub.Err())
-				return sub
+				chchHeaders <- testutils.NewRawSub(chHead, sub.Err())
+				return chHead, sub, nil
 			},
-			func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil },
 		)
 	ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(testutils.Head(0), nil)
 	ethClient.On("HeadByHash", mock.Anything, mock.Anything).Return(testutils.Head(0), nil).Maybe()
@@ -375,16 +375,19 @@ func TestHeadTracker_ReconnectOnError(t *testing.T) {
 	ethClient := testutils.NewEthClientMockWithDefaultChain(t)
 	mockEth := &testutils.MockEth{EthClient: ethClient}
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+	chHead := make(chan *evmtypes.Head)
+	ethClient.On("SubscribeNewHead", mock.Anything).
 		Return(
-			func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription { return mockEth.NewSub(t) },
-			func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil },
+			func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
+				return chHead, mockEth.NewSub(t), nil
+			},
 		)
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, errors.New("cannot reconnect"))
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+	ethClient.On("SubscribeNewHead", mock.Anything).Return((<-chan *evmtypes.Head)(chHead), nil, errors.New("cannot reconnect"))
+	ethClient.On("SubscribeNewHead", mock.Anything).
 		Return(
-			func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription { return mockEth.NewSub(t) },
-			func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil },
+			func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
+				return chHead, mockEth.NewSub(t), nil
+			},
 		)
 	ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(testutils.Head(0), nil)
 	checker := &mocks.MockHeadTrackable{}
@@ -409,16 +412,16 @@ func TestHeadTracker_ResubscribeOnSubscriptionError(t *testing.T) {
 
 	ethClient := testutils.NewEthClientMockWithDefaultChain(t)
 
+	ch := make(chan *evmtypes.Head)
 	chchHeaders := make(chan testutils.RawSub[*evmtypes.Head], 1)
 	mockEth := &testutils.MockEth{EthClient: ethClient}
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+	ethClient.On("SubscribeNewHead", mock.Anything).
 		Return(
-			func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription {
+			func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
 				sub := mockEth.NewSub(t)
 				chchHeaders <- testutils.NewRawSub(ch, sub.Err())
-				return sub
+				return ch, sub, nil
 			},
-			func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil },
 		)
 	ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(testutils.Head(0), nil)
 	ethClient.On("HeadByHash", mock.Anything, mock.Anything).Return(testutils.Head(0), nil).Maybe()
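Note the two Return styles in these hunks: fixed values bake in one (channel, subscription, error) triple for every call, while passing a function whose signature mirrors the method's results recomputes the triple per call — which the reconnect/resubscribe tests above rely on to mint a fresh subscription each time. A sketch contrasting the two (`newSub` is a hypothetical fixture helper):

	// Fixed: every SubscribeNewHead call yields the same triple.
	ethClient.On("SubscribeNewHead", mock.Anything).Return(heads, sub, nil)

	// Function value: evaluated on each call, so a retry after an error
	// can receive a brand-new subscription object.
	ethClient.On("SubscribeNewHead", mock.Anything).Return(
		func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
			return heads, newSub(), nil
		},
	)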
@@ -474,14 +477,14 @@ func TestHeadTracker_Start_LoadsLatestChain(t *testing.T) {
 
 	chchHeaders := make(chan testutils.RawSub[*evmtypes.Head], 1)
 	mockEth := &testutils.MockEth{EthClient: ethClient}
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+	ch := make(chan *evmtypes.Head)
+	ethClient.On("SubscribeNewHead", mock.Anything).
 		Return(
-			func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription {
+			func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
 				sub := mockEth.NewSub(t)
 				chchHeaders <- testutils.NewRawSub(ch, sub.Err())
-				return sub
+				return ch, sub, nil
 			},
-			func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil },
 		)
 
 	orm := headtracker.NewORM(*testutils.FixtureChainID, db)
@@ -531,14 +534,14 @@ func TestHeadTracker_SwitchesToLongestChainWithHeadSamplingEnabled(t *testing.T)
 
 	chchHeaders := make(chan testutils.RawSub[*evmtypes.Head], 1)
 	mockEth := &testutils.MockEth{EthClient: ethClient}
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+	chHead := make(chan *evmtypes.Head)
+	ethClient.On("SubscribeNewHead", mock.Anything).
 		Return(
-			func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription {
+			func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
 				sub := mockEth.NewSub(t)
-				chchHeaders <- testutils.NewRawSub(ch, sub.Err())
-				return sub
+				chchHeaders <- testutils.NewRawSub(chHead, sub.Err())
+				return chHead, sub, nil
 			},
-			func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil },
 		)
 
 	// ---------------------
@@ -659,14 +662,14 @@ func TestHeadTracker_SwitchesToLongestChainWithHeadSamplingDisabled(t *testing.T
 
 	chchHeaders := make(chan testutils.RawSub[*evmtypes.Head], 1)
 	mockEth := &testutils.MockEth{EthClient: ethClient}
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+	chHead := make(chan *evmtypes.Head)
+	ethClient.On("SubscribeNewHead", mock.Anything).
 		Return(
-			func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription {
+			func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
 				sub := mockEth.NewSub(t)
-				chchHeaders <- testutils.NewRawSub(ch, sub.Err())
-				return sub
+				chchHeaders <- testutils.NewRawSub(chHead, sub.Err())
+				return chHead, sub, nil
 			},
-			func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil },
 		)
 
 	// ---------------------
diff --git a/core/chains/legacyevm/chain.go b/core/chains/legacyevm/chain.go
index b38cd2c4508..07822cd3461 100644
--- a/core/chains/legacyevm/chain.go
+++ b/core/chains/legacyevm/chain.go
@@ -421,7 +421,6 @@ func (c *chain) listNodeStatuses(start, end int) ([]types.NodeStatus, int, error
 	for _, n := range nodes[start:end] {
 		var (
 			nodeState string
-			exists    bool
 		)
 		toml, err := gotoml.Marshal(n)
 		if err != nil {
@@ -430,10 +429,11 @@ func (c *chain) listNodeStatuses(start, end int) ([]types.NodeStatus, int, error
 		if states == nil {
 			nodeState = "Unknown"
 		} else {
-			nodeState, exists = states[*n.Name]
-			if !exists {
-				// The node is in the DB and the chain is enabled but it's not running
-				nodeState = "NotLoaded"
+			// The node is in the DB and the chain is enabled but it's not running
+			nodeState = "NotLoaded"
+			s, exists := states[*n.Name]
+			if exists {
+				nodeState = s.String()
 			}
 		}
 		stats = append(stats, types.NodeStatus{
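With NodeStates now returning typed commonclient.NodeState values rather than strings, the listing code above renders states via String() and falls back to "NotLoaded" for nodes that are configured but absent from the running pool. The branch, compressed (assuming `states` is the map returned by NodeStates and `name` a node name):

	nodeState := "Unknown" // no state map available at all
	if states != nil {
		// In the DB and the chain is enabled, but not in the running pool.
		nodeState = "NotLoaded"
		if s, exists := states[name]; exists {
			nodeState = s.String() // typed state rendered for display
		}
	}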
diff --git a/core/internal/cltest/cltest.go b/core/internal/cltest/cltest.go
index 508dde86a05..345c6b9b3db 100644
--- a/core/internal/cltest/cltest.go
+++ b/core/internal/cltest/cltest.go
@@ -466,8 +466,9 @@ func NewEthMocks(t testing.TB) *evmclimocks.Client {
 func NewEthMocksWithStartupAssertions(t testing.TB) *evmclimocks.Client {
 	testutils.SkipShort(t, "long test")
 	c := NewEthMocks(t)
+	chHead := make(<-chan *evmtypes.Head)
 	c.On("Dial", mock.Anything).Maybe().Return(nil)
-	c.On("SubscribeNewHead", mock.Anything, mock.Anything).Maybe().Return(EmptyMockSubscription(t), nil)
+	c.On("SubscribeNewHead", mock.Anything).Maybe().Return(chHead, EmptyMockSubscription(t), nil)
 	c.On("SendTransaction", mock.Anything, mock.Anything).Maybe().Return(nil)
 	c.On("HeadByNumber", mock.Anything, mock.Anything).Maybe().Return(Head(0), nil)
 	c.On("ConfiguredChainID").Maybe().Return(&FixtureChainID)
@@ -488,8 +489,9 @@ func NewEthMocksWithStartupAssertions(t testing.TB) *evmclimocks.Client {
 func NewEthMocksWithTransactionsOnBlocksAssertions(t testing.TB) *evmclimocks.Client {
 	testutils.SkipShort(t, "long test")
 	c := NewEthMocks(t)
+	chHead := make(<-chan *evmtypes.Head)
 	c.On("Dial", mock.Anything).Maybe().Return(nil)
-	c.On("SubscribeNewHead", mock.Anything, mock.Anything).Maybe().Return(EmptyMockSubscription(t), nil)
+	c.On("SubscribeNewHead", mock.Anything).Maybe().Return(chHead, EmptyMockSubscription(t), nil)
 	c.On("SendTransaction", mock.Anything, mock.Anything).Maybe().Return(nil)
 	c.On("SendTransactionReturnCode", mock.Anything, mock.Anything, mock.Anything).Maybe().Return(client.Successful, nil)
 	// Construct chain
@@ -1249,7 +1251,7 @@ func MockApplicationEthCalls(t *testing.T, app *TestApplication, ethClient *evmc
 
 	// Start
 	ethClient.On("Dial", mock.Anything).Return(nil)
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(sub, nil).Maybe()
+	ethClient.On("SubscribeNewHead", mock.Anything).Return(make(<-chan *evmtypes.Head), sub, nil).Maybe()
 	ethClient.On("ConfiguredChainID", mock.Anything).Return(evmtest.MustGetDefaultChainID(t, app.GetConfig().EVMConfigs()), nil)
 	ethClient.On("PendingNonceAt", mock.Anything, mock.Anything).Return(uint64(0), nil).Maybe()
 	ethClient.On("HeadByNumber", mock.Anything, mock.Anything).Return(nil, nil).Maybe()
diff --git a/core/internal/features/features_test.go b/core/internal/features/features_test.go
index 046f21b7f7d..60cb98748ac 100644
--- a/core/internal/features/features_test.go
+++ b/core/internal/features/features_test.go
@@ -1294,14 +1294,14 @@ func TestIntegration_BlockHistoryEstimator(t *testing.T) {
 	h42 := evmtypes.Head{Hash: b42.Hash, ParentHash: h41.Hash, Number: 42, EVMChainID: evmChainID}
 
 	mockEth := &evmtestutils.MockEth{EthClient: ethClient}
-	ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).
+	ethClient.On("SubscribeNewHead", mock.Anything).
 		Return(
-			func(ctx context.Context, ch chan<- *evmtypes.Head) ethereum.Subscription {
+			func(ctx context.Context) (<-chan *evmtypes.Head, ethereum.Subscription, error) {
+				ch := make(chan *evmtypes.Head)
 				sub := mockEth.NewSub(t)
 				chchNewHeads <- evmtestutils.NewRawSub(ch, sub.Err())
-				return sub
+				return ch, sub, nil
 			},
-			func(ctx context.Context, ch chan<- *evmtypes.Head) error { return nil },
 		)
 	// Nonce syncer
 	ethClient.On("PendingNonceAt", mock.Anything, mock.Anything).Maybe().Return(uint64(0), nil)
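One small idiom worth calling out from the cltest hunks: make(<-chan *evmtypes.Head) builds a receive-only channel with no send side at all, so reads on it block forever — exactly what a startup-assertion mock wants from a subscription that should simply stay silent. Sketch:

	silent := make(<-chan *evmtypes.Head) // nothing can ever send on or close this
	c.On("SubscribeNewHead", mock.Anything).
		Maybe().
		Return(silent, EmptyMockSubscription(t), nil)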