diff --git a/common/client/poller.go b/common/client/poller.go index d6080722c5c..eeb6c3af576 100644 --- a/common/client/poller.go +++ b/common/client/poller.go @@ -2,7 +2,6 @@ package client import ( "context" - "sync" "time" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -15,83 +14,80 @@ import ( // and delivers the result to a channel. It is used by multinode to poll // for new heads and implements the Subscription interface. type Poller[T any] struct { - services.StateMachine + services.Service + eng *services.Engine + pollingInterval time.Duration pollingFunc func(ctx context.Context) (T, error) pollingTimeout time.Duration - logger logger.Logger channel chan<- T errCh chan error - - stopCh services.StopChan - wg sync.WaitGroup } // NewPoller creates a new Poller instance and returns a channel to receive the polled data func NewPoller[ T any, -](pollingInterval time.Duration, pollingFunc func(ctx context.Context) (T, error), pollingTimeout time.Duration, logger logger.Logger) (Poller[T], <-chan T) { +](pollingInterval time.Duration, pollingFunc func(ctx context.Context) (T, error), pollingTimeout time.Duration, lggr logger.Logger) (Poller[T], <-chan T) { channel := make(chan T) - return Poller[T]{ + p := Poller[T]{ pollingInterval: pollingInterval, pollingFunc: pollingFunc, pollingTimeout: pollingTimeout, channel: channel, - logger: logger, errCh: make(chan error), - stopCh: make(chan struct{}), - }, channel + } + p.Service, p.eng = services.Config{ + Name: "Poller", + Start: p.start, + Close: p.close, + }.NewServiceEngine(lggr) + return p, channel } var _ types.Subscription = &Poller[any]{} -func (p *Poller[T]) Start() error { - return p.StartOnce("Poller", func() error { - p.wg.Add(1) - go p.pollingLoop() - return nil - }) +func (p *Poller[T]) start(ctx context.Context) error { + p.eng.Go(p.pollingLoop) + return nil } // Unsubscribe cancels the sending of events to the data channel func (p *Poller[T]) Unsubscribe() { - _ = p.StopOnce("Poller", func() error { - close(p.stopCh) - p.wg.Wait() - close(p.errCh) - close(p.channel) - return nil - }) + _ = p.Close() +} + +func (p *Poller[T]) close() error { + close(p.errCh) + close(p.channel) + return nil } func (p *Poller[T]) Err() <-chan error { return p.errCh } -func (p *Poller[T]) pollingLoop() { - defer p.wg.Done() - +func (p *Poller[T]) pollingLoop(ctx context.Context) { ticker := time.NewTicker(p.pollingInterval) defer ticker.Stop() for { select { - case <-p.stopCh: + case <-ctx.Done(): return case <-ticker.C: // Set polling timeout - pollingCtx, cancelPolling := p.stopCh.CtxCancel(context.WithTimeout(context.Background(), p.pollingTimeout)) + pollingCtx, cancelPolling := context.WithTimeout(ctx, p.pollingTimeout) // Execute polling function result, err := p.pollingFunc(pollingCtx) cancelPolling() if err != nil { - p.logger.Warnf("polling error: %v", err) + p.eng.Warnf("polling error: %v", err) continue } // Send result to channel or block if channel is full select { case p.channel <- result: - case <-p.stopCh: + case <-ctx.Done(): return } } diff --git a/common/client/poller_test.go b/common/client/poller_test.go index 91af579307b..930b101751c 100644 --- a/common/client/poller_test.go +++ b/common/client/poller_test.go @@ -20,20 +20,22 @@ func Test_Poller(t *testing.T) { lggr := logger.Test(t) t.Run("Test multiple start", func(t *testing.T) { + ctx := tests.Context(t) pollFunc := func(ctx context.Context) (Head, error) { return nil, nil } poller, _ := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr) - err := poller.Start() + err := poller.Start(ctx) require.NoError(t, err) - err = poller.Start() + err = poller.Start(ctx) require.Error(t, err) poller.Unsubscribe() }) t.Run("Test polling for heads", func(t *testing.T) { + ctx := tests.Context(t) // Mock polling function that returns a new value every time it's called var pollNumber int pollLock := sync.Mutex{} @@ -50,7 +52,7 @@ func Test_Poller(t *testing.T) { // Create poller and start to receive data poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr) - require.NoError(t, poller.Start()) + require.NoError(t, poller.Start(ctx)) defer poller.Unsubscribe() // Receive updates from the poller @@ -63,6 +65,7 @@ func Test_Poller(t *testing.T) { }) t.Run("Test polling errors", func(t *testing.T) { + ctx := tests.Context(t) // Mock polling function that returns an error var pollNumber int pollLock := sync.Mutex{} @@ -77,7 +80,7 @@ func Test_Poller(t *testing.T) { // Create poller and subscribe to receive data poller, _ := NewPoller[Head](time.Millisecond, pollFunc, time.Second, olggr) - require.NoError(t, poller.Start()) + require.NoError(t, poller.Start(ctx)) defer poller.Unsubscribe() // Ensure that all errors were logged as expected @@ -94,6 +97,7 @@ func Test_Poller(t *testing.T) { }) t.Run("Test polling timeout", func(t *testing.T) { + ctx := tests.Context(t) pollFunc := func(ctx context.Context) (Head, error) { if <-ctx.Done(); true { return nil, ctx.Err() @@ -108,7 +112,7 @@ func Test_Poller(t *testing.T) { // Create poller and subscribe to receive data poller, _ := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, olggr) - require.NoError(t, poller.Start()) + require.NoError(t, poller.Start(ctx)) defer poller.Unsubscribe() // Ensure that timeout errors were logged as expected @@ -119,6 +123,7 @@ func Test_Poller(t *testing.T) { }) t.Run("Test unsubscribe during polling", func(t *testing.T) { + ctx := tests.Context(t) wait := make(chan struct{}) closeOnce := sync.OnceFunc(func() { close(wait) }) pollFunc := func(ctx context.Context) (Head, error) { @@ -137,7 +142,7 @@ func Test_Poller(t *testing.T) { // Create poller and subscribe to receive data poller, _ := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, olggr) - require.NoError(t, poller.Start()) + require.NoError(t, poller.Start(ctx)) // Unsubscribe while blocked in polling function <-wait @@ -167,8 +172,9 @@ func Test_Poller_Unsubscribe(t *testing.T) { } t.Run("Test multiple unsubscribe", func(t *testing.T) { + ctx := tests.Context(t) poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr) - err := poller.Start() + err := poller.Start(ctx) require.NoError(t, err) <-channel @@ -177,8 +183,9 @@ func Test_Poller_Unsubscribe(t *testing.T) { }) t.Run("Read channel after unsubscribe", func(t *testing.T) { + ctx := tests.Context(t) poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr) - err := poller.Start() + err := poller.Start(ctx) require.NoError(t, err) poller.Unsubscribe() diff --git a/core/capabilities/remote/target/endtoend_test.go b/core/capabilities/remote/target/endtoend_test.go index 31bdc83e266..9abb6e99d5c 100644 --- a/core/capabilities/remote/target/endtoend_test.go +++ b/core/capabilities/remote/target/endtoend_test.go @@ -276,67 +276,47 @@ func testRemoteTarget(ctx context.Context, t *testing.T, underlying commoncap.Ta } type testAsyncMessageBroker struct { - services.StateMachine - t *testing.T + services.Service + eng *services.Engine + t *testing.T nodes map[p2ptypes.PeerID]remotetypes.Receiver sendCh chan *remotetypes.MessageBody - - stopCh services.StopChan - wg sync.WaitGroup -} - -func (a *testAsyncMessageBroker) HealthReport() map[string]error { - return nil -} - -func (a *testAsyncMessageBroker) Name() string { - return "testAsyncMessageBroker" } func newTestAsyncMessageBroker(t *testing.T, sendChBufferSize int) *testAsyncMessageBroker { - return &testAsyncMessageBroker{ + b := &testAsyncMessageBroker{ t: t, nodes: make(map[p2ptypes.PeerID]remotetypes.Receiver), - stopCh: make(services.StopChan), sendCh: make(chan *remotetypes.MessageBody, sendChBufferSize), } -} - -func (a *testAsyncMessageBroker) Start(ctx context.Context) error { - return a.StartOnce("testAsyncMessageBroker", func() error { - a.wg.Add(1) - go func() { - defer a.wg.Done() - - for { - select { - case <-a.stopCh: - return - case msg := <-a.sendCh: - receiverId := toPeerID(msg.Receiver) - - receiver, ok := a.nodes[receiverId] - if !ok { - panic("server not found for peer id") - } - - receiver.Receive(tests.Context(a.t), msg) + b.Service, b.eng = services.Config{ + Name: "testAsyncMessageBroker", + Start: b.start, + }.NewServiceEngine(logger.TestLogger(t)) + return b +} + +func (a *testAsyncMessageBroker) start(ctx context.Context) error { + a.eng.Go(func(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case msg := <-a.sendCh: + receiverId := toPeerID(msg.Receiver) + + receiver, ok := a.nodes[receiverId] + if !ok { + panic("server not found for peer id") } - } - }() - return nil - }) -} - -func (a *testAsyncMessageBroker) Close() error { - return a.StopOnce("testAsyncMessageBroker", func() error { - close(a.stopCh) - a.wg.Wait() - return nil + receiver.Receive(tests.Context(a.t), msg) + } + } }) + return nil } func (a *testAsyncMessageBroker) NewDispatcherForNode(nodePeerID p2ptypes.PeerID) remotetypes.Dispatcher { diff --git a/core/chains/evm/client/rpc_client.go b/core/chains/evm/client/rpc_client.go index 07aa86fc450..c9e172326e6 100644 --- a/core/chains/evm/client/rpc_client.go +++ b/core/chains/evm/client/rpc_client.go @@ -546,14 +546,14 @@ func (r *rpcClient) SubscribeToHeads(ctx context.Context) (ch <-chan *evmtypes.H return channel, forwarder, err } -func (r *rpcClient) SubscribeToFinalizedHeads(_ context.Context) (<-chan *evmtypes.Head, commontypes.Subscription, error) { +func (r *rpcClient) SubscribeToFinalizedHeads(ctx context.Context) (<-chan *evmtypes.Head, commontypes.Subscription, error) { interval := r.finalizedBlockPollInterval if interval == 0 { return nil, nil, errors.New("FinalizedBlockPollInterval is 0") } timeout := interval poller, channel := commonclient.NewPoller[*evmtypes.Head](interval, r.LatestFinalizedBlock, timeout, r.rpcLog) - if err := poller.Start(); err != nil { + if err := poller.Start(ctx); err != nil { return nil, nil, err } return channel, &poller, nil diff --git a/core/chains/evm/forwarders/forwarder_manager.go b/core/chains/evm/forwarders/forwarder_manager.go index 6436988ba82..0e9f8781c35 100644 --- a/core/chains/evm/forwarders/forwarder_manager.go +++ b/core/chains/evm/forwarders/forwarder_manager.go @@ -33,7 +33,9 @@ type Config interface { } type FwdMgr struct { - services.StateMachine + services.Service + eng *services.Engine + ORM ORM evmClient evmclient.Client cfg Config @@ -48,61 +50,51 @@ type FwdMgr struct { authRcvr authorized_receiver.AuthorizedReceiverInterface offchainAgg offchain_aggregator_wrapper.OffchainAggregatorInterface - stopCh services.StopChan - cacheMu sync.RWMutex - wg sync.WaitGroup } -func NewFwdMgr(ds sqlutil.DataSource, client evmclient.Client, logpoller evmlogpoller.LogPoller, l logger.Logger, cfg Config) *FwdMgr { - lggr := logger.Sugared(logger.Named(l, "EVMForwarderManager")) - fwdMgr := FwdMgr{ - logger: lggr, +func NewFwdMgr(ds sqlutil.DataSource, client evmclient.Client, logpoller evmlogpoller.LogPoller, lggr logger.Logger, cfg Config) *FwdMgr { + fm := FwdMgr{ cfg: cfg, evmClient: client, ORM: NewORM(ds), logpoller: logpoller, sendersCache: make(map[common.Address][]common.Address), } - fwdMgr.stopCh = make(chan struct{}) - return &fwdMgr -} - -func (f *FwdMgr) Name() string { - return f.logger.Name() + fm.Service, fm.eng = services.Config{ + Name: "ForwarderManager", + Start: fm.start, + }.NewServiceEngine(lggr) + fm.logger = logger.Sugared(fm.eng) + return &fm } -// Start starts Forwarder Manager. -func (f *FwdMgr) Start(ctx context.Context) error { - return f.StartOnce("EVMForwarderManager", func() error { - f.logger.Debug("Initializing EVM forwarder manager") - chainId := f.evmClient.ConfiguredChainID() +func (f *FwdMgr) start(ctx context.Context) error { + chainId := f.evmClient.ConfiguredChainID() - fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*chainId)) - if err != nil { - return pkgerrors.Wrapf(err, "Failed to retrieve forwarders for chain %d", chainId) - } - if len(fwdrs) != 0 { - f.initForwardersCache(ctx, fwdrs) - if err = f.subscribeForwardersLogs(ctx, fwdrs); err != nil { - return err - } + fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*chainId)) + if err != nil { + return pkgerrors.Wrapf(err, "Failed to retrieve forwarders for chain %d", chainId) + } + if len(fwdrs) != 0 { + f.initForwardersCache(ctx, fwdrs) + if err = f.subscribeForwardersLogs(ctx, fwdrs); err != nil { + return err } + } - f.authRcvr, err = authorized_receiver.NewAuthorizedReceiver(common.Address{}, f.evmClient) - if err != nil { - return pkgerrors.Wrap(err, "Failed to init AuthorizedReceiver") - } + f.authRcvr, err = authorized_receiver.NewAuthorizedReceiver(common.Address{}, f.evmClient) + if err != nil { + return pkgerrors.Wrap(err, "Failed to init AuthorizedReceiver") + } - f.offchainAgg, err = offchain_aggregator_wrapper.NewOffchainAggregator(common.Address{}, f.evmClient) - if err != nil { - return pkgerrors.Wrap(err, "Failed to init OffchainAggregator") - } + f.offchainAgg, err = offchain_aggregator_wrapper.NewOffchainAggregator(common.Address{}, f.evmClient) + if err != nil { + return pkgerrors.Wrap(err, "Failed to init OffchainAggregator") + } - f.wg.Add(1) - go f.runLoop() - return nil - }) + f.eng.Go(f.runLoop) + return nil } func FilterName(addr common.Address) string { @@ -176,7 +168,7 @@ func (f *FwdMgr) ConvertPayload(dest common.Address, origPayload []byte) ([]byte if err != nil { f.logger.AssumptionViolationw("Forwarder encoding failed, this should never happen", "err", err, "to", dest, "payload", origPayload) - f.SvcErrBuffer.Append(err) + f.eng.EmitHealthErr(err) } } return databytes, nil @@ -269,11 +261,7 @@ func (f *FwdMgr) getCachedSenders(addr common.Address) ([]common.Address, bool) return addrs, ok } -func (f *FwdMgr) runLoop() { - defer f.wg.Done() - ctx, cancel := f.stopCh.NewCtx() - defer cancel() - +func (f *FwdMgr) runLoop(ctx context.Context) { ticker := services.NewTicker(time.Minute) defer ticker.Stop() @@ -353,16 +341,3 @@ func (f *FwdMgr) collectAddresses() (addrs []common.Address) { } return } - -// Stop cancels all outgoings calls and stops internal ticker loop. -func (f *FwdMgr) Close() error { - return f.StopOnce("EVMForwarderManager", func() (err error) { - close(f.stopCh) - f.wg.Wait() - return nil - }) -} - -func (f *FwdMgr) HealthReport() map[string]error { - return map[string]error{f.Name(): f.Healthy()} -} diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_0_0/commit_store.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_0_0/commit_store.go index d0559b800e0..f6e957746e6 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_0_0/commit_store.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_0_0/commit_store.go @@ -2,6 +2,7 @@ package v1_0_0 import ( "context" + "errors" "fmt" "math/big" "sync" @@ -11,7 +12,7 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/pkg/errors" + "golang.org/x/exp/maps" "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -90,7 +91,7 @@ func encodeCommitReport(commitReportArgs abi.Arguments, report cciptypes.CommitS var usdPerUnitGas = big.NewInt(0) var destChainSelector = uint64(0) if len(report.GasPrices) > 1 { - return []byte{}, errors.Errorf("CommitStore V1_0_0 can only accept 1 gas price, received: %d", len(report.GasPrices)) + return []byte{}, fmt.Errorf("CommitStore V1_0_0 can only accept 1 gas price, received: %d", len(report.GasPrices)) } if len(report.GasPrices) > 0 { usdPerUnitGas = report.GasPrices[0].Value @@ -133,7 +134,7 @@ func DecodeCommitReport(commitReportArgs abi.Arguments, report []byte) (cciptype MerkleRoot [32]byte `json:"merkleRoot"` }) if !ok { - return cciptypes.CommitStoreReport{}, errors.Errorf("invalid commit report got %T", unpacked[0]) + return cciptypes.CommitStoreReport{}, fmt.Errorf("invalid commit report got %T", unpacked[0]) } var tokenPriceUpdates []cciptypes.TokenPrice @@ -382,7 +383,7 @@ func (c *CommitStore) GetLatestPriceEpochAndRound(ctx context.Context) (uint64, } func (c *CommitStore) IsDestChainHealthy(context.Context) (bool, error) { - if err := c.lp.Healthy(); err != nil { + if err := errors.Join(maps.Values(c.lp.HealthReport())...); err != nil { return false, nil } return true, nil diff --git a/plugins/medianpoc/plugin.go b/plugins/medianpoc/plugin.go index 41580afe499..e6ada1bec84 100644 --- a/plugins/medianpoc/plugin.go +++ b/plugins/medianpoc/plugin.go @@ -40,7 +40,7 @@ func (e *PipelineNotFoundError) Error() string { } func (p *Plugin) NewValidationService(ctx context.Context) (core.ValidationService, error) { - s := &reportingPluginValidationService{lggr: p.Logger} + s := &reportingPluginValidationService{Service: services.Config{Name: "ValidationService"}.NewService(p.Logger)} p.SubService(s) return s, nil } @@ -81,7 +81,10 @@ func (p *Plugin) NewReportingPluginFactory( if err != nil { return nil, err } - s := &reportingPluginFactoryService{lggr: p.Logger, ReportingPluginFactory: f} + s := &reportingPluginFactoryService{ + Service: services.Config{Name: "ReportingPluginFactory"}.NewService(p.Logger), + ReportingPluginFactory: f, + } p.SubService(s) return s, nil } @@ -157,28 +160,12 @@ func (p *Plugin) newFactory(ctx context.Context, config core.ReportingPluginServ } type reportingPluginFactoryService struct { - services.StateMachine - lggr logger.Logger + services.Service ocrtypes.ReportingPluginFactory } -func (r *reportingPluginFactoryService) Name() string { return r.lggr.Name() } - -func (r *reportingPluginFactoryService) Start(ctx context.Context) error { - return r.StartOnce("ReportingPluginFactory", func() error { return nil }) -} - -func (r *reportingPluginFactoryService) Close() error { - return r.StopOnce("ReportingPluginFactory", func() error { return nil }) -} - -func (r *reportingPluginFactoryService) HealthReport() map[string]error { - return map[string]error{r.Name(): r.Healthy()} -} - type reportingPluginValidationService struct { - services.StateMachine - lggr logger.Logger + services.Service } func (r *reportingPluginValidationService) ValidateConfig(ctx context.Context, config map[string]interface{}) error { @@ -196,16 +183,3 @@ func (r *reportingPluginValidationService) ValidateConfig(ctx context.Context, c return nil } -func (r *reportingPluginValidationService) Name() string { return r.lggr.Name() } - -func (r *reportingPluginValidationService) Start(ctx context.Context) error { - return r.StartOnce("ValidationService", func() error { return nil }) -} - -func (r *reportingPluginValidationService) Close() error { - return r.StopOnce("ValidationService", func() error { return nil }) -} - -func (r *reportingPluginValidationService) HealthReport() map[string]error { - return map[string]error{r.Name(): r.Healthy()} -}