diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index 7eaa6a3ae00..13c5b97e20a 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -25,6 +25,7 @@ import ( ocr2keepers20runner "github.com/smartcontractkit/chainlink-automation/pkg/v2/runner" ocr2keepers21config "github.com/smartcontractkit/chainlink-automation/pkg/v3/config" ocr2keepers21 "github.com/smartcontractkit/chainlink-automation/pkg/v3/plugin" + "github.com/smartcontractkit/chainlink/v2/core/config/env" "github.com/smartcontractkit/chainlink-vrf/altbn_128" @@ -36,6 +37,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/loop/reportingplugins" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipcommit" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipexec" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/rebalancer" @@ -316,13 +318,13 @@ func (d *Delegate) cleanupEVM(jb job.Job, q pg.Queryer, relayID relay.ID) error d.lggr.Errorw("failed to derive ocr2keeper filter names from spec", "err", err, "spec", spec) } case types.CCIPCommit: - err = ccipcommit.UnregisterCommitPluginLpFilters(context.Background(), d.lggr, jb, d.legacyChains, pg.WithQueryer(q)) + err = ccipcommit.UnregisterCommitPluginLpFilters(d.lggr, jb, d.legacyChains, pg.WithQueryer(q)) if err != nil { d.lggr.Errorw("failed to unregister ccip commit plugin filters", "err", err, "spec", spec) } return nil case types.CCIPExecution: - err = ccipexec.UnregisterExecPluginLpFilters(context.Background(), d.lggr, jb, d.legacyChains, pg.WithQueryer(q)) + err = ccipexec.UnregisterExecPluginLpFilters(d.lggr, jb, d.legacyChains, pg.WithQueryer(q)) if err != nil { d.lggr.Errorw("failed to unregister ccip exec plugin filters", "err", err, "spec", spec) } @@ -1604,7 +1606,7 @@ func (d *Delegate) newServicesCCIPExecution(ctx context.Context, lggr logger.Sug logError := func(msg string) { lggr.ErrorIf(d.jobORM.RecordError(jb.ID, msg), "unable to record error") } - return ccipexec.NewExecutionServices(ctx, lggr, jb, d.legacyChains, d.isNewlyCreatedJob, oracleArgsNoPlugin, logError, qopts...) + return ccipexec.NewExecutionServices(lggr, jb, d.legacyChains, d.isNewlyCreatedJob, oracleArgsNoPlugin, logError, qopts...) } func (d *Delegate) newServicesRebalancer(ctx context.Context, lggr logger.SugaredLogger, jb job.Job, bootstrapPeers []commontypes.BootstrapperLocator, kb ocr2key.KeyBundle, ocrDB *db, lc ocrtypes.LocalConfig, qopts ...pg.QOpt) ([]job.ServiceCtx, error) { diff --git a/core/services/ocr2/plugins/ccip/ccipcommit/initializers.go b/core/services/ocr2/plugins/ccip/ccipcommit/initializers.go index 17a76a4cad1..c1166a02a1d 100644 --- a/core/services/ocr2/plugins/ccip/ccipcommit/initializers.go +++ b/core/services/ocr2/plugins/ccip/ccipcommit/initializers.go @@ -87,7 +87,7 @@ func CommitReportToEthTxMeta(typ ccipconfig.ContractType, ver semver.Version) (f // https://github.com/smartcontractkit/ccip/blob/68e2197472fb017dd4e5630d21e7878d58bc2a44/core/services/feeds/service.go#L716 // TODO once that transaction is broken up, we should be able to simply rely on oracle.Close() to cleanup the filters. // Until then we have to deterministically reload the readers from the spec (and thus their filters) and close them. -func UnregisterCommitPluginLpFilters(ctx context.Context, lggr logger.Logger, jb job.Job, chainSet legacyevm.LegacyChainContainer, qopts ...pg.QOpt) error { +func UnregisterCommitPluginLpFilters(lggr logger.Logger, jb job.Job, chainSet legacyevm.LegacyChainContainer, qopts ...pg.QOpt) error { params, err := extractJobSpecParams(jb, chainSet) if err != nil { return err diff --git a/core/services/ocr2/plugins/ccip/ccipcommit/initializers_test.go b/core/services/ocr2/plugins/ccip/ccipcommit/initializers_test.go index b9c6704d4c1..d42e1987f92 100644 --- a/core/services/ocr2/plugins/ccip/ccipcommit/initializers_test.go +++ b/core/services/ocr2/plugins/ccip/ccipcommit/initializers_test.go @@ -1,7 +1,6 @@ package ccipcommit import ( - "context" "fmt" "strconv" "testing" @@ -66,7 +65,7 @@ func TestGetCommitPluginFilterNamesFromSpec(t *testing.T) { } } - err := UnregisterCommitPluginLpFilters(context.Background(), lggr, job.Job{OCR2OracleSpec: tc.spec}, chainSet) + err := UnregisterCommitPluginLpFilters(lggr, job.Job{OCR2OracleSpec: tc.spec}, chainSet) if tc.expectingErr { assert.Error(t, err) } else { diff --git a/core/services/ocr2/plugins/ccip/ccipexec/initializers.go b/core/services/ocr2/plugins/ccip/ccipexec/initializers.go index acd0f1b10aa..45a21dfcb5b 100644 --- a/core/services/ocr2/plugins/ccip/ccipexec/initializers.go +++ b/core/services/ocr2/plugins/ccip/ccipexec/initializers.go @@ -46,8 +46,8 @@ import ( const numTokenDataWorkers = 5 -func NewExecutionServices(ctx context.Context, lggr logger.Logger, jb job.Job, chainSet legacyevm.LegacyChainContainer, new bool, argsNoPlugin libocr2.OCR2OracleArgs, logError func(string), qopts ...pg.QOpt) ([]job.ServiceCtx, error) { - execPluginConfig, backfillArgs, chainHealthcheck, err := jobSpecToExecPluginConfig(ctx, lggr, jb, chainSet, qopts...) +func NewExecutionServices(lggr logger.Logger, jb job.Job, chainSet legacyevm.LegacyChainContainer, new bool, argsNoPlugin libocr2.OCR2OracleArgs, logError func(string), qopts ...pg.QOpt) ([]job.ServiceCtx, error) { + execPluginConfig, backfillArgs, chainHealthcheck, tokenWorker, err := jobSpecToExecPluginConfig(lggr, jb, chainSet, qopts...) if err != nil { return nil, err } @@ -74,18 +74,20 @@ func NewExecutionServices(ctx context.Context, lggr logger.Logger, jb job.Job, c job.NewServiceAdapter(oracle), ), chainHealthcheck, + tokenWorker, }, nil } return []job.ServiceCtx{ job.NewServiceAdapter(oracle), chainHealthcheck, + tokenWorker, }, nil } // UnregisterExecPluginLpFilters unregisters all the registered filters for both source and dest chains. // See comment in UnregisterCommitPluginLpFilters // It MUST mirror the filters registered in NewExecutionServices. -func UnregisterExecPluginLpFilters(ctx context.Context, lggr logger.Logger, jb job.Job, chainSet legacyevm.LegacyChainContainer, qopts ...pg.QOpt) error { +func UnregisterExecPluginLpFilters(lggr logger.Logger, jb job.Job, chainSet legacyevm.LegacyChainContainer, qopts ...pg.QOpt) error { params, err := extractJobSpecParams(lggr, jb, chainSet, false, qopts...) if err != nil { return err @@ -158,10 +160,10 @@ func initTokenDataProviders(lggr logger.Logger, jobID string, pluginConfig ccipc return tokenDataProviders, nil } -func jobSpecToExecPluginConfig(ctx context.Context, lggr logger.Logger, jb job.Job, chainSet legacyevm.LegacyChainContainer, qopts ...pg.QOpt) (*ExecutionPluginStaticConfig, *ccipcommon.BackfillArgs, *cache.ObservedChainHealthcheck, error) { +func jobSpecToExecPluginConfig(lggr logger.Logger, jb job.Job, chainSet legacyevm.LegacyChainContainer, qopts ...pg.QOpt) (*ExecutionPluginStaticConfig, *ccipcommon.BackfillArgs, *cache.ObservedChainHealthcheck, *tokendata.BackgroundWorker, error) { params, err := extractJobSpecParams(lggr, jb, chainSet, true, qopts...) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } lggr.Infow("Initializing exec plugin", @@ -177,39 +179,39 @@ func jobSpecToExecPluginConfig(ctx context.Context, lggr logger.Logger, jb job.J sourceChainName, destChainName, err := ccipconfig.ResolveChainNames(sourceChainID, destChainID) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } execLggr := lggr.Named("CCIPExecution").With("sourceChain", sourceChainName, "destChain", destChainName) onRampReader, err := factory.NewOnRampReader(execLggr, versionFinder, params.offRampConfig.SourceChainSelector, params.offRampConfig.ChainSelector, params.offRampConfig.OnRamp, params.sourceChain.LogPoller(), params.sourceChain.Client(), qopts...) if err != nil { - return nil, nil, nil, errors.Wrap(err, "create onramp reader") + return nil, nil, nil, nil, errors.Wrap(err, "create onramp reader") } dynamicOnRampConfig, err := onRampReader.GetDynamicConfig() if err != nil { - return nil, nil, nil, errors.Wrap(err, "get onramp dynamic config") + return nil, nil, nil, nil, errors.Wrap(err, "get onramp dynamic config") } routerAddr, err := ccipcalc.GenericAddrToEvm(dynamicOnRampConfig.Router) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } sourceRouter, err := router.NewRouter(routerAddr, params.sourceChain.Client()) if err != nil { - return nil, nil, nil, errors.Wrap(err, "failed loading source router") + return nil, nil, nil, nil, errors.Wrap(err, "failed loading source router") } sourceWrappedNative, err := sourceRouter.GetWrappedNative(&bind.CallOpts{}) if err != nil { - return nil, nil, nil, errors.Wrap(err, "could not get source native token") + return nil, nil, nil, nil, errors.Wrap(err, "could not get source native token") } commitStoreReader, err := factory.NewCommitStoreReader(lggr, versionFinder, params.offRampConfig.CommitStore, params.destChain.Client(), params.destChain.LogPoller(), params.sourceChain.GasEstimator(), params.sourceChain.Config().EVM().GasEstimator().PriceMax().ToInt(), qopts...) if err != nil { - return nil, nil, nil, errors.Wrap(err, "could not load commitStoreReader reader") + return nil, nil, nil, nil, errors.Wrap(err, "could not load commitStoreReader reader") } tokenDataProviders, err := initTokenDataProviders(lggr, jobIDToString(jb.ID), params.pluginConfig, params.sourceChain.LogPoller(), qopts...) if err != nil { - return nil, nil, nil, errors.Wrap(err, "could not get token data providers") + return nil, nil, nil, nil, errors.Wrap(err, "could not get token data providers") } // Prom wrappers @@ -220,11 +222,11 @@ func jobSpecToExecPluginConfig(ctx context.Context, lggr logger.Logger, jb job.J destChainSelector, err := chainselectors.SelectorFromChainId(uint64(destChainID)) if err != nil { - return nil, nil, nil, fmt.Errorf("get chain %d selector: %w", destChainID, err) + return nil, nil, nil, nil, fmt.Errorf("get chain %d selector: %w", destChainID, err) } sourceChainSelector, err := chainselectors.SelectorFromChainId(uint64(sourceChainID)) if err != nil { - return nil, nil, nil, fmt.Errorf("get chain %d selector: %w", sourceChainID, err) + return nil, nil, nil, nil, fmt.Errorf("get chain %d selector: %w", sourceChainID, err) } execLggr.Infow("Initialized exec plugin", @@ -238,7 +240,7 @@ func jobSpecToExecPluginConfig(ctx context.Context, lggr logger.Logger, jb job.J tokenPoolBatchedReader, err := batchreader.NewEVMTokenPoolBatchedReader(execLggr, sourceChainSelector, offRampReader.Address(), batchCaller) if err != nil { - return nil, nil, nil, fmt.Errorf("new token pool batched reader: %w", err) + return nil, nil, nil, nil, fmt.Errorf("new token pool batched reader: %w", err) } chainHealthcheck := cache.NewObservedChainHealthCheck( @@ -259,6 +261,12 @@ func jobSpecToExecPluginConfig(ctx context.Context, lggr logger.Logger, jb job.J params.offRampConfig.OnRamp, ) + tokenBackgroundWorker := tokendata.NewBackgroundWorker( + tokenDataProviders, + numTokenDataWorkers, + 5*time.Second, + offRampReader.OnchainConfig().PermissionLessExecutionThresholdSeconds, + ) return &ExecutionPluginStaticConfig{ lggr: execLggr, onRampReader: onRampReader, @@ -269,15 +277,9 @@ func jobSpecToExecPluginConfig(ctx context.Context, lggr logger.Logger, jb job.J destChainSelector: destChainSelector, priceRegistryProvider: ccipdataprovider.NewEvmPriceRegistry(params.destChain.LogPoller(), params.destChain.Client(), execLggr, ccip.ExecPluginLabel), tokenPoolBatchedReader: tokenPoolBatchedReader, - tokenDataWorker: tokendata.NewBackgroundWorker( - ctx, - tokenDataProviders, - numTokenDataWorkers, - 5*time.Second, - offRampReader.OnchainConfig().PermissionLessExecutionThresholdSeconds, - ), - metricsCollector: metricsCollector, - chainHealthcheck: chainHealthcheck, + tokenDataWorker: tokenBackgroundWorker, + metricsCollector: metricsCollector, + chainHealthcheck: chainHealthcheck, }, &ccipcommon.BackfillArgs{ SourceLP: params.sourceChain.LogPoller(), DestLP: params.destChain.LogPoller(), @@ -285,6 +287,7 @@ func jobSpecToExecPluginConfig(ctx context.Context, lggr logger.Logger, jb job.J DestStartBlock: params.pluginConfig.DestStartBlock, }, chainHealthcheck, + tokenBackgroundWorker, nil } diff --git a/core/services/ocr2/plugins/ccip/ccipexec/initializers_test.go b/core/services/ocr2/plugins/ccip/ccipexec/initializers_test.go index 07d53fc983b..5729d95d66b 100644 --- a/core/services/ocr2/plugins/ccip/ccipexec/initializers_test.go +++ b/core/services/ocr2/plugins/ccip/ccipexec/initializers_test.go @@ -1,7 +1,6 @@ package ccipexec import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -48,7 +47,7 @@ func TestGetExecutionPluginFilterNamesFromSpec(t *testing.T) { for _, tc := range testCases { chainSet := &legacyEvmORMMocks.LegacyChainContainer{} t.Run(tc.description, func(t *testing.T) { - err := UnregisterExecPluginLpFilters(context.Background(), logger.TestLogger(t), job.Job{OCR2OracleSpec: tc.spec}, chainSet) + err := UnregisterExecPluginLpFilters(logger.TestLogger(t), job.Job{OCR2OracleSpec: tc.spec}, chainSet) if tc.expectingErr { assert.Error(t, err) } else { diff --git a/core/services/ocr2/plugins/ccip/ccipexec/ocr2_test.go b/core/services/ocr2/plugins/ccip/ccipexec/ocr2_test.go index a13a1fec452..578d73f9bcf 100644 --- a/core/services/ocr2/plugins/ccip/ccipexec/ocr2_test.go +++ b/core/services/ocr2/plugins/ccip/ccipexec/ocr2_test.go @@ -139,7 +139,7 @@ func TestExecutionReportingPlugin_Observation(t *testing.T) { p.inflightReports.reports = tc.inflightReports p.lggr = logger.TestLogger(t) p.tokenDataWorker = tokendata.NewBackgroundWorker( - ctx, make(map[cciptypes.Address]tokendata.Reader), 10, 5*time.Second, time.Hour) + make(map[cciptypes.Address]tokendata.Reader), 10, 5*time.Second, time.Hour) p.metricsCollector = ccip.NoopMetricsCollector commitStoreReader := ccipdatamocks.NewCommitStoreReader(t) @@ -668,8 +668,6 @@ func TestExecutionReportingPlugin_buildBatch(t *testing.T) { }, } - ctx := testutils.Context(t) - for _, tc := range tt { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -685,7 +683,7 @@ func TestExecutionReportingPlugin_buildBatch(t *testing.T) { mockOffRampReader.On("GetSenderNonce", mock.Anything, sender1).Return(uint64(0), nil).Maybe() plugin := ExecutionReportingPlugin{ - tokenDataWorker: tokendata.NewBackgroundWorker(ctx, map[cciptypes.Address]tokendata.Reader{}, 10, 5*time.Second, time.Hour), + tokenDataWorker: tokendata.NewBackgroundWorker(map[cciptypes.Address]tokendata.Reader{}, 10, 5*time.Second, time.Hour), offRampReader: mockOffRampReader, destWrappedNative: destNative, offchainConfig: cciptypes.ExecOffchainConfig{ diff --git a/core/services/ocr2/plugins/ccip/internal/cache/chain_health.go b/core/services/ocr2/plugins/ccip/internal/cache/chain_health.go index de2d61db537..2530d479464 100644 --- a/core/services/ocr2/plugins/ccip/internal/cache/chain_health.go +++ b/core/services/ocr2/plugins/ccip/internal/cache/chain_health.go @@ -90,7 +90,7 @@ func NewChainHealthcheck(lggr logger.Logger, onRamp ccipdata.OnRampReader, commi func newChainHealthcheckWithCustomEviction(lggr logger.Logger, onRamp ccipdata.OnRampReader, commitStore ccipdata.CommitStoreReader, globalStatusDuration time.Duration, rmnStatusRefreshInterval time.Duration) *chainHealthcheck { ctx, cancel := context.WithCancel(context.Background()) - ch := &chainHealthcheck{ + return &chainHealthcheck{ cache: cache.New(rmnStatusRefreshInterval, 0), rmnStatusKey: rmnStatusKey, globalStatusKey: globalStatusKey, @@ -105,7 +105,6 @@ func newChainHealthcheckWithCustomEviction(lggr logger.Logger, onRamp ccipdata.O backgroundCtx: ctx, backgroundCancel: cancel, } - return ch } type rmnResponse struct { @@ -164,9 +163,9 @@ func (c *chainHealthcheck) Close() error { } func (c *chainHealthcheck) run() { - defer c.wg.Done() ticker := time.NewTicker(c.rmnStatusRefreshInterval) go func() { + defer c.wg.Done() // Refresh the RMN state immediately after starting the background refresher _, _ = c.refresh(c.backgroundCtx) diff --git a/core/services/ocr2/plugins/ccip/tokendata/bgworker.go b/core/services/ocr2/plugins/ccip/tokendata/bgworker.go index 7dc09debbde..6a7cacfc323 100644 --- a/core/services/ocr2/plugins/ccip/tokendata/bgworker.go +++ b/core/services/ocr2/plugins/ccip/tokendata/bgworker.go @@ -3,9 +3,14 @@ package tokendata import ( "context" "fmt" + "strconv" "sync" "time" + "github.com/patrickmn/go-cache" + "github.com/smartcontractkit/chainlink-common/pkg/services" + + "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/cciptypes" ) @@ -16,6 +21,7 @@ type msgResult struct { } type Worker interface { + job.ServiceCtx // AddJobsFromMsgs will include the provided msgs for background processing. AddJobsFromMsgs(ctx context.Context, msgs []cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta) @@ -30,14 +36,65 @@ type BackgroundWorker struct { tokenDataReaders map[cciptypes.Address]Reader numWorkers int jobsChan chan cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta - resultsCache *resultsCache + resultsCache *cache.Cache timeoutDur time.Duration + + services.StateMachine + wg *sync.WaitGroup + backgroundCtx context.Context + backgroundCancel context.CancelFunc +} + +func NewBackgroundWorker( + tokenDataReaders map[cciptypes.Address]Reader, + numWorkers int, + timeoutDur time.Duration, + expirationDur time.Duration, +) *BackgroundWorker { + if expirationDur == 0 { + expirationDur = 24 * time.Hour + } + + ctx, cancel := context.WithCancel(context.Background()) + return &BackgroundWorker{ + tokenDataReaders: tokenDataReaders, + numWorkers: numWorkers, + jobsChan: make(chan cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta, numWorkers*100), + resultsCache: cache.New(expirationDur, expirationDur/2), + timeoutDur: timeoutDur, + + wg: new(sync.WaitGroup), + backgroundCtx: ctx, + backgroundCancel: cancel, + } +} + +func (w *BackgroundWorker) Start(context.Context) error { + return w.StateMachine.StartOnce("Token BackgroundWorker", func() error { + for i := 0; i < w.numWorkers; i++ { + w.wg.Add(1) + w.run() + } + return nil + }) +} + +func (w *BackgroundWorker) Close() error { + return w.StateMachine.StopOnce("Token BackgroundWorker", func() error { + w.backgroundCancel() + w.wg.Wait() + return nil + }) } func (w *BackgroundWorker) AddJobsFromMsgs(ctx context.Context, msgs []cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta) { + w.wg.Add(1) go func() { + defer w.wg.Done() for _, msg := range msgs { select { + case <-w.backgroundCtx.Done(): + return case <-ctx.Done(): return default: @@ -73,49 +130,25 @@ func (w *BackgroundWorker) GetMsgTokenData(ctx context.Context, msg cciptypes.EV return tokenDatas, nil } -func NewBackgroundWorker( - ctx context.Context, - tokenDataReaders map[cciptypes.Address]Reader, - numWorkers int, - timeoutDur time.Duration, - expirationDur time.Duration, -) *BackgroundWorker { - if expirationDur == 0 { - expirationDur = 24 * time.Hour - } - - w := &BackgroundWorker{ - tokenDataReaders: tokenDataReaders, - numWorkers: numWorkers, - jobsChan: make(chan cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta, numWorkers*100), - resultsCache: newResultsCache(ctx, expirationDur, expirationDur/2), - timeoutDur: timeoutDur, - } - - w.spawnWorkers(ctx) - return w -} - -func (w *BackgroundWorker) spawnWorkers(ctx context.Context) { - for i := 0; i < w.numWorkers; i++ { - go func() { - for { - select { - case <-ctx.Done(): - return - case msg := <-w.jobsChan: - w.workOnMsg(ctx, msg) - } +func (w *BackgroundWorker) run() { + go func() { + defer w.wg.Done() + for { + select { + case <-w.backgroundCtx.Done(): + return + case msg := <-w.jobsChan: + w.workOnMsg(w.backgroundCtx, msg) } - }() - } + } + }() } func (w *BackgroundWorker) workOnMsg(ctx context.Context, msg cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta) { results := make([]msgResult, 0, len(msg.TokenAmounts)) cachedTokenData := make(map[int]msgResult) // tokenAmount index -> token data - if cachedData, exists := w.resultsCache.get(msg.SequenceNumber); exists { + if cachedData, exists := w.getFromCache(msg.SequenceNumber); exists { for _, r := range cachedData { cachedTokenData[r.TokenAmountIndex] = r } @@ -145,11 +178,11 @@ func (w *BackgroundWorker) workOnMsg(ctx context.Context, msg cciptypes.EVM2EVMO }) } - w.resultsCache.add(msg.SequenceNumber, results) + w.resultsCache.Set(strconv.FormatUint(msg.SequenceNumber, 10), results, cache.DefaultExpiration) } func (w *BackgroundWorker) getMsgTokenData(ctx context.Context, seqNum uint64) ([]msgResult, error) { - if msgTokenData, exists := w.resultsCache.get(seqNum); exists { + if msgTokenData, exists := w.getFromCache(seqNum); exists { return msgTokenData, nil } @@ -163,75 +196,17 @@ func (w *BackgroundWorker) getMsgTokenData(ctx context.Context, seqNum uint64) ( case <-ctx.Done(): return nil, context.DeadlineExceeded case <-tick.C: - if msgTokenData, exists := w.resultsCache.get(seqNum); exists { + if msgTokenData, exists := w.getFromCache(seqNum); exists { return msgTokenData, nil } } } } -type resultsCache struct { - expirationDuration time.Duration - expiresAt map[uint64]time.Time - results map[uint64][]msgResult - resultsMu *sync.RWMutex -} - -func newResultsCache(ctx context.Context, expirationDuration, cleanupInterval time.Duration) *resultsCache { - c := &resultsCache{ - expirationDuration: expirationDuration, - expiresAt: make(map[uint64]time.Time), - results: make(map[uint64][]msgResult), - resultsMu: &sync.RWMutex{}, - } - - ticker := time.NewTicker(cleanupInterval) - go func() { - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - c.cleanExpiredItems() - } - } - }() - - return c -} - -func (c *resultsCache) add(msgSeqNum uint64, results []msgResult) { - c.resultsMu.Lock() - defer c.resultsMu.Unlock() - c.results[msgSeqNum] = results - c.expiresAt[msgSeqNum] = time.Now().Add(c.expirationDuration) -} - -func (c *resultsCache) get(msgSeqNum uint64) ([]msgResult, bool) { - c.resultsMu.RLock() - defer c.resultsMu.RUnlock() - v, exists := c.results[msgSeqNum] - return v, exists -} - -func (c *resultsCache) cleanExpiredItems() { - c.resultsMu.RLock() - expiredKeys := make([]uint64, 0, len(c.expiresAt)) - for seqNum, expiresAt := range c.expiresAt { - if expiresAt.Before(time.Now()) { - expiredKeys = append(expiredKeys, seqNum) - } - } - c.resultsMu.RUnlock() - - if len(expiredKeys) == 0 { - return - } - - c.resultsMu.Lock() - for _, seqNum := range expiredKeys { - delete(c.results, seqNum) - delete(c.expiresAt, seqNum) +func (w *BackgroundWorker) getFromCache(seqNum uint64) ([]msgResult, bool) { + rawResult, found := w.resultsCache.Get(strconv.FormatUint(seqNum, 10)) + if !found { + return nil, false } - c.resultsMu.Unlock() + return rawResult.([]msgResult), true } diff --git a/core/services/ocr2/plugins/ccip/tokendata/bgworker_test.go b/core/services/ocr2/plugins/ccip/tokendata/bgworker_test.go index f703417128e..c55548ce6df 100644 --- a/core/services/ocr2/plugins/ccip/tokendata/bgworker_test.go +++ b/core/services/ocr2/plugins/ccip/tokendata/bgworker_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" @@ -43,7 +44,8 @@ func TestBackgroundWorker(t *testing.T) { readerLatency := rand.Intn(maxReaderLatencyMS) delays[tokens[i]] = time.Duration(readerLatency) * time.Millisecond } - w := tokendata.NewBackgroundWorker(ctx, tokenDataReaders, numWorkers, 5*time.Second, time.Hour) + w := tokendata.NewBackgroundWorker(tokenDataReaders, numWorkers, 5*time.Second, time.Hour) + require.NoError(t, w.Start(ctx)) msgs := make([]cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta, numMessages) for i := range msgs { @@ -90,6 +92,8 @@ func TestBackgroundWorker(t *testing.T) { assert.Equal(t, tokenData[msg.TokenAmounts[0].Token], b[0]) } assert.True(t, time.Since(tStart) < 200*time.Millisecond) + + require.NoError(t, w.Close()) } func TestBackgroundWorker_RetryOnErrors(t *testing.T) { @@ -101,10 +105,11 @@ func TestBackgroundWorker_RetryOnErrors(t *testing.T) { rdr1 := tokendata.NewMockReader(t) rdr2 := tokendata.NewMockReader(t) - w := tokendata.NewBackgroundWorker(ctx, map[cciptypes.Address]tokendata.Reader{ + w := tokendata.NewBackgroundWorker(map[cciptypes.Address]tokendata.Reader{ tk1: rdr1, tk2: rdr2, }, 10, 5*time.Second, time.Hour) + require.NoError(t, w.Start(ctx)) msgs := []cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta{ {EVM2EVMMessage: cciptypes.EVM2EVMMessage{ @@ -155,6 +160,8 @@ func TestBackgroundWorker_RetryOnErrors(t *testing.T) { tokenData, err = w.GetMsgTokenData(ctx, msgs[1]) assert.NoError(t, err) assert.Equal(t, []byte("some other data"), tokenData[0]) + + require.NoError(t, w.Close()) } func TestBackgroundWorker_Timeout(t *testing.T) { @@ -167,7 +174,8 @@ func TestBackgroundWorker_Timeout(t *testing.T) { rdr2 := tokendata.NewMockReader(t) w := tokendata.NewBackgroundWorker( - ctx, map[cciptypes.Address]tokendata.Reader{tk1: rdr1, tk2: rdr2}, 10, 5*time.Second, time.Hour) + map[cciptypes.Address]tokendata.Reader{tk1: rdr1, tk2: rdr2}, 10, 5*time.Second, time.Hour) + require.NoError(t, w.Start(ctx)) ctx, cf := context.WithTimeout(ctx, 500*time.Millisecond) defer cf() @@ -176,4 +184,5 @@ func TestBackgroundWorker_Timeout(t *testing.T) { EVM2EVMMessage: cciptypes.EVM2EVMMessage{SequenceNumber: 1}}, ) assert.Error(t, err) + require.NoError(t, w.Close()) } diff --git a/core/services/ocr2/plugins/ccip/tokendata/bgworker_unit_test.go b/core/services/ocr2/plugins/ccip/tokendata/bgworker_unit_test.go deleted file mode 100644 index d3f41cff410..00000000000 --- a/core/services/ocr2/plugins/ccip/tokendata/bgworker_unit_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package tokendata - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func Test_newResultsCache(t *testing.T) { - ctx := context.Background() - - t.Run("add and get", func(t *testing.T) { - c := newResultsCache(ctx, time.Hour, time.Hour) - c.add(123, []msgResult{{}, {}, {}}) - v, exists := c.get(123) - assert.True(t, exists) - assert.Equal(t, []msgResult{{}, {}, {}}, v) - }) - - t.Run("expired", func(t *testing.T) { - c := newResultsCache(ctx, time.Millisecond, time.Millisecond) - c.add(123, []msgResult{{}, {}, {}}) - time.Sleep(10 * time.Millisecond) - _, exists := c.get(123) - assert.False(t, exists) - }) - -}