diff --git a/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go b/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go index e9e3493b37..5571311660 100644 --- a/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go +++ b/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go @@ -84,6 +84,7 @@ type ExecutionReportingPlugin struct { offchainConfig ccipconfig.ExecOffchainConfig cachedSourceFeeTokens cache.AutoSync[[]common.Address] cachedDestTokens cache.AutoSync[cache.CachedTokens] + cachedTokenPools cache.AutoSync[map[common.Address]common.Address] customTokenPoolFactory func(ctx context.Context, poolAddress common.Address, bind bind.ContractBackend) (custom_token_pool.CustomTokenPoolInterface, error) } @@ -131,6 +132,8 @@ func (rf *ExecutionReportingPluginFactory) NewReportingPlugin(config types.Repor cachedSourceFeeTokens := cache.NewCachedFeeTokens(rf.config.sourceLP, rf.config.sourcePriceRegistry, int64(offchainConfig.SourceFinalityDepth)) cachedDestTokens := cache.NewCachedSupportedTokens(rf.config.destLP, rf.config.offRamp, priceRegistry, int64(offchainConfig.DestOptimisticConfirmations)) + + cachedTokenPools := cache.NewTokenPools(rf.config.lggr, rf.config.destLP, rf.config.offRamp, int64(offchainConfig.DestOptimisticConfirmations), 5) rf.config.lggr.Infow("Starting exec plugin", "offchainConfig", offchainConfig, "onchainConfig", onchainConfig) @@ -147,6 +150,7 @@ func (rf *ExecutionReportingPluginFactory) NewReportingPlugin(config types.Repor offchainConfig: offchainConfig, cachedDestTokens: cachedDestTokens, cachedSourceFeeTokens: cachedSourceFeeTokens, + cachedTokenPools: cachedTokenPools, customTokenPoolFactory: func(ctx context.Context, poolAddress common.Address, contractBackend bind.ContractBackend) (custom_token_pool.CustomTokenPoolInterface, error) { return custom_token_pool.NewCustomTokenPool(poolAddress, contractBackend) }, @@ -394,12 +398,17 @@ func (r *ExecutionReportingPlugin) destPoolRateLimits(ctx context.Context, commi } } + tokenPools, err := r.cachedTokenPools.Get(ctx) + if err != nil { + return nil, fmt.Errorf("get cached token pools: %w", err) + } + res := make(map[common.Address]*big.Int, len(dstTokens)) for dstToken := range dstTokens { - poolAddress, err := r.config.offRamp.GetPoolByDestToken(&bind.CallOpts{Context: ctx}, dstToken) - if err != nil { - return nil, fmt.Errorf("get pool by dest token (%s): %w", dstToken, err) + poolAddress, exists := tokenPools[dstToken] + if !exists { + return nil, fmt.Errorf("pool for token '%s' does not exist", dstToken) } tokenPool, err := r.customTokenPoolFactory(ctx, poolAddress, r.config.destClient) diff --git a/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go b/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go index 30ee04add9..fb02daa03c 100644 --- a/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go +++ b/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go @@ -57,6 +57,7 @@ func TestExecutionReportingPlugin_Observation(t *testing.T) { unexpiredReports []ccipdata.Event[commit_store.CommitStoreReportAccepted] sendRequests []ccipdata.Event[ccipdata.EVM2EVMMessage] executedSeqNums []uint64 + tokenPoolsMapping map[common.Address]common.Address blessedRoots map[[32]byte]bool senderNonce uint64 rateLimiterState evm_2_evm_offramp.RateLimiterTokenBucket @@ -88,7 +89,8 @@ func TestExecutionReportingPlugin_Observation(t *testing.T) { rateLimiterState: evm_2_evm_offramp.RateLimiterTokenBucket{ IsEnabled: false, }, - senderNonce: 9, + tokenPoolsMapping: map[common.Address]common.Address{}, + senderNonce: 9, sendRequests: []ccipdata.Event[ccipdata.EVM2EVMMessage]{ { Data: ccipdata.EVM2EVMMessage{SequenceNumber: 10}, @@ -161,6 +163,10 @@ func TestExecutionReportingPlugin_Observation(t *testing.T) { p.destPriceRegistry = priceRegistry p.config.sourcePriceRegistry = priceRegistry + cachedTokenPools := cache.NewMockAutoSync[map[common.Address]common.Address](t) + cachedTokenPools.On("Get", ctx).Return(tc.tokenPoolsMapping, nil).Maybe() + p.cachedTokenPools = cachedTokenPools + sourceFeeTokens := cache.NewMockAutoSync[[]common.Address](t) sourceFeeTokens.On("Get", ctx).Return([]common.Address{}, nil).Maybe() p.cachedSourceFeeTokens = sourceFeeTokens @@ -765,6 +771,7 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { sourceToDestToken map[common.Address]common.Address destPools map[common.Address]common.Address poolRateLimits map[common.Address]custom_token_pool.RateLimiterTokenBucket + destPoolsCacheErr error expRateLimits map[common.Address]*big.Int expErr bool @@ -838,6 +845,20 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { }, expErr: false, }, + { + name: "dest pool cache error", + tokenAmounts: []evm_2_evm_offramp.ClientEVMTokenAmount{{Token: tk1}}, + sourceToDestToken: map[common.Address]common.Address{tk1: tk1dest}, + destPoolsCacheErr: errors.New("some random error"), + expErr: true, + }, + { + name: "pool for token not found", + tokenAmounts: []evm_2_evm_offramp.ClientEVMTokenAmount{{Token: tk1}}, + sourceToDestToken: map[common.Address]common.Address{tk1: tk1dest}, + destPools: map[common.Address]common.Address{}, + expErr: true, + }, } ctx := testutils.Context(t) @@ -847,6 +868,10 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { p := &ExecutionReportingPlugin{} p.lggr = lggr + tokenPoolsCache := cache.NewMockAutoSync[map[common.Address]common.Address](t) + tokenPoolsCache.On("Get", ctx).Return(tc.destPools, tc.destPoolsCacheErr).Maybe() + p.cachedTokenPools = tokenPoolsCache + offRamp, offRampAddr := testhelpers.NewFakeOffRamp(t) offRamp.SetTokenPools(tc.destPools) p.config.offRamp = offRamp diff --git a/core/services/ocr2/plugins/ccip/internal/cache/tokenpool.go b/core/services/ocr2/plugins/ccip/internal/cache/tokenpool.go new file mode 100644 index 0000000000..8a72c8d357 --- /dev/null +++ b/core/services/ocr2/plugins/ccip/internal/cache/tokenpool.go @@ -0,0 +1,92 @@ +package cache + +import ( + "context" + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "golang.org/x/sync/errgroup" + + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_offramp" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" +) + +func NewTokenPools( + lggr logger.Logger, + lp logpoller.LogPoller, + offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface, + optimisticConfirmations int64, + numWorkers int, +) *CachedChain[map[common.Address]common.Address] { + return &CachedChain[map[common.Address]common.Address]{ + observedEvents: []common.Hash{ + abihelpers.EventSignatures.PoolAdded, + abihelpers.EventSignatures.PoolRemoved, + }, + logPoller: lp, + address: []common.Address{offRamp.Address()}, + optimisticConfirmations: optimisticConfirmations, + lock: &sync.RWMutex{}, + value: make(map[common.Address]common.Address), + lastChangeBlock: 0, + origin: newTokenPoolsOrigin(lggr, offRamp, numWorkers), + } +} + +func newTokenPoolsOrigin( + lggr logger.Logger, + offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface, + numWorkers int) *tokenPools { + return &tokenPools{ + lggr: lggr, + offRamp: offRamp, + numWorkers: numWorkers, + } +} + +type tokenPools struct { + lggr logger.Logger + offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface + numWorkers int +} + +func (t *tokenPools) Copy(value map[common.Address]common.Address) map[common.Address]common.Address { + return copyMap(value) +} + +func (t *tokenPools) CallOrigin(ctx context.Context) (map[common.Address]common.Address, error) { + destTokens, err := t.offRamp.GetDestinationTokens(&bind.CallOpts{Context: ctx}) + if err != nil { + return nil, err + } + + eg := new(errgroup.Group) + eg.SetLimit(t.numWorkers) + var mu sync.Mutex + + mapping := make(map[common.Address]common.Address, len(destTokens)) + for _, token := range destTokens { + token := token + eg.Go(func() error { + poolAddress, err := t.offRamp.GetPoolByDestToken(&bind.CallOpts{Context: ctx}, token) + if err != nil { + return fmt.Errorf("get token pool for token '%s': %w", token, err) + } + + mu.Lock() + mapping[token] = poolAddress + mu.Unlock() + return nil + }) + } + + if err := eg.Wait(); err != nil { + return nil, err + } + + return mapping, nil +} diff --git a/core/services/ocr2/plugins/ccip/internal/cache/tokenpool_test.go b/core/services/ocr2/plugins/ccip/internal/cache/tokenpool_test.go new file mode 100644 index 0000000000..bb6e0d5674 --- /dev/null +++ b/core/services/ocr2/plugins/ccip/internal/cache/tokenpool_test.go @@ -0,0 +1,127 @@ +package cache + +import ( + "math/rand" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/testhelpers" + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +func TestNewTokenPools(t *testing.T) { + ctx := testutils.Context(t) + + tk1src := utils.RandomAddress() + tk1dst := utils.RandomAddress() + tk1pool := utils.RandomAddress() + + tk2src := utils.RandomAddress() + tk2dst := utils.RandomAddress() + tk2pool := utils.RandomAddress() + + testCases := []struct { + name string + sourceToDestTokens map[common.Address]common.Address // offramp + feeTokens []common.Address // price registry + tokenToPool map[common.Address]common.Address // offramp + expRes map[common.Address]common.Address + expErr bool + }{ + { + name: "no tokens", + sourceToDestTokens: map[common.Address]common.Address{}, + feeTokens: []common.Address{}, + tokenToPool: map[common.Address]common.Address{}, + expRes: map[common.Address]common.Address{}, + expErr: false, + }, + { + name: "happy flow", + sourceToDestTokens: map[common.Address]common.Address{ + tk1src: tk1dst, + tk2src: tk2dst, + }, + feeTokens: []common.Address{tk1dst, tk2dst}, + tokenToPool: map[common.Address]common.Address{ + tk1dst: tk1pool, + tk2dst: tk2pool, + }, + expRes: map[common.Address]common.Address{ + tk1dst: tk1pool, + tk2dst: tk2pool, + }, + expErr: false, + }, + { + name: "token pool not found", + sourceToDestTokens: map[common.Address]common.Address{ + tk1src: tk1dst, + }, + feeTokens: []common.Address{tk1dst}, + tokenToPool: map[common.Address]common.Address{}, + expErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockLp := mocks.NewLogPoller(t) + mockLp.On("LatestBlock", mock.Anything).Return(int64(100), nil) + + offRamp, _ := testhelpers.NewFakeOffRamp(t) + offRamp.SetSourceToDestTokens(tc.sourceToDestTokens) + offRamp.SetTokenPools(tc.tokenToPool) + + priceReg, _ := testhelpers.NewFakePriceRegistry(t) + priceReg.SetFeeTokens(tc.feeTokens) + + c := NewTokenPools(logger.TestLogger(t), mockLp, offRamp, 0, 5) + + res, err := c.Get(ctx) + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, len(tc.expRes), len(res)) + for k, v := range tc.expRes { + assert.Equal(t, v, res[k]) + } + }) + } +} + +func Test_tokenPools_CallOrigin_concurrency(t *testing.T) { + numDestTokens := rand.Intn(500) + numWorkers := rand.Intn(500) + + sourceToDestTokens := make(map[common.Address]common.Address, numDestTokens) + tokenToPool := make(map[common.Address]common.Address) + for i := 0; i < numDestTokens; i++ { + sourceToken := utils.RandomAddress() + destToken := utils.RandomAddress() + destPool := utils.RandomAddress() + sourceToDestTokens[sourceToken] = destToken + tokenToPool[destToken] = destPool + } + + offRamp, _ := testhelpers.NewFakeOffRamp(t) + offRamp.SetSourceToDestTokens(sourceToDestTokens) + offRamp.SetTokenPools(tokenToPool) + + origin := newTokenPoolsOrigin(logger.TestLogger(t), offRamp, numWorkers) + res, err := origin.CallOrigin(testutils.Context(t)) + assert.NoError(t, err) + + assert.Equal(t, len(tokenToPool), len(res)) + for k, v := range tokenToPool { + assert.Equal(t, v, res[k]) + } +} diff --git a/core/services/ocr2/plugins/ccip/internal/cache/tokens.go b/core/services/ocr2/plugins/ccip/internal/cache/tokens.go index 5a53cf964b..536895b090 100644 --- a/core/services/ocr2/plugins/ccip/internal/cache/tokens.go +++ b/core/services/ocr2/plugins/ccip/internal/cache/tokens.go @@ -206,25 +206,12 @@ func (t *tokenToDecimals) Copy(value map[common.Address]uint8) map[common.Addres // CallOrigin Generates the token to decimal mapping for dest tokens and fee tokens. // NOTE: this queries token decimals n times, where n is the number of tokens whose decimals are not already cached. func (t *tokenToDecimals) CallOrigin(ctx context.Context) (map[common.Address]uint8, error) { - mapping := make(map[common.Address]uint8) - - destTokens, err := t.offRamp.GetDestinationTokens(&bind.CallOpts{Context: ctx}) + destTokens, err := getDestinationAndFeeTokens(ctx, t.offRamp, t.priceRegistry) if err != nil { return nil, err } - feeTokens, err := t.priceRegistry.GetFeeTokens(&bind.CallOpts{Context: ctx}) - if err != nil { - return nil, err - } - - // In case if a fee token is not an offramp dest token, we still want to update its decimals and price - for _, feeToken := range feeTokens { - if !slices.Contains(destTokens, feeToken) { - destTokens = append(destTokens, feeToken) - } - } - + mapping := make(map[common.Address]uint8, len(destTokens)) for _, token := range destTokens { if decimals, exists := t.getCachedDecimals(token); exists { mapping[token] = decimals @@ -247,6 +234,26 @@ func (t *tokenToDecimals) CallOrigin(ctx context.Context) (map[common.Address]ui return mapping, nil } +func getDestinationAndFeeTokens(ctx context.Context, offRamp evm_2_evm_offramp.EVM2EVMOffRampInterface, priceRegistry price_registry.PriceRegistryInterface) ([]common.Address, error) { + destTokens, err := offRamp.GetDestinationTokens(&bind.CallOpts{Context: ctx}) + if err != nil { + return nil, err + } + + feeTokens, err := priceRegistry.GetFeeTokens(&bind.CallOpts{Context: ctx}) + if err != nil { + return nil, err + } + + for _, feeToken := range feeTokens { + if !slices.Contains(destTokens, feeToken) { + destTokens = append(destTokens, feeToken) + } + } + + return destTokens, nil +} + func (t *tokenToDecimals) getCachedDecimals(token common.Address) (uint8, bool) { rawVal, exists := t.tokenDecimals.Load(token.String()) if !exists { diff --git a/core/services/ocr2/plugins/ccip/testhelpers/offramp.go b/core/services/ocr2/plugins/ccip/testhelpers/offramp.go index 16b572b5c0..4fc68db15a 100644 --- a/core/services/ocr2/plugins/ccip/testhelpers/offramp.go +++ b/core/services/ocr2/plugins/ccip/testhelpers/offramp.go @@ -51,7 +51,13 @@ func (o *FakeOffRamp) SetSenderNonces(senderNonces map[common.Address]uint64) { } func (o *FakeOffRamp) GetPoolByDestToken(opts *bind.CallOpts, destToken common.Address) (common.Address, error) { - return getOffRampVal(o, func(o *FakeOffRamp) (common.Address, error) { return o.tokenToPool[destToken], nil }) + return getOffRampVal(o, func(o *FakeOffRamp) (common.Address, error) { + addr, exists := o.tokenToPool[destToken] + if !exists { + return common.Address{}, errors.New("not found") + } + return addr, nil + }) } func (o *FakeOffRamp) SetTokenPools(tokenToPool map[common.Address]common.Address) { @@ -92,6 +98,16 @@ func (o *FakeOffRamp) GetDestinationToken(opts *bind.CallOpts, sourceToken commo }) } +func (o *FakeOffRamp) GetDestinationTokens(opts *bind.CallOpts) ([]common.Address, error) { + return getOffRampVal(o, func(o *FakeOffRamp) ([]common.Address, error) { + tokens := make([]common.Address, 0, len(o.sourceToDestTokens)) + for _, dst := range o.sourceToDestTokens { + tokens = append(tokens, dst) + } + return tokens, nil + }) +} + func getOffRampVal[T any](o *FakeOffRamp, getter func(o *FakeOffRamp) (T, error)) (T, error) { o.mu.RLock() defer o.mu.RUnlock() diff --git a/core/services/ocr2/plugins/ccip/testhelpers/priceregistry.go b/core/services/ocr2/plugins/ccip/testhelpers/priceregistry.go index 5b127e685c..9f0d9b5fce 100644 --- a/core/services/ocr2/plugins/ccip/testhelpers/priceregistry.go +++ b/core/services/ocr2/plugins/ccip/testhelpers/priceregistry.go @@ -16,6 +16,7 @@ type FakePriceRegistry struct { *mock_contracts.PriceRegistryInterface tokenPrices []price_registry.InternalTimestampedPackedUint224 + feeTokens []common.Address mu sync.RWMutex } @@ -39,6 +40,14 @@ func (p *FakePriceRegistry) GetTokenPrices(opts *bind.CallOpts, tokens []common. }) } +func (p *FakePriceRegistry) SetFeeTokens(tokens []common.Address) { + setPriceRegistryVal(p, func(p *FakePriceRegistry) { p.feeTokens = tokens }) +} + +func (p *FakePriceRegistry) GetFeeTokens(opts *bind.CallOpts) ([]common.Address, error) { + return getPriceRegistryVal(p, func(p *FakePriceRegistry) ([]common.Address, error) { return p.feeTokens, nil }) +} + func getPriceRegistryVal[T any](p *FakePriceRegistry, getter func(p *FakePriceRegistry) (T, error)) (T, error) { p.mu.RLock() defer p.mu.RUnlock()