diff --git a/.changeset/blue-camels-begin.md b/.changeset/blue-camels-begin.md new file mode 100644 index 00000000000..3ad57286e91 --- /dev/null +++ b/.changeset/blue-camels-begin.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +enforce proper result indexing on pipeline results #breaking_change diff --git a/.changeset/blue-camels-promise.md b/.changeset/blue-camels-promise.md new file mode 100644 index 00000000000..48d7fd47565 --- /dev/null +++ b/.changeset/blue-camels-promise.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +#changed Remove ClientErrors interface from common diff --git a/.changeset/silent-jars-relax.md b/.changeset/silent-jars-relax.md new file mode 100644 index 00000000000..3b076da1226 --- /dev/null +++ b/.changeset/silent-jars-relax.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +#internal [Keystone] EVM encoder support for tuples diff --git a/.changeset/tame-mice-give.md b/.changeset/tame-mice-give.md new file mode 100644 index 00000000000..7cd59b154ad --- /dev/null +++ b/.changeset/tame-mice-give.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +#db_update Add ON DELETE CASCADE to workflow tables diff --git a/.changeset/wild-berries-cry.md b/.changeset/wild-berries-cry.md new file mode 100644 index 00000000000..196de1a124e --- /dev/null +++ b/.changeset/wild-berries-cry.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +#db_update Add name to workflow spec. Add unique constraint to (owner,name) for workflow spec diff --git a/.changeset/young-candles-brush.md b/.changeset/young-candles-brush.md new file mode 100644 index 00000000000..5d10eaabf80 --- /dev/null +++ b/.changeset/young-candles-brush.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +#bugfix allow ChainType to be set to xdai diff --git a/.github/workflows/automation-nightly-tests.yml b/.github/workflows/automation-nightly-tests.yml index f25700f3155..ae0f3526e99 100644 --- a/.github/workflows/automation-nightly-tests.yml +++ b/.github/workflows/automation-nightly-tests.yml @@ -129,6 +129,7 @@ jobs: cl_image_tag: 'latest' aws_registries: ${{ secrets.QA_AWS_ACCOUNT_NUMBER }} artifacts_location: ./integration-tests/${{ matrix.tests.suite }}/logs + artifacts_name: testcontainers-logs-${{ matrix.tests.name }} publish_check_name: Automation Results ${{ matrix.tests.name }} token: ${{ secrets.GITHUB_TOKEN }} go_mod_path: ./integration-tests/go.mod @@ -139,7 +140,7 @@ jobs: uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 if: failure() with: - name: test-log-${{ matrix.tests.name }} + name: gotest-logs-${{ matrix.tests.name }} path: /tmp/gotest.log retention-days: 7 continue-on-error: true @@ -155,6 +156,9 @@ jobs: this-job-name: Automation ${{ matrix.tests.name }} Test test-results-file: '{"testType":"go","filePath":"/tmp/gotest.log"}' continue-on-error: true + - name: Print failed test summary + if: always() + uses: smartcontractkit/chainlink-github-actions/chainlink-testing-framework/show-test-summary@b49a9d04744b0237908831730f8553f26d73a94b # v2.3.17 test-notify: name: Start Slack Thread diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index ad04a9d808c..9e944ead9bf 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1214,7 +1214,7 @@ jobs: QA_AWS_REGION: ${{ secrets.QA_AWS_REGION }} QA_AWS_ROLE_TO_ASSUME: ${{ secrets.QA_AWS_ROLE_TO_ASSUME }} QA_KUBECONFIG: ${{ secrets.QA_KUBECONFIG }} - - name: Pull Artfacts + - name: Pull Artifacts if: needs.changes.outputs.src == 'true' || github.event_name == 'workflow_dispatch' run: | IMAGE_NAME=${{ secrets.QA_AWS_ACCOUNT_NUMBER }}.dkr.ecr.${{ secrets.QA_AWS_REGION }}.amazonaws.com/chainlink-solana-tests:${{ needs.get_solana_sha.outputs.sha }} @@ -1232,12 +1232,20 @@ jobs: docker rm "$CONTAINER_ID" - name: Install Solana CLI # required for ensuring the local test validator is configured correctly run: ./scripts/install-solana-ci.sh + - name: Install gauntlet + run: | + yarn --cwd ./gauntlet install --frozen-lockfile + yarn --cwd ./gauntlet build + yarn --cwd ./gauntlet gauntlet - name: Generate config overrides run: | # https://github.com/smartcontractkit/chainlink-testing-framework/blob/main/config/README.md cat << EOF > config.toml [ChainlinkImage] image="${{ env.CHAINLINK_IMAGE }}" version="${{ inputs.evm-ref || github.sha }}" + [Common] + user="${{ github.actor }}" + internal_docker_repo = "${{ secrets.QA_AWS_ACCOUNT_NUMBER }}.dkr.ecr.${{ secrets.QA_AWS_REGION }}.amazonaws.com" EOF # shellcheck disable=SC2002 BASE64_CONFIG_OVERRIDE=$(cat config.toml | base64 -w 0) diff --git a/.golangci.yml b/.golangci.yml index 96a7de282e0..3834400ba67 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -15,6 +15,7 @@ linters: - noctx - depguard - whitespace + - containedctx linters-settings: exhaustive: default-signifies-exhaustive: true diff --git a/common/client/node.go b/common/client/node.go index 6450b086f10..1d0a799321b 100644 --- a/common/client/node.go +++ b/common/client/node.go @@ -9,8 +9,6 @@ import ( "sync" "time" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" - "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -47,7 +45,6 @@ type NodeConfig interface { SyncThreshold() uint32 NodeIsSyncingEnabled() bool FinalizedBlockPollInterval() time.Duration - Errors() config.ClientErrors } type ChainConfig interface { @@ -106,10 +103,7 @@ type node[ stateLatestTotalDifficulty *big.Int stateLatestFinalizedBlockNumber int64 - // nodeCtx is the node lifetime's context - nodeCtx context.Context - // cancelNodeCtx cancels nodeCtx when stopping the node - cancelNodeCtx context.CancelFunc + stopCh services.StopChan // wg waits for subsidiary goroutines wg sync.WaitGroup @@ -148,7 +142,7 @@ func NewNode[ if httpuri != nil { n.http = httpuri } - n.nodeCtx, n.cancelNodeCtx = context.WithCancel(context.Background()) + n.stopCh = make(services.StopChan) lggr = logger.Named(lggr, "Node") lggr = logger.With(lggr, "nodeTier", Primary.String(), @@ -205,7 +199,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) close() error { n.stateMu.Lock() defer n.stateMu.Unlock() - n.cancelNodeCtx() + close(n.stopCh) n.state = nodeStateClosed return nil } diff --git a/common/client/node_lifecycle.go b/common/client/node_lifecycle.go index fa6397580c8..5947774e202 100644 --- a/common/client/node_lifecycle.go +++ b/common/client/node_lifecycle.go @@ -79,6 +79,8 @@ const rpcSubscriptionMethodNewHeads = "newHeads" // Should only be run ONCE per node, after a successful Dial func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -100,7 +102,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { lggr.Tracew("Alive loop starting", "nodeState", n.State()) headsC := make(chan HEAD) - sub, err := n.rpc.Subscribe(n.nodeCtx, headsC, rpcSubscriptionMethodNewHeads) + sub, err := n.rpc.Subscribe(ctx, headsC, rpcSubscriptionMethodNewHeads) if err != nil { lggr.Errorw("Initial subscribe for heads failed", "nodeState", n.State()) n.declareUnreachable() @@ -151,15 +153,16 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-pollCh: - var version string promPoolRPCNodePolls.WithLabelValues(n.chainID.String(), n.name).Inc() lggr.Tracew("Polling for version", "nodeState", n.State(), "pollFailures", pollFailures) - ctx, cancel := context.WithTimeout(n.nodeCtx, pollInterval) - version, err := n.RPC().ClientVersion(ctx) - cancel() + version, err := func(ctx context.Context) (string, error) { + ctx, cancel := context.WithTimeout(ctx, pollInterval) + defer cancel() + return n.RPC().ClientVersion(ctx) + }(ctx) if err != nil { // prevent overflow if pollFailures < math.MaxUint32 { @@ -240,9 +243,11 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { n.declareOutOfSync(func(num int64, td *big.Int) bool { return num < highestReceivedBlockNumber }) return case <-pollFinalizedHeadCh: - ctx, cancel := context.WithTimeout(n.nodeCtx, n.nodePoolCfg.FinalizedBlockPollInterval()) - latestFinalized, err := n.RPC().LatestFinalizedBlock(ctx) - cancel() + latestFinalized, err := func(ctx context.Context) (HEAD, error) { + ctx, cancel := context.WithTimeout(ctx, n.nodePoolCfg.FinalizedBlockPollInterval()) + defer cancel() + return n.RPC().LatestFinalizedBlock(ctx) + }(ctx) if err != nil { lggr.Warnw("Failed to fetch latest finalized block", "err", err) continue @@ -300,6 +305,8 @@ const ( // outOfSyncLoop takes an OutOfSync node and waits until isOutOfSync returns false to go back to live status func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -319,7 +326,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td lggr.Debugw("Trying to revive out-of-sync RPC node", "nodeState", n.State()) // Need to redial since out-of-sync nodes are automatically disconnected - state := n.createVerifiedConn(n.nodeCtx, lggr) + state := n.createVerifiedConn(ctx, lggr) if state != nodeStateAlive { n.declareState(state) return @@ -328,7 +335,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td lggr.Tracew("Successfully subscribed to heads feed on out-of-sync RPC node", "nodeState", n.State()) ch := make(chan HEAD) - sub, err := n.rpc.Subscribe(n.nodeCtx, ch, rpcSubscriptionMethodNewHeads) + sub, err := n.rpc.Subscribe(ctx, ch, rpcSubscriptionMethodNewHeads) if err != nil { lggr.Errorw("Failed to subscribe heads on out-of-sync RPC node", "nodeState", n.State(), "err", err) n.declareUnreachable() @@ -338,7 +345,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case head, open := <-ch: if !open { @@ -372,6 +379,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -394,12 +403,12 @@ func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(dialRetryBackoff.Duration()): lggr.Tracew("Trying to re-dial RPC node", "nodeState", n.State()) - err := n.rpc.Dial(n.nodeCtx) + err := n.rpc.Dial(ctx) if err != nil { lggr.Errorw(fmt.Sprintf("Failed to redial RPC node; still unreachable: %v", err), "err", err, "nodeState", n.State()) continue @@ -407,7 +416,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { n.setState(nodeStateDialed) - state := n.verifyConn(n.nodeCtx, lggr) + state := n.verifyConn(ctx, lggr) switch state { case nodeStateUnreachable: n.setState(nodeStateUnreachable) @@ -425,6 +434,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -443,7 +454,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { lggr := logger.Named(n.lfcLog, "InvalidChainID") // Need to redial since invalid chain ID nodes are automatically disconnected - state := n.createVerifiedConn(n.nodeCtx, lggr) + state := n.createVerifiedConn(ctx, lggr) if state != nodeStateInvalidChainID { n.declareState(state) return @@ -455,10 +466,10 @@ func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(chainIDRecheckBackoff.Duration()): - state := n.verifyConn(n.nodeCtx, lggr) + state := n.verifyConn(ctx, lggr) switch state { case nodeStateInvalidChainID: continue @@ -475,6 +486,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { func (n *node[CHAIN_ID, HEAD, RPC]) syncingLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -493,7 +506,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) syncingLoop() { lggr := logger.Sugared(logger.Named(n.lfcLog, "Syncing")) lggr.Debugw(fmt.Sprintf("Periodically re-checking RPC node %s with syncing status", n.String()), "nodeState", n.State()) // Need to redial since syncing nodes are automatically disconnected - state := n.createVerifiedConn(n.nodeCtx, lggr) + state := n.createVerifiedConn(ctx, lggr) if state != nodeStateSyncing { n.declareState(state) return @@ -503,11 +516,11 @@ func (n *node[CHAIN_ID, HEAD, RPC]) syncingLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(recheckBackoff.Duration()): lggr.Tracew("Trying to recheck if the node is still syncing", "nodeState", n.State()) - isSyncing, err := n.rpc.IsSyncing(n.nodeCtx) + isSyncing, err := n.rpc.IsSyncing(ctx) if err != nil { lggr.Errorw("Unexpected error while verifying RPC node synchronization status", "err", err, "nodeState", n.State()) n.declareUnreachable() diff --git a/common/client/node_test.go b/common/client/node_test.go index 85c96145740..a97f26555a9 100644 --- a/common/client/node_test.go +++ b/common/client/node_test.go @@ -9,7 +9,6 @@ import ( clientMocks "github.com/smartcontractkit/chainlink/v2/common/client/mocks" "github.com/smartcontractkit/chainlink/v2/common/types" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" ) type testNodeConfig struct { @@ -19,7 +18,6 @@ type testNodeConfig struct { syncThreshold uint32 nodeIsSyncingEnabled bool finalizedBlockPollInterval time.Duration - errors config.ClientErrors } func (n testNodeConfig) PollFailureThreshold() uint32 { @@ -46,10 +44,6 @@ func (n testNodeConfig) FinalizedBlockPollInterval() time.Duration { return n.finalizedBlockPollInterval } -func (n testNodeConfig) Errors() config.ClientErrors { - return n.errors -} - type testNode struct { *node[types.ID, Head, NodeClient[types.ID, Head]] } diff --git a/common/config/chaintype.go b/common/config/chaintype.go index 906b57415e3..e29cd1a337e 100644 --- a/common/config/chaintype.go +++ b/common/config/chaintype.go @@ -5,10 +5,8 @@ import ( "strings" ) -// ChainType denotes the chain or network to work with type ChainType string -// nolint const ( ChainArbitrum ChainType = "arbitrum" ChainCelo ChainType = "celo" @@ -18,12 +16,104 @@ const ( ChainOptimismBedrock ChainType = "optimismBedrock" ChainScroll ChainType = "scroll" ChainWeMix ChainType = "wemix" - ChainXDai ChainType = "xdai" // Deprecated: use ChainGnosis instead ChainXLayer ChainType = "xlayer" ChainZkEvm ChainType = "zkevm" ChainZkSync ChainType = "zksync" ) +// IsL2 returns true if this chain is a Layer 2 chain. Notably: +// - the block numbers used for log searching are different from calling block.number +// - gas bumping is not supported, since there is no tx mempool +func (c ChainType) IsL2() bool { + switch c { + case ChainArbitrum, ChainMetis: + return true + default: + return false + } +} + +func (c ChainType) IsValid() bool { + switch c { + case "", ChainArbitrum, ChainCelo, ChainGnosis, ChainKroma, ChainMetis, ChainOptimismBedrock, ChainScroll, ChainWeMix, ChainXLayer, ChainZkSync: + return true + } + return false +} + +func ChainTypeFromSlug(slug string) ChainType { + switch slug { + case "arbitrum": + return ChainArbitrum + case "celo": + return ChainCelo + case "gnosis", "xdai": + return ChainGnosis + case "kroma": + return ChainKroma + case "metis": + return ChainMetis + case "optimismBedrock": + return ChainOptimismBedrock + case "scroll": + return ChainScroll + case "wemix": + return ChainWeMix + case "xlayer": + return ChainXLayer + case "zksync": + return ChainZkSync + default: + return ChainType(slug) + } +} + +type ChainTypeConfig struct { + value ChainType + slug string +} + +func NewChainTypeConfig(slug string) *ChainTypeConfig { + return &ChainTypeConfig{ + value: ChainTypeFromSlug(slug), + slug: slug, + } +} + +func (c *ChainTypeConfig) MarshalText() ([]byte, error) { + if c == nil { + return nil, nil + } + return []byte(c.slug), nil +} + +func (c *ChainTypeConfig) UnmarshalText(b []byte) error { + c.slug = string(b) + c.value = ChainTypeFromSlug(c.slug) + return nil +} + +func (c *ChainTypeConfig) Slug() string { + if c == nil { + return "" + } + return c.slug +} + +func (c *ChainTypeConfig) ChainType() ChainType { + if c == nil { + return "" + } + return c.value +} + +func (c *ChainTypeConfig) String() string { + if c == nil { + return "" + } + return string(c.value) +} + var ErrInvalidChainType = fmt.Errorf("must be one of %s or omitted", strings.Join([]string{ string(ChainArbitrum), string(ChainCelo), @@ -37,24 +127,3 @@ var ErrInvalidChainType = fmt.Errorf("must be one of %s or omitted", strings.Joi string(ChainZkEvm), string(ChainZkSync), }, ", ")) - -// IsValid returns true if the ChainType value is known or empty. -func (c ChainType) IsValid() bool { - switch c { - case "", ChainArbitrum, ChainCelo, ChainGnosis, ChainKroma, ChainMetis, ChainOptimismBedrock, ChainScroll, ChainWeMix, ChainXDai, ChainXLayer, ChainZkEvm, ChainZkSync: - return true - } - return false -} - -// IsL2 returns true if this chain is a Layer 2 chain. Notably: -// - the block numbers used for log searching are different from calling block.number -// - gas bumping is not supported, since there is no tx mempool -func (c ChainType) IsL2() bool { - switch c { - case ChainArbitrum, ChainMetis: - return true - default: - return false - } -} diff --git a/common/txmgr/confirmer.go b/common/txmgr/confirmer.go index d35172895f3..294e922c1c0 100644 --- a/common/txmgr/confirmer.go +++ b/common/txmgr/confirmer.go @@ -134,8 +134,7 @@ type Confirmer[ enabledAddresses []ADDR mb *mailbox.Mailbox[HEAD] - ctx context.Context - ctxCancel context.CancelFunc + stopCh services.StopChan wg sync.WaitGroup initSync sync.Mutex isStarted bool @@ -213,7 +212,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) sta ec.lggr.Debugf("Confirmer: failed to load the last purged block num for enabled addresses. Process can continue as normal but purge rate limiting may be affected.") } - ec.ctx, ec.ctxCancel = context.WithCancel(context.Background()) + ec.stopCh = make(chan struct{}) ec.wg = sync.WaitGroup{} ec.wg.Add(1) go ec.runLoop() @@ -234,7 +233,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) clo if !ec.isStarted { return fmt.Errorf("Confirmer is not started: %w", services.ErrAlreadyStopped) } - ec.ctxCancel() + close(ec.stopCh) ec.wg.Wait() ec.isStarted = false return nil @@ -254,23 +253,25 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Hea func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() { defer ec.wg.Done() + ctx, cancel := ec.stopCh.NewCtx() + defer cancel() for { select { case <-ec.mb.Notify(): for { - if ec.ctx.Err() != nil { + if ctx.Err() != nil { return } head, exists := ec.mb.Retrieve() if !exists { break } - if err := ec.ProcessHead(ec.ctx, head); err != nil { + if err := ec.ProcessHead(ctx, head); err != nil { ec.lggr.Errorw("Error processing head", "err", err) continue } } - case <-ec.ctx.Done(): + case <-ctx.Done(): return } } @@ -1028,7 +1029,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Ens for _, etx := range etxs { if !hasReceiptInLongestChain(*etx, head) { - if err := ec.markForRebroadcast(*etx, head); err != nil { + if err := ec.markForRebroadcast(ctx, *etx, head); err != nil { return fmt.Errorf("markForRebroadcast failed for etx %v: %w", etx.ID, err) } } @@ -1080,7 +1081,7 @@ func hasReceiptInLongestChain[ } } -func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) markForRebroadcast(etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], head types.Head[BLOCK_HASH]) error { +func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) markForRebroadcast(ctx context.Context, etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], head types.Head[BLOCK_HASH]) error { if len(etx.TxAttempts) == 0 { return fmt.Errorf("invariant violation: expected tx %v to have at least one attempt", etx.ID) } @@ -1115,7 +1116,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) mar ec.lggr.Infow(fmt.Sprintf("Re-org detected. Rebroadcasting transaction %s which may have been re-org'd out of the main chain", attempt.Hash.String()), logValues...) // Put it back in progress and delete all receipts (they do not apply to the new chain) - if err := ec.txStore.UpdateTxForRebroadcast(ec.ctx, etx, attempt); err != nil { + if err := ec.txStore.UpdateTxForRebroadcast(ctx, etx, attempt); err != nil { return fmt.Errorf("markForRebroadcast failed: %w", err) } diff --git a/common/txmgr/mocks/tx_manager.go b/common/txmgr/mocks/tx_manager.go index a3e8c489314..974fd455903 100644 --- a/common/txmgr/mocks/tx_manager.go +++ b/common/txmgr/mocks/tx_manager.go @@ -273,9 +273,9 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTx return r0, r1 } -// GetForwarderForEOA provides a mock function with given fields: eoa -func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(eoa ADDR) (ADDR, error) { - ret := _m.Called(eoa) +// GetForwarderForEOA provides a mock function with given fields: ctx, eoa +func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(ctx context.Context, eoa ADDR) (ADDR, error) { + ret := _m.Called(ctx, eoa) if len(ret) == 0 { panic("no return value specified for GetForwarderForEOA") @@ -283,17 +283,17 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR) (ADDR, error)); ok { - return rf(eoa) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) (ADDR, error)); ok { + return rf(ctx, eoa) } - if rf, ok := ret.Get(0).(func(ADDR) ADDR); ok { - r0 = rf(eoa) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) ADDR); ok { + r0 = rf(ctx, eoa) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR) error); ok { - r1 = rf(eoa) + if rf, ok := ret.Get(1).(func(context.Context, ADDR) error); ok { + r1 = rf(ctx, eoa) } else { r1 = ret.Error(1) } @@ -301,9 +301,9 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor return r0, r1 } -// GetForwarderForEOAOCR2Feeds provides a mock function with given fields: eoa, ocr2AggregatorID -func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(eoa ADDR, ocr2AggregatorID ADDR) (ADDR, error) { - ret := _m.Called(eoa, ocr2AggregatorID) +// GetForwarderForEOAOCR2Feeds provides a mock function with given fields: ctx, eoa, ocr2AggregatorID +func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa ADDR, ocr2AggregatorID ADDR) (ADDR, error) { + ret := _m.Called(ctx, eoa, ocr2AggregatorID) if len(ret) == 0 { panic("no return value specified for GetForwarderForEOAOCR2Feeds") @@ -311,17 +311,17 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR, ADDR) (ADDR, error)); ok { - return rf(eoa, ocr2AggregatorID) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) (ADDR, error)); ok { + return rf(ctx, eoa, ocr2AggregatorID) } - if rf, ok := ret.Get(0).(func(ADDR, ADDR) ADDR); ok { - r0 = rf(eoa, ocr2AggregatorID) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) ADDR); ok { + r0 = rf(ctx, eoa, ocr2AggregatorID) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR, ADDR) error); ok { - r1 = rf(eoa, ocr2AggregatorID) + if rf, ok := ret.Get(1).(func(context.Context, ADDR, ADDR) error); ok { + r1 = rf(ctx, eoa, ocr2AggregatorID) } else { r1 = ret.Error(1) } diff --git a/common/txmgr/resender.go b/common/txmgr/resender.go index b752ec63f13..8483b7a0264 100644 --- a/common/txmgr/resender.go +++ b/common/txmgr/resender.go @@ -8,6 +8,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/chains/label" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink/v2/common/client" @@ -56,8 +57,7 @@ type Resender[ logger logger.SugaredLogger lastAlertTimestamps map[string]time.Time - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan chDone chan struct{} } @@ -83,7 +83,6 @@ func NewResender[ panic("Resender requires a non-zero threshold") } // todo: add context to txStore https://smartcontract-it.atlassian.net/browse/BCI-1585 - ctx, cancel := context.WithCancel(context.Background()) return &Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]{ txStore, client, @@ -95,8 +94,7 @@ func NewResender[ txConfig, logger.Sugared(logger.Named(lggr, "Resender")), make(map[string]time.Time), - ctx, - cancel, + make(chan struct{}), make(chan struct{}), } } @@ -109,14 +107,16 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(ctx // Stop is a comment which satisfies the linter func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Stop() { - er.cancel() + close(er.stopCh) <-er.chDone } func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() { defer close(er.chDone) + ctx, cancel := er.stopCh.NewCtx() + defer cancel() - if err := er.resendUnconfirmed(er.ctx); err != nil { + if err := er.resendUnconfirmed(ctx); err != nil { er.logger.Warnw("Failed to resend unconfirmed transactions", "err", err) } @@ -124,10 +124,10 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() defer ticker.Stop() for { select { - case <-er.ctx.Done(): + case <-ctx.Done(): return case <-ticker.C: - if err := er.resendUnconfirmed(er.ctx); err != nil { + if err := er.resendUnconfirmed(ctx); err != nil { er.logger.Warnw("Failed to resend unconfirmed transactions", "err", err) } } @@ -135,6 +135,9 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() } func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) resendUnconfirmed(ctx context.Context) error { + var cancel func() + ctx, cancel = er.stopCh.Ctx(ctx) + defer cancel() resendAddresses, err := er.ks.EnabledAddressesForChain(ctx, er.chainID) if err != nil { return fmt.Errorf("Resender failed getting enabled keys for chain %s: %w", er.chainID.String(), err) @@ -147,7 +150,7 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) resendUnco for _, k := range resendAddresses { var attempts []txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] - attempts, err = er.txStore.FindTxAttemptsRequiringResend(er.ctx, olderThan, maxInFlightTransactions, er.chainID, k) + attempts, err = er.txStore.FindTxAttemptsRequiringResend(ctx, olderThan, maxInFlightTransactions, er.chainID, k) if err != nil { return fmt.Errorf("failed to FindTxAttemptsRequiringResend: %w", err) } @@ -165,13 +168,13 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) resendUnco er.logger.Infow(fmt.Sprintf("Re-sending %d unconfirmed transactions that were last sent over %s ago. These transactions are taking longer than usual to be mined. %s", len(allAttempts), ageThreshold, label.NodeConnectivityProblemWarning), "n", len(allAttempts)) batchSize := int(er.config.RPCDefaultBatchSize()) - ctx, cancel := context.WithTimeout(er.ctx, batchSendTransactionTimeout) - defer cancel() - txErrTypes, _, broadcastTime, txIDs, err := er.client.BatchSendTransactions(ctx, allAttempts, batchSize, er.logger) + batchCtx, batchCancel := context.WithTimeout(ctx, batchSendTransactionTimeout) + defer batchCancel() + txErrTypes, _, broadcastTime, txIDs, err := er.client.BatchSendTransactions(batchCtx, allAttempts, batchSize, er.logger) // update broadcast times before checking additional errors if len(txIDs) > 0 { - if updateErr := er.txStore.UpdateBroadcastAts(er.ctx, broadcastTime, txIDs); updateErr != nil { + if updateErr := er.txStore.UpdateBroadcastAts(ctx, broadcastTime, txIDs); updateErr != nil { err = errors.Join(err, fmt.Errorf("failed to update broadcast time: %w", updateErr)) } } diff --git a/common/txmgr/test_helpers.go b/common/txmgr/test_helpers.go index dbc07861ffe..3051e0985d8 100644 --- a/common/txmgr/test_helpers.go +++ b/common/txmgr/test_helpers.go @@ -35,7 +35,9 @@ func (eb *Broadcaster[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) XXXT } func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestStartInternal() error { - return ec.startInternal(ec.ctx) + ctx, cancel := ec.stopCh.NewCtx() + defer cancel() + return ec.startInternal(ctx) } func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestCloseInternal() error { @@ -43,7 +45,9 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXX } func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestResendUnconfirmed() error { - return er.resendUnconfirmed(er.ctx) + ctx, cancel := er.stopCh.NewCtx() + defer cancel() + return er.resendUnconfirmed(ctx) } func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestAbandon(addr ADDR) (err error) { diff --git a/common/txmgr/txmgr.go b/common/txmgr/txmgr.go index 1c8b59a55cc..44b518fdaab 100644 --- a/common/txmgr/txmgr.go +++ b/common/txmgr/txmgr.go @@ -46,8 +46,8 @@ type TxManager[ services.Service Trigger(addr ADDR) CreateTransaction(ctx context.Context, txRequest txmgrtypes.TxRequest[ADDR, TX_HASH]) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) - GetForwarderForEOA(eoa ADDR) (forwarder ADDR, err error) - GetForwarderForEOAOCR2Feeds(eoa, ocr2AggregatorID ADDR) (forwarder ADDR, err error) + GetForwarderForEOA(ctx context.Context, eoa ADDR) (forwarder ADDR, err error) + GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa, ocr2AggregatorID ADDR) (forwarder ADDR, err error) RegisterResumeCallback(fn ResumeCallback) SendNativeToken(ctx context.Context, chainID CHAIN_ID, from, to ADDR, value big.Int, gasLimit uint64) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) Reset(addr ADDR, abandon bool) error @@ -546,20 +546,20 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) CreateTran } // Calls forwarderMgr to get a proper forwarder for a given EOA. -func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOA(eoa ADDR) (forwarder ADDR, err error) { +func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOA(ctx context.Context, eoa ADDR) (forwarder ADDR, err error) { if !b.txConfig.ForwardersEnabled() { return forwarder, fmt.Errorf("forwarding is not enabled, to enable set Transactions.ForwardersEnabled =true") } - forwarder, err = b.fwdMgr.ForwarderFor(eoa) + forwarder, err = b.fwdMgr.ForwarderFor(ctx, eoa) return } // GetForwarderForEOAOCR2Feeds calls forwarderMgr to get a proper forwarder for a given EOA and checks if its set as a transmitter on the OCR2Aggregator contract. -func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) { +func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) { if !b.txConfig.ForwardersEnabled() { return forwarder, fmt.Errorf("forwarding is not enabled, to enable set Transactions.ForwardersEnabled =true") } - forwarder, err = b.fwdMgr.ForwarderForOCR2Feeds(eoa, ocr2Aggregator) + forwarder, err = b.fwdMgr.ForwarderForOCR2Feeds(ctx, eoa, ocr2Aggregator) return } @@ -656,10 +656,10 @@ func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) Tri func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) CreateTransaction(ctx context.Context, txRequest txmgrtypes.TxRequest[ADDR, TX_HASH]) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) { return etx, errors.New(n.ErrMsg) } -func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(addr ADDR) (fwdr ADDR, err error) { +func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(ctx context.Context, addr ADDR) (fwdr ADDR, err error) { return fwdr, err } -func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(_, _ ADDR) (fwdr ADDR, err error) { +func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(ctx context.Context, _, _ ADDR) (fwdr ADDR, err error) { return fwdr, err } diff --git a/common/txmgr/types/forwarder_manager.go b/common/txmgr/types/forwarder_manager.go index 3e51ffb1524..6acb491a1fb 100644 --- a/common/txmgr/types/forwarder_manager.go +++ b/common/txmgr/types/forwarder_manager.go @@ -1,15 +1,18 @@ package types import ( + "context" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink/v2/common/types" ) //go:generate mockery --quiet --name ForwarderManager --output ./mocks/ --case=underscore type ForwarderManager[ADDR types.Hashable] interface { services.Service - ForwarderFor(addr ADDR) (forwarder ADDR, err error) - ForwarderForOCR2Feeds(eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) + ForwarderFor(ctx context.Context, addr ADDR) (forwarder ADDR, err error) + ForwarderForOCR2Feeds(ctx context.Context, eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) // Converts payload to be forwarder-friendly ConvertPayload(dest ADDR, origPayload []byte) ([]byte, error) } diff --git a/common/txmgr/types/mocks/forwarder_manager.go b/common/txmgr/types/mocks/forwarder_manager.go index 1021e776e9d..b2cf9bc9d35 100644 --- a/common/txmgr/types/mocks/forwarder_manager.go +++ b/common/txmgr/types/mocks/forwarder_manager.go @@ -63,9 +63,9 @@ func (_m *ForwarderManager[ADDR]) ConvertPayload(dest ADDR, origPayload []byte) return r0, r1 } -// ForwarderFor provides a mock function with given fields: addr -func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { - ret := _m.Called(addr) +// ForwarderFor provides a mock function with given fields: ctx, addr +func (_m *ForwarderManager[ADDR]) ForwarderFor(ctx context.Context, addr ADDR) (ADDR, error) { + ret := _m.Called(ctx, addr) if len(ret) == 0 { panic("no return value specified for ForwarderFor") @@ -73,17 +73,17 @@ func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR) (ADDR, error)); ok { - return rf(addr) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) (ADDR, error)); ok { + return rf(ctx, addr) } - if rf, ok := ret.Get(0).(func(ADDR) ADDR); ok { - r0 = rf(addr) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) ADDR); ok { + r0 = rf(ctx, addr) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR) error); ok { - r1 = rf(addr) + if rf, ok := ret.Get(1).(func(context.Context, ADDR) error); ok { + r1 = rf(ctx, addr) } else { r1 = ret.Error(1) } @@ -91,9 +91,9 @@ func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { return r0, r1 } -// ForwarderForOCR2Feeds provides a mock function with given fields: eoa, ocr2Aggregator -func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(eoa ADDR, ocr2Aggregator ADDR) (ADDR, error) { - ret := _m.Called(eoa, ocr2Aggregator) +// ForwarderForOCR2Feeds provides a mock function with given fields: ctx, eoa, ocr2Aggregator +func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(ctx context.Context, eoa ADDR, ocr2Aggregator ADDR) (ADDR, error) { + ret := _m.Called(ctx, eoa, ocr2Aggregator) if len(ret) == 0 { panic("no return value specified for ForwarderForOCR2Feeds") @@ -101,17 +101,17 @@ func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(eoa ADDR, ocr2Aggregator var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR, ADDR) (ADDR, error)); ok { - return rf(eoa, ocr2Aggregator) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) (ADDR, error)); ok { + return rf(ctx, eoa, ocr2Aggregator) } - if rf, ok := ret.Get(0).(func(ADDR, ADDR) ADDR); ok { - r0 = rf(eoa, ocr2Aggregator) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) ADDR); ok { + r0 = rf(ctx, eoa, ocr2Aggregator) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR, ADDR) error); ok { - r1 = rf(eoa, ocr2Aggregator) + if rf, ok := ret.Get(1).(func(context.Context, ADDR, ADDR) error); ok { + r1 = rf(ctx, eoa, ocr2Aggregator) } else { r1 = ret.Error(1) } diff --git a/core/chains/evm/abi/selector_parser.go b/core/chains/evm/abi/selector_parser.go index 30e687ba33a..329ed6eb181 100644 --- a/core/chains/evm/abi/selector_parser.go +++ b/core/chains/evm/abi/selector_parser.go @@ -1,5 +1,5 @@ -// Sourced from https://github.com/ethereum/go-ethereum/blob/fe91d476ba3e29316b6dc99b6efd4a571481d888/accounts/abi/selector_parser.go#L126 -// Modified assembleArgs to retain argument names +// Originally sourced from https://github.com/ethereum/go-ethereum/blob/fe91d476ba3e29316b6dc99b6efd4a571481d888/accounts/abi/selector_parser.go#L126 +// Modified to suppor parsing selectors with argument names // Copyright 2022 The go-ethereum Authors // This file is part of the go-ethereum library. @@ -83,55 +83,24 @@ func parseElementaryType(unescapedSelector string) (string, string, error) { return parsedType, rest, nil } -func parseCompositeType(unescapedSelector string) ([]interface{}, string, error) { +func parseCompositeType(unescapedSelector string) ([]abi.ArgumentMarshaling, string, error) { if len(unescapedSelector) == 0 || unescapedSelector[0] != '(' { return nil, "", fmt.Errorf("expected '(', got %c", unescapedSelector[0]) } - parsedType, rest, err := parseType(unescapedSelector[1:]) - if err != nil { - return nil, "", fmt.Errorf("failed to parse type: %v", err) - } - result := []interface{}{parsedType} - for len(rest) > 0 && rest[0] != ')' { - parsedType, rest, err = parseType(rest[1:]) - if err != nil { - return nil, "", fmt.Errorf("failed to parse type: %v", err) - } - result = append(result, parsedType) - } - if len(rest) == 0 || rest[0] != ')' { - return nil, "", fmt.Errorf("expected ')', got '%s'", rest) - } - if len(rest) >= 3 && rest[1] == '[' && rest[2] == ']' { - return append(result, "[]"), rest[3:], nil - } - return result, rest[1:], nil -} - -func parseType(unescapedSelector string) (interface{}, string, error) { - if len(unescapedSelector) == 0 { - return nil, "", errors.New("empty type") - } - if unescapedSelector[0] == '(' { - return parseCompositeType(unescapedSelector) - } - return parseElementaryType(unescapedSelector) -} - -func parseArgs(unescapedSelector string) ([]abi.ArgumentMarshaling, error) { - if len(unescapedSelector) == 0 || unescapedSelector[0] != '(' { - return nil, fmt.Errorf("expected '(', got %c", unescapedSelector[0]) - } + rest := unescapedSelector[1:] // skip over the opening `(` result := []abi.ArgumentMarshaling{} - rest := unescapedSelector[1:] var parsedType any var err error + i := 0 for len(rest) > 0 && rest[0] != ')' { - // parse method name - var name string - name, rest, err = parseIdentifier(rest[:]) + // skip any leading whitespace + for rest[0] == ' ' { + rest = rest[1:] + } + + parsedType, rest, err = parseType(rest[0:]) if err != nil { - return nil, fmt.Errorf("failed to parse name: %v", err) + return nil, "", fmt.Errorf("failed to parse type: %v", err) } // skip whitespace between name and identifier @@ -139,40 +108,58 @@ func parseArgs(unescapedSelector string) ([]abi.ArgumentMarshaling, error) { rest = rest[1:] } - // parse type - parsedType, rest, err = parseType(rest[:]) - if err != nil { - return nil, fmt.Errorf("failed to parse type: %v", err) + name := fmt.Sprintf("name%d", i) + // if we're at a delimiter the parameter is unnamed + if !(rest[0] == ',' || rest[0] == ')') { + // attempt to parse name + name, rest, err = parseIdentifier(rest[:]) + if err != nil { + return nil, "", fmt.Errorf("failed to parse name: %v", err) + } } arg, err := assembleArg(name, parsedType) if err != nil { - return nil, fmt.Errorf("failed to parse type: %v", err) + return nil, "", fmt.Errorf("failed to parse type: %v", err) } result = append(result, arg) + i++ + // skip trailing whitespace, consume comma for rest[0] == ' ' || rest[0] == ',' { rest = rest[1:] } } if len(rest) == 0 || rest[0] != ')' { - return nil, fmt.Errorf("expected ')', got '%s'", rest) + return nil, "", fmt.Errorf("expected ')', got '%s'", rest) } - if len(rest) > 1 { - return nil, fmt.Errorf("failed to parse selector '%s': unexpected string '%s'", unescapedSelector, rest) + if len(rest) >= 3 && rest[1] == '[' && rest[2] == ']' { + // emits a sentinel value that later gets removed when assembling + array, err := assembleArg("", "[]") + if err != nil { + panic("unreachable") + } + return append(result, array), rest[3:], nil } - return result, nil + return result, rest[1:], nil +} + +// type-name rule +func parseType(unescapedSelector string) (interface{}, string, error) { + if len(unescapedSelector) == 0 { + return nil, "", errors.New("empty type") + } + if unescapedSelector[0] == '(' { + return parseCompositeType(unescapedSelector) + } + return parseElementaryType(unescapedSelector) } func assembleArg(name string, arg any) (abi.ArgumentMarshaling, error) { if s, ok := arg.(string); ok { return abi.ArgumentMarshaling{Name: name, Type: s, InternalType: s, Components: nil, Indexed: false}, nil - } else if components, ok := arg.([]interface{}); ok { - subArgs, err := assembleArgs(components) - if err != nil { - return abi.ArgumentMarshaling{}, fmt.Errorf("failed to assemble components: %v", err) - } + } else if subArgs, ok := arg.([]abi.ArgumentMarshaling); ok { tupleType := "tuple" if len(subArgs) != 0 && subArgs[len(subArgs)-1].Type == "[]" { subArgs = subArgs[:len(subArgs)-1] @@ -183,20 +170,6 @@ func assembleArg(name string, arg any) (abi.ArgumentMarshaling, error) { return abi.ArgumentMarshaling{}, fmt.Errorf("failed to assemble args: unexpected type %T", arg) } -func assembleArgs(args []interface{}) ([]abi.ArgumentMarshaling, error) { - arguments := make([]abi.ArgumentMarshaling, 0) - for i, arg := range args { - // generate dummy name to avoid unmarshal issues - name := fmt.Sprintf("name%d", i) - arg, err := assembleArg(name, arg) - if err != nil { - return nil, err - } - arguments = append(arguments, arg) - } - return arguments, nil -} - // ParseSelector converts a method selector into a struct that can be JSON encoded // and consumed by other functions in this package. // Note, although uppercase letters are not part of the ABI spec, this function @@ -204,46 +177,19 @@ func assembleArgs(args []interface{}) ([]abi.ArgumentMarshaling, error) { func ParseSelector(unescapedSelector string) (abi.SelectorMarshaling, error) { name, rest, err := parseIdentifier(unescapedSelector) if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': %v", unescapedSelector, err) + return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector identifier '%s': %v", unescapedSelector, err) } - args := []interface{}{} + args := []abi.ArgumentMarshaling{} if len(rest) >= 2 && rest[0] == '(' && rest[1] == ')' { rest = rest[2:] } else { args, rest, err = parseCompositeType(rest) if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': %v", unescapedSelector, err) + return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector args '%s': %v", unescapedSelector, err) } } if len(rest) > 0 { return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': unexpected string '%s'", unescapedSelector, rest) } - - // Reassemble the fake ABI and construct the JSON - fakeArgs, err := assembleArgs(args) - if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector: %v", err) - } - - return abi.SelectorMarshaling{Name: name, Type: "function", Inputs: fakeArgs}, nil -} - -// ParseSelector converts a method selector into a struct that can be JSON encoded -// and consumed by other functions in this package. -// Note, although uppercase letters are not part of the ABI spec, this function -// still accepts it as the general format is valid. -func ParseSignature(unescapedSelector string) (abi.SelectorMarshaling, error) { - name, rest, err := parseIdentifier(unescapedSelector) - if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': %v", unescapedSelector, err) - } - args := []abi.ArgumentMarshaling{} - if len(rest) < 2 || rest[0] != '(' || rest[1] != ')' { - args, err = parseArgs(rest) - if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': %v", unescapedSelector, err) - } - } - return abi.SelectorMarshaling{Name: name, Type: "function", Inputs: args}, nil } diff --git a/core/chains/evm/abi/selector_parser_test.go b/core/chains/evm/abi/selector_parser_test.go index caae3744678..8ef37a193f8 100644 --- a/core/chains/evm/abi/selector_parser_test.go +++ b/core/chains/evm/abi/selector_parser_test.go @@ -25,6 +25,8 @@ import ( "testing" "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParseSelector(t *testing.T) { @@ -83,7 +85,7 @@ func TestParseSelector(t *testing.T) { } } -func TestParseSignature(t *testing.T) { +func TestParseSelectorWithNames(t *testing.T) { t.Parallel() mkType := func(name string, typeOrComponents interface{}) abi.ArgumentMarshaling { if typeName, ok := typeOrComponents.(string); ok { @@ -102,13 +104,14 @@ func TestParseSignature(t *testing.T) { args []abi.ArgumentMarshaling }{ {"noargs()", "noargs", []abi.ArgumentMarshaling{}}, - {"simple(a uint256, b uint256, c uint256)", "simple", []abi.ArgumentMarshaling{mkType("a", "uint256"), mkType("b", "uint256"), mkType("c", "uint256")}}, - {"other(foo uint256, bar address)", "other", []abi.ArgumentMarshaling{mkType("foo", "uint256"), mkType("bar", "address")}}, - {"withArray(a uint256[], b address[2], c uint8[4][][5])", "withArray", []abi.ArgumentMarshaling{mkType("a", "uint256[]"), mkType("b", "address[2]"), mkType("c", "uint8[4][][5]")}}, - {"singleNest(d bytes32, e uint8, f (uint256,uint256), g address)", "singleNest", []abi.ArgumentMarshaling{mkType("d", "bytes32"), mkType("e", "uint8"), mkType("f", []abi.ArgumentMarshaling{mkType("name0", "uint256"), mkType("name1", "uint256")}), mkType("g", "address")}}, + {"simple(uint256 a , uint256 b, uint256 c)", "simple", []abi.ArgumentMarshaling{mkType("a", "uint256"), mkType("b", "uint256"), mkType("c", "uint256")}}, + {"other(uint256 foo, address bar )", "other", []abi.ArgumentMarshaling{mkType("foo", "uint256"), mkType("bar", "address")}}, + {"withArray(uint256[] a, address[2] b, uint8[4][][5] c)", "withArray", []abi.ArgumentMarshaling{mkType("a", "uint256[]"), mkType("b", "address[2]"), mkType("c", "uint8[4][][5]")}}, + {"singleNest(bytes32 d, uint8 e, (uint256,uint256) f, address g)", "singleNest", []abi.ArgumentMarshaling{mkType("d", "bytes32"), mkType("e", "uint8"), mkType("f", []abi.ArgumentMarshaling{mkType("name0", "uint256"), mkType("name1", "uint256")}), mkType("g", "address")}}, + {"singleNest(bytes32 d, uint8 e, (uint256 first, uint256 second ) f, address g)", "singleNest", []abi.ArgumentMarshaling{mkType("d", "bytes32"), mkType("e", "uint8"), mkType("f", []abi.ArgumentMarshaling{mkType("first", "uint256"), mkType("second", "uint256")}), mkType("g", "address")}}, } for i, tt := range tests { - selector, err := ParseSignature(tt.input) + selector, err := ParseSelector(tt.input) if err != nil { t.Errorf("test %d: failed to parse selector '%v': %v", i, tt.input, err) } @@ -124,3 +127,35 @@ func TestParseSignature(t *testing.T) { } } } + +func TestParseSelectorErrors(t *testing.T) { + type errorTestCases struct { + description string + input string + expectedError string + } + + for _, scenario := range []errorTestCases{ + { + description: "invalid name", + input: "123()", + expectedError: "failed to parse selector identifier '123()': invalid token start: 1", + }, + { + description: "missing closing parenthesis", + input: "noargs(", + expectedError: "failed to parse selector args 'noargs(': expected ')', got ''", + }, + { + description: "missing opening parenthesis", + input: "noargs)", + expectedError: "failed to parse selector args 'noargs)': expected '(', got )", + }, + } { + t.Run(scenario.description, func(t *testing.T) { + _, err := ParseSelector(scenario.input) + require.Error(t, err) + assert.Equal(t, scenario.expectedError, err.Error()) + }) + } +} diff --git a/core/chains/evm/client/chain_client.go b/core/chains/evm/client/chain_client.go index 3ee10a600da..04b1ff29387 100644 --- a/core/chains/evm/client/chain_client.go +++ b/core/chains/evm/client/chain_client.go @@ -5,8 +5,6 @@ import ( "math/big" "time" - evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" - "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -14,10 +12,10 @@ import ( commonassets "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/logger" - commonclient "github.com/smartcontractkit/chainlink/v2/common/client" "github.com/smartcontractkit/chainlink/v2/common/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" + evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) diff --git a/core/chains/evm/client/config_builder.go b/core/chains/evm/client/config_builder.go index d78a981b881..9817879b579 100644 --- a/core/chains/evm/client/config_builder.go +++ b/core/chains/evm/client/config_builder.go @@ -8,6 +8,7 @@ import ( "go.uber.org/multierr" commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink/v2/common/config" commonclient "github.com/smartcontractkit/chainlink/v2/common/client" evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" @@ -55,7 +56,7 @@ func NewClientConfigs( chainConfig := &evmconfig.EVMConfig{ C: &toml.EVMConfig{ Chain: toml.Chain{ - ChainType: &chainType, + ChainType: config.NewChainTypeConfig(chainType), FinalityDepth: finalityDepth, FinalityTagEnabled: finalityTagEnabled, NoNewHeadsThreshold: commonconfig.MustNewDuration(noNewHeadsThreshold), diff --git a/core/chains/evm/client/errors.go b/core/chains/evm/client/errors.go index b7e0d9317eb..bcc8ff961f0 100644 --- a/core/chains/evm/client/errors.go +++ b/core/chains/evm/client/errors.go @@ -6,6 +6,8 @@ import ( "fmt" "regexp" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" pkgerrors "github.com/pkg/errors" @@ -13,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" commonclient "github.com/smartcontractkit/chainlink/v2/common/client" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/label" ) diff --git a/core/chains/evm/client/node.go b/core/chains/evm/client/node.go index 92b7a8301e5..968cb34b9fe 100644 --- a/core/chains/evm/client/node.go +++ b/core/chains/evm/client/node.go @@ -166,10 +166,8 @@ type node struct { // this node. Closing and replacing should be serialized through // stateMu since it can happen on state transitions as well as node Close. chStopInFlight chan struct{} - // nodeCtx is the node lifetime's context - nodeCtx context.Context - // cancelNodeCtx cancels nodeCtx when stopping the node - cancelNodeCtx context.CancelFunc + + stopCh services.StopChan // wg waits for subsidiary goroutines wg sync.WaitGroup @@ -196,7 +194,7 @@ func NewNode(nodeCfg config.NodePool, noNewHeadsThreshold time.Duration, lggr lo n.http = &rawclient{uri: *httpuri} } n.chStopInFlight = make(chan struct{}) - n.nodeCtx, n.cancelNodeCtx = context.WithCancel(context.Background()) + n.stopCh = make(chan struct{}) lggr = logger.Named(lggr, "Node") lggr = logger.With(lggr, "nodeTier", "primary", @@ -367,7 +365,7 @@ func (n *node) Close() error { n.stateMu.Lock() defer n.stateMu.Unlock() - n.cancelNodeCtx() + close(n.stopCh) n.cancelInflightRequests() n.state = NodeStateClosed return nil diff --git a/core/chains/evm/client/node_lifecycle.go b/core/chains/evm/client/node_lifecycle.go index 41add532222..c18c8032009 100644 --- a/core/chains/evm/client/node_lifecycle.go +++ b/core/chains/evm/client/node_lifecycle.go @@ -75,6 +75,8 @@ const ( // Should only be run ONCE per node, after a successful Dial func (n *node) aliveLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -96,7 +98,7 @@ func (n *node) aliveLoop() { lggr.Tracew("Alive loop starting", "nodeState", n.State()) headsC := make(chan *evmtypes.Head) - sub, err := n.EthSubscribe(n.nodeCtx, headsC, "newHeads") + sub, err := n.EthSubscribe(ctx, headsC, "newHeads") if err != nil { lggr.Errorw("Initial subscribe for heads failed", "nodeState", n.State()) n.declareUnreachable() @@ -137,18 +139,19 @@ func (n *node) aliveLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-pollCh: - var version string promEVMPoolRPCNodePolls.WithLabelValues(n.chainID.String(), n.name).Inc() lggr.Tracew("Polling for version", "nodeState", n.State(), "pollFailures", pollFailures) - ctx, cancel := context.WithTimeout(n.nodeCtx, pollInterval) - ctx, cancel2 := n.makeQueryCtx(ctx) - err := n.CallContext(ctx, &version, "web3_clientVersion") - cancel2() - cancel() - if err != nil { + var version string + if err := func(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, pollInterval) + defer cancel() + ctx, cancel2 := n.makeQueryCtx(ctx) + defer cancel2() + return n.CallContext(ctx, &version, "web3_clientVersion") + }(ctx); err != nil { // prevent overflow if pollFailures < math.MaxUint32 { promEVMPoolRPCNodePollsFailed.WithLabelValues(n.chainID.String(), n.name).Inc() @@ -262,6 +265,8 @@ const ( // outOfSyncLoop takes an OutOfSync node and waits until isOutOfSync returns false to go back to live status func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -281,14 +286,14 @@ func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { lggr.Debugw("Trying to revive out-of-sync RPC node", "nodeState", n.State()) // Need to redial since out-of-sync nodes are automatically disconnected - if err := n.dial(n.nodeCtx); err != nil { + if err := n.dial(ctx); err != nil { lggr.Errorw("Failed to dial out-of-sync RPC node", "nodeState", n.State()) n.declareUnreachable() return } // Manually re-verify since out-of-sync nodes are automatically disconnected - if err := n.verify(n.nodeCtx); err != nil { + if err := n.verify(ctx); err != nil { lggr.Errorw(fmt.Sprintf("Failed to verify out-of-sync RPC node: %v", err), "err", err) n.declareInvalidChainID() return @@ -297,7 +302,7 @@ func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { lggr.Tracew("Successfully subscribed to heads feed on out-of-sync RPC node", "nodeState", n.State()) ch := make(chan *evmtypes.Head) - subCtx, cancel := n.makeQueryCtx(n.nodeCtx) + subCtx, cancel := n.makeQueryCtx(ctx) // raw call here to bypass node state checking sub, err := n.ws.rpc.EthSubscribe(subCtx, ch, "newHeads") cancel() @@ -310,7 +315,7 @@ func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case head, open := <-ch: if !open { @@ -344,6 +349,8 @@ func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { func (n *node) unreachableLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -366,12 +373,12 @@ func (n *node) unreachableLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(dialRetryBackoff.Duration()): lggr.Tracew("Trying to re-dial RPC node", "nodeState", n.State()) - err := n.dial(n.nodeCtx) + err := n.dial(ctx) if err != nil { lggr.Errorw(fmt.Sprintf("Failed to redial RPC node; still unreachable: %v", err), "err", err, "nodeState", n.State()) continue @@ -379,7 +386,7 @@ func (n *node) unreachableLoop() { n.setState(NodeStateDialed) - err = n.verify(n.nodeCtx) + err = n.verify(ctx) if pkgerrors.Is(err, errInvalidChainID) { lggr.Errorw("Failed to redial RPC node; remote endpoint returned the wrong chain ID", "err", err) @@ -400,6 +407,8 @@ func (n *node) unreachableLoop() { func (n *node) invalidChainIDLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -422,10 +431,10 @@ func (n *node) invalidChainIDLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(chainIDRecheckBackoff.Duration()): - err := n.verify(n.nodeCtx) + err := n.verify(ctx) if pkgerrors.Is(err, errInvalidChainID) { lggr.Errorw("Failed to verify RPC node; remote endpoint returned the wrong chain ID", "err", err) continue diff --git a/core/chains/evm/client/pool_test.go b/core/chains/evm/client/pool_test.go index 5f614b7ed24..5a2c13130d3 100644 --- a/core/chains/evm/client/pool_test.go +++ b/core/chains/evm/client/pool_test.go @@ -169,7 +169,6 @@ func TestPool_Dial(t *testing.T) { if err == nil { t.Cleanup(func() { assert.NoError(t, p.Close()) }) } - assert.True(t, p.ChainType().IsValid()) assert.False(t, p.ChainType().IsL2()) if test.errStr != "" { require.Error(t, err) @@ -333,7 +332,6 @@ func TestUnit_Pool_BatchCallContextAll(t *testing.T) { p := evmclient.NewPool(logger.Test(t), defaultConfig.NodeSelectionMode(), defaultConfig.LeaseDuration(), time.Second*0, nodes, sendonlys, &cltest.FixtureChainID, "") - assert.True(t, p.ChainType().IsValid()) assert.False(t, p.ChainType().IsL2()) require.NoError(t, p.BatchCallContextAll(ctx, b)) } diff --git a/core/chains/evm/config/chain_scoped.go b/core/chains/evm/config/chain_scoped.go index 8f94fef09f4..17d4120ddf6 100644 --- a/core/chains/evm/config/chain_scoped.go +++ b/core/chains/evm/config/chain_scoped.go @@ -128,7 +128,7 @@ func (e *EVMConfig) ChainType() commonconfig.ChainType { if e.C.ChainType == nil { return "" } - return commonconfig.ChainType(*e.C.ChainType) + return e.C.ChainType.ChainType() } func (e *EVMConfig) ChainID() *big.Int { diff --git a/core/chains/evm/config/chain_scoped_node_pool.go b/core/chains/evm/config/chain_scoped_node_pool.go index 7f071e5506f..50269366829 100644 --- a/core/chains/evm/config/chain_scoped_node_pool.go +++ b/core/chains/evm/config/chain_scoped_node_pool.go @@ -38,4 +38,6 @@ func (n *NodePoolConfig) FinalizedBlockPollInterval() time.Duration { return n.C.FinalizedBlockPollInterval.Duration() } -func (n *NodePoolConfig) Errors() ClientErrors { return &clientErrorsConfig{c: n.C.Errors} } +func (n *NodePoolConfig) Errors() ClientErrors { + return &clientErrorsConfig{c: n.C.Errors} +} diff --git a/core/chains/evm/config/config_test.go b/core/chains/evm/config/config_test.go index 9553f59ad61..ddf9817958d 100644 --- a/core/chains/evm/config/config_test.go +++ b/core/chains/evm/config/config_test.go @@ -406,7 +406,7 @@ func Test_chainScopedConfig_Validate(t *testing.T) { t.Run("arbitrum-estimator", func(t *testing.T) { t.Run("custom", func(t *testing.T) { cfg := configWithChains(t, 0, &toml.Chain{ - ChainType: ptr(string(commonconfig.ChainArbitrum)), + ChainType: commonconfig.NewChainTypeConfig(string(commonconfig.ChainArbitrum)), GasEstimator: toml.GasEstimator{ Mode: ptr("BlockHistory"), }, @@ -437,7 +437,7 @@ func Test_chainScopedConfig_Validate(t *testing.T) { t.Run("optimism-estimator", func(t *testing.T) { t.Run("custom", func(t *testing.T) { cfg := configWithChains(t, 0, &toml.Chain{ - ChainType: ptr(string(commonconfig.ChainOptimismBedrock)), + ChainType: commonconfig.NewChainTypeConfig(string(commonconfig.ChainOptimismBedrock)), GasEstimator: toml.GasEstimator{ Mode: ptr("BlockHistory"), }, diff --git a/core/chains/evm/config/toml/config.go b/core/chains/evm/config/toml/config.go index b747dc641fd..38fa54f521f 100644 --- a/core/chains/evm/config/toml/config.go +++ b/core/chains/evm/config/toml/config.go @@ -294,18 +294,14 @@ func (c *EVMConfig) ValidateConfig() (err error) { } else if c.ChainID.String() == "" { err = multierr.Append(err, commonconfig.ErrEmpty{Name: "ChainID", Msg: "required for all chains"}) } else if must, ok := ChainTypeForID(c.ChainID); ok { // known chain id - if c.ChainType == nil && must != "" { - err = multierr.Append(err, commonconfig.ErrMissing{Name: "ChainType", - Msg: fmt.Sprintf("only %q can be used with this chain id", must)}) - } else if c.ChainType != nil && *c.ChainType != string(must) { - if *c.ChainType == "" { - err = multierr.Append(err, commonconfig.ErrEmpty{Name: "ChainType", - Msg: fmt.Sprintf("only %q can be used with this chain id", must)}) - } else if must == "" { - err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: *c.ChainType, + // Check if the parsed value matched the expected value + is := c.ChainType.ChainType() + if is != must { + if must == "" { + err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: c.ChainType.ChainType(), Msg: "must not be set with this chain id"}) } else { - err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: *c.ChainType, + err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: c.ChainType.ChainType(), Msg: fmt.Sprintf("only %q can be used with this chain id", must)}) } } @@ -345,7 +341,7 @@ type Chain struct { AutoCreateKey *bool BlockBackfillDepth *uint32 BlockBackfillSkip *bool - ChainType *string + ChainType *config.ChainTypeConfig FinalityDepth *uint32 FinalityTagEnabled *bool FlagsContractAddress *types.EIP55Address @@ -375,12 +371,8 @@ type Chain struct { } func (c *Chain) ValidateConfig() (err error) { - var chainType config.ChainType - if c.ChainType != nil { - chainType = config.ChainType(*c.ChainType) - } - if !chainType.IsValid() { - err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: *c.ChainType, + if !c.ChainType.ChainType().IsValid() { + err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: c.ChainType.ChainType(), Msg: config.ErrInvalidChainType.Error()}) } diff --git a/core/chains/evm/config/toml/defaults.go b/core/chains/evm/config/toml/defaults.go index 951246eeb22..622ac132e13 100644 --- a/core/chains/evm/config/toml/defaults.go +++ b/core/chains/evm/config/toml/defaults.go @@ -94,10 +94,7 @@ func Defaults(chainID *big.Big, with ...*Chain) Chain { func ChainTypeForID(chainID *big.Big) (config.ChainType, bool) { s := chainID.String() if d, ok := defaults[s]; ok { - if d.ChainType == nil { - return "", true - } - return config.ChainType(*d.ChainType), true + return d.ChainType.ChainType(), true } return "", false } diff --git a/core/chains/evm/forwarders/forwarder_manager.go b/core/chains/evm/forwarders/forwarder_manager.go index 9505cdfbbbf..fca35274708 100644 --- a/core/chains/evm/forwarders/forwarder_manager.go +++ b/core/chains/evm/forwarders/forwarder_manager.go @@ -49,8 +49,7 @@ type FwdMgr struct { authRcvr authorized_receiver.AuthorizedReceiverInterface offchainAgg offchain_aggregator_wrapper.OffchainAggregatorInterface - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan cacheMu sync.RWMutex wg sync.WaitGroup @@ -66,7 +65,7 @@ func NewFwdMgr(ds sqlutil.DataSource, client evmclient.Client, logpoller evmlogp logpoller: logpoller, sendersCache: make(map[common.Address][]common.Address), } - fwdMgr.ctx, fwdMgr.cancel = context.WithCancel(context.Background()) + fwdMgr.stopCh = make(chan struct{}) return &fwdMgr } @@ -86,7 +85,7 @@ func (f *FwdMgr) Start(ctx context.Context) error { } if len(fwdrs) != 0 { f.initForwardersCache(ctx, fwdrs) - if err = f.subscribeForwardersLogs(fwdrs); err != nil { + if err = f.subscribeForwardersLogs(ctx, fwdrs); err != nil { return err } } @@ -111,15 +110,15 @@ func FilterName(addr common.Address) string { return evmlogpoller.FilterName("ForwarderManager AuthorizedSendersChanged", addr.String()) } -func (f *FwdMgr) ForwarderFor(addr common.Address) (forwarder common.Address, err error) { +func (f *FwdMgr) ForwarderFor(ctx context.Context, addr common.Address) (forwarder common.Address, err error) { // Gets forwarders for current chain. - fwdrs, err := f.ORM.FindForwardersByChain(f.ctx, big.Big(*f.evmClient.ConfiguredChainID())) + fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*f.evmClient.ConfiguredChainID())) if err != nil { return common.Address{}, err } for _, fwdr := range fwdrs { - eoas, err := f.getContractSenders(fwdr.Address) + eoas, err := f.getContractSenders(ctx, fwdr.Address) if err != nil { f.logger.Errorw("Failed to get forwarder senders", "forwarder", fwdr.Address, "err", err) continue @@ -133,8 +132,8 @@ func (f *FwdMgr) ForwarderFor(addr common.Address) (forwarder common.Address, er return common.Address{}, pkgerrors.Errorf("Cannot find forwarder for given EOA") } -func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forwarder common.Address, err error) { - fwdrs, err := f.ORM.FindForwardersByChain(f.ctx, big.Big(*f.evmClient.ConfiguredChainID())) +func (f *FwdMgr) ForwarderForOCR2Feeds(ctx context.Context, eoa, ocr2Aggregator common.Address) (forwarder common.Address, err error) { + fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*f.evmClient.ConfiguredChainID())) if err != nil { return common.Address{}, err } @@ -144,7 +143,7 @@ func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forw return common.Address{}, err } - transmitters, err := offchainAggregator.GetTransmitters(&bind.CallOpts{Context: f.ctx}) + transmitters, err := offchainAggregator.GetTransmitters(&bind.CallOpts{Context: ctx}) if err != nil { return common.Address{}, pkgerrors.Errorf("failed to get ocr2 aggregator transmitters: %s", err.Error()) } @@ -155,7 +154,7 @@ func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forw continue } - eoas, err := f.getContractSenders(fwdr.Address) + eoas, err := f.getContractSenders(ctx, fwdr.Address) if err != nil { f.logger.Errorw("Failed to get forwarder senders", "forwarder", fwdr.Address, "err", err) continue @@ -191,16 +190,16 @@ func (f *FwdMgr) getForwardedPayload(dest common.Address, origPayload []byte) ([ return dataBytes, nil } -func (f *FwdMgr) getContractSenders(addr common.Address) ([]common.Address, error) { +func (f *FwdMgr) getContractSenders(ctx context.Context, addr common.Address) ([]common.Address, error) { if senders, ok := f.getCachedSenders(addr); ok { return senders, nil } - senders, err := f.getAuthorizedSenders(f.ctx, addr) + senders, err := f.getAuthorizedSenders(ctx, addr) if err != nil { return nil, pkgerrors.Wrapf(err, "Failed to call getAuthorizedSenders on %s", addr) } f.setCachedSenders(addr, senders) - if err = f.subscribeSendersChangedLogs(addr); err != nil { + if err = f.subscribeSendersChangedLogs(ctx, addr); err != nil { return nil, err } return senders, nil @@ -230,23 +229,23 @@ func (f *FwdMgr) initForwardersCache(ctx context.Context, fwdrs []Forwarder) { } } -func (f *FwdMgr) subscribeForwardersLogs(fwdrs []Forwarder) error { +func (f *FwdMgr) subscribeForwardersLogs(ctx context.Context, fwdrs []Forwarder) error { for _, fwdr := range fwdrs { - if err := f.subscribeSendersChangedLogs(fwdr.Address); err != nil { + if err := f.subscribeSendersChangedLogs(ctx, fwdr.Address); err != nil { return err } } return nil } -func (f *FwdMgr) subscribeSendersChangedLogs(addr common.Address) error { +func (f *FwdMgr) subscribeSendersChangedLogs(ctx context.Context, addr common.Address) error { if err := f.logpoller.Ready(); err != nil { f.logger.Warnw("Unable to subscribe to AuthorizedSendersChanged logs", "forwarder", addr, "err", err) return nil } err := f.logpoller.RegisterFilter( - f.ctx, + ctx, evmlogpoller.Filter{ Name: FilterName(addr), EventSigs: []common.Hash{authChangedTopic}, @@ -270,8 +269,10 @@ func (f *FwdMgr) getCachedSenders(addr common.Address) ([]common.Address, bool) func (f *FwdMgr) runLoop() { defer f.wg.Done() - tick := time.After(0) + ctx, cancel := f.stopCh.NewCtx() + defer cancel() + tick := time.After(0) for ; ; tick = time.After(utils.WithJitter(time.Minute)) { select { case <-tick: @@ -287,7 +288,7 @@ func (f *FwdMgr) runLoop() { } logs, err := f.logpoller.LatestLogEventSigsAddrsWithConfs( - f.ctx, + ctx, f.latestBlock, []common.Hash{authChangedTopic}, addrs, @@ -308,7 +309,7 @@ func (f *FwdMgr) runLoop() { } } - case <-f.ctx.Done(): + case <-ctx.Done(): return } } @@ -352,7 +353,7 @@ func (f *FwdMgr) collectAddresses() (addrs []common.Address) { // Stop cancels all outgoings calls and stops internal ticker loop. func (f *FwdMgr) Close() error { return f.StopOnce("EVMForwarderManager", func() (err error) { - f.cancel() + close(f.stopCh) f.wg.Wait() return nil }) diff --git a/core/chains/evm/forwarders/forwarder_manager_test.go b/core/chains/evm/forwarders/forwarder_manager_test.go index 993efacac4a..020446aa547 100644 --- a/core/chains/evm/forwarders/forwarder_manager_test.go +++ b/core/chains/evm/forwarders/forwarder_manager_test.go @@ -18,6 +18,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/testhelpers" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" @@ -86,7 +87,7 @@ func TestFwdMgr_MaybeForwardTransaction(t *testing.T) { require.Equal(t, lst[0].Address, forwarderAddr) require.NoError(t, fwdMgr.Start(testutils.Context(t))) - addr, err := fwdMgr.ForwarderFor(owner.From) + addr, err := fwdMgr.ForwarderFor(ctx, owner.From) require.NoError(t, err) require.Equal(t, addr.String(), forwarderAddr.String()) err = fwdMgr.Close() @@ -148,7 +149,7 @@ func TestFwdMgr_AccountUnauthorizedToForward_SkipsForwarding(t *testing.T) { err = fwdMgr.Start(testutils.Context(t)) require.NoError(t, err) - addr, err := fwdMgr.ForwarderFor(owner.From) + addr, err := fwdMgr.ForwarderFor(ctx, owner.From) require.ErrorContains(t, err, "Cannot find forwarder for given EOA") require.True(t, utils.IsZero(addr)) err = fwdMgr.Close() @@ -214,7 +215,7 @@ func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { fwdMgr = forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM()) require.NoError(t, fwdMgr.Start(testutils.Context(t))) // cannot find forwarder because it isn't authorized nor added as a transmitter - addr, err := fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + addr, err := fwdMgr.ForwarderForOCR2Feeds(ctx, owner.From, ocr2Address) require.ErrorContains(t, err, "Cannot find forwarder for given EOA") require.True(t, utils.IsZero(addr)) @@ -227,7 +228,7 @@ func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { require.Equal(t, owner.From, authorizedSenders[0]) // cannot find forwarder because it isn't added as a transmitter - addr, err = fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + addr, err = fwdMgr.ForwarderForOCR2Feeds(ctx, owner.From, ocr2Address) require.ErrorContains(t, err, "Cannot find forwarder for given EOA") require.True(t, utils.IsZero(addr)) @@ -251,7 +252,7 @@ func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { // create new fwd to have an empty cache that has to fetch authorized forwarders from log poller fwdMgr = forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM()) require.NoError(t, fwdMgr.Start(testutils.Context(t))) - addr, err = fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + addr, err = fwdMgr.ForwarderForOCR2Feeds(ctx, owner.From, ocr2Address) require.NoError(t, err, "forwarder should be valid and found because it is both authorized and set as a transmitter") require.Equal(t, forwarderAddr, addr) require.NoError(t, fwdMgr.Close()) diff --git a/core/chains/evm/gas/block_history_estimator.go b/core/chains/evm/gas/block_history_estimator.go index 82f1c46fff8..f6d15f7aff7 100644 --- a/core/chains/evm/gas/block_history_estimator.go +++ b/core/chains/evm/gas/block_history_estimator.go @@ -104,13 +104,12 @@ type BlockHistoryEstimator struct { bhConfig BlockHistoryConfig // NOTE: it is assumed that blocks will be kept sorted by // block number ascending - blocks []evmtypes.Block - blocksMu sync.RWMutex - size int64 - mb *mailbox.Mailbox[*evmtypes.Head] - wg *sync.WaitGroup - ctx context.Context - ctxCancel context.CancelFunc + blocks []evmtypes.Block + blocksMu sync.RWMutex + size int64 + mb *mailbox.Mailbox[*evmtypes.Head] + wg *sync.WaitGroup + stopCh services.StopChan gasPrice *assets.Wei tipCap *assets.Wei @@ -128,9 +127,7 @@ type BlockHistoryEstimator struct { // for new heads and updates the base gas price dynamically based on the // configured percentile of gas prices in that block func NewBlockHistoryEstimator(lggr logger.Logger, ethClient feeEstimatorClient, cfg chainConfig, eCfg estimatorGasEstimatorConfig, bhCfg BlockHistoryConfig, chainID *big.Int, l1Oracle rollups.L1Oracle) EvmEstimator { - ctx, cancel := context.WithCancel(context.Background()) - - b := &BlockHistoryEstimator{ + return &BlockHistoryEstimator{ ethClient: ethClient, chainID: chainID, config: cfg, @@ -138,16 +135,13 @@ func NewBlockHistoryEstimator(lggr logger.Logger, ethClient feeEstimatorClient, bhConfig: bhCfg, blocks: make([]evmtypes.Block, 0), // Must have enough blocks for both estimator and connectivity checker - size: int64(mathutil.Max(bhCfg.BlockHistorySize(), bhCfg.CheckInclusionBlocks())), - mb: mailbox.NewSingle[*evmtypes.Head](), - wg: new(sync.WaitGroup), - ctx: ctx, - ctxCancel: cancel, - logger: logger.Sugared(logger.Named(lggr, "BlockHistoryEstimator")), - l1Oracle: l1Oracle, + size: int64(mathutil.Max(bhCfg.BlockHistorySize(), bhCfg.CheckInclusionBlocks())), + mb: mailbox.NewSingle[*evmtypes.Head](), + wg: new(sync.WaitGroup), + stopCh: make(chan struct{}), + logger: logger.Sugared(logger.Named(lggr, "BlockHistoryEstimator")), + l1Oracle: l1Oracle, } - - return b } // OnNewLongestChain recalculates and sets global gas price if a sampled new head comes @@ -240,7 +234,7 @@ func (b *BlockHistoryEstimator) L1Oracle() rollups.L1Oracle { func (b *BlockHistoryEstimator) Close() error { return b.StopOnce("BlockHistoryEstimator", func() error { - b.ctxCancel() + close(b.stopCh) b.wg.Wait() return nil }) @@ -482,9 +476,12 @@ func (b *BlockHistoryEstimator) BumpDynamicFee(_ context.Context, originalFee Dy func (b *BlockHistoryEstimator) runLoop() { defer b.wg.Done() + ctx, cancel := b.stopCh.NewCtx() + defer cancel() + for { select { - case <-b.ctx.Done(): + case <-ctx.Done(): return case <-b.mb.Notify(): head, exists := b.mb.Retrieve() @@ -492,7 +489,7 @@ func (b *BlockHistoryEstimator) runLoop() { b.logger.Debug("No head to retrieve") continue } - b.FetchBlocksAndRecalculate(b.ctx, head) + b.FetchBlocksAndRecalculate(ctx, head) } } } diff --git a/core/chains/evm/gas/block_history_estimator_test.go b/core/chains/evm/gas/block_history_estimator_test.go index 730bfcab7e1..b38cd069c69 100644 --- a/core/chains/evm/gas/block_history_estimator_test.go +++ b/core/chains/evm/gas/block_history_estimator_test.go @@ -994,11 +994,6 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { bhe.Recalculate(testutils.Head(0)) require.Equal(t, assets.NewWeiI(80), gas.GetGasPrice(bhe)) - // Same for xDai (deprecated) - cfg.ChainTypeF = string(config.ChainXDai) - bhe.Recalculate(testutils.Head(0)) - require.Equal(t, assets.NewWeiI(80), gas.GetGasPrice(bhe)) - // And for X Layer cfg.ChainTypeF = string(config.ChainXLayer) bhe.Recalculate(testutils.Head(0)) diff --git a/core/chains/evm/gas/chain_specific.go b/core/chains/evm/gas/chain_specific.go index 694411f164b..f9985a6fafc 100644 --- a/core/chains/evm/gas/chain_specific.go +++ b/core/chains/evm/gas/chain_specific.go @@ -9,7 +9,7 @@ import ( // chainSpecificIsUsable allows for additional logic specific to a particular // Config that determines whether a transaction should be used for gas estimation func chainSpecificIsUsable(tx evmtypes.Transaction, baseFee *assets.Wei, chainType config.ChainType, minGasPriceWei *assets.Wei) bool { - if chainType == config.ChainGnosis || chainType == config.ChainXDai || chainType == config.ChainXLayer { + if chainType == config.ChainGnosis || chainType == config.ChainXLayer { // GasPrice 0 on most chains is great since it indicates cheap/free transactions. // However, Gnosis and XLayer reserve a special type of "bridge" transaction with 0 gas // price that is always processed at top priority. Ordinary transactions diff --git a/core/chains/evm/logpoller/log_poller.go b/core/chains/evm/logpoller/log_poller.go index e964b86221c..26978b18d48 100644 --- a/core/chains/evm/logpoller/log_poller.go +++ b/core/chains/evm/logpoller/log_poller.go @@ -119,8 +119,7 @@ type logPoller struct { replayStart chan int64 replayComplete chan error - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan wg sync.WaitGroup // This flag is raised whenever the log poller detects that the chain's finality has been violated. // It can happen when reorg is deeper than the latest finalized block that LogPoller saw in a previous PollAndSave tick. @@ -152,10 +151,8 @@ type Opts struct { // How fast that can be done depends largely on network speed and DB, but even for the fastest // support chain, polygon, which has 2s block times, we need RPCs roughly with <= 500ms latency func NewLogPoller(orm ORM, ec Client, lggr logger.Logger, opts Opts) *logPoller { - ctx, cancel := context.WithCancel(context.Background()) return &logPoller{ - ctx: ctx, - cancel: cancel, + stopCh: make(chan struct{}), ec: ec, orm: orm, lggr: logger.Sugared(logger.Named(lggr, "LogPoller")), @@ -465,21 +462,23 @@ func (lp *logPoller) savedFinalizedBlockNumber(ctx context.Context) (int64, erro } func (lp *logPoller) recvReplayComplete() { + defer lp.wg.Done() err := <-lp.replayComplete if err != nil { lp.lggr.Error(err) } - lp.wg.Done() } // Asynchronous wrapper for Replay() func (lp *logPoller) ReplayAsync(fromBlock int64) { lp.wg.Add(1) go func() { - if err := lp.Replay(lp.ctx, fromBlock); err != nil { + defer lp.wg.Done() + ctx, cancel := lp.stopCh.NewCtx() + defer cancel() + if err := lp.Replay(ctx, fromBlock); err != nil { lp.lggr.Error(err) } - lp.wg.Done() }() } @@ -498,7 +497,7 @@ func (lp *logPoller) Close() error { case lp.replayComplete <- ErrLogPollerShutdown: default: } - lp.cancel() + close(lp.stopCh) lp.wg.Wait() return nil }) @@ -535,10 +534,10 @@ func (lp *logPoller) GetReplayFromBlock(ctx context.Context, requested int64) (i return mathutil.Min(requested, lastProcessed.BlockNumber), nil } -func (lp *logPoller) loadFilters() error { +func (lp *logPoller) loadFilters(ctx context.Context) error { lp.filterMu.Lock() defer lp.filterMu.Unlock() - filters, err := lp.orm.LoadFilters(lp.ctx) + filters, err := lp.orm.LoadFilters(ctx) if err != nil { return pkgerrors.Wrapf(err, "Failed to load initial filters from db, retrying") @@ -551,6 +550,8 @@ func (lp *logPoller) loadFilters() error { func (lp *logPoller) run() { defer lp.wg.Done() + ctx, cancel := lp.stopCh.NewCtx() + defer cancel() logPollTick := time.After(0) // stagger these somewhat, so they don't all run back-to-back backupLogPollTick := time.After(100 * time.Millisecond) @@ -558,14 +559,14 @@ func (lp *logPoller) run() { for { select { - case <-lp.ctx.Done(): + case <-ctx.Done(): return case fromBlockReq := <-lp.replayStart: - lp.handleReplayRequest(fromBlockReq, filtersLoaded) + lp.handleReplayRequest(ctx, fromBlockReq, filtersLoaded) case <-logPollTick: logPollTick = time.After(utils.WithJitter(lp.pollPeriod)) if !filtersLoaded { - if err := lp.loadFilters(); err != nil { + if err := lp.loadFilters(ctx); err != nil { lp.lggr.Errorw("Failed loading filters in main logpoller loop, retrying later", "err", err) continue } @@ -574,7 +575,7 @@ func (lp *logPoller) run() { // Always start from the latest block in the db. var start int64 - lastProcessed, err := lp.orm.SelectLatestBlock(lp.ctx) + lastProcessed, err := lp.orm.SelectLatestBlock(ctx) if err != nil { if !pkgerrors.Is(err, sql.ErrNoRows) { // Assume transient db reading issue, retry forever. @@ -583,7 +584,7 @@ func (lp *logPoller) run() { } // Otherwise this is the first poll _ever_ on a new chain. // Only safe thing to do is to start at the first finalized block. - latestBlock, latestFinalizedBlockNumber, err := lp.latestBlocks(lp.ctx) + latestBlock, latestFinalizedBlockNumber, err := lp.latestBlocks(ctx) if err != nil { lp.lggr.Warnw("Unable to get latest for first poll", "err", err) continue @@ -600,7 +601,7 @@ func (lp *logPoller) run() { } else { start = lastProcessed.BlockNumber + 1 } - lp.PollAndSaveLogs(lp.ctx, start) + lp.PollAndSaveLogs(ctx, start) case <-backupLogPollTick: if lp.backupPollerBlockDelay == 0 { continue // backup poller is disabled @@ -618,13 +619,15 @@ func (lp *logPoller) run() { lp.lggr.Warnw("Backup log poller ran before filters loaded, skipping") continue } - lp.BackupPollAndSaveLogs(lp.ctx) + lp.BackupPollAndSaveLogs(ctx) } } } func (lp *logPoller) backgroundWorkerRun() { defer lp.wg.Done() + ctx, cancel := lp.stopCh.NewCtx() + defer cancel() // Avoid putting too much pressure on the database by staggering the pruning of old blocks and logs. // Usually, node after restart will have some work to boot the plugins and other services. @@ -634,11 +637,11 @@ func (lp *logPoller) backgroundWorkerRun() { for { select { - case <-lp.ctx.Done(): + case <-ctx.Done(): return case <-blockPruneTick: blockPruneTick = time.After(utils.WithJitter(lp.pollPeriod * 1000)) - if allRemoved, err := lp.PruneOldBlocks(lp.ctx); err != nil { + if allRemoved, err := lp.PruneOldBlocks(ctx); err != nil { lp.lggr.Errorw("Unable to prune old blocks", "err", err) } else if !allRemoved { // Tick faster when cleanup can't keep up with the pace of new blocks @@ -646,7 +649,7 @@ func (lp *logPoller) backgroundWorkerRun() { } case <-logPruneTick: logPruneTick = time.After(utils.WithJitter(lp.pollPeriod * 2401)) // = 7^5 avoids common factors with 1000 - if allRemoved, err := lp.PruneExpiredLogs(lp.ctx); err != nil { + if allRemoved, err := lp.PruneExpiredLogs(ctx); err != nil { lp.lggr.Errorw("Unable to prune expired logs", "err", err) } else if !allRemoved { // Tick faster when cleanup can't keep up with the pace of new logs @@ -656,26 +659,26 @@ func (lp *logPoller) backgroundWorkerRun() { } } -func (lp *logPoller) handleReplayRequest(fromBlockReq int64, filtersLoaded bool) { - fromBlock, err := lp.GetReplayFromBlock(lp.ctx, fromBlockReq) +func (lp *logPoller) handleReplayRequest(ctx context.Context, fromBlockReq int64, filtersLoaded bool) { + fromBlock, err := lp.GetReplayFromBlock(ctx, fromBlockReq) if err == nil { if !filtersLoaded { lp.lggr.Warnw("Received replayReq before filters loaded", "fromBlock", fromBlock, "requested", fromBlockReq) - if err = lp.loadFilters(); err != nil { + if err = lp.loadFilters(ctx); err != nil { lp.lggr.Errorw("Failed loading filters during Replay", "err", err, "fromBlock", fromBlock) } } if err == nil { // Serially process replay requests. lp.lggr.Infow("Executing replay", "fromBlock", fromBlock, "requested", fromBlockReq) - lp.PollAndSaveLogs(lp.ctx, fromBlock) + lp.PollAndSaveLogs(ctx, fromBlock) lp.lggr.Infow("Executing replay finished", "fromBlock", fromBlock, "requested", fromBlockReq) } } else { lp.lggr.Errorw("Error executing replay, could not get fromBlock", "err", err) } select { - case <-lp.ctx.Done(): + case <-ctx.Done(): // We're shutting down, notify client and exit select { case lp.replayComplete <- ErrReplayRequestAborted: diff --git a/core/chains/evm/logpoller/log_poller_internal_test.go b/core/chains/evm/logpoller/log_poller_internal_test.go index b7dbb074568..bc295105874 100644 --- a/core/chains/evm/logpoller/log_poller_internal_test.go +++ b/core/chains/evm/logpoller/log_poller_internal_test.go @@ -280,7 +280,6 @@ func TestLogPoller_Replay(t *testing.T) { chainID := testutils.FixtureChainID db := pgtest.NewSqlxDB(t) orm := NewORM(chainID, db, lggr) - ctx := testutils.Context(t) head := evmtypes.Head{Number: 4} events := []common.Hash{EmitterABI.Events["Log1"].ID} @@ -312,18 +311,21 @@ func TestLogPoller_Replay(t *testing.T) { } lp := NewLogPoller(orm, ec, lggr, lpOpts) - // process 1 log in block 3 - lp.PollAndSaveLogs(ctx, 4) - latest, err := lp.LatestBlock(ctx) - require.NoError(t, err) - require.Equal(t, int64(4), latest.BlockNumber) - require.Equal(t, int64(1), latest.FinalizedBlockNumber) + { + ctx := testutils.Context(t) + // process 1 log in block 3 + lp.PollAndSaveLogs(ctx, 4) + latest, err := lp.LatestBlock(ctx) + require.NoError(t, err) + require.Equal(t, int64(4), latest.BlockNumber) + require.Equal(t, int64(1), latest.FinalizedBlockNumber) + } t.Run("abort before replayStart received", func(t *testing.T) { // Replay() should abort immediately if caller's context is cancelled before request signal is read cancelCtx, cancel := context.WithCancel(testutils.Context(t)) cancel() - err = lp.Replay(cancelCtx, 3) + err := lp.Replay(cancelCtx, 3) assert.ErrorIs(t, err, ErrReplayRequestAborted) }) @@ -338,6 +340,7 @@ func TestLogPoller_Replay(t *testing.T) { // Replay() should return error code received from replayComplete t.Run("returns error code on replay complete", func(t *testing.T) { + ctx := testutils.Context(t) ec.On("FilterLogs", mock.Anything, mock.Anything).Return([]types.Log{log1}, nil).Once() mockBatchCallContext(t, ec) anyErr := pkgerrors.New("any error") @@ -368,6 +371,7 @@ func TestLogPoller_Replay(t *testing.T) { // Main lp.run() loop shouldn't get stuck if client aborts t.Run("client abort doesnt hang run loop", func(t *testing.T) { + ctx := testutils.Context(t) lp.backupPollerNextBlock = 0 pass := make(chan struct{}) @@ -420,6 +424,7 @@ func TestLogPoller_Replay(t *testing.T) { // run() should abort if log poller shuts down while replay is in progress t.Run("shutdown during replay", func(t *testing.T) { + ctx := testutils.Context(t) lp.backupPollerNextBlock = 0 pass := make(chan struct{}) @@ -438,8 +443,15 @@ func TestLogPoller_Replay(t *testing.T) { }() }) ec.On("FilterLogs", mock.Anything, mock.Anything).Once().Return([]types.Log{log1}, nil).Run(func(args mock.Arguments) { - lp.cancel() - close(pass) + go func() { + assert.NoError(t, lp.Close()) + + // prevent double close + lp.reset() + assert.NoError(t, lp.Start(ctx)) + + close(pass) + }() }) ec.On("FilterLogs", mock.Anything, mock.Anything).Return([]types.Log{log1}, nil) @@ -468,6 +480,7 @@ func TestLogPoller_Replay(t *testing.T) { }) t.Run("ReplayAsync error", func(t *testing.T) { + ctx := testutils.Context(t) t.Cleanup(lp.reset) servicetest.Run(t, lp) head = evmtypes.Head{Number: 4} @@ -481,7 +494,7 @@ func TestLogPoller_Replay(t *testing.T) { select { case lp.replayComplete <- anyErr: time.Sleep(2 * time.Second) - case <-lp.ctx.Done(): + case <-ctx.Done(): t.Error("timed out waiting to send replaceComplete") } require.Equal(t, 1, observedLogs.Len()) @@ -489,6 +502,7 @@ func TestLogPoller_Replay(t *testing.T) { }) t.Run("run regular replay when there are not blocks in db", func(t *testing.T) { + ctx := testutils.Context(t) err := lp.orm.DeleteLogsAndBlocksAfter(ctx, 0) require.NoError(t, err) @@ -497,6 +511,7 @@ func TestLogPoller_Replay(t *testing.T) { }) t.Run("run only backfill when everything is finalized", func(t *testing.T) { + ctx := testutils.Context(t) err := lp.orm.DeleteLogsAndBlocksAfter(ctx, 0) require.NoError(t, err) @@ -513,7 +528,7 @@ func TestLogPoller_Replay(t *testing.T) { func (lp *logPoller) reset() { lp.StateMachine = services.StateMachine{} - lp.ctx, lp.cancel = context.WithCancel(context.Background()) + lp.stopCh = make(chan struct{}) } func Test_latestBlockAndFinalityDepth(t *testing.T) { diff --git a/core/chains/evm/txmgr/client.go b/core/chains/evm/txmgr/client.go index 8ba3c841289..661a180af50 100644 --- a/core/chains/evm/txmgr/client.go +++ b/core/chains/evm/txmgr/client.go @@ -18,7 +18,7 @@ import ( commonclient "github.com/smartcontractkit/chainlink/v2/common/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" - evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) @@ -27,10 +27,10 @@ var _ TxmClient = (*evmTxmClient)(nil) type evmTxmClient struct { client client.Client - clientErrors evmconfig.ClientErrors + clientErrors config.ClientErrors } -func NewEvmTxmClient(c client.Client, clientErrors evmconfig.ClientErrors) *evmTxmClient { +func NewEvmTxmClient(c client.Client, clientErrors config.ClientErrors) *evmTxmClient { return &evmTxmClient{client: c, clientErrors: clientErrors} } diff --git a/core/chains/evm/txmgr/evm_tx_store.go b/core/chains/evm/txmgr/evm_tx_store.go index 2e4c85c879d..b0ed0b41284 100644 --- a/core/chains/evm/txmgr/evm_tx_store.go +++ b/core/chains/evm/txmgr/evm_tx_store.go @@ -19,6 +19,7 @@ import ( nullv4 "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/null" @@ -77,10 +78,9 @@ type TestEvmTxStore interface { } type evmTxStore struct { - q sqlutil.DataSource - logger logger.SugaredLogger - ctx context.Context - ctxCancel context.CancelFunc + q sqlutil.DataSource + logger logger.SugaredLogger + stopCh services.StopChan } var _ EvmTxStore = (*evmTxStore)(nil) @@ -349,12 +349,10 @@ func NewTxStore( lggr logger.Logger, ) *evmTxStore { namedLogger := logger.Named(lggr, "TxmStore") - ctx, cancel := context.WithCancel(context.Background()) return &evmTxStore{ - q: db, - logger: logger.Sugared(namedLogger), - ctx: ctx, - ctxCancel: cancel, + q: db, + logger: logger.Sugared(namedLogger), + stopCh: make(chan struct{}), } } @@ -365,7 +363,7 @@ RETURNING *; ` func (o *evmTxStore) Close() { - o.ctxCancel() + close(o.stopCh) } func (o *evmTxStore) preloadTxAttempts(ctx context.Context, txs []Tx) error { @@ -402,7 +400,7 @@ func (o *evmTxStore) preloadTxAttempts(ctx context.Context, txs []Tx) error { func (o *evmTxStore) PreloadTxes(ctx context.Context, attempts []TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() return o.preloadTxesAtomic(ctx, attempts) } @@ -577,7 +575,7 @@ func (o *evmTxStore) InsertReceipt(ctx context.Context, receipt *evmtypes.Receip func (o *evmTxStore) GetFatalTransactions(ctx context.Context) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { stmt := `SELECT * FROM evm.txes WHERE state = 'fatal_error'` @@ -655,7 +653,7 @@ func (o *evmTxStore) LoadTxesAttempts(ctx context.Context, etxs []*Tx) error { func (o *evmTxStore) LoadTxAttempts(ctx context.Context, etx *Tx) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() return o.loadTxAttemptsAtomic(ctx, etx) } @@ -720,7 +718,7 @@ func loadConfirmedAttemptsReceipts(ctx context.Context, q sqlutil.DataSource, at // eth_tx that was last sent before or at the given time (up to limit) func (o *evmTxStore) FindTxAttemptsRequiringResend(ctx context.Context, olderThan time.Time, maxInFlightTransactions uint32, chainID *big.Int, address common.Address) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var limit null.Uint32 if maxInFlightTransactions > 0 { @@ -744,7 +742,7 @@ LIMIT $4 func (o *evmTxStore) UpdateBroadcastAts(ctx context.Context, now time.Time, etxIDs []int64) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() // Deliberately do nothing on NULL broadcast_at because that indicates the // tx has been moved into a state where broadcast_at is not relevant, e.g. @@ -762,7 +760,7 @@ func (o *evmTxStore) UpdateBroadcastAts(ctx context.Context, now time.Time, etxI // the attempt is already broadcast it _must_ have been before this head. func (o *evmTxStore) SetBroadcastBeforeBlockNum(ctx context.Context, blockNum int64, chainID *big.Int) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() _, err := o.q.ExecContext(ctx, `UPDATE evm.tx_attempts @@ -777,7 +775,7 @@ AND evm.txes.id = evm.tx_attempts.eth_tx_id AND evm.txes.evm_chain_id = $2`, func (o *evmTxStore) FindTxAttemptsConfirmedMissingReceipt(ctx context.Context, chainID *big.Int) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbAttempts []DbEthTxAttempt err = o.q.SelectContext(ctx, &dbAttempts, @@ -796,7 +794,7 @@ func (o *evmTxStore) FindTxAttemptsConfirmedMissingReceipt(ctx context.Context, func (o *evmTxStore) UpdateTxsUnconfirmed(ctx context.Context, ids []int64) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() _, err := o.q.ExecContext(ctx, `UPDATE evm.txes SET state='unconfirmed' WHERE id = ANY($1)`, pq.Array(ids)) @@ -808,7 +806,7 @@ func (o *evmTxStore) UpdateTxsUnconfirmed(ctx context.Context, ids []int64) erro func (o *evmTxStore) FindTxAttemptsRequiringReceiptFetch(ctx context.Context, chainID *big.Int) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbAttempts []DbEthTxAttempt @@ -862,7 +860,7 @@ func (o *evmTxStore) FindTxsByStateAndFromAddresses(ctx context.Context, address func (o *evmTxStore) SaveFetchedReceipts(ctx context.Context, r []*evmtypes.Receipt, state txmgrtypes.TxState, errorMsg *string, chainID *big.Int) (err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() receipts := toOnchainReceipt(r) if len(receipts) == 0 { @@ -966,7 +964,7 @@ func (o *evmTxStore) SaveFetchedReceipts(ctx context.Context, r []*evmtypes.Rece // attempts are below the finality depth from current head. func (o *evmTxStore) MarkAllConfirmedMissingReceipt(ctx context.Context, chainID *big.Int) (err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() res, err := o.q.ExecContext(ctx, ` UPDATE evm.txes @@ -997,7 +995,7 @@ WHERE state = 'unconfirmed' func (o *evmTxStore) GetInProgressTxAttempts(ctx context.Context, address common.Address, chainID *big.Int) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbAttempts []DbEthTxAttempt @@ -1021,7 +1019,7 @@ func (o *evmTxStore) FindTxesPendingCallback(ctx context.Context, blockNum int64 var rs []dbReceiptPlus var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.q.SelectContext(ctx, &rs, ` SELECT evm.txes.pipeline_task_run_id, evm.receipts.receipt, COALESCE((evm.txes.meta->>'FailOnRevert')::boolean, false) "FailOnRevert" FROM evm.txes @@ -1040,7 +1038,7 @@ func (o *evmTxStore) FindTxesPendingCallback(ctx context.Context, blockNum int64 // Update tx to mark that its callback has been signaled func (o *evmTxStore) UpdateTxCallbackCompleted(ctx context.Context, pipelineTaskRunId uuid.UUID, chainId *big.Int) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() _, err := o.q.ExecContext(ctx, `UPDATE evm.txes SET callback_completed = TRUE WHERE pipeline_task_run_id = $1 AND evm_chain_id = $2`, pipelineTaskRunId, chainId.String()) if err != nil { @@ -1051,7 +1049,7 @@ func (o *evmTxStore) UpdateTxCallbackCompleted(ctx context.Context, pipelineTask func (o *evmTxStore) FindLatestSequence(ctx context.Context, fromAddress common.Address, chainId *big.Int) (nonce evmtypes.Nonce, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() sql := `SELECT nonce FROM evm.txes WHERE from_address = $1 AND evm_chain_id = $2 AND nonce IS NOT NULL ORDER BY nonce DESC LIMIT 1` err = o.q.GetContext(ctx, &nonce, sql, fromAddress, chainId.String()) @@ -1061,7 +1059,7 @@ func (o *evmTxStore) FindLatestSequence(ctx context.Context, fromAddress common. // FindTxWithIdempotencyKey returns any broadcast ethtx with the given idempotencyKey and chainID func (o *evmTxStore) FindTxWithIdempotencyKey(ctx context.Context, idempotencyKey string, chainID *big.Int) (etx *Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtx DbEthTx err = o.q.GetContext(ctx, &dbEtx, `SELECT * FROM evm.txes WHERE idempotency_key = $1 and evm_chain_id = $2`, idempotencyKey, chainID.String()) @@ -1079,7 +1077,7 @@ func (o *evmTxStore) FindTxWithIdempotencyKey(ctx context.Context, idempotencyKe // FindTxWithSequence returns any broadcast ethtx with the given nonce func (o *evmTxStore) FindTxWithSequence(ctx context.Context, fromAddress common.Address, nonce evmtypes.Nonce) (etx *Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() etx = new(Tx) err = o.Transact(ctx, true, func(orm *evmTxStore) error { @@ -1128,7 +1126,7 @@ AND evm.tx_attempts.eth_tx_id = $1 func (o *evmTxStore) UpdateTxForRebroadcast(ctx context.Context, etx Tx, etxAttempt TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() return o.Transact(ctx, false, func(orm *evmTxStore) error { if err := deleteEthReceipts(ctx, orm, etx.ID); err != nil { @@ -1143,7 +1141,7 @@ func (o *evmTxStore) UpdateTxForRebroadcast(ctx context.Context, etx Tx, etxAtte func (o *evmTxStore) FindTransactionsConfirmedInBlockRange(ctx context.Context, highBlockNumber, lowBlockNumber int64, chainID *big.Int) (etxs []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtxs []DbEthTx @@ -1170,7 +1168,7 @@ ORDER BY nonce ASC func (o *evmTxStore) FindEarliestUnconfirmedBroadcastTime(ctx context.Context, chainID *big.Int) (broadcastAt nullv4.Time, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { if err = orm.q.QueryRowxContext(ctx, `SELECT min(initial_broadcast_at) FROM evm.txes WHERE state = 'unconfirmed' AND evm_chain_id = $1`, chainID.String()).Scan(&broadcastAt); err != nil { @@ -1183,7 +1181,7 @@ func (o *evmTxStore) FindEarliestUnconfirmedBroadcastTime(ctx context.Context, c func (o *evmTxStore) FindEarliestUnconfirmedTxAttemptBlock(ctx context.Context, chainID *big.Int) (earliestUnconfirmedTxBlock nullv4.Int, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { err = orm.q.QueryRowxContext(ctx, ` @@ -1201,7 +1199,7 @@ AND evm_chain_id = $1`, chainID.String()).Scan(&earliestUnconfirmedTxBlock) func (o *evmTxStore) IsTxFinalized(ctx context.Context, blockHeight int64, txID int64, chainID *big.Int) (finalized bool, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var count int32 @@ -1234,7 +1232,7 @@ func (o *evmTxStore) saveAttemptWithNewState(ctx context.Context, attempt TxAtte func (o *evmTxStore) SaveInsufficientFundsAttempt(ctx context.Context, timeout time.Duration, attempt *TxAttempt, broadcastAt time.Time) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if !(attempt.State == txmgrtypes.TxAttemptInProgress || attempt.State == txmgrtypes.TxAttemptInsufficientFunds) { return errors.New("expected state to be either in_progress or insufficient_eth") @@ -1257,14 +1255,14 @@ func (o *evmTxStore) saveSentAttempt(ctx context.Context, timeout time.Duration, func (o *evmTxStore) SaveSentAttempt(ctx context.Context, timeout time.Duration, attempt *TxAttempt, broadcastAt time.Time) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() return o.saveSentAttempt(ctx, timeout, attempt, broadcastAt) } func (o *evmTxStore) SaveConfirmedMissingReceiptAttempt(ctx context.Context, timeout time.Duration, attempt *TxAttempt, broadcastAt time.Time) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err := o.Transact(ctx, false, func(orm *evmTxStore) error { if err := orm.saveSentAttempt(ctx, timeout, attempt, broadcastAt); err != nil { @@ -1280,7 +1278,7 @@ func (o *evmTxStore) SaveConfirmedMissingReceiptAttempt(ctx context.Context, tim func (o *evmTxStore) DeleteInProgressAttempt(ctx context.Context, attempt TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if attempt.State != txmgrtypes.TxAttemptInProgress { return errors.New("DeleteInProgressAttempt: expected attempt state to be in_progress") @@ -1295,7 +1293,7 @@ func (o *evmTxStore) DeleteInProgressAttempt(ctx context.Context, attempt TxAtte // SaveInProgressAttempt inserts or updates an attempt func (o *evmTxStore) SaveInProgressAttempt(ctx context.Context, attempt *TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if attempt.State != txmgrtypes.TxAttemptInProgress { return errors.New("SaveInProgressAttempt failed: attempt state must be in_progress") @@ -1329,7 +1327,7 @@ func (o *evmTxStore) SaveInProgressAttempt(ctx context.Context, attempt *TxAttem func (o *evmTxStore) GetAbandonedTransactionsByBatch(ctx context.Context, chainID *big.Int, enabledAddrs []common.Address, offset, limit uint) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var enabledAddrsBytea [][]byte @@ -1353,7 +1351,7 @@ func (o *evmTxStore) GetAbandonedTransactionsByBatch(ctx context.Context, chainI func (o *evmTxStore) GetTxByID(ctx context.Context, id int64) (txe *Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { @@ -1388,7 +1386,7 @@ func (o *evmTxStore) FindTxsRequiringGasBump(ctx context.Context, address common return } var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { stmt := ` @@ -1415,7 +1413,7 @@ ORDER BY nonce ASC // block func (o *evmTxStore) FindTxsRequiringResubmissionDueToInsufficientFunds(ctx context.Context, address common.Address, chainID *big.Int) (etxs []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtxs []DbEthTx @@ -1445,7 +1443,7 @@ ORDER BY nonce ASC // receipt and thus cannot pass on any transaction hash func (o *evmTxStore) MarkOldTxesMissingReceiptAsErrored(ctx context.Context, blockNum int64, finalityDepth uint32, chainID *big.Int) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() // cutoffBlockNum is a block height // Any 'confirmed_missing_receipt' eth_tx with all attempts older than this block height will be marked as errored @@ -1538,7 +1536,7 @@ GROUP BY e.id func (o *evmTxStore) SaveReplacementInProgressAttempt(ctx context.Context, oldAttempt TxAttempt, replacementAttempt *TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if oldAttempt.State != txmgrtypes.TxAttemptInProgress || replacementAttempt.State != txmgrtypes.TxAttemptInProgress { return errors.New("expected attempts to be in_progress") @@ -1565,7 +1563,7 @@ func (o *evmTxStore) SaveReplacementInProgressAttempt(ctx context.Context, oldAt // Finds earliest saved transaction that has yet to be broadcast from the given address func (o *evmTxStore) FindNextUnstartedTransactionFromAddress(ctx context.Context, fromAddress common.Address, chainID *big.Int) (*Tx, error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtx DbEthTx err := o.q.GetContext(ctx, &dbEtx, `SELECT * FROM evm.txes WHERE from_address = $1 AND state = 'unstarted' AND evm_chain_id = $2 ORDER BY value ASC, created_at ASC, id ASC`, fromAddress, chainID.String()) @@ -1580,7 +1578,7 @@ func (o *evmTxStore) FindNextUnstartedTransactionFromAddress(ctx context.Context func (o *evmTxStore) UpdateTxFatalError(ctx context.Context, etx *Tx) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if etx.State != txmgr.TxInProgress && etx.State != txmgr.TxUnstarted { return pkgerrors.Errorf("can only transition to fatal_error from in_progress or unstarted, transaction is currently %s", etx.State) @@ -1607,7 +1605,7 @@ func (o *evmTxStore) UpdateTxFatalError(ctx context.Context, etx *Tx) error { // Updates eth attempt from in_progress to broadcast. Also updates the eth tx to unconfirmed. func (o *evmTxStore) UpdateTxAttemptInProgressToBroadcast(ctx context.Context, etx *Tx, attempt TxAttempt, NewAttemptState txmgrtypes.TxAttemptState) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if etx.BroadcastAt == nil { return errors.New("unconfirmed transaction must have broadcast_at time") @@ -1645,7 +1643,7 @@ func (o *evmTxStore) UpdateTxAttemptInProgressToBroadcast(ctx context.Context, e // Updates eth tx from unstarted to in_progress and inserts in_progress eth attempt func (o *evmTxStore) UpdateTxUnstartedToInProgress(ctx context.Context, etx *Tx, attempt *TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if etx.Sequence == nil { return errors.New("in_progress transaction must have nonce") @@ -1717,7 +1715,7 @@ func (o *evmTxStore) UpdateTxUnstartedToInProgress(ctx context.Context, etx *Tx, // It may or may not have been broadcast to an eth node. func (o *evmTxStore) GetTxInProgress(ctx context.Context, fromAddress common.Address) (etx *Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() etx = new(Tx) if err != nil { @@ -1748,7 +1746,7 @@ func (o *evmTxStore) GetTxInProgress(ctx context.Context, fromAddress common.Add func (o *evmTxStore) HasInProgressTransaction(ctx context.Context, account common.Address, chainID *big.Int) (exists bool, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.q.GetContext(ctx, &exists, `SELECT EXISTS(SELECT 1 FROM evm.txes WHERE state = 'in_progress' AND from_address = $1 AND evm_chain_id = $2)`, account, chainID.String()) return exists, pkgerrors.Wrap(err, "hasInProgressTransaction failed") @@ -1756,7 +1754,7 @@ func (o *evmTxStore) HasInProgressTransaction(ctx context.Context, account commo func (o *evmTxStore) countTransactionsWithState(ctx context.Context, fromAddress common.Address, state txmgrtypes.TxState, chainID *big.Int) (count uint32, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.q.GetContext(ctx, &count, `SELECT count(*) FROM evm.txes WHERE from_address = $1 AND state = $2 AND evm_chain_id = $3`, fromAddress, state, chainID.String()) @@ -1771,7 +1769,7 @@ func (o *evmTxStore) CountUnconfirmedTransactions(ctx context.Context, fromAddre // CountTransactionsByState returns the number of transactions with any fromAddress in the given state func (o *evmTxStore) CountTransactionsByState(ctx context.Context, state txmgrtypes.TxState, chainID *big.Int) (count uint32, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.q.GetContext(ctx, &count, `SELECT count(*) FROM evm.txes WHERE state = $1 AND evm_chain_id = $2`, state, chainID.String()) @@ -1788,7 +1786,7 @@ func (o *evmTxStore) CountUnstartedTransactions(ctx context.Context, fromAddress func (o *evmTxStore) CheckTxQueueCapacity(ctx context.Context, fromAddress common.Address, maxQueuedTransactions uint64, chainID *big.Int) (err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if maxQueuedTransactions == 0 { return nil @@ -1808,7 +1806,7 @@ func (o *evmTxStore) CheckTxQueueCapacity(ctx context.Context, fromAddress commo func (o *evmTxStore) CreateTransaction(ctx context.Context, txRequest TxRequest, chainID *big.Int) (tx Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtx DbEthTx err = o.Transact(ctx, false, func(orm *evmTxStore) error { @@ -1842,7 +1840,7 @@ RETURNING "txes".* func (o *evmTxStore) PruneUnstartedTxQueue(ctx context.Context, queueSize uint32, subject uuid.UUID) (ids []int64, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, false, func(orm *evmTxStore) error { err := orm.q.SelectContext(ctx, &ids, ` @@ -1870,7 +1868,7 @@ id < ( func (o *evmTxStore) ReapTxHistory(ctx context.Context, minBlockNumberToKeep int64, timeThreshold time.Time, chainID *big.Int) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() // Delete old confirmed evm.txes @@ -1929,7 +1927,7 @@ AND evm_chain_id = $2`, timeThreshold, chainID.String()) func (o *evmTxStore) Abandon(ctx context.Context, chainID *big.Int, addr common.Address) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() _, err := o.q.ExecContext(ctx, `UPDATE evm.txes SET state='fatal_error', nonce = NULL, error = 'abandoned' WHERE state IN ('unconfirmed', 'in_progress', 'unstarted') AND evm_chain_id = $1 AND from_address = $2`, chainID.String(), addr) return err @@ -1938,7 +1936,7 @@ func (o *evmTxStore) Abandon(ctx context.Context, chainID *big.Int, addr common. // Find transactions by a field in the TxMeta blob and transaction states func (o *evmTxStore) FindTxesByMetaFieldAndStates(ctx context.Context, metaField string, metaValue string, states []txmgrtypes.TxState, chainID *big.Int) ([]*Tx, error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtxs []DbEthTx sql := fmt.Sprintf("SELECT * FROM evm.txes WHERE evm_chain_id = $1 AND meta->>'%s' = $2 AND state = ANY($3)", metaField) @@ -1951,7 +1949,7 @@ func (o *evmTxStore) FindTxesByMetaFieldAndStates(ctx context.Context, metaField // Find transactions with a non-null TxMeta field that was provided by transaction states func (o *evmTxStore) FindTxesWithMetaFieldByStates(ctx context.Context, metaField string, states []txmgrtypes.TxState, chainID *big.Int) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtxs []DbEthTx sql := fmt.Sprintf("SELECT * FROM evm.txes WHERE meta->'%s' IS NOT NULL AND state = ANY($1) AND evm_chain_id = $2", metaField) @@ -1964,7 +1962,7 @@ func (o *evmTxStore) FindTxesWithMetaFieldByStates(ctx context.Context, metaFiel // Find transactions with a non-null TxMeta field that was provided and a receipt block number greater than or equal to the one provided func (o *evmTxStore) FindTxesWithMetaFieldByReceiptBlockNum(ctx context.Context, metaField string, blockNum int64, chainID *big.Int) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtxs []DbEthTx sql := fmt.Sprintf("SELECT et.* FROM evm.txes et JOIN evm.tx_attempts eta on et.id = eta.eth_tx_id JOIN evm.receipts er on eta.hash = er.tx_hash WHERE et.meta->'%s' IS NOT NULL AND er.block_number >= $1 AND et.evm_chain_id = $2", metaField) @@ -1977,7 +1975,7 @@ func (o *evmTxStore) FindTxesWithMetaFieldByReceiptBlockNum(ctx context.Context, // Find transactions loaded with transaction attempts and receipts by transaction IDs and states func (o *evmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []txmgrtypes.TxState, chainID *big.Int) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtxs []DbEthTx @@ -2000,7 +1998,7 @@ func (o *evmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Co // For testing only, get all txes in the DB func (o *evmTxStore) GetAllTxes(ctx context.Context) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtxs []DbEthTx sql := "SELECT * FROM evm.txes" @@ -2013,7 +2011,7 @@ func (o *evmTxStore) GetAllTxes(ctx context.Context) (txes []*Tx, err error) { // For testing only, get all tx attempts in the DB func (o *evmTxStore) GetAllTxAttempts(ctx context.Context) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbAttempts []DbEthTxAttempt sql := "SELECT * FROM evm.tx_attempts" @@ -2024,7 +2022,7 @@ func (o *evmTxStore) GetAllTxAttempts(ctx context.Context) (attempts []TxAttempt func (o *evmTxStore) CountTxesByStateAndSubject(ctx context.Context, state txmgrtypes.TxState, subject uuid.UUID) (count int, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() sql := "SELECT COUNT(*) FROM evm.txes WHERE state = $1 AND subject = $2" err = o.q.GetContext(ctx, &count, sql, state, subject) @@ -2033,7 +2031,7 @@ func (o *evmTxStore) CountTxesByStateAndSubject(ctx context.Context, state txmgr func (o *evmTxStore) FindTxesByFromAddressAndState(ctx context.Context, fromAddress common.Address, state string) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() sql := "SELECT * FROM evm.txes WHERE from_address = $1 AND state = $2" var dbEtxs []DbEthTx @@ -2045,23 +2043,9 @@ func (o *evmTxStore) FindTxesByFromAddressAndState(ctx context.Context, fromAddr func (o *evmTxStore) UpdateTxAttemptBroadcastBeforeBlockNum(ctx context.Context, id int64, blockNum uint) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() sql := "UPDATE evm.tx_attempts SET broadcast_before_block_num = $1 WHERE eth_tx_id = $2" _, err := o.q.ExecContext(ctx, sql, blockNum, id) return err } - -// Returns a context that contains the values of the provided context, -// and which is canceled when either the provided context or TxStore parent context is canceled. -func (o *evmTxStore) mergeContexts(ctx context.Context) (context.Context, context.CancelFunc) { - var cancel context.CancelCauseFunc - ctx, cancel = context.WithCancelCause(ctx) - stop := context.AfterFunc(o.ctx, func() { - cancel(context.Cause(o.ctx)) - }) - return ctx, func() { - stop() - cancel(context.Canceled) - } -} diff --git a/core/cmd/shell_test.go b/core/cmd/shell_test.go index 1e3b93851f3..6ecdc4a34de 100644 --- a/core/cmd/shell_test.go +++ b/core/cmd/shell_test.go @@ -18,9 +18,7 @@ import ( "github.com/urfave/cli" commoncfg "github.com/smartcontractkit/chainlink-common/pkg/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" - "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" "github.com/smartcontractkit/chainlink/v2/core/cmd" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" @@ -359,20 +357,20 @@ func TestSetupSolanaRelayer(t *testing.T) { // config 3 chains but only enable 2 => should only be 2 relayer nEnabledChains := 2 tConfig := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { - c.Solana = solana.TOMLConfigs{ - &solana.TOMLConfig{ + c.Solana = solcfg.TOMLConfigs{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("solana-id-1"), Enabled: ptr(true), Chain: solcfg.Chain{}, Nodes: []*solcfg.Node{}, }, - &solana.TOMLConfig{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("solana-id-2"), Enabled: ptr(true), Chain: solcfg.Chain{}, Nodes: []*solcfg.Node{}, }, - &solana.TOMLConfig{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("disabled-solana-id-1"), Enabled: ptr(false), Chain: solcfg.Chain{}, @@ -382,8 +380,8 @@ func TestSetupSolanaRelayer(t *testing.T) { }) t2Config := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { - c.Solana = solana.TOMLConfigs{ - &solana.TOMLConfig{ + c.Solana = solcfg.TOMLConfigs{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("solana-id-1"), Enabled: ptr(true), Chain: solcfg.Chain{}, @@ -420,14 +418,14 @@ func TestSetupSolanaRelayer(t *testing.T) { // test that duplicate enabled chains is an error when duplicateConfig := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { - c.Solana = solana.TOMLConfigs{ - &solana.TOMLConfig{ + c.Solana = solcfg.TOMLConfigs{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("dupe"), Enabled: ptr(true), Chain: solcfg.Chain{}, Nodes: []*solcfg.Node{}, }, - &solana.TOMLConfig{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("dupe"), Enabled: ptr(true), Chain: solcfg.Chain{}, @@ -478,21 +476,21 @@ func TestSetupStarkNetRelayer(t *testing.T) { ChainID: ptr[string]("starknet-id-1"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, &stkcfg.TOMLConfig{ ChainID: ptr[string]("starknet-id-2"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, &stkcfg.TOMLConfig{ ChainID: ptr[string]("disabled-starknet-id-1"), Enabled: ptr(false), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, } @@ -504,7 +502,7 @@ func TestSetupStarkNetRelayer(t *testing.T) { ChainID: ptr[string]("starknet-id-3"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, } @@ -542,14 +540,14 @@ func TestSetupStarkNetRelayer(t *testing.T) { ChainID: ptr[string]("dupe"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, &stkcfg.TOMLConfig{ ChainID: ptr[string]("dupe"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, } diff --git a/core/cmd/solana_chains_commands_test.go b/core/cmd/solana_chains_commands_test.go index 88bc8049247..e374ba11c65 100644 --- a/core/cmd/solana_chains_commands_test.go +++ b/core/cmd/solana_chains_commands_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink/v2/core/cmd" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/solanatest" @@ -16,7 +16,7 @@ func TestShell_IndexSolanaChains(t *testing.T) { t.Parallel() id := solanatest.RandomChainID() - cfg := solana.TOMLConfig{ + cfg := solcfg.TOMLConfig{ ChainID: &id, Enabled: ptr(true), } diff --git a/core/cmd/solana_node_commands_test.go b/core/cmd/solana_node_commands_test.go index ebe9502d1fa..adc699de79b 100644 --- a/core/cmd/solana_node_commands_test.go +++ b/core/cmd/solana_node_commands_test.go @@ -12,14 +12,13 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/config" solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" "github.com/smartcontractkit/chainlink/v2/core/cmd" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/solanatest" "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" ) -func solanaStartNewApplication(t *testing.T, cfgs ...*solana.TOMLConfig) *cltest.TestApplication { +func solanaStartNewApplication(t *testing.T, cfgs ...*solcfg.TOMLConfig) *cltest.TestApplication { for i := range cfgs { cfgs[i].SetDefaults() } @@ -41,9 +40,9 @@ func TestShell_IndexSolanaNodes(t *testing.T) { Name: ptr("second"), URL: config.MustParseURL("https://solana2.example"), } - chain := solana.TOMLConfig{ + chain := solcfg.TOMLConfig{ ChainID: &id, - Nodes: solana.SolanaNodes{&node1, &node2}, + Nodes: solcfg.Nodes{&node1, &node2}, } app := solanaStartNewApplication(t, &chain) client, r := app.NewShellAndRenderer() diff --git a/core/cmd/solana_transaction_commands_test.go b/core/cmd/solana_transaction_commands_test.go index c26bd89ab94..79a5513f190 100644 --- a/core/cmd/solana_transaction_commands_test.go +++ b/core/cmd/solana_transaction_commands_test.go @@ -16,7 +16,6 @@ import ( "github.com/urfave/cli" "github.com/smartcontractkit/chainlink-common/pkg/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" solanaClient "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" @@ -32,9 +31,9 @@ func TestShell_SolanaSendSol(t *testing.T) { Name: ptr(t.Name()), URL: config.MustParseURL(url), } - cfg := solana.TOMLConfig{ + cfg := solcfg.TOMLConfig{ ChainID: &chainID, - Nodes: solana.SolanaNodes{&node}, + Nodes: solcfg.Nodes{&node}, Enabled: ptr(true), } app := solanaStartNewApplication(t, &cfg) diff --git a/core/config/docs/docs_test.go b/core/config/docs/docs_test.go index 5bd91a7d93c..fd59edbab6a 100644 --- a/core/config/docs/docs_test.go +++ b/core/config/docs/docs_test.go @@ -11,10 +11,11 @@ import ( "github.com/stretchr/testify/require" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" "github.com/smartcontractkit/chainlink-common/pkg/config" + commonconfig "github.com/smartcontractkit/chainlink/v2/common/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" evmcfg "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -46,7 +47,7 @@ func TestDoc(t *testing.T) { fallbackDefaults := evmcfg.Defaults(nil) docDefaults := defaults.EVM[0].Chain - require.Equal(t, "", *docDefaults.ChainType) + require.Equal(t, commonconfig.ChainType(""), docDefaults.ChainType.ChainType()) docDefaults.ChainType = nil // clean up KeySpecific as a special case @@ -102,7 +103,7 @@ func TestDoc(t *testing.T) { }) t.Run("Solana", func(t *testing.T) { - var fallbackDefaults solana.TOMLConfig + var fallbackDefaults solcfg.TOMLConfig fallbackDefaults.SetDefaults() assertTOML(t, fallbackDefaults.Chain, defaults.Solana[0].Chain) diff --git a/core/scripts/go.mod b/core/scripts/go.mod index 18840a04cad..cea4ca21242 100644 --- a/core/scripts/go.mod +++ b/core/scripts/go.mod @@ -259,7 +259,7 @@ require ( github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240508101745-af1ed7bc8a69 // indirect github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 // indirect github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab // indirect - github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 // indirect + github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc // indirect github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 // indirect github.com/smartcontractkit/tdh2/go/ocr2/decryptionplugin v0.0.0-20230906073235-9e478e5e19f1 // indirect github.com/smartcontractkit/tdh2/go/tdh2 v0.0.0-20230906073235-9e478e5e19f1 // indirect diff --git a/core/scripts/go.sum b/core/scripts/go.sum index 48cff213690..cde855e2d2f 100644 --- a/core/scripts/go.sum +++ b/core/scripts/go.sum @@ -1193,8 +1193,8 @@ github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab/go.mod h1:RPUY7r8GxgzXxS1ijtU1P/fpJomOXztXgUbEziNmbCA= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 h1:f3W82k9V/XA6ZP/VQVJcGMVR6CrL3pQrPJSwyQWVFys= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83/go.mod h1:RdAtOeBUWq2zByw2kEbwPlXaPIb7YlaDOmnn+nVUBJI= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc h1:ZqgatXFWsJR/hkvm2mKAta6ivXZqTw7542Iz9ucoOq0= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc/go.mod h1:sR0dMjjpvvEpX3qH8DPRANauPkbO9jgUUGYK95xjLRU= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 h1:ssh/w3oXWu+C6bE88GuFRC1+0Bx/4ihsbc80XMLrl2k= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69/go.mod h1:VsfjhvWgjxqWja4q+FlXEtX5lu8BSxn10xRo6gi948g= github.com/smartcontractkit/chainlink-vrf v0.0.0-20240222010609-cd67d123c772 h1:LQmRsrzzaYYN3wEU1l5tWiccznhvbyGnu2N+wHSXZAo= diff --git a/core/services/blockhashstore/delegate.go b/core/services/blockhashstore/delegate.go index 2181084aeec..172dbafc4a4 100644 --- a/core/services/blockhashstore/delegate.go +++ b/core/services/blockhashstore/delegate.go @@ -215,29 +215,32 @@ type service struct { pollPeriod time.Duration runTimeout time.Duration logger logger.Logger - parentCtx context.Context - cancel context.CancelFunc + stopCh services.StopChan } // Start the BHS feeder service, satisfying the job.Service interface. func (s *service) Start(context.Context) error { return s.StartOnce("BHS Feeder Service", func() error { s.logger.Infow("Starting BHS feeder") - ticker := time.NewTicker(utils.WithJitter(s.pollPeriod)) - s.parentCtx, s.cancel = context.WithCancel(context.Background()) + s.stopCh = make(chan struct{}) s.wg.Add(2) go func() { defer s.wg.Done() - s.feeder.StartHeartbeats(s.parentCtx, &realTimer{}) + ctx, cancel := s.stopCh.NewCtx() + defer cancel() + s.feeder.StartHeartbeats(ctx, &realTimer{}) }() go func() { defer s.wg.Done() + ctx, cancel := s.stopCh.NewCtx() + defer cancel() + ticker := time.NewTicker(utils.WithJitter(s.pollPeriod)) defer ticker.Stop() for { select { case <-ticker.C: - s.runFeeder() - case <-s.parentCtx.Done(): + s.runFeeder(ctx) + case <-ctx.Done(): return } } @@ -250,15 +253,15 @@ func (s *service) Start(context.Context) error { func (s *service) Close() error { return s.StopOnce("BHS Feeder Service", func() error { s.logger.Infow("Stopping BHS feeder") - s.cancel() + close(s.stopCh) s.wg.Wait() return nil }) } -func (s *service) runFeeder() { +func (s *service) runFeeder(ctx context.Context) { s.logger.Debugw("Running BHS feeder") - ctx, cancel := context.WithTimeout(s.parentCtx, s.runTimeout) + ctx, cancel := context.WithTimeout(ctx, s.runTimeout) defer cancel() err := s.feeder.Run(ctx) if err == nil { diff --git a/core/services/blockheaderfeeder/delegate.go b/core/services/blockheaderfeeder/delegate.go index 36d1d1cf895..830c2e23377 100644 --- a/core/services/blockheaderfeeder/delegate.go +++ b/core/services/blockheaderfeeder/delegate.go @@ -229,24 +229,25 @@ type service struct { pollPeriod time.Duration runTimeout time.Duration logger logger.Logger - parentCtx context.Context - cancel context.CancelFunc + stopCh services.StopChan } // Start the BHS feeder service, satisfying the job.Service interface. func (s *service) Start(context.Context) error { return s.StartOnce("Block Header Feeder Service", func() error { s.logger.Infow("Starting BlockHeaderFeeder") - ticker := time.NewTicker(utils.WithJitter(s.pollPeriod)) - s.parentCtx, s.cancel = context.WithCancel(context.Background()) + s.stopCh = make(chan struct{}) go func() { defer close(s.done) + ctx, cancel := s.stopCh.NewCtx() + defer cancel() + ticker := time.NewTicker(utils.WithJitter(s.pollPeriod)) defer ticker.Stop() for { select { case <-ticker.C: - s.runFeeder() - case <-s.parentCtx.Done(): + s.runFeeder(ctx) + case <-ctx.Done(): return } } @@ -259,15 +260,15 @@ func (s *service) Start(context.Context) error { func (s *service) Close() error { return s.StopOnce("Block Header Feeder Service", func() error { s.logger.Infow("Stopping BlockHeaderFeeder") - s.cancel() + close(s.stopCh) <-s.done return nil }) } -func (s *service) runFeeder() { +func (s *service) runFeeder(ctx context.Context) { s.logger.Debugw("Running BlockHeaderFeeder") - ctx, cancel := context.WithTimeout(s.parentCtx, s.runTimeout) + ctx, cancel := context.WithTimeout(ctx, s.runTimeout) defer cancel() err := s.feeder.Run(ctx) if err == nil { diff --git a/core/services/chainlink/config.go b/core/services/chainlink/config.go index b77a54f39a8..d0d25a5e461 100644 --- a/core/services/chainlink/config.go +++ b/core/services/chainlink/config.go @@ -9,10 +9,9 @@ import ( gotoml "github.com/pelletier/go-toml/v2" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" - commoncfg "github.com/smartcontractkit/chainlink/v2/common/config" evmcfg "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" "github.com/smartcontractkit/chainlink/v2/core/config/docs" "github.com/smartcontractkit/chainlink/v2/core/config/env" @@ -39,7 +38,7 @@ type Config struct { Cosmos coscfg.TOMLConfigs `toml:",omitempty"` - Solana solana.TOMLConfigs `toml:",omitempty"` + Solana solcfg.TOMLConfigs `toml:",omitempty"` Starknet stkcfg.TOMLConfigs `toml:",omitempty"` } @@ -79,10 +78,10 @@ func (c *Config) valueWarnings() (err error) { func (c *Config) deprecationWarnings() (err error) { // ChainType xdai is deprecated and has been renamed to gnosis for _, evm := range c.EVM { - if evm.ChainType != nil && *evm.ChainType == string(commoncfg.ChainXDai) { + if evm.ChainType != nil && evm.ChainType.Slug() == "xdai" { err = multierr.Append(err, config.ErrInvalid{ Name: "EVM.ChainType", - Value: *evm.ChainType, + Value: evm.ChainType.Slug(), Msg: "deprecated and will be removed in v2.13.0, use 'gnosis' instead", }) } @@ -122,7 +121,7 @@ func (c *Config) setDefaults() { for i := range c.Solana { if c.Solana[i] == nil { - c.Solana[i] = new(solana.TOMLConfig) + c.Solana[i] = new(solcfg.TOMLConfig) } c.Solana[i].Chain.SetDefaults() } diff --git a/core/services/chainlink/config_general.go b/core/services/chainlink/config_general.go index cae01c01cb7..ce34cc47e47 100644 --- a/core/services/chainlink/config_general.go +++ b/core/services/chainlink/config_general.go @@ -14,7 +14,7 @@ import ( "go.uber.org/zap/zapcore" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" starknet "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" @@ -201,7 +201,7 @@ func (g *generalConfig) CosmosConfigs() coscfg.TOMLConfigs { return g.c.Cosmos } -func (g *generalConfig) SolanaConfigs() solana.TOMLConfigs { +func (g *generalConfig) SolanaConfigs() solcfg.TOMLConfigs { return g.c.Solana } diff --git a/core/services/chainlink/config_test.go b/core/services/chainlink/config_test.go index 41075944c5f..2aa1d26c326 100644 --- a/core/services/chainlink/config_test.go +++ b/core/services/chainlink/config_test.go @@ -22,12 +22,11 @@ import ( commoncfg "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/utils/hex" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" commonconfig "github.com/smartcontractkit/chainlink/v2/common/config" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" evmcfg "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -160,7 +159,7 @@ var ( {Name: ptr("secondary"), TendermintURL: commoncfg.MustParseURL("http://bombay.cosmos.com")}, }}, }, - Solana: []*solana.TOMLConfig{ + Solana: []*solcfg.TOMLConfig{ { ChainID: ptr("mainnet"), Chain: solcfg.Chain{ @@ -495,7 +494,7 @@ func TestConfig_Marshal(t *testing.T) { }, BlockBackfillDepth: ptr[uint32](100), BlockBackfillSkip: ptr(true), - ChainType: ptr("Optimism"), + ChainType: commonconfig.NewChainTypeConfig("Optimism"), FinalityDepth: ptr[uint32](42), FinalityTagEnabled: ptr[bool](false), FlagsContractAddress: mustAddress("0xae4E781a6218A8031764928E88d457937A954fC3"), @@ -635,7 +634,7 @@ func TestConfig_Marshal(t *testing.T) { }, }}, } - full.Solana = []*solana.TOMLConfig{ + full.Solana = []*solcfg.TOMLConfig{ { ChainID: ptr("mainnet"), Enabled: ptr(false), @@ -1565,7 +1564,7 @@ func TestConfig_setDefaults(t *testing.T) { var c Config c.EVM = evmcfg.EVMConfigs{{ChainID: ubig.NewI(99999133712345)}} c.Cosmos = coscfg.TOMLConfigs{{ChainID: ptr("unknown cosmos chain")}} - c.Solana = solana.TOMLConfigs{{ChainID: ptr("unknown solana chain")}} + c.Solana = solcfg.TOMLConfigs{{ChainID: ptr("unknown solana chain")}} c.Starknet = stkcfg.TOMLConfigs{{ChainID: ptr("unknown starknet chain")}} c.setDefaults() if s, err := c.TOMLString(); assert.NoError(t, err) { @@ -1645,7 +1644,7 @@ func TestConfig_warnings(t *testing.T) { { name: "Value warning - ChainType=xdai is deprecated", config: Config{ - EVM: evmcfg.EVMConfigs{{Chain: evmcfg.Chain{ChainType: ptr(string(commonconfig.ChainXDai))}}}, + EVM: evmcfg.EVMConfigs{{Chain: evmcfg.Chain{ChainType: commonconfig.NewChainTypeConfig("xdai")}}}, }, expectedErrors: []string{"EVM.ChainType: invalid value (xdai): deprecated and will be removed in v2.13.0, use 'gnosis' instead"}, }, diff --git a/core/services/chainlink/mocks/general_config.go b/core/services/chainlink/mocks/general_config.go index a86753a59e3..c7e224f4f2a 100644 --- a/core/services/chainlink/mocks/general_config.go +++ b/core/services/chainlink/mocks/general_config.go @@ -10,7 +10,7 @@ import ( mock "github.com/stretchr/testify/mock" - solana "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solanaconfig "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" time "time" @@ -616,19 +616,19 @@ func (_m *GeneralConfig) ShutdownGracePeriod() time.Duration { } // SolanaConfigs provides a mock function with given fields: -func (_m *GeneralConfig) SolanaConfigs() solana.TOMLConfigs { +func (_m *GeneralConfig) SolanaConfigs() solanaconfig.TOMLConfigs { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for SolanaConfigs") } - var r0 solana.TOMLConfigs - if rf, ok := ret.Get(0).(func() solana.TOMLConfigs); ok { + var r0 solanaconfig.TOMLConfigs + if rf, ok := ret.Get(0).(func() solanaconfig.TOMLConfigs); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(solana.TOMLConfigs) + r0 = ret.Get(0).(solanaconfig.TOMLConfigs) } } diff --git a/core/services/chainlink/relayer_chain_interoperators_test.go b/core/services/chainlink/relayer_chain_interoperators_test.go index c6183cc1a34..c2baa1edcde 100644 --- a/core/services/chainlink/relayer_chain_interoperators_test.go +++ b/core/services/chainlink/relayer_chain_interoperators_test.go @@ -18,7 +18,6 @@ import ( solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" @@ -75,8 +74,8 @@ func TestCoreRelayerChainInteroperators(t *testing.T) { Nodes: evmcfg.EVMNodes{&node2_1}, }) - c.Solana = solana.TOMLConfigs{ - &solana.TOMLConfig{ + c.Solana = solcfg.TOMLConfigs{ + &solcfg.TOMLConfig{ ChainID: &solanaChainID1, Enabled: ptr(true), Chain: solcfg.Chain{}, @@ -85,7 +84,7 @@ func TestCoreRelayerChainInteroperators(t *testing.T) { URL: ((*commonconfig.URL)(commonconfig.MustParseURL("http://localhost:8547").URL())), }}, }, - &solana.TOMLConfig{ + &solcfg.TOMLConfig{ ChainID: &solanaChainID2, Enabled: ptr(true), Chain: solcfg.Chain{}, diff --git a/core/services/chainlink/relayer_factory.go b/core/services/chainlink/relayer_factory.go index 2aaeb253c0a..bcdb08b8026 100644 --- a/core/services/chainlink/relayer_factory.go +++ b/core/services/chainlink/relayer_factory.go @@ -14,7 +14,7 @@ import ( "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana" - pkgsolana "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" pkgstarknet "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink" starkchain "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/chain" "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" @@ -92,10 +92,10 @@ func (r *RelayerFactory) NewEVM(ctx context.Context, config EVMFactoryConfig) (m type SolanaFactoryConfig struct { Keystore keystore.Solana - solana.TOMLConfigs + solcfg.TOMLConfigs } -func (r *RelayerFactory) NewSolana(ks keystore.Solana, chainCfgs solana.TOMLConfigs) (map[types.RelayID]loop.Relayer, error) { +func (r *RelayerFactory) NewSolana(ks keystore.Solana, chainCfgs solcfg.TOMLConfigs) (map[types.RelayID]loop.Relayer, error) { solanaRelayers := make(map[types.RelayID]loop.Relayer) var ( solLggr = r.Logger.Named("Solana") @@ -123,7 +123,7 @@ func (r *RelayerFactory) NewSolana(ks keystore.Solana, chainCfgs solana.TOMLConf if cmdName := env.SolanaPlugin.Cmd.Get(); cmdName != "" { // setup the solana relayer to be a LOOP cfgTOML, err := toml.Marshal(struct { - Solana solana.TOMLConfig + Solana solcfg.TOMLConfig }{Solana: *chainCfg}) if err != nil { @@ -154,7 +154,7 @@ func (r *RelayerFactory) NewSolana(ks keystore.Solana, chainCfgs solana.TOMLConf if err != nil { return nil, err } - solanaRelayers[relayID] = relay.NewServerAdapter(pkgsolana.NewRelayer(lggr, chain), chain) + solanaRelayers[relayID] = relay.NewServerAdapter(solana.NewRelayer(lggr, chain), chain) } } return solanaRelayers, nil diff --git a/core/services/chainlink/types.go b/core/services/chainlink/types.go index 72cad694167..4c7550142a2 100644 --- a/core/services/chainlink/types.go +++ b/core/services/chainlink/types.go @@ -2,7 +2,7 @@ package chainlink import ( coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" @@ -15,7 +15,7 @@ type GeneralConfig interface { config.AppConfig toml.HasEVMConfigs CosmosConfigs() coscfg.TOMLConfigs - SolanaConfigs() solana.TOMLConfigs + SolanaConfigs() solcfg.TOMLConfigs StarknetConfigs() stkcfg.TOMLConfigs // ConfigTOML returns both the user provided and effective configuration as TOML. ConfigTOML() (user, effective string) diff --git a/core/services/feeds/connection_manager.go b/core/services/feeds/connection_manager.go index 6339339ab7c..d388bc0899f 100644 --- a/core/services/feeds/connection_manager.go +++ b/core/services/feeds/connection_manager.go @@ -1,7 +1,6 @@ package feeds import ( - "context" "crypto/ed25519" "sync" @@ -10,6 +9,7 @@ import ( "github.com/smartcontractkit/wsrpc" "github.com/smartcontractkit/wsrpc/connectivity" + "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/recovery" pb "github.com/smartcontractkit/chainlink/v2/core/services/feeds/proto" @@ -35,10 +35,7 @@ type connectionsManager struct { } type connection struct { - // ctx allows us to cancel any connections which are currently blocking - // while waiting to establish a connection to FMS. - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan connected bool client pb.FeedsManagerClient @@ -81,11 +78,8 @@ type ConnectOpts struct { // Eventually when FMS does come back up, wsrpc will establish the connection // without any interaction on behalf of the node operator. func (mgr *connectionsManager) Connect(opts ConnectOpts) { - ctx, cancel := context.WithCancel(context.Background()) - conn := &connection{ - ctx: ctx, - cancel: cancel, + stopCh: make(chan struct{}), connected: false, } @@ -96,11 +90,13 @@ func (mgr *connectionsManager) Connect(opts ConnectOpts) { mgr.mu.Unlock() go recovery.WrapRecover(mgr.lggr, func() { + ctx, cancel := conn.stopCh.NewCtx() + defer cancel() defer mgr.wgClosed.Done() mgr.lggr.Infow("Connecting to Feeds Manager...", "feedsManagerID", opts.FeedsManagerID) - clientConn, err := wsrpc.DialWithContext(conn.ctx, opts.URI, + clientConn, err := wsrpc.DialWithContext(ctx, opts.URI, wsrpc.WithTransportCreds(opts.Privkey, ed25519.PublicKey(opts.Pubkey)), wsrpc.WithBlock(), wsrpc.WithLogger(mgr.lggr), @@ -108,7 +104,7 @@ func (mgr *connectionsManager) Connect(opts ConnectOpts) { if err != nil { // We only want to log if there was an error that did not occur // from a context cancel. - if conn.ctx.Err() == nil { + if ctx.Err() == nil { mgr.lggr.Warnf("Error connecting to Feeds Manager server: %v", err) } else { mgr.lggr.Infof("Closing wsrpc websocket connection: %v", err) @@ -139,7 +135,7 @@ func (mgr *connectionsManager) Connect(opts ConnectOpts) { for { s := clientConn.GetState() - clientConn.WaitForStateChange(conn.ctx, s) + clientConn.WaitForStateChange(ctx, s) s = clientConn.GetState() @@ -155,7 +151,7 @@ func (mgr *connectionsManager) Connect(opts ConnectOpts) { }() // Wait for close - <-conn.ctx.Done() + <-ctx.Done() }) } @@ -169,7 +165,7 @@ func (mgr *connectionsManager) Disconnect(id int64) error { return errors.New("feeds manager is not connected") } - conn.cancel() + close(conn.stopCh) delete(mgr.connections, id) mgr.lggr.Infow("Disconnected Feeds Manager", "feedsManagerID", id) @@ -181,7 +177,7 @@ func (mgr *connectionsManager) Disconnect(id int64) error { func (mgr *connectionsManager) Close() { mgr.mu.Lock() for _, conn := range mgr.connections { - conn.cancel() + close(conn.stopCh) } mgr.mu.Unlock() diff --git a/core/services/feeds/service_test.go b/core/services/feeds/service_test.go index 43d75f712a0..b8cd590a402 100644 --- a/core/services/feeds/service_test.go +++ b/core/services/feeds/service_test.go @@ -643,6 +643,7 @@ func Test_Service_ProposeJob(t *testing.T) { // variables for workflow spec wfID = "15c631d295ef5e32deb99a10ee6804bc4af1385568f9b3363f6552ac6dbb2cef" wfOwner = "00000000000000000000000000000000000000aa" + wfName = "my-workflow" specYaml = ` triggers: - id: "a-trigger" @@ -666,7 +667,7 @@ targets: inputs: consensus_output: $(a-consensus.outputs) ` - wfSpec = testspecs.GenerateWorkflowSpec(wfID, wfOwner, specYaml).Toml() + wfSpec = testspecs.GenerateWorkflowSpec(wfID, wfOwner, wfName, specYaml).Toml() proposalIDWF = int64(11) remoteUUIDWF = uuid.New() argsWF = &feeds.ProposeJobArgs{ diff --git a/core/services/functions/listener.go b/core/services/functions/listener.go index d2033ff74de..bfcbf10f692 100644 --- a/core/services/functions/listener.go +++ b/core/services/functions/listener.go @@ -130,9 +130,7 @@ type functionsListener struct { job job.Job bridgeAccessor BridgeAccessor shutdownWaitGroup sync.WaitGroup - serviceContext context.Context - serviceCancel context.CancelFunc - chStop chan struct{} + chStop services.StopChan pluginORM ORM pluginConfig config.PluginConfig s4Storage s4.Storage @@ -186,12 +184,10 @@ func NewFunctionsListener( // Start complies with job.Service func (l *functionsListener) Start(context.Context) error { return l.StartOnce("FunctionsListener", func() error { - l.serviceContext, l.serviceCancel = context.WithCancel(context.Background()) - switch l.pluginConfig.ContractVersion { case 1: l.shutdownWaitGroup.Add(1) - go l.processOracleEventsV1(l.serviceContext) + go l.processOracleEventsV1() default: return fmt.Errorf("unsupported contract version: %d", l.pluginConfig.ContractVersion) } @@ -213,15 +209,16 @@ func (l *functionsListener) Start(context.Context) error { // Close complies with job.Service func (l *functionsListener) Close() error { return l.StopOnce("FunctionsListener", func() error { - l.serviceCancel() close(l.chStop) l.shutdownWaitGroup.Wait() return nil }) } -func (l *functionsListener) processOracleEventsV1(ctx context.Context) { +func (l *functionsListener) processOracleEventsV1() { defer l.shutdownWaitGroup.Done() + ctx, cancel := l.chStop.NewCtx() + defer cancel() freqMillis := l.pluginConfig.ListenerEventsCheckFrequencyMillis if freqMillis == 0 { l.logger.Errorw("ListenerEventsCheckFrequencyMillis must set to more than 0 in PluginConfig") @@ -255,11 +252,17 @@ func (l *functionsListener) processOracleEventsV1(ctx context.Context) { } func (l *functionsListener) getNewHandlerContext() (context.Context, context.CancelFunc) { + ctx, cancel := l.chStop.NewCtx() timeoutSec := l.pluginConfig.ListenerEventHandlerTimeoutSec if timeoutSec == 0 { - return context.WithCancel(l.serviceContext) + return ctx, cancel + } + var cancel2 func() + ctx, cancel2 = context.WithTimeout(ctx, time.Duration(timeoutSec)*time.Second) + return ctx, func() { + cancel2() + cancel() } - return context.WithTimeout(l.serviceContext, time.Duration(timeoutSec)*time.Second) } func (l *functionsListener) setError(ctx context.Context, requestId RequestID, errType ErrType, errBytes []byte) { diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist.go b/core/services/gateway/handlers/functions/allowlist/allowlist.go index f0fe5c8c829..2a27f51471a 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist.go @@ -210,33 +210,23 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b return errors.Wrap(err, "unexpected error during functions_allow_list.NewTermsOfServiceAllowList") } - var allowedSenderList []common.Address - typeAndVersion, err := tosContract.TypeAndVersion(&bind.CallOpts{ - Pending: false, - BlockNumber: blockNum, - Context: ctx, - }) - if err != nil { - return errors.Wrap(err, "failed to fetch the tos contract type and version") - } - - currentVersion, err := ExtractContractVersion(typeAndVersion) + currentVersion, err := fetchTosCurrentVersion(ctx, tosContract, blockNum) if err != nil { - return fmt.Errorf("failed to extract version: %w", err) + return fmt.Errorf("failed to fetch tos current version: %w", err) } if semver.Compare(tosContractMinBatchProcessingVersion, currentVersion) <= 0 { - err = a.syncBlockedSenders(ctx, tosContract, blockNum) + err = a.updateAllowedSendersInBatches(ctx, tosContract, blockNum) if err != nil { - return errors.Wrap(err, "failed to sync the stored allowed and blocked senders") + return errors.Wrap(err, "failed to get allowed senders in rage") } - allowedSenderList, err = a.getAllowedSendersBatched(ctx, tosContract, blockNum) + err := a.syncBlockedSenders(ctx, tosContract, blockNum) if err != nil { - return errors.Wrap(err, "failed to get allowed senders in rage") + return errors.Wrap(err, "failed to sync the stored allowed and blocked senders") } } else { - allowedSenderList, err = tosContract.GetAllAllowedSenders(&bind.CallOpts{ + allowedSenderList, err := tosContract.GetAllAllowedSenders(&bind.CallOpts{ Pending: false, BlockNumber: blockNum, Context: ctx, @@ -254,50 +244,108 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b if err != nil { a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) } + + a.update(allowedSenderList) } - a.update(allowedSenderList) return nil } -func (a *onchainAllowlist) getAllowedSendersBatched(ctx context.Context, tosContract *functions_allow_list.TermsOfServiceAllowList, blockNum *big.Int) ([]common.Address, error) { - allowedSenderList := make([]common.Address, 0) - count, err := tosContract.GetAllowedSendersCount(&bind.CallOpts{ +// updateAllowedSendersInBatches will update the node's inmemory state and the orm layer representing the allowlist. +// it will get the current node's in memory allowlist and start fetching and adding from the tos contract in batches. +// the iteration order will give priority to new allowed senders, if new addresses are added while iterating over the batches +// an extra step will be executed to keep this up to date. +func (a *onchainAllowlist) updateAllowedSendersInBatches(ctx context.Context, tosContract functions_allow_list.TermsOfServiceAllowListInterface, blockNum *big.Int) error { + // currentAllowedSenderList will be the starting point from which we will be adding the new allowed senders + currentAllowedSenderList := make(map[common.Address]struct{}, 0) + if cal := a.allowlist.Load(); cal != nil { + for k := range *cal { + currentAllowedSenderList[k] = struct{}{} + } + } + + currentAllowedSenderCount, err := tosContract.GetAllowedSendersCount(&bind.CallOpts{ Pending: false, BlockNumber: blockNum, Context: ctx, }) if err != nil { - return nil, errors.Wrap(err, "unexpected error during functions_allow_list.GetAllowedSendersCount") + return errors.Wrap(err, "unexpected error during functions_allow_list.GetAllowedSendersCount") } throttleTicker := time.NewTicker(time.Duration(a.config.FetchingDelayInRangeSec) * time.Second) - for idxStart := uint64(0); idxStart < count; idxStart += uint64(a.config.OnchainAllowlistBatchSize) { - <-throttleTicker.C - idxEnd := idxStart + uint64(a.config.OnchainAllowlistBatchSize) - if idxEnd >= count { - idxEnd = count - 1 + for i := int64(currentAllowedSenderCount); i > 0; i -= int64(a.config.OnchainAllowlistBatchSize) { + <-throttleTicker.C + var idxStart uint64 + if uint64(i) > uint64(a.config.OnchainAllowlistBatchSize) { + idxStart = uint64(i) - uint64(a.config.OnchainAllowlistBatchSize) } - allowedSendersBatch, err := tosContract.GetAllowedSendersInRange(&bind.CallOpts{ + idxEnd := uint64(i) - 1 + + // before continuing we evaluate if the size of the list changed, if that happens we trigger an extra step + // getting the latest added addresses from the list + updatedAllowedSenderCount, err := tosContract.GetAllowedSendersCount(&bind.CallOpts{ Pending: false, BlockNumber: blockNum, Context: ctx, - }, idxStart, idxEnd) + }) if err != nil { - return nil, errors.Wrap(err, "error calling GetAllowedSendersInRange") + return errors.Wrap(err, "unexpected error while fetching the updated functions_allow_list.GetAllowedSendersCount") + } + + if updatedAllowedSenderCount > currentAllowedSenderCount { + lastBatchIdxStart := currentAllowedSenderCount + lastBatchIdxEnd := updatedAllowedSenderCount - 1 + currentAllowedSenderCount = updatedAllowedSenderCount + + err = a.updateAllowedSendersBatch(ctx, tosContract, blockNum, lastBatchIdxStart, lastBatchIdxEnd, currentAllowedSenderList) + if err != nil { + return err + } } - allowedSenderList = append(allowedSenderList, allowedSendersBatch...) - err = a.orm.CreateAllowedSenders(ctx, allowedSendersBatch) + err = a.updateAllowedSendersBatch(ctx, tosContract, blockNum, idxStart, idxEnd, currentAllowedSenderList) if err != nil { - a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) + return err } } throttleTicker.Stop() - return allowedSenderList, nil + return nil +} + +func (a *onchainAllowlist) updateAllowedSendersBatch( + ctx context.Context, + tosContract functions_allow_list.TermsOfServiceAllowListInterface, + blockNum *big.Int, + idxStart uint64, + idxEnd uint64, + currentAllowedSenderList map[common.Address]struct{}, +) error { + allowedSendersBatch, err := tosContract.GetAllowedSendersInRange(&bind.CallOpts{ + Pending: false, + BlockNumber: blockNum, + Context: ctx, + }, idxStart, idxEnd) + if err != nil { + return errors.Wrap(err, "error calling GetAllowedSendersInRange") + } + + // add the fetched batch to the currentAllowedSenderList and replace the existing allowlist + for _, addr := range allowedSendersBatch { + currentAllowedSenderList[addr] = struct{}{} + } + a.allowlist.Store(¤tAllowedSenderList) + a.lggr.Infow("allowlist updated in batches successfully", "len", len(currentAllowedSenderList)) + + // persist each batch to the underalying orm layer + err = a.orm.CreateAllowedSenders(ctx, allowedSendersBatch) + if err != nil { + a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) + } + return nil } // syncBlockedSenders fetches the list of blocked addresses from the contract in batches @@ -370,6 +418,19 @@ func (a *onchainAllowlist) loadStoredAllowedSenderList(ctx context.Context) { a.update(allowedList) } +func fetchTosCurrentVersion(ctx context.Context, tosContract *functions_allow_list.TermsOfServiceAllowList, blockNum *big.Int) (string, error) { + typeAndVersion, err := tosContract.TypeAndVersion(&bind.CallOpts{ + Pending: false, + BlockNumber: blockNum, + Context: ctx, + }) + if err != nil { + return "", errors.Wrap(err, "failed to fetch the tos contract type and version") + } + + return ExtractContractVersion(typeAndVersion) +} + func ExtractContractVersion(str string) (string, error) { pattern := `v(\d+).(\d+).(\d+)` re := regexp.MustCompile(pattern) diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist_internal_test.go b/core/services/gateway/handlers/functions/allowlist/allowlist_internal_test.go new file mode 100644 index 00000000000..966db032636 --- /dev/null +++ b/core/services/gateway/handlers/functions/allowlist/allowlist_internal_test.go @@ -0,0 +1,216 @@ +package allowlist + +import ( + "context" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/functions/generated/functions_allow_list" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/logger" + amocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/allowlist/mocks" +) + +func TestUpdateAllowedSendersInBatches(t *testing.T) { + t.Run("OK-simple_update_in_batches", func(t *testing.T) { + ctx := context.Background() + config := OnchainAllowlistConfig{ + ContractAddress: testutils.NewAddress(), + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + StoredAllowlistBatchSize: 2, + OnchainAllowlistBatchSize: 10, + FetchingDelayInRangeSec: 1, + } + + // allowlistSize defines how big the mocked allowlist will be + allowlistSize := 53 + // allowlist represents the actual allowlist the tos contract will return + allowlist := make([]common.Address, 0, allowlistSize) + // expectedAllowlist will be used to compare the actual status with what we actually want + expectedAllowlist := make(map[common.Address]struct{}, 0) + + // we load both the expectedAllowlist and the allowlist the contract will return with some new addresses + for i := 0; i < allowlistSize; i++ { + addr := testutils.NewAddress() + allowlist = append(allowlist, addr) + expectedAllowlist[addr] = struct{}{} + } + + tosContract := NewTosContractMock(allowlist) + + // with the orm mock we can validate the actual order in which the allowlist is fetched giving priority to newest addresses + orm := amocks.NewORM(t) + firstCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[43:53]).Times(1).Return(nil) + secondCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[33:43]).Times(1).Return(nil).NotBefore(firstCall) + thirdCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[23:33]).Times(1).Return(nil).NotBefore(secondCall) + forthCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[13:23]).Times(1).Return(nil).NotBefore(thirdCall) + fifthCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[3:13]).Times(1).Return(nil).NotBefore(forthCall) + orm.On("CreateAllowedSenders", context.Background(), allowlist[0:3]).Times(1).Return(nil).NotBefore(fifthCall) + + onchainAllowlist := &onchainAllowlist{ + config: config, + orm: orm, + blockConfirmations: big.NewInt(int64(config.BlockConfirmations)), + lggr: logger.TestLogger(t).Named("OnchainAllowlist"), + stopCh: make(services.StopChan), + } + + // we set the onchain allowlist to an empty state before updating it in batches + emptyMap := make(map[common.Address]struct{}) + onchainAllowlist.allowlist.Store(&emptyMap) + + err := onchainAllowlist.updateAllowedSendersInBatches(ctx, tosContract, big.NewInt(0)) + require.NoError(t, err) + + currentAllowlist := onchainAllowlist.allowlist.Load() + require.Equal(t, &expectedAllowlist, currentAllowlist) + }) + + t.Run("OK-new_address_added_while_updating_in_batches", func(t *testing.T) { + ctx := context.Background() + config := OnchainAllowlistConfig{ + ContractAddress: testutils.NewAddress(), + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + StoredAllowlistBatchSize: 2, + OnchainAllowlistBatchSize: 10, + FetchingDelayInRangeSec: 1, + } + + // allowlistSize defines how big the initial mocked allowlist will be + allowlistSize := 50 + // allowlist represents the actual allowlist the tos contract will return + allowlist := make([]common.Address, 0) + // expectedAllowlist will be used to compare the actual status with what we actually want + expectedAllowlist := make(map[common.Address]struct{}, 0) + + // we load both the expectedAllowlist and the allowlist the contract will return with some new addresses + for i := 0; i < allowlistSize; i++ { + addr := testutils.NewAddress() + allowlist = append(allowlist, addr) + expectedAllowlist[addr] = struct{}{} + } + + tosContract := NewTosContractMock(allowlist) + + // with the orm mock we can validate the actual order in which the allowlist is fetched giving priority to newest addresses + orm := amocks.NewORM(t) + firstCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[40:50]).Times(1).Run(func(args mock.Arguments) { + // after the first call we update the tosContract by adding a new address + addr := testutils.NewAddress() + allowlist = append(allowlist, addr) + expectedAllowlist[addr] = struct{}{} + *tosContract = *NewTosContractMock(allowlist) + }).Return(nil) + + // this is the extra step that will fetch the new address we want to validate + extraStepCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[50:51]).Times(1).Return(nil).NotBefore(firstCall) + + secondCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[30:40]).Times(1).Return(nil).NotBefore(extraStepCall) + thirdCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[20:30]).Times(1).Return(nil).NotBefore(secondCall) + forthCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[10:20]).Times(1).Return(nil).NotBefore(thirdCall) + orm.On("CreateAllowedSenders", context.Background(), allowlist[0:10]).Times(1).Return(nil).NotBefore(forthCall) + + onchainAllowlist := &onchainAllowlist{ + config: config, + orm: orm, + blockConfirmations: big.NewInt(int64(config.BlockConfirmations)), + lggr: logger.TestLogger(t).Named("OnchainAllowlist"), + stopCh: make(services.StopChan), + } + + // we set the onchain allowlist to an empty state before updating it in batches + emptyMap := make(map[common.Address]struct{}) + onchainAllowlist.allowlist.Store(&emptyMap) + + err := onchainAllowlist.updateAllowedSendersInBatches(ctx, tosContract, big.NewInt(0)) + require.NoError(t, err) + + currentAllowlist := onchainAllowlist.allowlist.Load() + require.Equal(t, &expectedAllowlist, currentAllowlist) + }) + + t.Run("OK-allowlist_size_smaller_than_batchsize", func(t *testing.T) { + ctx := context.Background() + config := OnchainAllowlistConfig{ + ContractAddress: testutils.NewAddress(), + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + StoredAllowlistBatchSize: 2, + OnchainAllowlistBatchSize: 100, + FetchingDelayInRangeSec: 1, + } + + // allowlistSize defines how big the mocked allowlist will be + allowlistSize := 50 + // allowlist represents the actual allowlist the tos contract will return + allowlist := make([]common.Address, 0, allowlistSize) + // expectedAllowlist will be used to compare the actual status with what we actually want + expectedAllowlist := make(map[common.Address]struct{}, 0) + + // we load both the expectedAllowlist and the allowlist the contract will return with some new addresses + for i := 0; i < allowlistSize; i++ { + addr := testutils.NewAddress() + allowlist = append(allowlist, addr) + expectedAllowlist[addr] = struct{}{} + } + + tosContract := NewTosContractMock(allowlist) + + // with the orm mock we can validate the actual order in which the allowlist is fetched giving priority to newest addresses + orm := amocks.NewORM(t) + orm.On("CreateAllowedSenders", context.Background(), allowlist[0:50]).Times(1).Return(nil) + + onchainAllowlist := &onchainAllowlist{ + config: config, + orm: orm, + blockConfirmations: big.NewInt(int64(config.BlockConfirmations)), + lggr: logger.TestLogger(t).Named("OnchainAllowlist"), + stopCh: make(services.StopChan), + } + + // we set the onchain allowlist to an empty state before updating it in batches + emptyMap := make(map[common.Address]struct{}) + onchainAllowlist.allowlist.Store(&emptyMap) + + err := onchainAllowlist.updateAllowedSendersInBatches(ctx, tosContract, big.NewInt(0)) + require.NoError(t, err) + + currentAllowlist := onchainAllowlist.allowlist.Load() + require.Equal(t, &expectedAllowlist, currentAllowlist) + }) +} + +type tosContractMock struct { + functions_allow_list.TermsOfServiceAllowListInterface + + onchainAllowlist []common.Address +} + +func NewTosContractMock(onchainAllowlist []common.Address) *tosContractMock { + return &tosContractMock{ + onchainAllowlist: onchainAllowlist, + } +} + +func (t *tosContractMock) GetAllowedSendersCount(opts *bind.CallOpts) (uint64, error) { + return uint64(len(t.onchainAllowlist)), nil +} + +func (t *tosContractMock) GetAllowedSendersInRange(opts *bind.CallOpts, allowedSenderIdxStart uint64, allowedSenderIdxEnd uint64) ([]common.Address, error) { + // we replicate the onchain behaviour of including start and end indexes + return t.onchainAllowlist[allowedSenderIdxStart : allowedSenderIdxEnd+1], nil +} diff --git a/core/services/job/job_orm_test.go b/core/services/job/job_orm_test.go index 3c7d5a7afa5..c13cf9da4b1 100644 --- a/core/services/job/job_orm_test.go +++ b/core/services/job/job_orm_test.go @@ -10,10 +10,12 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/google/uuid" "github.com/lib/pq" + "github.com/pelletier/go-toml/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/guregu/null.v4" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable" @@ -1803,6 +1805,187 @@ func Test_CountPipelineRunsByJobID(t *testing.T) { }) } +func Test_ORM_FindJobByWorkflow(t *testing.T) { + type fields struct { + ds sqlutil.DataSource + } + type args struct { + spec job.WorkflowSpec + before func(t *testing.T, o job.ORM, s job.WorkflowSpec) int32 + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "wf not job found", + fields: fields{ + ds: pgtest.NewSqlxDB(t), + }, + args: args{ + // before is nil, so no job is inserted + spec: job.WorkflowSpec{ + ID: 1, + WorkflowID: "workflow 1", + Workflow: "abcd", + WorkflowOwner: "me", + WorkflowName: "myworkflow", + }, + }, + wantErr: true, + }, + { + name: "wf job found", + fields: fields{ + ds: pgtest.NewSqlxDB(t), + }, + args: args{ + spec: job.WorkflowSpec{ + ID: 1, + WorkflowID: "workflow 2", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "myworkflow", + }, + before: mustInsertWFJob, + }, + wantErr: false, + }, + + { + name: "wf wrong name", + fields: fields{ + ds: pgtest.NewSqlxDB(t), + }, + args: args{ + spec: job.WorkflowSpec{ + ID: 1, + WorkflowID: "workflow 3", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "wf3", + }, + before: func(t *testing.T, o job.ORM, s job.WorkflowSpec) int32 { + s.WorkflowName = "notmyworkflow" + return mustInsertWFJob(t, o, s) + }, + }, + wantErr: true, + }, + { + name: "wf wrong owner", + fields: fields{ + ds: pgtest.NewSqlxDB(t), + }, + args: args{ + spec: job.WorkflowSpec{ + ID: 1, + WorkflowID: "workflow 4", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "wf4", + }, + before: func(t *testing.T, o job.ORM, s job.WorkflowSpec) int32 { + s.WorkflowOwner = "not me" + return mustInsertWFJob(t, o, s) + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ks := cltest.NewKeyStore(t, tt.fields.ds) + pipelineORM := pipeline.NewORM(tt.fields.ds, logger.TestLogger(t), configtest.NewTestGeneralConfig(t).JobPipeline().MaxSuccessfulRuns()) + bridgesORM := bridges.NewORM(tt.fields.ds) + o := NewTestORM(t, tt.fields.ds, pipelineORM, bridgesORM, ks) + var wantJobID int32 + if tt.args.before != nil { + wantJobID = tt.args.before(t, o, tt.args.spec) + } + ctx := testutils.Context(t) + gotJ, err := o.FindJobIDByWorkflow(ctx, tt.args.spec) + if (err != nil) != tt.wantErr { + t.Errorf("orm.FindJobByWorkflow() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + assert.Equal(t, wantJobID, gotJ, "mismatch job id") + } + }) + } + + t.Run("multiple jobs", func(t *testing.T) { + db := pgtest.NewSqlxDB(t) + o := NewTestORM(t, + db, + pipeline.NewORM(db, + logger.TestLogger(t), + configtest.NewTestGeneralConfig(t).JobPipeline().MaxSuccessfulRuns()), + bridges.NewORM(db), + cltest.NewKeyStore(t, db)) + ctx := testutils.Context(t) + s1 := job.WorkflowSpec{ + WorkflowID: "workflowid", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "a_common_name", + } + wantJobID1 := mustInsertWFJob(t, o, s1) + + s2 := job.WorkflowSpec{ + WorkflowID: "another workflowid", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "another workflow name", + } + wantJobID2 := mustInsertWFJob(t, o, s2) + + s3 := job.WorkflowSpec{ + WorkflowID: "xworkflowid", + Workflow: "anything", + WorkflowOwner: "someone else", + WorkflowName: "a_common_name", + } + wantJobID3 := mustInsertWFJob(t, o, s3) + + expectedIDs := []int32{wantJobID1, wantJobID2, wantJobID3} + for i, s := range []job.WorkflowSpec{s1, s2, s3} { + gotJ, err := o.FindJobIDByWorkflow(ctx, s) + require.NoError(t, err) + assert.Equal(t, expectedIDs[i], gotJ, "mismatch job id case %d, spec %v", i, s) + j, err := o.FindJob(ctx, expectedIDs[i]) + require.NoError(t, err) + assert.NotNil(t, j) + t.Logf("found job %v", j) + assert.EqualValues(t, j.WorkflowSpec.Workflow, s.Workflow) + assert.EqualValues(t, j.WorkflowSpec.WorkflowID, s.WorkflowID) + assert.EqualValues(t, j.WorkflowSpec.WorkflowOwner, s.WorkflowOwner) + assert.EqualValues(t, j.WorkflowSpec.WorkflowName, s.WorkflowName) + } + }) +} + +func mustInsertWFJob(t *testing.T, orm job.ORM, s job.WorkflowSpec) int32 { + t.Helper() + ctx := testutils.Context(t) + _, err := toml.Marshal(s.Workflow) + require.NoError(t, err) + j := job.Job{ + Type: job.Workflow, + WorkflowSpec: &s, + ExternalJobID: uuid.New(), + Name: null.StringFrom(s.WorkflowOwner + "_" + s.WorkflowName), + SchemaVersion: 1, + } + err = orm.CreateJob(ctx, &j) + require.NoError(t, err) + return j.ID +} + func mustInsertPipelineRun(t *testing.T, orm pipeline.ORM, j job.Job) pipeline.Run { t.Helper() ctx := testutils.Context(t) diff --git a/core/services/job/mocks/orm.go b/core/services/job/mocks/orm.go index ec60137de93..e8911b25af3 100644 --- a/core/services/job/mocks/orm.go +++ b/core/services/job/mocks/orm.go @@ -248,6 +248,34 @@ func (_m *ORM) FindJobIDByAddress(ctx context.Context, address types.EIP55Addres return r0, r1 } +// FindJobIDByWorkflow provides a mock function with given fields: ctx, spec +func (_m *ORM) FindJobIDByWorkflow(ctx context.Context, spec job.WorkflowSpec) (int32, error) { + ret := _m.Called(ctx, spec) + + if len(ret) == 0 { + panic("no return value specified for FindJobIDByWorkflow") + } + + var r0 int32 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, job.WorkflowSpec) (int32, error)); ok { + return rf(ctx, spec) + } + if rf, ok := ret.Get(0).(func(context.Context, job.WorkflowSpec) int32); ok { + r0 = rf(ctx, spec) + } else { + r0 = ret.Get(0).(int32) + } + + if rf, ok := ret.Get(1).(func(context.Context, job.WorkflowSpec) error); ok { + r1 = rf(ctx, spec) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindJobIDsWithBridge provides a mock function with given fields: ctx, name func (_m *ORM) FindJobIDsWithBridge(ctx context.Context, name string) ([]int32, error) { ret := _m.Called(ctx, name) diff --git a/core/services/job/models.go b/core/services/job/models.go index 578e9e079b8..9601df2e02d 100644 --- a/core/services/job/models.go +++ b/core/services/job/models.go @@ -845,6 +845,7 @@ type WorkflowSpec struct { WorkflowID string `toml:"workflowId"` Workflow string `toml:"workflow"` WorkflowOwner string `toml:"workflowOwner"` + WorkflowName string `toml:"workflowName"` CreatedAt time.Time `toml:"-"` UpdatedAt time.Time `toml:"-"` } @@ -863,5 +864,9 @@ func (w *WorkflowSpec) Validate() error { return fmt.Errorf("incorrect length for owner %s: expected %d, got %d", w.WorkflowOwner, workflowOwnerLen, len(w.WorkflowOwner)) } + if w.WorkflowName == "" { + return fmt.Errorf("workflow name is required") + } + return nil } diff --git a/core/services/job/orm.go b/core/services/job/orm.go index d54d6fba522..71a4ebebb1e 100644 --- a/core/services/job/orm.go +++ b/core/services/job/orm.go @@ -78,6 +78,8 @@ type ORM interface { DataSource() sqlutil.DataSource WithDataSource(source sqlutil.DataSource) ORM + + FindJobIDByWorkflow(ctx context.Context, spec WorkflowSpec) (int32, error) } type ORMConfig interface { @@ -395,8 +397,8 @@ func (o *orm) CreateJob(ctx context.Context, jb *Job) error { case Stream: // 'stream' type has no associated spec, nothing to do here case Workflow: - sql := `INSERT INTO workflow_specs (workflow, workflow_id, workflow_owner, created_at, updated_at) - VALUES (:workflow, :workflow_id, :workflow_owner, NOW(), NOW()) + sql := `INSERT INTO workflow_specs (workflow, workflow_id, workflow_owner, workflow_name, created_at, updated_at) + VALUES (:workflow, :workflow_id, :workflow_owner, :workflow_name, NOW(), NOW()) RETURNING id;` specID, err := tx.prepareQuerySpecID(ctx, sql, jb.WorkflowSpec) if err != nil { @@ -1043,6 +1045,23 @@ func (o *orm) FindJobIDsWithBridge(ctx context.Context, name string) (jids []int return } +func (o *orm) FindJobIDByWorkflow(ctx context.Context, spec WorkflowSpec) (jobID int32, err error) { + stmt := ` +SELECT jobs.id FROM jobs +INNER JOIN workflow_specs ws on jobs.workflow_spec_id = ws.id AND ws.workflow_owner = $1 AND ws.workflow_name = $2 +` + err = o.ds.GetContext(ctx, &jobID, stmt, spec.WorkflowOwner, spec.WorkflowName) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + err = fmt.Errorf("error searching for job by workflow (owner,name) ('%s','%s'): %w", spec.WorkflowOwner, spec.WorkflowName, err) + } + err = fmt.Errorf("FindJobIDByWorkflow failed: %w", err) + return + } + + return +} + // PipelineRunsByJobsIDs returns pipeline runs for multiple jobs, not preloading data func (o *orm) PipelineRunsByJobsIDs(ctx context.Context, ids []int32) (runs []pipeline.Run, err error) { err = o.transact(ctx, false, func(tx *orm) error { diff --git a/core/services/keeper/delegate.go b/core/services/keeper/delegate.go index 71a0c5c43a9..c9d189b30c5 100644 --- a/core/services/keeper/delegate.go +++ b/core/services/keeper/delegate.go @@ -93,7 +93,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services // In the case of forwarding, the keeper address is the forwarder contract deployed onchain between EOA and Registry. effectiveKeeperAddress := spec.KeeperSpec.FromAddress.Address() if spec.ForwardingAllowed { - fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(spec.KeeperSpec.FromAddress.Address()) + fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(ctx, spec.KeeperSpec.FromAddress.Address()) if fwderr == nil { effectiveKeeperAddress = fwdrAddress } else { diff --git a/core/services/keeper/integration_test.go b/core/services/keeper/integration_test.go index 9e4cf5f9041..cbbe89b3f21 100644 --- a/core/services/keeper/integration_test.go +++ b/core/services/keeper/integration_test.go @@ -417,7 +417,7 @@ func TestKeeperForwarderEthIntegration(t *testing.T) { _, err = forwarderORM.CreateForwarder(ctx, fwdrAddress, chainID) require.NoError(t, err) - addr, err := app.GetRelayers().LegacyEVMChains().Slice()[0].TxManager().GetForwarderForEOA(nodeAddress) + addr, err := app.GetRelayers().LegacyEVMChains().Slice()[0].TxManager().GetForwarderForEOA(ctx, nodeAddress) require.NoError(t, err) require.Equal(t, addr, fwdrAddress) diff --git a/core/services/ocr/config_overrider.go b/core/services/ocr/config_overrider.go index 435efa437c7..067c06f58ce 100644 --- a/core/services/ocr/config_overrider.go +++ b/core/services/ocr/config_overrider.go @@ -31,9 +31,8 @@ type ConfigOverriderImpl struct { DeltaCFromAddress time.Duration // Start/Stop lifecycle - ctx context.Context - ctxCancel context.CancelFunc - chDone chan struct{} + chStop services.StopChan + chDone chan struct{} mu sync.RWMutex } @@ -63,7 +62,6 @@ func NewConfigOverriderImpl( addressSeconds := addressBig.Mod(addressBig, big.NewInt(jitterSeconds)).Uint64() deltaC := cfg.DeltaCOverride() + time.Duration(addressSeconds)*time.Second - ctx, cancel := context.WithCancel(context.Background()) co := ConfigOverriderImpl{ services.StateMachine{}, logger, @@ -73,8 +71,7 @@ func NewConfigOverriderImpl( time.Now(), InitialHibernationStatus, deltaC, - ctx, - cancel, + make(chan struct{}), make(chan struct{}), sync.RWMutex{}, } @@ -96,7 +93,7 @@ func (c *ConfigOverriderImpl) Start(context.Context) error { func (c *ConfigOverriderImpl) Close() error { return c.StopOnce("OCRContractTracker", func() error { - c.ctxCancel() + close(c.chStop) <-c.chDone return nil }) @@ -104,11 +101,13 @@ func (c *ConfigOverriderImpl) Close() error { func (c *ConfigOverriderImpl) eventLoop() { defer close(c.chDone) + ctx, cancel := c.chStop.NewCtx() + defer cancel() c.pollTicker.Resume() defer c.pollTicker.Destroy() for { select { - case <-c.ctx.Done(): + case <-ctx.Done(): return case <-c.pollTicker.Ticks(): if err := c.updateFlagsStatus(); err != nil { diff --git a/core/services/ocr/contract_tracker.go b/core/services/ocr/contract_tracker.go index e1f51c588cd..34852bbe74b 100644 --- a/core/services/ocr/contract_tracker.go +++ b/core/services/ocr/contract_tracker.go @@ -400,7 +400,7 @@ func (t *OCRContractTracker) LatestBlockHeight(ctx context.Context) (blockheight // care about the block height; we have no way of getting the L1 block // height anyway return 0, nil - case "", config.ChainArbitrum, config.ChainCelo, config.ChainGnosis, config.ChainKroma, config.ChainOptimismBedrock, config.ChainScroll, config.ChainWeMix, config.ChainXDai, config.ChainXLayer, config.ChainZkEvm, config.ChainZkSync: + case "", config.ChainArbitrum, config.ChainCelo, config.ChainGnosis, config.ChainKroma, config.ChainOptimismBedrock, config.ChainScroll, config.ChainWeMix, config.ChainXLayer, config.ChainZkEvm, config.ChainZkSync: // continue } latestBlockHeight := t.getLatestBlockHeight() diff --git a/core/services/ocr/delegate.go b/core/services/ocr/delegate.go index e748823ad71..a47e7ec9e7d 100644 --- a/core/services/ocr/delegate.go +++ b/core/services/ocr/delegate.go @@ -216,7 +216,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] // In the case of forwarding, the transmitter address is the forwarder contract deployed onchain between EOA and OCR contract. effectiveTransmitterAddress := concreteSpec.TransmitterAddress.Address() if jb.ForwardingAllowed { - fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(effectiveTransmitterAddress) + fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(ctx, effectiveTransmitterAddress) if fwderr == nil { effectiveTransmitterAddress = fwdrAddress } else { diff --git a/core/services/ocr/validate_test.go b/core/services/ocr/validate_test.go index 6e68559d09d..59213c7168e 100644 --- a/core/services/ocr/validate_test.go +++ b/core/services/ocr/validate_test.go @@ -27,6 +27,58 @@ func TestValidateOracleSpec(t *testing.T) { overrides func(c *chainlink.Config, s *chainlink.Secrets) assertion func(t *testing.T, os job.Job, err error) }{ + { + name: "invalid result sorting index", + toml: ` +ds1 [type=memo value=10000.1234]; +ds2 [type=memo value=100]; + +div_by_ds2 [type=divide divisor="$(ds2)"]; + +ds1 -> div_by_ds2 -> answer1; + +answer1 [type=multiply times=10000 index=-1]; +`, + assertion: func(t *testing.T, os job.Job, err error) { + require.Error(t, err) + }, + }, + { + name: "duplicate sorting indexes not allowed", + toml: ` +ds1 [type=memo value=10000.1234]; +ds2 [type=memo value=100]; + +div_by_ds2 [type=divide divisor="$(ds2)"]; + +ds1 -> div_by_ds2 -> answer1; +ds1 -> div_by_ds2 -> answer2; + +answer1 [type=multiply times=10000 index=0]; +answer2 [type=multiply times=10000 index=0]; +`, + assertion: func(t *testing.T, os job.Job, err error) { + require.Error(t, err) + }, + }, + { + name: "invalid result sorting index", + toml: ` +type = "offchainreporting" +schemaVersion = 1 +contractAddress = "0x613a38AC1659769640aaE063C651F48E0250454C" +isBootstrapPeer = false +observationSource = """ +ds1 [type=bridge name=voter_turnout]; +ds1_parse [type=jsonparse path="one,two"]; +ds1_multiply [type=multiply times=1.23]; +ds1 -> ds1_parse -> ds1_multiply -> answer1; +answer1 [type=median index=-1]; +"""`, + assertion: func(t *testing.T, os job.Job, err error) { + require.Error(t, err) + }, + }, { name: "minimal non-bootstrap oracle spec", toml: ` diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index 8ea43582126..350cbc8d593 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -382,7 +382,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi if err2 != nil { return nil, fmt.Errorf("ServicesForSpec: could not get EVM chain %s: %w", rid.ChainID, err2) } - effectiveTransmitterID, err2 = GetEVMEffectiveTransmitterID(&jb, chain, lggr) + effectiveTransmitterID, err2 = GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) if err2 != nil { return nil, fmt.Errorf("ServicesForSpec failed to get evm transmitterID: %w", err2) } @@ -470,7 +470,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi } } -func GetEVMEffectiveTransmitterID(jb *job.Job, chain legacyevm.Chain, lggr logger.SugaredLogger) (string, error) { +func GetEVMEffectiveTransmitterID(ctx context.Context, jb *job.Job, chain legacyevm.Chain, lggr logger.SugaredLogger) (string, error) { spec := jb.OCR2OracleSpec if spec.PluginType == types.Mercury || spec.PluginType == types.LLO { return spec.TransmitterID.String, nil @@ -501,9 +501,9 @@ func GetEVMEffectiveTransmitterID(jb *job.Job, chain legacyevm.Chain, lggr logge var effectiveTransmitterID common.Address // Median forwarders need special handling because of OCR2Aggregator transmitters whitelist. if spec.PluginType == types.Median { - effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOAOCR2Feeds(common.HexToAddress(spec.TransmitterID.String), common.HexToAddress(spec.ContractID)) + effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOAOCR2Feeds(ctx, common.HexToAddress(spec.TransmitterID.String), common.HexToAddress(spec.ContractID)) } else { - effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOA(common.HexToAddress(spec.TransmitterID.String)) + effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOA(ctx, common.HexToAddress(spec.TransmitterID.String)) } if err == nil { diff --git a/core/services/ocr2/delegate_test.go b/core/services/ocr2/delegate_test.go index bc5c2df2bbe..1e4be66c7d1 100644 --- a/core/services/ocr2/delegate_test.go +++ b/core/services/ocr2/delegate_test.go @@ -5,10 +5,12 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink-common/pkg/types" + evmcfg "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" txmmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr/mocks" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" @@ -27,7 +29,6 @@ import ( ) func TestGetEVMEffectiveTransmitterID(t *testing.T) { - ctx := testutils.Context(t) customChainID := big.New(testutils.NewRandomEVMChainID()) config := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -41,7 +42,7 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { }) db := pgtest.NewSqlxDB(t) keyStore := cltest.NewKeyStore(t, db) - require.NoError(t, keyStore.OCR2().Add(ctx, cltest.DefaultOCR2Key)) + require.NoError(t, keyStore.OCR2().Add(testutils.Context(t), cltest.DefaultOCR2Key)) lggr := logger.TestLogger(t) txManager := txmmocks.NewMockEvmTxManager(t) @@ -67,7 +68,7 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = tc.sendingKeys jb.ForwardingAllowed = tc.forwardingEnabled - args := []interface{}{tc.getForwarderForEOAArg} + args := []interface{}{mock.Anything, tc.getForwarderForEOAArg} getForwarderMethodName := "GetForwarderForEOA" if tc.pluginType == types.Median { getForwarderMethodName = "GetForwarderForEOAOCR2Feeds" @@ -144,13 +145,14 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { } t.Run("when sending keys are not defined, the first one should be set to transmitterID", func(t *testing.T) { + ctx := testutils.Context(t) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) jb.OCR2OracleSpec.TransmitterID = null.StringFrom("some transmitterID string") jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = nil chain, err := legacyChains.Get(customChainID.String()) require.NoError(t, err) - effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(&jb, chain, lggr) + effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) require.NoError(t, err) require.Equal(t, "some transmitterID string", effectiveTransmitterID) require.Equal(t, []string{"some transmitterID string"}, jb.OCR2OracleSpec.RelayConfig["sendingKeys"].([]string)) @@ -158,13 +160,14 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + ctx := testutils.Context(t) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) setTestCase(&jb, tc, txManager) chain, err := legacyChains.Get(customChainID.String()) require.NoError(t, err) - effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(&jb, chain, lggr) + effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) if tc.expectedError { require.Error(t, err) } else { @@ -180,13 +183,14 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { } t.Run("when forwarders are enabled and chain retrieval fails, error should be handled", func(t *testing.T) { + ctx := testutils.Context(t) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) jb.ForwardingAllowed = true jb.OCR2OracleSpec.TransmitterID = null.StringFrom("0x7e57000000000000000000000000000000000001") chain, err := legacyChains.Get("not an id") require.Error(t, err) - _, err = ocr2.GetEVMEffectiveTransmitterID(&jb, chain, lggr) + _, err = ocr2.GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) require.Error(t, err) }) } diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry.go index da4dd17d96f..56c200f9b13 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry.go @@ -102,10 +102,12 @@ func NewEVMRegistryService(addr common.Address, client legacyevm.Chain, lggr log enc: EVMAutomationEncoder20{}, } - r.ctx, r.cancel = context.WithCancel(context.Background()) + r.stopCh = make(chan struct{}) r.reInit = time.NewTimer(reInitializationDelay) - if err := r.registerEvents(client.ID().Uint64(), addr); err != nil { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() + if err := r.registerEvents(ctx, client.ID().Uint64(), addr); err != nil { return nil, fmt.Errorf("logPoller error while registering automation events: %w", err) } @@ -152,8 +154,7 @@ type EvmRegistry struct { mu sync.RWMutex txHashes map[string]bool lastPollBlock int64 - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan active map[string]activeUpkeep headFunc func(ocr2keepers.BlockKey) runState int @@ -209,8 +210,10 @@ func (r *EvmRegistry) Start(_ context.Context) error { defer r.mu.Unlock() // initialize the upkeep keys; if the reInit timer returns, do it again { - go func(cx context.Context, tmr *time.Timer, lggr logger.Logger, f func() error) { - err := f() + go func(tmr *time.Timer, lggr logger.Logger, f func(context.Context) error) { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() + err := f(ctx) if err != nil { lggr.Errorf("failed to initialize upkeeps", err) } @@ -218,53 +221,57 @@ func (r *EvmRegistry) Start(_ context.Context) error { for { select { case <-tmr.C: - err = f() + err = f(ctx) if err != nil { lggr.Errorf("failed to re-initialize upkeeps", err) } tmr.Reset(reInitializationDelay) - case <-cx.Done(): + case <-ctx.Done(): return } } - }(r.ctx, r.reInit, r.lggr, r.initialize) + }(r.reInit, r.lggr, r.initialize) } // start polling logs on an interval { - go func(cx context.Context, lggr logger.Logger, f func() error) { + go func(lggr logger.Logger, f func(context.Context) error) { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() ticker := time.NewTicker(time.Second) for { select { case <-ticker.C: - err := f() + err := f(ctx) if err != nil { lggr.Errorf("failed to poll logs for upkeeps", err) } - case <-cx.Done(): + case <-ctx.Done(): ticker.Stop() return } } - }(r.ctx, r.lggr, r.pollLogs) + }(r.lggr, r.pollLogs) } // run process to process logs from log channel { - go func(cx context.Context, ch chan logpoller.Log, lggr logger.Logger, f func(logpoller.Log) error) { + go func(ch chan logpoller.Log, lggr logger.Logger, f func(context.Context, logpoller.Log) error) { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() for { select { case l := <-ch: - err := f(l) + err := f(ctx, l) if err != nil { lggr.Errorf("failed to process log for upkeep", err) } - case <-cx.Done(): + case <-ctx.Done(): return } } - }(r.ctx, r.chLog, r.lggr, r.processUpkeepStateLog) + }(r.chLog, r.lggr, r.processUpkeepStateLog) } r.runState = 1 @@ -276,7 +283,7 @@ func (r *EvmRegistry) Close() error { return r.sync.StopOnce("AutomationRegistry", func() error { r.mu.Lock() defer r.mu.Unlock() - r.cancel() + close(r.stopCh) r.runState = 0 r.runError = nil return nil @@ -303,8 +310,8 @@ func (r *EvmRegistry) HealthReport() map[string]error { return map[string]error{r.Name(): r.sync.Healthy()} } -func (r *EvmRegistry) initialize() error { - startupCtx, cancel := context.WithTimeout(r.ctx, reInitializationDelay) +func (r *EvmRegistry) initialize(ctx context.Context) error { + startupCtx, cancel := context.WithTimeout(ctx, reInitializationDelay) defer cancel() idMap := make(map[string]activeUpkeep) @@ -345,12 +352,12 @@ func (r *EvmRegistry) initialize() error { return nil } -func (r *EvmRegistry) pollLogs() error { +func (r *EvmRegistry) pollLogs(ctx context.Context) error { var latest int64 var end logpoller.LogPollerBlock var err error - if end, err = r.poller.LatestBlock(r.ctx); err != nil { + if end, err = r.poller.LatestBlock(ctx); err != nil { return fmt.Errorf("%w: %s", ErrHeadNotAvailable, err) } @@ -367,7 +374,7 @@ func (r *EvmRegistry) pollLogs() error { { var logs []logpoller.Log if logs, err = r.poller.LogsWithSigs( - r.ctx, + ctx, end.BlockNumber-logEventLookback, end.BlockNumber, upkeepStateEvents, @@ -388,17 +395,17 @@ func UpkeepFilterName(addr common.Address) string { return logpoller.FilterName("EvmRegistry - Upkeep events for", addr.String()) } -func (r *EvmRegistry) registerEvents(chainID uint64, addr common.Address) error { +func (r *EvmRegistry) registerEvents(ctx context.Context, chainID uint64, addr common.Address) error { // Add log filters for the log poller so that it can poll and find the logs that // we need - return r.poller.RegisterFilter(r.ctx, logpoller.Filter{ + return r.poller.RegisterFilter(ctx, logpoller.Filter{ Name: UpkeepFilterName(addr), EventSigs: append(upkeepStateEvents, upkeepActiveEvents...), Addresses: []common.Address{addr}, }) } -func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error { +func (r *EvmRegistry) processUpkeepStateLog(ctx context.Context, l logpoller.Log) error { hash := l.TxHash.String() if _, ok := r.txHashes[hash]; ok { return nil @@ -414,22 +421,22 @@ func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error { switch l := abilog.(type) { case *keeper_registry_wrapper2_0.KeeperRegistryUpkeepRegistered: r.lggr.Debugf("KeeperRegistryUpkeepRegistered log detected for upkeep ID %s in transaction %s", l.Id.String(), hash) - r.addToActive(l.Id, false) + r.addToActive(ctx, l.Id, false) case *keeper_registry_wrapper2_0.KeeperRegistryUpkeepReceived: r.lggr.Debugf("KeeperRegistryUpkeepReceived log detected for upkeep ID %s in transaction %s", l.Id.String(), hash) - r.addToActive(l.Id, false) + r.addToActive(ctx, l.Id, false) case *keeper_registry_wrapper2_0.KeeperRegistryUpkeepUnpaused: r.lggr.Debugf("KeeperRegistryUpkeepUnpaused log detected for upkeep ID %s in transaction %s", l.Id.String(), hash) - r.addToActive(l.Id, false) + r.addToActive(ctx, l.Id, false) case *keeper_registry_wrapper2_0.KeeperRegistryUpkeepGasLimitSet: r.lggr.Debugf("KeeperRegistryUpkeepGasLimitSet log detected for upkeep ID %s in transaction %s", l.Id.String(), hash) - r.addToActive(l.Id, true) + r.addToActive(ctx, l.Id, true) } return nil } -func (r *EvmRegistry) addToActive(id *big.Int, force bool) { +func (r *EvmRegistry) addToActive(ctx context.Context, id *big.Int, force bool) { r.mu.Lock() defer r.mu.Unlock() @@ -438,7 +445,7 @@ func (r *EvmRegistry) addToActive(id *big.Int, force bool) { } if _, ok := r.active[id.String()]; !ok || force { - actives, err := r.getUpkeepConfigs(r.ctx, []*big.Int{id}) + actives, err := r.getUpkeepConfigs(ctx, []*big.Int{id}) if err != nil { r.lggr.Errorf("failed to get upkeep configs during adding active upkeep: %w", err) return diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry_test.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry_test.go index 592563f0b04..8100980dd6b 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry_test.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry_test.go @@ -186,6 +186,7 @@ func TestPollLogs(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { + ctx := testutils.Context(t) mp := new(mocks.LogPoller) if test.LatestBlock != nil { @@ -205,7 +206,7 @@ func TestPollLogs(t *testing.T) { chLog: make(chan logpoller.Log, 10), } - err := rg.pollLogs() + err := rg.pollLogs(ctx) assert.Equal(t, test.ExpectedLastPoll, rg.lastPollBlock) if test.ExpectedErr != nil { diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry.go index 082318518a5..6bab073b9b3 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry.go @@ -98,7 +98,7 @@ func NewEvmRegistry( hc := http.DefaultClient return &EvmRegistry{ - ctx: context.Background(), + stopCh: make(chan struct{}), threadCtrl: utils.NewThreadControl(), lggr: lggr.Named(RegistryServiceName), poller: client.LogPoller(), @@ -190,7 +190,7 @@ type EvmRegistry struct { logProcessed map[string]bool active ActiveUpkeepList lastPollBlock int64 - ctx context.Context + stopCh services.StopChan headFunc func(ocr2keepers.BlockKey) mercury *MercuryConfig hc HttpClient @@ -207,13 +207,13 @@ func (r *EvmRegistry) Name() string { func (r *EvmRegistry) Start(ctx context.Context) error { return r.StartOnce(RegistryServiceName, func() error { - if err := r.registerEvents(r.chainID, r.addr); err != nil { + if err := r.registerEvents(ctx, r.chainID, r.addr); err != nil { return fmt.Errorf("logPoller error while registering automation events: %w", err) } r.threadCtrl.Go(func(ctx context.Context) { lggr := r.lggr.With("where", "upkeeps_referesh") - err := r.refreshActiveUpkeeps() + err := r.refreshActiveUpkeeps(ctx) if err != nil { lggr.Errorf("failed to initialize upkeeps", err) } @@ -224,7 +224,7 @@ func (r *EvmRegistry) Start(ctx context.Context) error { for { select { case <-ticker.C: - err = r.refreshActiveUpkeeps() + err = r.refreshActiveUpkeeps(ctx) if err != nil { lggr.Errorf("failed to refresh upkeeps", err) } @@ -242,7 +242,7 @@ func (r *EvmRegistry) Start(ctx context.Context) error { for { select { case <-ticker.C: - err := r.pollUpkeepStateLogs() + err := r.pollUpkeepStateLogs(ctx) if err != nil { lggr.Errorf("failed to poll logs for upkeeps", err) } @@ -259,7 +259,7 @@ func (r *EvmRegistry) Start(ctx context.Context) error { for { select { case l := <-ch: - err := r.processUpkeepStateLog(l) + err := r.processUpkeepStateLog(ctx, l) if err != nil { lggr.Errorf("failed to process log for upkeep", err) } @@ -284,9 +284,9 @@ func (r *EvmRegistry) HealthReport() map[string]error { return map[string]error{RegistryServiceName: r.Healthy()} } -func (r *EvmRegistry) refreshActiveUpkeeps() error { +func (r *EvmRegistry) refreshActiveUpkeeps(ctx context.Context) error { // Allow for max timeout of refreshInterval - ctx, cancel := context.WithTimeout(r.ctx, refreshInterval) + ctx, cancel := context.WithTimeout(ctx, refreshInterval) defer cancel() r.lggr.Debugf("Refreshing active upkeeps list") @@ -311,17 +311,17 @@ func (r *EvmRegistry) refreshActiveUpkeeps() error { } } - _, err = r.logEventProvider.RefreshActiveUpkeeps(r.ctx, logTriggerIDs...) + _, err = r.logEventProvider.RefreshActiveUpkeeps(ctx, logTriggerIDs...) if err != nil { return fmt.Errorf("failed to refresh active upkeep ids in log event provider: %w", err) } // Try to refersh log trigger config for all log upkeeps - return r.refreshLogTriggerUpkeeps(logTriggerIDs) + return r.refreshLogTriggerUpkeeps(ctx, logTriggerIDs) } // refreshLogTriggerUpkeeps refreshes the active upkeep ids for log trigger upkeeps -func (r *EvmRegistry) refreshLogTriggerUpkeeps(ids []*big.Int) error { +func (r *EvmRegistry) refreshLogTriggerUpkeeps(ctx context.Context, ids []*big.Int) error { var err error for i := 0; i < len(ids); i += logTriggerRefreshBatchSize { end := i + logTriggerRefreshBatchSize @@ -330,7 +330,7 @@ func (r *EvmRegistry) refreshLogTriggerUpkeeps(ids []*big.Int) error { } idBatch := ids[i:end] - if batchErr := r.refreshLogTriggerUpkeepsBatch(idBatch); batchErr != nil { + if batchErr := r.refreshLogTriggerUpkeepsBatch(ctx, idBatch); batchErr != nil { multierr.AppendInto(&err, batchErr) } @@ -340,17 +340,17 @@ func (r *EvmRegistry) refreshLogTriggerUpkeeps(ids []*big.Int) error { return err } -func (r *EvmRegistry) refreshLogTriggerUpkeepsBatch(logTriggerIDs []*big.Int) error { +func (r *EvmRegistry) refreshLogTriggerUpkeepsBatch(ctx context.Context, logTriggerIDs []*big.Int) error { var logTriggerHashes []common.Hash for _, id := range logTriggerIDs { logTriggerHashes = append(logTriggerHashes, common.BigToHash(id)) } - unpausedLogs, err := r.poller.IndexedLogs(r.ctx, ac.IAutomationV21PlusCommonUpkeepUnpaused{}.Topic(), r.addr, 1, logTriggerHashes, evmtypes.Confirmations(r.finalityDepth)) + unpausedLogs, err := r.poller.IndexedLogs(ctx, ac.IAutomationV21PlusCommonUpkeepUnpaused{}.Topic(), r.addr, 1, logTriggerHashes, evmtypes.Confirmations(r.finalityDepth)) if err != nil { return err } - configSetLogs, err := r.poller.IndexedLogs(r.ctx, ac.IAutomationV21PlusCommonUpkeepTriggerConfigSet{}.Topic(), r.addr, 1, logTriggerHashes, evmtypes.Confirmations(r.finalityDepth)) + configSetLogs, err := r.poller.IndexedLogs(ctx, ac.IAutomationV21PlusCommonUpkeepTriggerConfigSet{}.Topic(), r.addr, 1, logTriggerHashes, evmtypes.Confirmations(r.finalityDepth)) if err != nil { return err } @@ -400,7 +400,7 @@ func (r *EvmRegistry) refreshLogTriggerUpkeepsBatch(logTriggerIDs []*big.Int) er if unpausedBlockNumbers[id.String()] > logBlock { logBlock = unpausedBlockNumbers[id.String()] } - if err := r.updateTriggerConfig(id, config, logBlock); err != nil { + if err := r.updateTriggerConfig(ctx, id, config, logBlock); err != nil { merr = goerrors.Join(merr, fmt.Errorf("failed to update trigger config for upkeep id %s: %w", id.String(), err)) } } @@ -408,12 +408,12 @@ func (r *EvmRegistry) refreshLogTriggerUpkeepsBatch(logTriggerIDs []*big.Int) er return merr } -func (r *EvmRegistry) pollUpkeepStateLogs() error { +func (r *EvmRegistry) pollUpkeepStateLogs(ctx context.Context) error { var latest int64 var end logpoller.LogPollerBlock var err error - if end, err = r.poller.LatestBlock(r.ctx); err != nil { + if end, err = r.poller.LatestBlock(ctx); err != nil { return fmt.Errorf("%w: %s", ErrHeadNotAvailable, err) } @@ -429,7 +429,7 @@ func (r *EvmRegistry) pollUpkeepStateLogs() error { var logs []logpoller.Log if logs, err = r.poller.LogsWithSigs( - r.ctx, + ctx, end.BlockNumber-logEventLookback, end.BlockNumber, upkeepStateEvents, @@ -445,7 +445,7 @@ func (r *EvmRegistry) pollUpkeepStateLogs() error { return nil } -func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error { +func (r *EvmRegistry) processUpkeepStateLog(ctx context.Context, l logpoller.Log) error { lid := fmt.Sprintf("%s%d", l.TxHash.String(), l.LogIndex) r.mu.Lock() if _, ok := r.logProcessed[lid]; ok { @@ -465,16 +465,16 @@ func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error { switch l := abilog.(type) { case *ac.IAutomationV21PlusCommonUpkeepPaused: r.lggr.Debugf("KeeperRegistryUpkeepPaused log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash) - r.removeFromActive(r.ctx, l.Id) + r.removeFromActive(ctx, l.Id) case *ac.IAutomationV21PlusCommonUpkeepCanceled: r.lggr.Debugf("KeeperRegistryUpkeepCanceled log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash) - r.removeFromActive(r.ctx, l.Id) + r.removeFromActive(ctx, l.Id) case *ac.IAutomationV21PlusCommonUpkeepMigrated: r.lggr.Debugf("AutomationV2CommonUpkeepMigrated log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash) - r.removeFromActive(r.ctx, l.Id) + r.removeFromActive(ctx, l.Id) case *ac.IAutomationV21PlusCommonUpkeepTriggerConfigSet: r.lggr.Debugf("KeeperRegistryUpkeepTriggerConfigSet log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash) - if err := r.updateTriggerConfig(l.Id, l.TriggerConfig, rawLog.BlockNumber); err != nil { + if err := r.updateTriggerConfig(ctx, l.Id, l.TriggerConfig, rawLog.BlockNumber); err != nil { r.lggr.Warnf("failed to update trigger config upon AutomationV2CommonUpkeepTriggerConfigSet for upkeep ID %s: %s", l.Id.String(), err) } case *ac.IAutomationV21PlusCommonUpkeepRegistered: @@ -483,19 +483,19 @@ func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error { trigger := core.GetUpkeepType(*uid) r.lggr.Debugf("KeeperRegistryUpkeepRegistered log detected for upkeep ID %s (trigger=%d) in transaction %s", l.Id.String(), trigger, txHash) r.active.Add(l.Id) - if err := r.updateTriggerConfig(l.Id, nil, rawLog.BlockNumber); err != nil { + if err := r.updateTriggerConfig(ctx, l.Id, nil, rawLog.BlockNumber); err != nil { r.lggr.Warnf("failed to update trigger config upon AutomationV2CommonUpkeepRegistered for upkeep ID %s: %s", err) } case *ac.IAutomationV21PlusCommonUpkeepReceived: r.lggr.Debugf("KeeperRegistryUpkeepReceived log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash) r.active.Add(l.Id) - if err := r.updateTriggerConfig(l.Id, nil, rawLog.BlockNumber); err != nil { + if err := r.updateTriggerConfig(ctx, l.Id, nil, rawLog.BlockNumber); err != nil { r.lggr.Warnf("failed to update trigger config upon AutomationV2CommonUpkeepReceived for upkeep ID %s: %s", err) } case *ac.IAutomationV21PlusCommonUpkeepUnpaused: r.lggr.Debugf("KeeperRegistryUpkeepUnpaused log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash) r.active.Add(l.Id) - if err := r.updateTriggerConfig(l.Id, nil, rawLog.BlockNumber); err != nil { + if err := r.updateTriggerConfig(ctx, l.Id, nil, rawLog.BlockNumber); err != nil { r.lggr.Warnf("failed to update trigger config upon AutomationV2CommonUpkeepUnpaused for upkeep ID %s: %s", err) } default: @@ -510,9 +510,9 @@ func RegistryUpkeepFilterName(addr common.Address) string { } // registerEvents registers upkeep state events from keeper registry on log poller -func (r *EvmRegistry) registerEvents(_ uint64, addr common.Address) error { +func (r *EvmRegistry) registerEvents(ctx context.Context, _ uint64, addr common.Address) error { // Add log filters for the log poller so that it can poll and find the logs that we need - return r.poller.RegisterFilter(r.ctx, logpoller.Filter{ + return r.poller.RegisterFilter(ctx, logpoller.Filter{ Name: RegistryUpkeepFilterName(addr), EventSigs: upkeepStateEvents, Addresses: []common.Address{addr}, @@ -591,13 +591,13 @@ func (r *EvmRegistry) getLatestIDsFromContract(ctx context.Context) ([]*big.Int, } // updateTriggerConfig updates the trigger config for an upkeep. it will re-register a filter for this upkeep. -func (r *EvmRegistry) updateTriggerConfig(id *big.Int, cfg []byte, logBlock uint64) error { +func (r *EvmRegistry) updateTriggerConfig(ctx context.Context, id *big.Int, cfg []byte, logBlock uint64) error { uid := &ocr2keepers.UpkeepIdentifier{} uid.FromBigInt(id) switch core.GetUpkeepType(*uid) { case types2.LogTrigger: if len(cfg) == 0 { - fetched, err := r.fetchTriggerConfig(id) + fetched, err := r.fetchTriggerConfig(ctx, id) if err != nil { return errors.Wrap(err, "failed to fetch log upkeep config") } @@ -609,7 +609,7 @@ func (r *EvmRegistry) updateTriggerConfig(id *big.Int, cfg []byte, logBlock uint r.lggr.Warnw("failed to unpack log upkeep config", "upkeepID", id.String(), "err", err) return nil } - if err := r.logEventProvider.RegisterFilter(r.ctx, logprovider.FilterOptions{ + if err := r.logEventProvider.RegisterFilter(ctx, logprovider.FilterOptions{ TriggerConfig: logprovider.LogTriggerConfig(parsed), UpkeepID: id, UpdateBlock: logBlock, @@ -623,8 +623,8 @@ func (r *EvmRegistry) updateTriggerConfig(id *big.Int, cfg []byte, logBlock uint } // fetchTriggerConfig fetches trigger config in raw bytes for an upkeep. -func (r *EvmRegistry) fetchTriggerConfig(id *big.Int) ([]byte, error) { - opts := r.buildCallOpts(r.ctx, nil) +func (r *EvmRegistry) fetchTriggerConfig(ctx context.Context, id *big.Int) ([]byte, error) { + opts := r.buildCallOpts(ctx, nil) cfg, err := r.registry.GetUpkeepTriggerConfig(opts, id) if err != nil { r.lggr.Warnw("failed to get trigger config", "err", err) @@ -634,8 +634,8 @@ func (r *EvmRegistry) fetchTriggerConfig(id *big.Int) ([]byte, error) { } // fetchUpkeepOffchainConfig fetches upkeep offchain config in raw bytes for an upkeep. -func (r *EvmRegistry) fetchUpkeepOffchainConfig(id *big.Int) ([]byte, error) { - opts := r.buildCallOpts(r.ctx, nil) +func (r *EvmRegistry) fetchUpkeepOffchainConfig(ctx context.Context, id *big.Int) ([]byte, error) { + opts := r.buildCallOpts(ctx, nil) ui, err := r.registry.GetUpkeep(opts, id) if err != nil { return []byte{}, err diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline.go index 491099496cb..5294530140b 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline.go @@ -96,8 +96,8 @@ func (r *EvmRegistry) getBlockAndUpkeepId(upkeepID ocr2keepers.UpkeepIdentifier, return block, common.BytesToHash(trigger.BlockHash[:]), upkeepID.BigInt() } -func (r *EvmRegistry) getBlockHash(blockNumber *big.Int) (common.Hash, error) { - blocks, err := r.poller.GetBlocksRange(r.ctx, []uint64{blockNumber.Uint64()}) +func (r *EvmRegistry) getBlockHash(ctx context.Context, blockNumber *big.Int) (common.Hash, error) { + blocks, err := r.poller.GetBlocksRange(ctx, []uint64{blockNumber.Uint64()}) if err != nil { return [32]byte{}, err } @@ -109,7 +109,7 @@ func (r *EvmRegistry) getBlockHash(blockNumber *big.Int) (common.Hash, error) { } // verifyCheckBlock checks that the check block and hash are valid, returns the pipeline execution state and retryable -func (r *EvmRegistry) verifyCheckBlock(_ context.Context, checkBlock, upkeepId *big.Int, checkHash common.Hash) (state encoding.PipelineExecutionState, retryable bool) { +func (r *EvmRegistry) verifyCheckBlock(ctx context.Context, checkBlock, upkeepId *big.Int, checkHash common.Hash) (state encoding.PipelineExecutionState, retryable bool) { // verify check block number and hash are valid h, ok := r.bs.queryBlocksMap(checkBlock.Int64()) // if this block number/hash combo exists in block subscriber, this check block and hash still exist on chain and are valid @@ -119,7 +119,7 @@ func (r *EvmRegistry) verifyCheckBlock(_ context.Context, checkBlock, upkeepId * return encoding.NoPipelineError, false } r.lggr.Warnf("check block %s does not exist in block subscriber or hash does not match for upkeepId %s. this may be caused by block subscriber outdated due to re-org, querying eth client to confirm", checkBlock, upkeepId) - b, err := r.getBlockHash(checkBlock) + b, err := r.getBlockHash(ctx, checkBlock) if err != nil { r.lggr.Warnf("failed to query block %s: %s", checkBlock, err.Error()) return encoding.RpcFlakyFailure, true @@ -132,7 +132,7 @@ func (r *EvmRegistry) verifyCheckBlock(_ context.Context, checkBlock, upkeepId * } // verifyLogExists checks that the log still exists on chain, returns failure reason, pipeline error, and retryable -func (r *EvmRegistry) verifyLogExists(upkeepId *big.Int, p ocr2keepers.UpkeepPayload) (encoding.UpkeepFailureReason, encoding.PipelineExecutionState, bool) { +func (r *EvmRegistry) verifyLogExists(ctx context.Context, upkeepId *big.Int, p ocr2keepers.UpkeepPayload) (encoding.UpkeepFailureReason, encoding.PipelineExecutionState, bool) { logBlockNumber := int64(p.Trigger.LogTriggerExtension.BlockNumber) logBlockHash := common.BytesToHash(p.Trigger.LogTriggerExtension.BlockHash[:]) checkBlockHash := common.BytesToHash(p.Trigger.BlockHash[:]) @@ -158,7 +158,7 @@ func (r *EvmRegistry) verifyLogExists(upkeepId *big.Int, p ocr2keepers.UpkeepPay r.lggr.Debugf("log block not provided, querying eth client for tx hash %s for upkeepId %s", hexutil.Encode(p.Trigger.LogTriggerExtension.TxHash[:]), upkeepId) } // query eth client as a fallback - bn, bh, err := core.GetTxBlock(r.ctx, r.client, p.Trigger.LogTriggerExtension.TxHash) + bn, bh, err := core.GetTxBlock(ctx, r.client, p.Trigger.LogTriggerExtension.TxHash) if err != nil { // primitive way of checking errors if strings.Contains(err.Error(), "missing required field") || strings.Contains(err.Error(), "not found") { @@ -202,7 +202,7 @@ func (r *EvmRegistry) checkUpkeeps(ctx context.Context, payloads []ocr2keepers.U uid.FromBigInt(upkeepId) switch core.GetUpkeepType(*uid) { case types.LogTrigger: - reason, state, retryable := r.verifyLogExists(upkeepId, p) + reason, state, retryable := r.verifyLogExists(ctx, upkeepId, p) if reason != encoding.UpkeepFailureReasonNone || state != encoding.NoPipelineError { results[i] = encoding.GetIneligibleCheckResultWithoutPerformData(p, reason, state, retryable) continue @@ -306,7 +306,7 @@ func (r *EvmRegistry) simulatePerformUpkeeps(ctx context.Context, checkResults [ block, _, upkeepId := r.getBlockAndUpkeepId(cr.UpkeepID, cr.Trigger) - oc, err := r.fetchUpkeepOffchainConfig(upkeepId) + oc, err := r.fetchUpkeepOffchainConfig(ctx, upkeepId) if err != nil { // this is mostly caused by RPC flakiness r.lggr.Errorw("failed get offchain config, gas price check will be disabled", "err", err, "upkeepId", upkeepId, "block", block) diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline_test.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline_test.go index f1b9cc66ae4..6f8785fda78 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline_test.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline_test.go @@ -346,13 +346,13 @@ func TestRegistry_VerifyLogExists(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + ctx := testutils.Context(t) bs := &BlockSubscriber{ blocks: tc.blocks, } e := &EvmRegistry{ lggr: lggr, bs: bs, - ctx: testutils.Context(t), } if tc.makeEthCall { @@ -370,7 +370,7 @@ func TestRegistry_VerifyLogExists(t *testing.T) { e.client = client } - reason, state, retryable := e.verifyLogExists(tc.upkeepId, tc.payload) + reason, state, retryable := e.verifyLogExists(ctx, tc.upkeepId, tc.payload) assert.Equal(t, tc.reason, reason) assert.Equal(t, tc.state, state) assert.Equal(t, tc.retryable, retryable) diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_test.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_test.go index 34b3f822dbf..ab530f877ae 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_test.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_test.go @@ -16,7 +16,9 @@ import ( types2 "github.com/smartcontractkit/chainlink-automation/pkg/v3/types" "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" types3 "github.com/smartcontractkit/chainlink/v2/core/chains/evm/headtracker/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" @@ -186,6 +188,7 @@ func TestPollLogs(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { + ctx := testutils.Context(t) mp := new(mocks.LogPoller) if test.LatestBlock != nil { @@ -205,7 +208,7 @@ func TestPollLogs(t *testing.T) { chLog: make(chan logpoller.Log, 10), } - err := rg.pollUpkeepStateLogs() + err := rg.pollUpkeepStateLogs(ctx) assert.Equal(t, test.ExpectedLastPoll, rg.lastPollBlock) if test.ExpectedErr != nil { @@ -542,6 +545,7 @@ func TestRegistry_refreshLogTriggerUpkeeps(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { + ctx := tests.Context(t) lggr := logger.TestLogger(t) var hb types3.HeadBroadcaster var lp logpoller.LogPoller @@ -559,7 +563,7 @@ func TestRegistry_refreshLogTriggerUpkeeps(t *testing.T) { lggr: lggr, } - err := registry.refreshLogTriggerUpkeeps(tc.ids) + err := registry.refreshLogTriggerUpkeeps(ctx, tc.ids) if tc.expectsErr { assert.Error(t, err) assert.Equal(t, err.Error(), tc.wantErr.Error()) diff --git a/core/services/ocr2/plugins/ocr2keeper/integration_test.go b/core/services/ocr2/plugins/ocr2keeper/integration_test.go index 29e56460c36..1c9c6025391 100644 --- a/core/services/ocr2/plugins/ocr2keeper/integration_test.go +++ b/core/services/ocr2/plugins/ocr2keeper/integration_test.go @@ -427,6 +427,7 @@ func setupForwarderForNode( backend *backends.SimulatedBackend, recipient common.Address, linkAddr common.Address) common.Address { + ctx := testutils.Context(t) faddr, _, authorizedForwarder, err := authorized_forwarder.DeployAuthorizedForwarder(caller, backend, linkAddr, caller.From, recipient, []byte{}) require.NoError(t, err) @@ -438,12 +439,12 @@ func setupForwarderForNode( // add forwarder address to be tracked in db forwarderORM := forwarders.NewORM(app.GetDB()) chainID := ubig.Big(*backend.Blockchain().Config().ChainID) - _, err = forwarderORM.CreateForwarder(testutils.Context(t), faddr, chainID) + _, err = forwarderORM.CreateForwarder(ctx, faddr, chainID) require.NoError(t, err) chain, err := app.GetRelayers().LegacyEVMChains().Get((*big.Int)(&chainID).String()) require.NoError(t, err) - fwdr, err := chain.TxManager().GetForwarderForEOA(recipient) + fwdr, err := chain.TxManager().GetForwarderForEOA(ctx, recipient) require.NoError(t, err) require.Equal(t, faddr, fwdr) diff --git a/core/services/ocrcommon/block_translator.go b/core/services/ocrcommon/block_translator.go index ddae8c51d98..7bce661e692 100644 --- a/core/services/ocrcommon/block_translator.go +++ b/core/services/ocrcommon/block_translator.go @@ -21,7 +21,7 @@ func NewBlockTranslator(cfg Config, client evmclient.Client, lggr logger.Logger) switch cfg.ChainType() { case config.ChainArbitrum: return NewArbitrumBlockTranslator(client, lggr) - case "", config.ChainCelo, config.ChainGnosis, config.ChainKroma, config.ChainMetis, config.ChainOptimismBedrock, config.ChainScroll, config.ChainWeMix, config.ChainXDai, config.ChainXLayer, config.ChainZkEvm, config.ChainZkSync: + case "", config.ChainCelo, config.ChainGnosis, config.ChainKroma, config.ChainMetis, config.ChainOptimismBedrock, config.ChainScroll, config.ChainWeMix, config.ChainXLayer, config.ChainZkEvm, config.ChainZkSync: fallthrough default: return &l1BlockTranslator{} diff --git a/core/services/pipeline/common.go b/core/services/pipeline/common.go index a0fc28c6862..5d843b8b918 100644 --- a/core/services/pipeline/common.go +++ b/core/services/pipeline/common.go @@ -415,9 +415,11 @@ func UnmarshalTaskFromMap(taskType TaskType, taskMap interface{}, ID int, dotID return nil, pkgerrors.Errorf(`unknown task type: "%v"`, taskType) } + metadata := mapstructure.Metadata{} decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ Result: task, WeaklyTypedInput: true, + Metadata: &metadata, DecodeHook: mapstructure.ComposeDecodeHookFunc( mapstructure.StringToTimeDurationHookFunc(), func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { @@ -441,6 +443,23 @@ func UnmarshalTaskFromMap(taskType TaskType, taskMap interface{}, ID int, dotID if err != nil { return nil, err } + + // valid explicit index values are 0-based + for _, key := range metadata.Keys { + if key == "index" { + if task.OutputIndex() < 0 { + return nil, errors.New("result sorting indexes should start with 0") + } + } + } + + // the 'unset' value should be -1 to allow explicit indexes to be 0-based + for _, key := range metadata.Unset { + if key == "index" { + task.Base().Index = -1 + } + } + return task, nil } diff --git a/core/services/pipeline/graph.go b/core/services/pipeline/graph.go index c3914e698c6..12bec2bc8b9 100644 --- a/core/services/pipeline/graph.go +++ b/core/services/pipeline/graph.go @@ -235,6 +235,8 @@ func Parse(text string) (*Pipeline, error) { // we need a temporary mapping of graph.IDs to positional ids after toposort ids := make(map[int64]int) + resultIdxs := make(map[int32]struct{}) + // use the new ordering as the id so that we can easily reproduce the original toposort for id, node := range nodes { node, is := node.(*GraphNode) @@ -251,6 +253,15 @@ func Parse(text string) (*Pipeline, error) { return nil, err } + if task.OutputIndex() > 0 { + _, exists := resultIdxs[task.OutputIndex()] + if exists { + return nil, errors.New("duplicate sorting indexes detected") + } + + resultIdxs[task.OutputIndex()] = struct{}{} + } + // re-link the edges for inputs := g.To(node.ID()); inputs.Next(); { isImplicitEdge := g.IsImplicitEdge(inputs.Node().ID(), node.ID()) diff --git a/core/services/pipeline/graph_test.go b/core/services/pipeline/graph_test.go index b3960bb1f46..c6248a38e24 100644 --- a/core/services/pipeline/graph_test.go +++ b/core/services/pipeline/graph_test.go @@ -171,27 +171,27 @@ func TestGraph_TasksInDependencyOrder(t *testing.T) { "ds1_multiply", []pipeline.TaskDependency{{PropagateResult: true, InputTask: pipeline.Task(ds1_parse)}}, []pipeline.Task{answer1}, - 0) + -1) ds2_multiply.BaseTask = pipeline.NewBaseTask( 5, "ds2_multiply", []pipeline.TaskDependency{{PropagateResult: true, InputTask: pipeline.Task(ds2_parse)}}, []pipeline.Task{answer1}, - 0) + -1) ds1_parse.BaseTask = pipeline.NewBaseTask( 1, "ds1_parse", []pipeline.TaskDependency{{PropagateResult: true, InputTask: pipeline.Task(ds1)}}, []pipeline.Task{ds1_multiply}, - 0) + -1) ds2_parse.BaseTask = pipeline.NewBaseTask( 4, "ds2_parse", []pipeline.TaskDependency{{PropagateResult: true, InputTask: pipeline.Task(ds2)}}, []pipeline.Task{ds2_multiply}, - 0) - ds1.BaseTask = pipeline.NewBaseTask(0, "ds1", nil, []pipeline.Task{ds1_parse}, 0) - ds2.BaseTask = pipeline.NewBaseTask(3, "ds2", nil, []pipeline.Task{ds2_parse}, 0) + -1) + ds1.BaseTask = pipeline.NewBaseTask(0, "ds1", nil, []pipeline.Task{ds1_parse}, -1) + ds2.BaseTask = pipeline.NewBaseTask(3, "ds2", nil, []pipeline.Task{ds2_parse}, -1) for i, task := range p.Tasks { // Make sure inputs appear before the task, and outputs don't diff --git a/core/services/pipeline/helpers_test.go b/core/services/pipeline/helpers_test.go index 97d81f56f74..7068209aa18 100644 --- a/core/services/pipeline/helpers_test.go +++ b/core/services/pipeline/helpers_test.go @@ -1,6 +1,7 @@ package pipeline import ( + "context" "net/http" "github.com/google/uuid" @@ -64,4 +65,4 @@ func (t *ETHTxTask) HelperSetDependencies(legacyChains legacyevm.LegacyChainCont t.jobType = jobType } -func (o *orm) Prune(pipelineSpecID int32) { o.prune(o.ds, pipelineSpecID) } +func (o *orm) Prune(ctx context.Context, pipelineSpecID int32) { o.prune(ctx, o.ds, pipelineSpecID) } diff --git a/core/services/pipeline/orm.go b/core/services/pipeline/orm.go index 06774e06e99..266b605ed42 100644 --- a/core/services/pipeline/orm.go +++ b/core/services/pipeline/orm.go @@ -108,22 +108,19 @@ type orm struct { lggr logger.Logger maxSuccessfulRuns uint64 // jobID => count - pm sync.Map - wg sync.WaitGroup - ctx context.Context - cncl context.CancelFunc + pm sync.Map + wg sync.WaitGroup + stopCh services.StopChan } var _ ORM = (*orm)(nil) func NewORM(ds sqlutil.DataSource, lggr logger.Logger, jobPipelineMaxSuccessfulRuns uint64) *orm { - ctx, cancel := context.WithCancel(context.Background()) return &orm{ ds: ds, lggr: lggr.Named("PipelineORM"), maxSuccessfulRuns: jobPipelineMaxSuccessfulRuns, - ctx: ctx, - cncl: cancel, + stopCh: make(chan struct{}), } } @@ -142,7 +139,7 @@ func (o *orm) Start(_ context.Context) error { func (o *orm) Close() error { return o.StopOnce("PipelineORM", func() error { - o.cncl() + close(o.stopCh) o.wg.Wait() return nil }) @@ -177,13 +174,11 @@ func (o *orm) DataSource() sqlutil.DataSource { return o.ds } func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { return o.withDataSource(ds) } func (o *orm) withDataSource(ds sqlutil.DataSource) *orm { - ctx, cancel := context.WithCancel(context.Background()) return &orm{ ds: ds, lggr: o.lggr, maxSuccessfulRuns: o.maxSuccessfulRuns, - ctx: ctx, - cncl: cancel, + stopCh: make(chan struct{}), } } @@ -231,7 +226,7 @@ func (o *orm) CreateRun(ctx context.Context, run *Run) (err error) { // InsertRun inserts a run into the database func (o *orm) InsertRun(ctx context.Context, run *Run) error { if run.Status() == RunStatusCompleted { - defer o.prune(o.ds, run.PruningKey) + defer o.prune(ctx, o.ds, run.PruningKey) } query, args, err := o.ds.BindNamed(`INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) VALUES (:pipeline_spec_id, :pruning_key, :meta, :all_errors, :fatal_errors, :inputs, :outputs, :created_at, :finished_at, :state) @@ -287,7 +282,7 @@ func (o *orm) StoreRun(ctx context.Context, run *Run) (restart bool, err error) return fmt.Errorf("failed to update pipeline run %d to %s: %w", run.ID, run.State, err) } } else { - defer o.prune(tx.ds, run.PruningKey) + defer o.prune(ctx, tx.ds, run.PruningKey) // Simply finish the run, no need to do any sort of locking if run.Outputs.Val == nil || len(run.FatalErrors)+len(run.AllErrors) == 0 { return fmt.Errorf("run must have both Outputs and Errors, got Outputs: %#v, FatalErrors: %#v, AllErrors: %#v", run.Outputs.Val, run.FatalErrors, run.AllErrors) @@ -401,7 +396,7 @@ RETURNING id defer func() { for pruningKey := range pruningKeysm { - o.prune(tx.ds, pruningKey) + o.prune(ctx, tx.ds, pruningKey) } }() @@ -510,7 +505,7 @@ func (o *orm) insertFinishedRun(ctx context.Context, run *Run, saveSuccessfulTas return nil } - defer o.prune(o.ds, run.PruningKey) + defer o.prune(ctx, o.ds, run.PruningKey) sql = ` INSERT INTO pipeline_task_runs (pipeline_run_id, id, type, index, output, error, dot_id, created_at, finished_at) VALUES (:pipeline_run_id, :id, :type, :index, :output, :error, :dot_id, :created_at, :finished_at);` @@ -709,13 +704,13 @@ const syncLimit = 1000 // // Note this does not guarantee the pipeline_runs table is kept to exactly the // max length, rather that it doesn't excessively larger than it. -func (o *orm) prune(tx sqlutil.DataSource, jobID int32) { +func (o *orm) prune(ctx context.Context, tx sqlutil.DataSource, jobID int32) { if jobID == 0 { o.lggr.Panic("expected a non-zero job ID") } // For small maxSuccessfulRuns its fast enough to prune every time if o.maxSuccessfulRuns < syncLimit { - o.withDataSource(tx).execPrune(o.ctx, jobID) + o.withDataSource(tx).execPrune(ctx, jobID) return } // for large maxSuccessfulRuns we do it async on a sampled basis @@ -725,15 +720,15 @@ func (o *orm) prune(tx sqlutil.DataSource, jobID int32) { if val%every == 0 { ok := o.IfStarted(func() { o.wg.Add(1) - go func() { + go func(ctx context.Context) { o.lggr.Debugw("Pruning runs", "jobID", jobID, "count", val, "every", every, "maxSuccessfulRuns", o.maxSuccessfulRuns) defer o.wg.Done() - ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(o.ctx), time.Minute) + ctx, cancel := o.stopCh.CtxCancel(context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute)) defer cancel() // Must not use tx here since it could be stale by the time we execute async. o.execPrune(ctx, jobID) - }() + }(context.WithoutCancel(ctx)) // don't propagate cancellation }) if !ok { o.lggr.Warnw("Cannot prune: ORM is not running", "jobID", jobID) @@ -743,7 +738,7 @@ func (o *orm) prune(tx sqlutil.DataSource, jobID int32) { } func (o *orm) execPrune(ctx context.Context, jobID int32) { - res, err := o.ds.ExecContext(o.ctx, `DELETE FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 AND id NOT IN ( + res, err := o.ds.ExecContext(ctx, `DELETE FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 AND id NOT IN ( SELECT id FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 ORDER BY id DESC diff --git a/core/services/pipeline/orm_test.go b/core/services/pipeline/orm_test.go index 8c99635c8d1..877aa9e4aa5 100644 --- a/core/services/pipeline/orm_test.go +++ b/core/services/pipeline/orm_test.go @@ -16,6 +16,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/hex" "github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" @@ -862,7 +863,8 @@ func Test_Prune(t *testing.T) { jobID := ps1.ID t.Run("when there are no runs to prune, does nothing", func(t *testing.T) { - porm.Prune(jobID) + ctx := tests.Context(t) + porm.Prune(ctx, jobID) // no error logs; it did nothing assert.Empty(t, observed.All()) @@ -898,7 +900,7 @@ func Test_Prune(t *testing.T) { cltest.MustInsertPipelineRunWithStatus(t, db, ps2.ID, pipeline.RunStatusSuspended, jobID2) } - porm.Prune(jobID2) + porm.Prune(tests.Context(t), jobID2) cnt := pgtest.MustCount(t, db, "SELECT count(*) FROM pipeline_runs WHERE pipeline_spec_id = $1 AND state = $2", ps1.ID, pipeline.RunStatusCompleted) assert.Equal(t, cnt, 20) diff --git a/core/services/pipeline/task.divide_test.go b/core/services/pipeline/task.divide_test.go index 3c2f57c7a07..8eb8e4de063 100644 --- a/core/services/pipeline/task.divide_test.go +++ b/core/services/pipeline/task.divide_test.go @@ -3,6 +3,7 @@ package pipeline_test import ( "fmt" "math" + "reflect" "testing" "github.com/pkg/errors" @@ -198,19 +199,17 @@ func TestDivideTask_Overflow(t *testing.T) { } func TestDivide_Example(t *testing.T) { - testutils.SkipFlakey(t, "BCF-3236") t.Parallel() dag := ` -ds1 [type=memo value=10000.1234] +ds1 [type=memo value=10000.1234]; +ds2 [type=memo value=100]; -ds2 [type=memo value=100] +div_by_ds2 [type=divide divisor="$(ds2)"]; +multiply [type=multiply times=10000 index=0]; -div_by_ds2 [type=divide divisor="$(ds2)"] +ds1 -> div_by_ds2 -> multiply; -multiply [type=multiply times=10000 index=0] - -ds1->div_by_ds2->multiply; ` db := pgtest.NewSqlxDB(t) @@ -223,12 +222,14 @@ ds1->div_by_ds2->multiply; lggr := logger.TestLogger(t) _, trrs, err := r.ExecuteRun(testutils.Context(t), spec, vars, lggr) - require.NoError(t, err) + require.NoError(t, err) require.Len(t, trrs, 4) finalResult := trrs[3] - assert.Nil(t, finalResult.Result.Error) + require.NoError(t, finalResult.Result.Error) + require.Equal(t, reflect.TypeOf(decimal.Decimal{}), reflect.TypeOf(finalResult.Result.Value)) + assert.Equal(t, "1000012.34", finalResult.Result.Value.(decimal.Decimal).String()) } diff --git a/core/services/pipeline/task.eth_tx.go b/core/services/pipeline/task.eth_tx.go index 354651acbb4..964591cacd2 100644 --- a/core/services/pipeline/task.eth_tx.go +++ b/core/services/pipeline/task.eth_tx.go @@ -140,7 +140,7 @@ func (t *ETHTxTask) Run(ctx context.Context, lggr logger.Logger, vars Vars, inpu var forwarderAddress common.Address if t.forwardingAllowed { var fwderr error - forwarderAddress, fwderr = chain.TxManager().GetForwarderForEOA(fromAddr) + forwarderAddress, fwderr = chain.TxManager().GetForwarderForEOA(ctx, fromAddr) if fwderr != nil { lggr.Warnw("Skipping forwarding for job, will fallback to default behavior", "err", fwderr) } diff --git a/core/services/relay/evm/cap_encoder.go b/core/services/relay/evm/cap_encoder.go index 00fa3bc8773..e0e3a2cf0f5 100644 --- a/core/services/relay/evm/cap_encoder.go +++ b/core/services/relay/evm/cap_encoder.go @@ -35,7 +35,7 @@ func NewEVMEncoder(config *values.Map) (consensustypes.Encoder, error) { if !ok { return nil, fmt.Errorf("expected %s to be a string", abiConfigFieldName) } - selector, err := abiutil.ParseSignature("inner(" + selectorStr + ")") + selector, err := abiutil.ParseSelector("inner(" + selectorStr + ")") if err != nil { return nil, err } diff --git a/core/services/relay/evm/cap_encoder_test.go b/core/services/relay/evm/cap_encoder_test.go index 10a19fd962b..8c56fb9075a 100644 --- a/core/services/relay/evm/cap_encoder_test.go +++ b/core/services/relay/evm/cap_encoder_test.go @@ -2,6 +2,7 @@ package evm_test import ( "encoding/hex" + "math/big" "testing" "github.com/stretchr/testify/assert" @@ -27,9 +28,9 @@ var ( wrongLength = "8d4e66" ) -func TestEVMEncoder(t *testing.T) { +func TestEVMEncoder_SingleField(t *testing.T) { config := map[string]any{ - "abi": "mercury_reports bytes[]", + "abi": "bytes[] Full_reports", } wrapped, err := values.NewMap(config) require.NoError(t, err) @@ -38,7 +39,7 @@ func TestEVMEncoder(t *testing.T) { // output of a DF2.0 aggregator + metadata fields appended by OCR input := map[string]any{ - "mercury_reports": []any{reportA, reportB}, + "Full_reports": []any{reportA, reportB}, consensustypes.WorkflowIDFieldName: workflowID, consensustypes.ExecutionIDFieldName: executionID, } @@ -48,28 +49,156 @@ func TestEVMEncoder(t *testing.T) { require.NoError(t, err) expected := - // start of the outer tuple ((user_fields), workflow_id, workflow_execution_id) + // start of the outer tuple workflowID + donID + executionID + workflowOwnerID + // start of the inner tuple (user_fields) - "0000000000000000000000000000000000000000000000000000000000000020" + // offset of mercury_reports array - "0000000000000000000000000000000000000000000000000000000000000002" + // length of mercury_reports array + "0000000000000000000000000000000000000000000000000000000000000020" + // offset of Full_reports array + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Full_reports array "0000000000000000000000000000000000000000000000000000000000000040" + // offset of reportA "0000000000000000000000000000000000000000000000000000000000000080" + // offset of reportB "0000000000000000000000000000000000000000000000000000000000000003" + // length of reportA "0102030000000000000000000000000000000000000000000000000000000000" + // reportA "0000000000000000000000000000000000000000000000000000000000000004" + // length of reportB "aabbccdd00000000000000000000000000000000000000000000000000000000" // reportB - // end of the inner tuple (user_fields) + + require.Equal(t, expected, hex.EncodeToString(encoded)) +} + +func TestEVMEncoder_TwoFields(t *testing.T) { + config := map[string]any{ + "abi": "uint256[] Prices, uint32[] Timestamps", + } + wrapped, err := values.NewMap(config) + require.NoError(t, err) + enc, err := evm.NewEVMEncoder(wrapped) + require.NoError(t, err) + + // output of a DF2.0 aggregator + metadata fields appended by OCR + input := map[string]any{ + "Prices": []any{big.NewInt(234), big.NewInt(456)}, + "Timestamps": []any{int64(111), int64(222)}, + consensustypes.WorkflowIDFieldName: workflowID, + consensustypes.ExecutionIDFieldName: executionID, + } + wrapped, err = values.NewMap(input) + require.NoError(t, err) + encoded, err := enc.Encode(testutils.Context(t), *wrapped) + require.NoError(t, err) + + expected := + // start of the outer tuple + workflowID + + donID + + executionID + + workflowOwnerID + + // start of the inner tuple (user_fields) + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of Prices array + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset of Timestamps array + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Prices array + "00000000000000000000000000000000000000000000000000000000000000ea" + // Prices[0] + "00000000000000000000000000000000000000000000000000000000000001c8" + // Prices[1] + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Timestamps array + "000000000000000000000000000000000000000000000000000000000000006f" + // Timestamps[0] + "00000000000000000000000000000000000000000000000000000000000000de" // Timestamps[1] + + require.Equal(t, expected, hex.EncodeToString(encoded)) +} + +func TestEVMEncoder_Tuple(t *testing.T) { + config := map[string]any{ + "abi": "(uint256[] Prices, uint32[] Timestamps) Elem", + } + wrapped, err := values.NewMap(config) + require.NoError(t, err) + enc, err := evm.NewEVMEncoder(wrapped) + require.NoError(t, err) + + // output of a DF2.0 aggregator + metadata fields appended by OCR + input := map[string]any{ + "Elem": map[string]any{ + "Prices": []any{big.NewInt(234), big.NewInt(456)}, + "Timestamps": []any{int64(111), int64(222)}, + }, + consensustypes.WorkflowIDFieldName: workflowID, + consensustypes.ExecutionIDFieldName: executionID, + } + wrapped, err = values.NewMap(input) + require.NoError(t, err) + encoded, err := enc.Encode(testutils.Context(t), *wrapped) + require.NoError(t, err) + + expected := + // start of the outer tuple + workflowID + + donID + + executionID + + workflowOwnerID + + // start of the inner tuple (user_fields) + "0000000000000000000000000000000000000000000000000000000000000020" + // offset of Elem tuple + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of Prices array + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset of Timestamps array + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Prices array + "00000000000000000000000000000000000000000000000000000000000000ea" + // Prices[0] = 234 + "00000000000000000000000000000000000000000000000000000000000001c8" + // Prices[1] = 456 + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Timestamps array + "000000000000000000000000000000000000000000000000000000000000006f" + // Timestamps[0] = 111 + "00000000000000000000000000000000000000000000000000000000000000de" // Timestamps[1] = 222 + + require.Equal(t, expected, hex.EncodeToString(encoded)) +} + +func TestEVMEncoder_ListOfTuples(t *testing.T) { + config := map[string]any{ + "abi": "(uint256 Price, uint32 Timestamp)[] Elems", + } + wrapped, err := values.NewMap(config) + require.NoError(t, err) + enc, err := evm.NewEVMEncoder(wrapped) + require.NoError(t, err) + + // output of a DF2.0 aggregator + metadata fields appended by OCR + input := map[string]any{ + "Elems": []any{ + map[string]any{ + "Price": big.NewInt(234), + "Timestamp": int64(111), + }, + map[string]any{ + "Price": big.NewInt(456), + "Timestamp": int64(222), + }, + }, + consensustypes.WorkflowIDFieldName: workflowID, + consensustypes.ExecutionIDFieldName: executionID, + } + wrapped, err = values.NewMap(input) + require.NoError(t, err) + encoded, err := enc.Encode(testutils.Context(t), *wrapped) + require.NoError(t, err) + + expected := + // start of the outer tuple + workflowID + + donID + + executionID + + workflowOwnerID + + // start of the inner tuple (user_fields) + "0000000000000000000000000000000000000000000000000000000000000020" + // offset of Elem list + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Elem list + "00000000000000000000000000000000000000000000000000000000000000ea" + // Elem[0].Price = 234 + "000000000000000000000000000000000000000000000000000000000000006f" + // Elem[0].Timestamp = 111 + "00000000000000000000000000000000000000000000000000000000000001c8" + // Elem[1].Price = 456 + "00000000000000000000000000000000000000000000000000000000000000de" // Elem[1].Timestamp = 222 require.Equal(t, expected, hex.EncodeToString(encoded)) } func TestEVMEncoder_InvalidIDs(t *testing.T) { config := map[string]any{ - "abi": "mercury_reports bytes[]", + "abi": "bytes[] Full_reports", } wrapped, err := values.NewMap(config) require.NoError(t, err) @@ -79,7 +208,7 @@ func TestEVMEncoder_InvalidIDs(t *testing.T) { // output of a DF2.0 aggregator + metadata fields appended by OCR // using an invalid ID input := map[string]any{ - "mercury_reports": []any{reportA, reportB}, + "Full_reports": []any{reportA, reportB}, consensustypes.WorkflowIDFieldName: invalidID, consensustypes.ExecutionIDFieldName: executionID, } @@ -90,7 +219,7 @@ func TestEVMEncoder_InvalidIDs(t *testing.T) { // using valid hex string of wrong length input = map[string]any{ - "mercury_reports": []any{reportA, reportB}, + "full_reports": []any{reportA, reportB}, consensustypes.WorkflowIDFieldName: wrongLength, consensustypes.ExecutionIDFieldName: executionID, } diff --git a/core/services/relay/evm/codec_test.go b/core/services/relay/evm/codec_test.go index 0597560aaec..0773b274107 100644 --- a/core/services/relay/evm/codec_test.go +++ b/core/services/relay/evm/codec_test.go @@ -2,6 +2,7 @@ package evm_test import ( "encoding/json" + "math/big" "testing" "github.com/ethereum/go-ethereum/accounts/abi" @@ -68,13 +69,79 @@ func TestCodec_SimpleEncode(t *testing.T) { require.NoError(t, err) expected := "0000000000000000000000000000000000000000000000000000000000000006" + // int32(6) - "0000000000000000000000000000000000000000000000000000000000000040" + // total bytes occupied by the string (64) + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of the beginning of second value (64 bytes) "0000000000000000000000000000000000000000000000000000000000000007" + // length of the string (7 chars) "6162636465666700000000000000000000000000000000000000000000000000" // actual string require.Equal(t, expected, hexutil.Encode(result)[2:]) } +func TestCodec_EncodeTuple(t *testing.T) { + codecName := "my_codec" + input := map[string]any{ + "Report": int32(6), + "Nested": map[string]any{ + "Meta": "abcdefg", + "Count": int32(14), + "Other": "12334", + }, + } + evmEncoderConfig := `[{"Name":"Report","Type":"int32"},{"Name":"Nested","Type":"tuple","Components":[{"Name":"Other","Type":"string"},{"Name":"Count","Type":"int32"},{"Name":"Meta","Type":"string"}]}]` + + codecConfig := types.CodecConfig{Configs: map[string]types.ChainCodecConfig{ + codecName: {TypeABI: evmEncoderConfig}, + }} + c, err := evm.NewCodec(codecConfig) + require.NoError(t, err) + + result, err := c.Encode(testutils.Context(t), input, codecName) + require.NoError(t, err) + expected := + "0000000000000000000000000000000000000000000000000000000000000006" + // Report integer (=6) + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of the first dynamic value (tuple, 64 bytes) + "0000000000000000000000000000000000000000000000000000000000000060" + // offset of the first nested dynamic value (string, 96 bytes) + "000000000000000000000000000000000000000000000000000000000000000e" + // "Count" integer (=14) + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset of the second nested dynamic value (string, 160 bytes) + "0000000000000000000000000000000000000000000000000000000000000005" + // length of the "Meta" string (5 chars) + "3132333334000000000000000000000000000000000000000000000000000000" + // "Other" string (="12334") + "0000000000000000000000000000000000000000000000000000000000000007" + // length of the "Other" string (7 chars) + "6162636465666700000000000000000000000000000000000000000000000000" // "Meta" string (="abcdefg") + + require.Equal(t, expected, hexutil.Encode(result)[2:]) +} + +func TestCodec_EncodeTupleWithLists(t *testing.T) { + codecName := "my_codec" + input := map[string]any{ + "Elem": map[string]any{ + "Prices": []any{big.NewInt(234), big.NewInt(456)}, + "Timestamps": []any{int64(111), int64(222)}, + }, + } + evmEncoderConfig := `[{"Name":"Elem","Type":"tuple","InternalType":"tuple","Components":[{"Name":"Prices","Type":"uint256[]","InternalType":"uint256[]","Components":null,"Indexed":false},{"Name":"Timestamps","Type":"uint32[]","InternalType":"uint32[]","Components":null,"Indexed":false}],"Indexed":false}]` + + codecConfig := types.CodecConfig{Configs: map[string]types.ChainCodecConfig{ + codecName: {TypeABI: evmEncoderConfig}, + }} + c, err := evm.NewCodec(codecConfig) + require.NoError(t, err) + + result, err := c.Encode(testutils.Context(t), input, codecName) + require.NoError(t, err) + expected := + "0000000000000000000000000000000000000000000000000000000000000020" + // offset of Elem tuple + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of Prices array + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset of Timestamps array + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Prices array + "00000000000000000000000000000000000000000000000000000000000000ea" + // Prices[0] = 234 + "00000000000000000000000000000000000000000000000000000000000001c8" + // Prices[1] = 456 + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Timestamps array + "000000000000000000000000000000000000000000000000000000000000006f" + // Timestamps[0] = 111 + "00000000000000000000000000000000000000000000000000000000000000de" // Timestamps[1] = 222 + + require.Equal(t, expected, hexutil.Encode(result)[2:]) +} + type codecInterfaceTester struct{} func (it *codecInterfaceTester) Setup(_ *testing.T) {} diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index 97241044206..5a0ccffaf71 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -439,8 +439,7 @@ type configWatcher struct { chain legacyevm.Chain runReplay bool fromBlock uint64 - replayCtx context.Context - replayCancel context.CancelFunc + stopCh services.StopChan wg sync.WaitGroup } @@ -452,7 +451,6 @@ func newConfigWatcher(lggr logger.Logger, fromBlock uint64, runReplay bool, ) *configWatcher { - replayCtx, replayCancel := context.WithCancel(context.Background()) return &configWatcher{ lggr: lggr.Named("ConfigWatcher").Named(contractAddress.String()), contractAddress: contractAddress, @@ -461,8 +459,7 @@ func newConfigWatcher(lggr logger.Logger, chain: chain, runReplay: runReplay, fromBlock: fromBlock, - replayCtx: replayCtx, - replayCancel: replayCancel, + stopCh: make(chan struct{}), } } @@ -477,8 +474,10 @@ func (c *configWatcher) Start(ctx context.Context) error { c.wg.Add(1) go func() { defer c.wg.Done() + ctx, cancel := c.stopCh.NewCtx() + defer cancel() c.lggr.Infow("starting replay for config", "fromBlock", c.fromBlock) - if err := c.configPoller.Replay(c.replayCtx, int64(c.fromBlock)); err != nil { + if err := c.configPoller.Replay(ctx, int64(c.fromBlock)); err != nil { c.lggr.Errorf("error replaying for config", "err", err) } else { c.lggr.Infow("completed replaying for config", "fromBlock", c.fromBlock) @@ -492,7 +491,7 @@ func (c *configWatcher) Start(ctx context.Context) error { func (c *configWatcher) Close() error { return c.StopOnce(fmt.Sprintf("configWatcher %x", c.contractAddress), func() error { - c.replayCancel() + close(c.stopCh) c.wg.Wait() return c.configPoller.Close() }) diff --git a/core/services/relay/evm/request_round_tracker.go b/core/services/relay/evm/request_round_tracker.go index fe6b6826eb2..7cf13775693 100644 --- a/core/services/relay/evm/request_round_tracker.go +++ b/core/services/relay/evm/request_round_tracker.go @@ -37,8 +37,7 @@ type RequestRoundTracker struct { blockTranslator ocrcommon.BlockTranslator // Start/Stop lifecycle - ctx context.Context - ctxCancel context.CancelFunc + stopCh services.StopChan unsubscribeLogs func() // LatestRoundRequested @@ -58,7 +57,6 @@ func NewRequestRoundTracker( odb RequestRoundDB, chain ocrcommon.Config, ) (o *RequestRoundTracker) { - ctx, cancel := context.WithCancel(context.Background()) return &RequestRoundTracker{ ethClient: ethClient, contract: contract, @@ -69,8 +67,7 @@ func NewRequestRoundTracker( odb: odb, ds: ds, blockTranslator: ocrcommon.NewBlockTranslator(chain, ethClient, lggr), - ctx: ctx, - ctxCancel: cancel, + stopCh: make(chan struct{}), } } @@ -98,7 +95,7 @@ func (t *RequestRoundTracker) Start(ctx context.Context) error { // Close should be called after teardown of the OCR job relying on this tracker func (t *RequestRoundTracker) Close() error { return t.StopOnce("RequestRoundTracker", func() error { - t.ctxCancel() + close(t.stopCh) t.unsubscribeLogs() return nil }) diff --git a/core/services/vrf/v2/listener_v2_log_listener_test.go b/core/services/vrf/v2/listener_v2_log_listener_test.go index a393aec3ee3..5b827a5291d 100644 --- a/core/services/vrf/v2/listener_v2_log_listener_test.go +++ b/core/services/vrf/v2/listener_v2_log_listener_test.go @@ -1,7 +1,6 @@ package v2 import ( - "context" "fmt" "math/big" "strings" @@ -20,6 +19,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -56,7 +56,6 @@ type vrfLogPollerListenerTH struct { EthDB ethdb.Database Db *sqlx.DB Listener *listenerV2 - Ctx context.Context } func setupVRFLogPollerListenerTH(t *testing.T, @@ -173,7 +172,6 @@ func setupVRFLogPollerListenerTH(t *testing.T, EthDB: ethDB, Db: db, Listener: listener, - Ctx: ctx, } mockChainUpdateFn(chain, th) return th @@ -189,6 +187,7 @@ func setupVRFLogPollerListenerTH(t *testing.T, func TestInitProcessedBlock_NoVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, th *vrfLogPollerListenerTH) { @@ -226,7 +225,7 @@ func TestInitProcessedBlock_NoVRFReqs(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, len(logs)) - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(6), lastProcessedBlock) } @@ -262,6 +261,7 @@ func TestLogPollerFilterRegistered(t *testing.T) { func TestInitProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -298,7 +298,7 @@ func TestInitProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) { } // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 2 (VRF req/resp block) + 5 (EmitLog blocks) = 11 latestBlock := int64(2 + 2 + 2 + 5) @@ -308,17 +308,18 @@ func TestInitProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) { // Then test if log poller is able to replay from finalizedBlockNumber (8 --> onwards) // since there are no pending VRF requests // Blocks: 1 2 3 4 [5;Request] [6;Fulfilment] 7 8 9 10 11 - require.NoError(t, th.LogPoller.Replay(th.Ctx, latestBlock)) + require.NoError(t, th.LogPoller.Replay(ctx, latestBlock)) // initializeLastProcessedBlock must return the finalizedBlockNumber (8) instead of // VRF request block number (5), since all VRF requests are fulfilled - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(8), lastProcessedBlock) } func TestInitProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -353,7 +354,7 @@ func TestInitProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) { } // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 1 (VRF req block) + 5 (EmitLog blocks) = 10 latestBlock := int64(2 + 2 + 1 + 5) @@ -362,17 +363,18 @@ func TestInitProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) { // Replay from block 10 (latest) onwards, so that log poller has a latest block // Then test if log poller is able to replay from earliestUnprocessedBlock (5 --> onwards) // Blocks: 1 2 3 4 [5;Request] 6 7 8 9 10 - require.NoError(t, th.LogPoller.Replay(th.Ctx, latestBlock)) + require.NoError(t, th.LogPoller.Replay(ctx, latestBlock)) // initializeLastProcessedBlock must return the unfulfilled VRF // request block number (5) instead of finalizedBlockNumber (8) - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(5), lastProcessedBlock) } func TestInitProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -413,7 +415,7 @@ func TestInitProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { } // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 3*5 (EmitLog + VRF req/resp blocks) = 19 latestBlock := int64(2 + 2 + 3*5) @@ -424,17 +426,18 @@ func TestInitProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { // Blocks: 1 2 3 4 5 [6;Request] [7;Request] 8 [9;Request] [10;Request] // 11 [12;Request] [13;Request] 14 [15;Request] [16;Request] // 17 [18;Request] [19;Request] - require.NoError(t, th.LogPoller.Replay(th.Ctx, latestBlock)) + require.NoError(t, th.LogPoller.Replay(ctx, latestBlock)) // initializeLastProcessedBlock must return the earliest unfulfilled VRF request block // number instead of finalizedBlockNumber - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(6), lastProcessedBlock) } func TestInitProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -480,7 +483,7 @@ func TestInitProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { } // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 3*5 (EmitLog + VRF req/resp blocks) = 19 latestBlock := int64(2 + 2 + 3*5) @@ -490,11 +493,11 @@ func TestInitProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { // Blocks: 1 2 3 4 5 [6;Request] [7;Request;6-Fulfilment] 8 [9;Request] [10;Request;9-Fulfilment] // 11 [12;Request] [13;Request;12-Fulfilment] 14 [15;Request] [16;Request;15-Fulfilment] // 17 [18;Request] [19;Request;18-Fulfilment] - require.NoError(t, th.LogPoller.Replay(th.Ctx, latestBlock)) + require.NoError(t, th.LogPoller.Replay(ctx, latestBlock)) // initializeLastProcessedBlock must return the earliest unfulfilled VRF request block // number instead of finalizedBlockNumber - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(7), lastProcessedBlock) } @@ -511,6 +514,7 @@ func TestInitProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { func TestUpdateLastProcessedBlock_NoVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -552,22 +556,23 @@ func TestUpdateLastProcessedBlock_NoVRFReqs(t *testing.T) { // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 2 (VRF req blocks) + 5 (EmitLog blocks) = 11 // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // We've to replay from before VRF request log, since updateLastProcessedBlock // does not internally call LogPoller.Replay - require.NoError(t, th.LogPoller.Replay(th.Ctx, 4)) + require.NoError(t, th.LogPoller.Replay(ctx, 4)) // updateLastProcessedBlock must return the finalizedBlockNumber as there are // no VRF requests, after currLastProcessedBlock (block 6). The VRF requests // made above are before the currLastProcessedBlock (7) passed in below - lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 7) + lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 7) require.Nil(t, err) require.Equal(t, int64(8), lastProcessedBlock) } func TestUpdateLastProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -607,22 +612,23 @@ func TestUpdateLastProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) { // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 2 (VRF req/resp blocks) + 5 (EmitLog blocks) = 11 // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // We've to replay from before VRF request log, since updateLastProcessedBlock // does not internally call LogPoller.Replay - require.NoError(t, th.LogPoller.Replay(th.Ctx, 4)) + require.NoError(t, th.LogPoller.Replay(ctx, 4)) // updateLastProcessedBlock must return the finalizedBlockNumber (8) though we have // a VRF req at block (5) after currLastProcessedBlock (4) passed below, because // the VRF request is fulfilled - lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 4) + lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 4) require.Nil(t, err) require.Equal(t, int64(8), lastProcessedBlock) } func TestUpdateLastProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -658,22 +664,23 @@ func TestUpdateLastProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) { // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 1 (VRF req block) + 5 (EmitLog blocks) = 10 // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // We've to replay from before VRF request log, since updateLastProcessedBlock // does not internally call LogPoller.Replay - require.NoError(t, th.LogPoller.Replay(th.Ctx, 4)) + require.NoError(t, th.LogPoller.Replay(ctx, 4)) // updateLastProcessedBlock must return the VRF req at block (5) instead of // finalizedBlockNumber (8) after currLastProcessedBlock (4) passed below, // because the VRF request is unfulfilled - lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 4) + lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 4) require.Nil(t, err) require.Equal(t, int64(5), lastProcessedBlock) } func TestUpdateLastProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -715,22 +722,23 @@ func TestUpdateLastProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 3*5 (EmitLog + VRF req blocks) = 19 // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // We've to replay from before VRF request log, since updateLastProcessedBlock // does not internally call LogPoller.Replay - require.NoError(t, th.LogPoller.Replay(th.Ctx, 4)) + require.NoError(t, th.LogPoller.Replay(ctx, 4)) // updateLastProcessedBlock must return the VRF req at block (6) instead of // finalizedBlockNumber (16) after currLastProcessedBlock (4) passed below, // as block 6 contains the earliest unfulfilled VRF request - lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 4) + lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 4) require.Nil(t, err) require.Equal(t, int64(6), lastProcessedBlock) } func TestUpdateLastProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -776,17 +784,17 @@ func TestUpdateLastProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 3*5 (EmitLog + VRF req blocks) = 19 // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // We've to replay from before VRF request log, since updateLastProcessedBlock // does not internally call LogPoller.Replay - require.NoError(t, th.LogPoller.Replay(th.Ctx, 4)) + require.NoError(t, th.LogPoller.Replay(ctx, 4)) // updateLastProcessedBlock must return the VRF req at block (7) instead of // finalizedBlockNumber (16) after currLastProcessedBlock (4) passed below, // as block 7 contains the earliest unfulfilled VRF request. VRF request // in block 6 has been fulfilled in block 7. - lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 4) + lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 4) require.Nil(t, err) require.Equal(t, int64(7), lastProcessedBlock) } diff --git a/core/services/workflows/delegate_test.go b/core/services/workflows/delegate_test.go index 68abfa2f7a1..d87e6d68466 100644 --- a/core/services/workflows/delegate_test.go +++ b/core/services/workflows/delegate_test.go @@ -23,6 +23,7 @@ type = "workflow" schemaVersion = 1 workflowId = "15c631d295ef5e32deb99a10ee6804bc4af1385568f9b3363f6552ac6dbb2cef" workflowOwner = "00000000000000000000000000000000000000aa" +workflowName = "test" `, true, }, @@ -38,6 +39,16 @@ invalid syntax{{{{ ` type = "work flows" schemaVersion = 1 +`, + false, + }, + { + "missing name", + ` +type = "workflow" +schemaVersion = 1 +workflowId = "15c631d295ef5e32deb99a10ee6804bc4af1385568f9b3363f6552ac6dbb2cef" +workflowOwner = "00000000000000000000000000000000000000aa" `, false, }, diff --git a/core/store/migrate/migrations/0237_add_workflow_executions_on_delete.sql b/core/store/migrate/migrations/0237_add_workflow_executions_on_delete.sql new file mode 100644 index 00000000000..87670d0ab61 --- /dev/null +++ b/core/store/migrate/migrations/0237_add_workflow_executions_on_delete.sql @@ -0,0 +1,31 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE workflow_executions +DROP CONSTRAINT workflow_executions_workflow_id_fkey, +ADD CONSTRAINT workflow_executions_workflow_id_fkey + FOREIGN KEY (workflow_id) + REFERENCES workflow_specs(workflow_id) + ON DELETE CASCADE; + +ALTER TABLE workflow_steps +DROP CONSTRAINT workflow_steps_workflow_execution_id_fkey, +ADD CONSTRAINT workflow_steps_workflow_execution_id_fkey + FOREIGN KEY (workflow_execution_id) + REFERENCES workflow_executions(id) + ON DELETE CASCADE; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE workflow_executions +DROP CONSTRAINT workflow_executions_workflow_id_fkey, +ADD CONSTRAINT workflow_executions_workflow_id_fkey + FOREIGN KEY (workflow_id) + REFERENCES workflow_specs(workflow_id); + +ALTER TABLE workflow_steps +DROP CONSTRAINT workflow_steps_workflow_execution_id_fkey, +ADD CONSTRAINT workflow_steps_workflow_execution_id_fkey + FOREIGN KEY (workflow_execution_id) + REFERENCES workflow_executions(id); +-- +goose StatementEnd diff --git a/core/store/migrate/migrations/0238_workflow_spec_name.sql b/core/store/migrate/migrations/0238_workflow_spec_name.sql new file mode 100644 index 00000000000..8b9986b4da9 --- /dev/null +++ b/core/store/migrate/migrations/0238_workflow_spec_name.sql @@ -0,0 +1,22 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE workflow_specs ADD COLUMN workflow_name varchar(255); + +-- ensure that we can forward migrate to non-null name +UPDATE workflow_specs +SET + workflow_name = workflow_id +WHERE + workflow_name IS NULL; + +ALTER TABLE workflow_specs ALTER COLUMN workflow_name SET NOT NULL; + +-- unique constraint on workflow_owner and workflow_name +ALTER TABLE workflow_specs ADD CONSTRAINT unique_workflow_owner_name unique (workflow_owner, workflow_name); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE workflow_specs DROP CONSTRAINT unique_workflow_owner_name; +ALTER TABLE workflow_specs DROP COLUMN workflow_name; +-- +goose StatementEnd \ No newline at end of file diff --git a/core/testdata/testspecs/v2_specs.go b/core/testdata/testspecs/v2_specs.go index a0d8ea863e2..fb0e019d931 100644 --- a/core/testdata/testspecs/v2_specs.go +++ b/core/testdata/testspecs/v2_specs.go @@ -872,17 +872,18 @@ func (w WorkflowSpec) Toml() string { return w.toml } -func GenerateWorkflowSpec(id, owner, spec string) WorkflowSpec { +func GenerateWorkflowSpec(id, owner, name, spec string) WorkflowSpec { template := ` type = "workflow" schemaVersion = 1 name = "test-spec" workflowId = "%s" workflowOwner = "%s" +workflowName = "%s" workflow = """ %s """ ` - toml := fmt.Sprintf(template, id, owner, spec) + toml := fmt.Sprintf(template, id, owner, name, spec) return WorkflowSpec{toml: toml} } diff --git a/core/utils/config/validate.go b/core/utils/config/validate.go index 5fbae24ad53..f8508f27bf7 100644 --- a/core/utils/config/validate.go +++ b/core/utils/config/validate.go @@ -33,7 +33,9 @@ func validate(v reflect.Value, checkInterface bool) (err error) { if checkInterface { i := v.Interface() if vc, ok := i.(Validated); ok { - err = multierr.Append(err, vc.ValidateConfig()) + for _, e := range utils.UnwrapError(vc.ValidateConfig()) { + err = multierr.Append(err, e) + } } else if v.CanAddr() { i = v.Addr().Interface() if vc, ok := i.(Validated); ok { diff --git a/core/utils/thread_control_test.go b/core/utils/thread_control_test.go index dff3740eda7..51d5c00a578 100644 --- a/core/utils/thread_control_test.go +++ b/core/utils/thread_control_test.go @@ -49,7 +49,8 @@ func TestThreadControl_GoCtx(t *testing.T) { start := time.Now() wg.Wait() - require.True(t, time.Since(start) > timeout-1) - require.True(t, time.Since(start) < 2*timeout) + end := time.Since(start) + require.True(t, end > timeout-1) + require.True(t, end < 2*timeout) require.Equal(t, int32(1), finished.Load()) } diff --git a/core/web/jobs_controller_test.go b/core/web/jobs_controller_test.go index 8aaae0d5ba3..359f9ba8b1c 100644 --- a/core/web/jobs_controller_test.go +++ b/core/web/jobs_controller_test.go @@ -394,6 +394,7 @@ func TestJobController_Create_HappyPath(t *testing.T) { tomlTemplate: func(_ string) string { id := "15c631d295ef5e32deb99a10ee6804bc4af1385568f9b3363f6552ac6dbb2cef" owner := "00000000000000000000000000000000000000aa" + name := "my-test-workflow" workflow := ` triggers: - id: "mercury-trigger" @@ -441,14 +442,14 @@ targets: params: ["$(report)"] abi: "receive(report bytes)" ` - return testspecs.GenerateWorkflowSpec(id, owner, workflow).Toml() + return testspecs.GenerateWorkflowSpec(id, owner, name, workflow).Toml() }, assertion: func(t *testing.T, nameAndExternalJobID string, r *http.Response) { require.Equal(t, http.StatusOK, r.StatusCode) resp := cltest.ParseResponseBody(t, r) resource := presenters.JobResource{} err := web.ParseJSONAPIResponse(resp, &resource) - require.NoError(t, err) + require.NoError(t, err, "failed to parse response body: %s", resp) jb, err := jorm.FindJob(testutils.Context(t), mustInt32FromString(t, resource.ID)) require.NoError(t, err) @@ -457,6 +458,7 @@ targets: assert.Equal(t, jb.WorkflowSpec.Workflow, resource.WorkflowSpec.Workflow) assert.Equal(t, jb.WorkflowSpec.WorkflowID, resource.WorkflowSpec.WorkflowID) assert.Equal(t, jb.WorkflowSpec.WorkflowOwner, resource.WorkflowSpec.WorkflowOwner) + assert.Equal(t, jb.WorkflowSpec.WorkflowName, resource.WorkflowSpec.WorkflowName) }, }, } diff --git a/core/web/presenters/job.go b/core/web/presenters/job.go index 12b958a346d..ff59bc9bd11 100644 --- a/core/web/presenters/job.go +++ b/core/web/presenters/job.go @@ -433,6 +433,7 @@ type WorkflowSpec struct { Workflow string `json:"workflow"` WorkflowID string `json:"workflowId"` WorkflowOwner string `json:"workflowOwner"` + WorkflowName string `json:"workflowName"` CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` } @@ -442,6 +443,7 @@ func NewWorkflowSpec(spec *job.WorkflowSpec) *WorkflowSpec { Workflow: spec.Workflow, WorkflowID: spec.WorkflowID, WorkflowOwner: spec.WorkflowOwner, + WorkflowName: spec.WorkflowName, CreatedAt: spec.CreatedAt, UpdatedAt: spec.UpdatedAt, } diff --git a/core/web/presenters/job_test.go b/core/web/presenters/job_test.go index 7d3c31465db..ba485d27789 100644 --- a/core/web/presenters/job_test.go +++ b/core/web/presenters/job_test.go @@ -861,6 +861,7 @@ func TestJob(t *testing.T) { WorkflowID: "", Workflow: ``, WorkflowOwner: "", + WorkflowName: "", }, PipelineSpec: &pipeline.Spec{ ID: 1, @@ -896,6 +897,7 @@ func TestJob(t *testing.T) { "workflow": "", "workflowId": "", "workflowOwner": "", + "workflowName": "", "createdAt":"0001-01-01T00:00:00Z", "updatedAt":"0001-01-01T00:00:00Z" }, diff --git a/core/web/resolver/api_token_test.go b/core/web/resolver/api_token_test.go index bb82e91ed9b..d35d5d6d9a8 100644 --- a/core/web/resolver/api_token_test.go +++ b/core/web/resolver/api_token_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" gqlerrors "github.com/graph-gophers/graphql-go/errors" @@ -52,8 +53,8 @@ func TestResolver_CreateAPIToken(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -85,8 +86,8 @@ func TestResolver_CreateAPIToken(t *testing.T) { { name: "input errors", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -110,8 +111,8 @@ func TestResolver_CreateAPIToken(t *testing.T) { { name: "failed to find user", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -138,8 +139,8 @@ func TestResolver_CreateAPIToken(t *testing.T) { { name: "failed to generate token", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -208,8 +209,8 @@ func TestResolver_DeleteAPIToken(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -239,8 +240,8 @@ func TestResolver_DeleteAPIToken(t *testing.T) { { name: "input errors", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -264,8 +265,8 @@ func TestResolver_DeleteAPIToken(t *testing.T) { { name: "failed to find user", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -292,8 +293,8 @@ func TestResolver_DeleteAPIToken(t *testing.T) { { name: "failed to delete token", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) diff --git a/core/web/resolver/bridge_test.go b/core/web/resolver/bridge_test.go index 2244ddf3dac..ae18fe59d2d 100644 --- a/core/web/resolver/bridge_test.go +++ b/core/web/resolver/bridge_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "encoding/json" "net/url" @@ -46,7 +47,7 @@ func Test_Bridges(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("BridgeTypes", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return([]bridges.BridgeType{ { @@ -116,7 +117,7 @@ func Test_Bridge(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{ Name: name, @@ -143,7 +144,7 @@ func Test_Bridge(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, sql.ErrNoRows) }, @@ -198,7 +199,7 @@ func Test_CreateBridge(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, sql.ErrNoRows) f.Mocks.bridgeORM.On("CreateBridgeType", mock.Anything, mock.IsType(&bridges.BridgeType{})). @@ -286,7 +287,7 @@ func Test_UpdateBridge(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { // Initialize the existing bridge bridge := bridges.BridgeType{ Name: name, @@ -340,7 +341,7 @@ func Test_UpdateBridge(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, sql.ErrNoRows) }, @@ -407,7 +408,7 @@ func Test_DeleteBridgeMutation(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { bridge := bridges.BridgeType{ Name: name, URL: models.WebURL(*bridgeURL), @@ -460,7 +461,7 @@ func Test_DeleteBridgeMutation(t *testing.T) { variables: map[string]interface{}{ "id": "bridge1", }, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, sql.ErrNoRows) f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) }, @@ -479,7 +480,7 @@ func Test_DeleteBridgeMutation(t *testing.T) { variables: map[string]interface{}{ "id": "bridge1", }, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, nil) f.Mocks.jobORM.On("FindJobIDsWithBridge", mock.Anything, name.String()).Return([]int32{1}, nil) f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) diff --git a/core/web/resolver/chain_test.go b/core/web/resolver/chain_test.go index 5e51356d928..75d7e36a5b5 100644 --- a/core/web/resolver/chain_test.go +++ b/core/web/resolver/chain_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "fmt" "testing" @@ -76,7 +77,7 @@ ResendAfterThreshold = '1h0m0s' { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { chainConf := evmtoml.EVMConfig{ ChainID: &chainID, Enabled: chain.Enabled, @@ -113,7 +114,7 @@ ResendAfterThreshold = '1h0m0s' { name: "no chains", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetRelayers").Return(&chainlinkmocks.FakeRelayerChainInteroperators{Relayers: []loop.Relayer{}}) }, query: query, @@ -192,7 +193,7 @@ ResendAfterThreshold = '1h0m0s' { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("EVMORM").Return(f.Mocks.evmORM) f.Mocks.evmORM.PutChains(evmtoml.EVMConfig{ ChainID: &chainID, @@ -212,7 +213,7 @@ ResendAfterThreshold = '1h0m0s' { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("EVMORM").Return(f.Mocks.evmORM) }, query: query, diff --git a/core/web/resolver/config_test.go b/core/web/resolver/config_test.go index a04b3fa2484..f380e4db55a 100644 --- a/core/web/resolver/config_test.go +++ b/core/web/resolver/config_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" _ "embed" "encoding/json" "fmt" @@ -38,7 +39,7 @@ func TestResolver_ConfigV2(t *testing.T) { { name: "empty", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { opts := chainlink.GeneralConfigOpts{} cfg, err := opts.New() require.NoError(t, err) @@ -50,7 +51,7 @@ func TestResolver_ConfigV2(t *testing.T) { { name: "full", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { opts := chainlink.GeneralConfigOpts{ ConfigStrings: []string{configFull}, SecretsStrings: []string{}, @@ -65,7 +66,7 @@ func TestResolver_ConfigV2(t *testing.T) { { name: "partial", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { opts := chainlink.GeneralConfigOpts{ ConfigStrings: []string{configMulti}, SecretsStrings: []string{}, diff --git a/core/web/resolver/csa_keys_test.go b/core/web/resolver/csa_keys_test.go index 1048d9aa4bc..94513b53e45 100644 --- a/core/web/resolver/csa_keys_test.go +++ b/core/web/resolver/csa_keys_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "fmt" "testing" @@ -58,7 +59,7 @@ func Test_CSAKeysQuery(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.csa.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -109,7 +110,7 @@ func Test_CreateCSAKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.csa.On("Create", mock.Anything).Return(fakeKey, nil) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -120,7 +121,7 @@ func Test_CreateCSAKey(t *testing.T) { { name: "csa key exists error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.csa.On("Create", mock.Anything).Return(csakey.KeyV2{}, keystore.ErrCSAKeyExists) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -178,7 +179,7 @@ func Test_DeleteCSAKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetKeyStore").Return(f.Mocks.keystore) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.Mocks.csa.On("Delete", mock.Anything, fakeKey.ID()).Return(fakeKey, nil) @@ -190,7 +191,7 @@ func Test_DeleteCSAKey(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetKeyStore").Return(f.Mocks.keystore) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.Mocks.csa. diff --git a/core/web/resolver/eth_key_test.go b/core/web/resolver/eth_key_test.go index 40a60263f06..55cdc230bd2 100644 --- a/core/web/resolver/eth_key_test.go +++ b/core/web/resolver/eth_key_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "fmt" "testing" @@ -80,7 +81,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "success on prod", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -147,7 +148,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "success with no chains", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -202,7 +203,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "generic error on GetAll()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ethKs.On("GetAll", mock.Anything).Return(nil, gError) f.Mocks.keystore.On("Eth").Return(f.Mocks.ethKs) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -221,7 +222,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "generic error on GetStatesForKeys()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ethKs.On("GetAll", mock.Anything).Return(keys, nil) f.Mocks.ethKs.On("GetStatesForKeys", mock.Anything, keys).Return(nil, gError) f.Mocks.keystore.On("Eth").Return(f.Mocks.ethKs) @@ -241,7 +242,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "generic error on Get()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -273,7 +274,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "Empty set on legacy evm chains", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -304,7 +305,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "generic error on GetLINKBalance()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -366,7 +367,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "success with no eth balance", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.EIP55AddressFromAddress(address), diff --git a/core/web/resolver/eth_transaction_test.go b/core/web/resolver/eth_transaction_test.go index 5568a6664cb..f288450edcc 100644 --- a/core/web/resolver/eth_transaction_test.go +++ b/core/web/resolver/eth_transaction_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "errors" "math/big" @@ -68,7 +69,7 @@ func TestResolver_EthTransaction(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("FindTxByHash", mock.Anything, hash).Return(&txmgr.Tx{ ID: 1, ToAddress: common.HexToAddress("0x5431F5F973781809D18643b87B44921b11355d81"), @@ -130,7 +131,7 @@ func TestResolver_EthTransaction(t *testing.T) { { name: "success without nil values", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { num := int64(2) nonce := evmtypes.Nonce(num) @@ -195,7 +196,7 @@ func TestResolver_EthTransaction(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("FindTxByHash", mock.Anything, hash).Return(nil, sql.ErrNoRows) f.App.On("TxmStorageService").Return(f.Mocks.txmStore) }, @@ -212,7 +213,7 @@ func TestResolver_EthTransaction(t *testing.T) { { name: "generic error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("FindTxByHash", mock.Anything, hash).Return(nil, gError) f.App.On("TxmStorageService").Return(f.Mocks.txmStore) }, @@ -266,7 +267,7 @@ func TestResolver_EthTransactions(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { num := int64(2) f.Mocks.txmStore.On("Transactions", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return([]txmgr.Tx{ @@ -319,7 +320,7 @@ func TestResolver_EthTransactions(t *testing.T) { { name: "generic error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("Transactions", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return(nil, 0, gError) f.App.On("TxmStorageService").Return(f.Mocks.txmStore) }, @@ -364,7 +365,7 @@ func TestResolver_EthTransactionsAttempts(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { num := int64(2) f.Mocks.txmStore.On("TxAttempts", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return([]txmgr.TxAttempt{ @@ -397,7 +398,7 @@ func TestResolver_EthTransactionsAttempts(t *testing.T) { { name: "success with nil values", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("TxAttempts", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return([]txmgr.TxAttempt{ { Hash: hash, @@ -427,7 +428,7 @@ func TestResolver_EthTransactionsAttempts(t *testing.T) { { name: "generic error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("TxAttempts", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return(nil, 0, gError) f.App.On("TxmStorageService").Return(f.Mocks.txmStore) }, diff --git a/core/web/resolver/features_test.go b/core/web/resolver/features_test.go index f14f71abc90..76394f038b0 100644 --- a/core/web/resolver/features_test.go +++ b/core/web/resolver/features_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" @@ -23,7 +24,7 @@ func Test_ToFeatures(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetConfig").Return(configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { t, f := true, false c.Feature.UICSAKeys = &f diff --git a/core/web/resolver/feeds_manager_chain_config_test.go b/core/web/resolver/feeds_manager_chain_config_test.go index 31208aa0581..c5dd77c14a1 100644 --- a/core/web/resolver/feeds_manager_chain_config_test.go +++ b/core/web/resolver/feeds_manager_chain_config_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "testing" @@ -71,7 +72,7 @@ func Test_CreateFeedsManagerChainConfig(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CreateChainConfig", mock.Anything, feeds.ChainConfig{ FeedsManagerID: mgrID, @@ -144,7 +145,7 @@ func Test_CreateFeedsManagerChainConfig(t *testing.T) { { name: "create call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CreateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(int64(0), sql.ErrNoRows) }, @@ -161,7 +162,7 @@ func Test_CreateFeedsManagerChainConfig(t *testing.T) { { name: "get call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CreateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(cfgID, nil) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(nil, sql.ErrNoRows) @@ -209,7 +210,7 @@ func Test_DeleteFeedsManagerChainConfig(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(&feeds.ChainConfig{ ID: cfgID, @@ -230,7 +231,7 @@ func Test_DeleteFeedsManagerChainConfig(t *testing.T) { { name: "get call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(nil, sql.ErrNoRows) }, @@ -247,7 +248,7 @@ func Test_DeleteFeedsManagerChainConfig(t *testing.T) { { name: "delete call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(&feeds.ChainConfig{ ID: cfgID, @@ -324,7 +325,7 @@ func Test_UpdateFeedsManagerChainConfig(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateChainConfig", mock.Anything, feeds.ChainConfig{ ID: cfgID, @@ -393,7 +394,7 @@ func Test_UpdateFeedsManagerChainConfig(t *testing.T) { { name: "update call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(int64(0), sql.ErrNoRows) }, @@ -410,7 +411,7 @@ func Test_UpdateFeedsManagerChainConfig(t *testing.T) { { name: "get call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(cfgID, nil) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(nil, sql.ErrNoRows) diff --git a/core/web/resolver/feeds_manager_test.go b/core/web/resolver/feeds_manager_test.go index a3ea80a6443..bafb50ab0d5 100644 --- a/core/web/resolver/feeds_manager_test.go +++ b/core/web/resolver/feeds_manager_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "testing" @@ -40,7 +41,7 @@ func Test_FeedsManagers(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ListJobProposalsByManagersIDs", mock.Anything, []int64{1}).Return([]feeds.JobProposal{ { @@ -113,7 +114,7 @@ func Test_FeedsManager(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(&feeds.FeedsManager{ ID: mgrID, @@ -140,7 +141,7 @@ func Test_FeedsManager(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(nil, sql.ErrNoRows) }, @@ -212,7 +213,7 @@ func Test_CreateFeedsManager(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RegisterManager", mock.Anything, feeds.RegisterManagerParams{ Name: name, @@ -247,7 +248,7 @@ func Test_CreateFeedsManager(t *testing.T) { { name: "single feeds manager error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc. On("RegisterManager", mock.Anything, mock.IsType(feeds.RegisterManagerParams{})). @@ -266,7 +267,7 @@ func Test_CreateFeedsManager(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RegisterManager", mock.Anything, mock.IsType(feeds.RegisterManagerParams{})).Return(mgrID, nil) f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(nil, sql.ErrNoRows) @@ -358,7 +359,7 @@ func Test_UpdateFeedsManager(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateManager", mock.Anything, feeds.FeedsManager{ ID: mgrID, @@ -394,7 +395,7 @@ func Test_UpdateFeedsManager(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateManager", mock.Anything, mock.IsType(feeds.FeedsManager{})).Return(nil) f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(nil, sql.ErrNoRows) diff --git a/core/web/resolver/job_error_test.go b/core/web/resolver/job_error_test.go index 69899a3ec47..e3af3230e27 100644 --- a/core/web/resolver/job_error_test.go +++ b/core/web/resolver/job_error_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "encoding/json" "testing" @@ -27,7 +28,7 @@ func TestResolver_JobErrors(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: int32(1), @@ -123,7 +124,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{ ID: id, Occurrences: 5, @@ -140,7 +141,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "not found on FindSpecError()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -158,7 +159,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "not found on DismissError()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, nil) f.Mocks.jobORM.On("DismissError", mock.Anything, id).Return(sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) @@ -177,7 +178,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "generic error on FindSpecError()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -196,7 +197,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "generic error on DismissError()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, nil) f.Mocks.jobORM.On("DismissError", mock.Anything, id).Return(gError) f.App.On("JobORM").Return(f.Mocks.jobORM) diff --git a/core/web/resolver/job_proposal_spec_test.go b/core/web/resolver/job_proposal_spec_test.go index 5875a5acb69..46364998c9c 100644 --- a/core/web/resolver/job_proposal_spec_test.go +++ b/core/web/resolver/job_proposal_spec_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "testing" "time" @@ -50,7 +51,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ @@ -64,7 +65,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { { name: "not found error on approval", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(sql.ErrNoRows) }, @@ -81,7 +82,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { { name: "not found error on fetch", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) @@ -99,7 +100,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { { name: "unprocessable error on approval if job already exists", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(feeds.ErrJobAlreadyExists) }, @@ -154,7 +155,7 @@ func TestResolver_CancelJobProposalSpec(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CancelSpec", mock.Anything, specID).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ @@ -168,7 +169,7 @@ func TestResolver_CancelJobProposalSpec(t *testing.T) { { name: "not found error on cancel", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CancelSpec", mock.Anything, specID).Return(sql.ErrNoRows) }, @@ -185,7 +186,7 @@ func TestResolver_CancelJobProposalSpec(t *testing.T) { { name: "not found error on fetch", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CancelSpec", mock.Anything, specID).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) @@ -241,7 +242,7 @@ func TestResolver_RejectJobProposalSpec(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RejectSpec", mock.Anything, specID).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ @@ -255,7 +256,7 @@ func TestResolver_RejectJobProposalSpec(t *testing.T) { { name: "not found error on reject", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RejectSpec", mock.Anything, specID).Return(sql.ErrNoRows) }, @@ -272,7 +273,7 @@ func TestResolver_RejectJobProposalSpec(t *testing.T) { { name: "not found error on fetch", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RejectSpec", mock.Anything, specID).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) @@ -331,7 +332,7 @@ func TestResolver_UpdateJobProposalSpecDefinition(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateSpecDefinition", mock.Anything, specID, "").Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ @@ -345,7 +346,7 @@ func TestResolver_UpdateJobProposalSpecDefinition(t *testing.T) { { name: "not found error on update", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateSpecDefinition", mock.Anything, specID, "").Return(sql.ErrNoRows) }, @@ -362,7 +363,7 @@ func TestResolver_UpdateJobProposalSpecDefinition(t *testing.T) { { name: "not found error on fetch", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateSpecDefinition", mock.Anything, specID, "").Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) @@ -443,7 +444,7 @@ func TestResolver_GetJobProposal_Spec(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.feedsSvc.On("GetJobProposal", mock.Anything, jpID).Return(&feeds.JobProposal{ ID: jpID, Status: feeds.JobProposalStatusApproved, diff --git a/core/web/resolver/job_proposal_test.go b/core/web/resolver/job_proposal_test.go index 5544b39c936..3c09435e56e 100644 --- a/core/web/resolver/job_proposal_test.go +++ b/core/web/resolver/job_proposal_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "fmt" "testing" @@ -64,7 +65,7 @@ func TestResolver_GetJobProposal(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.feedsSvc.On("ListManagersByIDs", mock.Anything, []int64{1}).Return([]feeds.FeedsManager{ { ID: 1, @@ -89,7 +90,7 @@ func TestResolver_GetJobProposal(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.feedsSvc.On("GetJobProposal", mock.Anything, jpID).Return(nil, sql.ErrNoRows) f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) }, diff --git a/core/web/resolver/job_run_test.go b/core/web/resolver/job_run_test.go index 51631864e8c..3029710bcc4 100644 --- a/core/web/resolver/job_run_test.go +++ b/core/web/resolver/job_run_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "testing" @@ -39,7 +40,7 @@ func TestQuery_PaginatedJobRuns(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("PipelineRuns", mock.Anything, (*int32)(nil), PageDefaultOffset, PageDefaultLimit).Return([]pipeline.Run{ { ID: int64(200), @@ -63,7 +64,7 @@ func TestQuery_PaginatedJobRuns(t *testing.T) { { name: "generic error on PipelineRuns()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("PipelineRuns", mock.Anything, (*int32)(nil), PageDefaultOffset, PageDefaultLimit).Return(nil, 0, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -130,7 +131,7 @@ func TestResolver_JobRun(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindPipelineRunByID", mock.Anything, int64(2)).Return(pipeline.Run{ ID: 2, PipelineSpecID: 5, @@ -179,7 +180,7 @@ func TestResolver_JobRun(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindPipelineRunByID", mock.Anything, int64(2)).Return(pipeline.Run{}, sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -196,7 +197,7 @@ func TestResolver_JobRun(t *testing.T) { { name: "generic error on FindPipelineRunByID()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindPipelineRunByID", mock.Anything, int64(2)).Return(pipeline.Run{}, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -284,7 +285,7 @@ func TestResolver_RunJob(t *testing.T) { { name: "success without body", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), nil) f.Mocks.pipelineORM.On("FindRun", mock.Anything, int64(25)).Return(pipeline.Run{ ID: 2, @@ -337,7 +338,7 @@ func TestResolver_RunJob(t *testing.T) { { name: "not found job error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), webhook.ErrJobNotExists) }, query: mutation, @@ -355,7 +356,7 @@ func TestResolver_RunJob(t *testing.T) { { name: "generic error on RunJobV2", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), gError) }, query: mutation, @@ -375,7 +376,7 @@ func TestResolver_RunJob(t *testing.T) { { name: "generic error on FindRun", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), nil) f.Mocks.pipelineORM.On("FindRun", mock.Anything, int64(25)).Return(pipeline.Run{}, gError) f.App.On("PipelineORM").Return(f.Mocks.pipelineORM) diff --git a/core/web/resolver/job_test.go b/core/web/resolver/job_test.go index 0615e47a621..e00c4604bca 100644 --- a/core/web/resolver/job_test.go +++ b/core/web/resolver/job_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "encoding/json" "fmt" @@ -68,7 +69,7 @@ func TestResolver_Jobs(t *testing.T) { { name: "get jobs success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { plnSpecID := int32(12) f.App.On("JobORM").Return(f.Mocks.jobORM) @@ -206,7 +207,7 @@ func TestResolver_Job(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: 1, @@ -238,7 +239,7 @@ func TestResolver_Job(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, sql.ErrNoRows) }, @@ -255,7 +256,7 @@ func TestResolver_Job(t *testing.T) { { name: "show job when chainID is disabled", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: 1, @@ -351,7 +352,7 @@ func TestResolver_CreateJob(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetConfig").Return(f.Mocks.cfg) f.App.On("AddJobV2", mock.Anything, &jb).Return(nil) }, @@ -378,7 +379,7 @@ func TestResolver_CreateJob(t *testing.T) { { name: "generic error when adding the job", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetConfig").Return(f.Mocks.cfg) f.App.On("AddJobV2", mock.Anything, &jb).Return(gError) }, @@ -452,7 +453,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: id, Name: null.StringFrom("test-job"), @@ -470,7 +471,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: "not found on FindJob()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -488,7 +489,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: "not found on DeleteJob()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, nil) f.App.On("JobORM").Return(f.Mocks.jobORM) f.App.On("DeleteJob", mock.Anything, id).Return(sql.ErrNoRows) @@ -507,7 +508,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: "generic error on FindJob()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -526,7 +527,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: "generic error on DeleteJob()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, nil) f.App.On("JobORM").Return(f.Mocks.jobORM) f.App.On("DeleteJob", mock.Anything, id).Return(gError) diff --git a/core/web/resolver/log_test.go b/core/web/resolver/log_test.go index 8b1b941da5a..cf5620845b2 100644 --- a/core/web/resolver/log_test.go +++ b/core/web/resolver/log_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" gqlerrors "github.com/graph-gophers/graphql-go/errors" @@ -35,7 +36,7 @@ func TestResolver_SetSQLLogging(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.cfg.On("SetLogSQL", true).Return(nil) f.App.On("GetConfig").Return(f.Mocks.cfg) }, @@ -79,7 +80,7 @@ func TestResolver_SQLLogging(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.cfg.On("Database").Return(&databaseConfig{logSQL: false}) f.App.On("GetConfig").Return(f.Mocks.cfg) }, @@ -126,7 +127,7 @@ func TestResolver_GlobalLogLevel(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.cfg.On("Log").Return(&log{level: warnLvl}) f.App.On("GetConfig").Return(f.Mocks.cfg) }, @@ -178,7 +179,7 @@ func TestResolver_SetGlobalLogLevel(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("SetLogLevel", errorLvl).Return(nil) }, query: mutation, @@ -195,7 +196,7 @@ func TestResolver_SetGlobalLogLevel(t *testing.T) { { name: "generic error on SetLogLevel", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("SetLogLevel", errorLvl).Return(gError) }, query: mutation, diff --git a/core/web/resolver/node_test.go b/core/web/resolver/node_test.go index e103a470097..870f694990f 100644 --- a/core/web/resolver/node_test.go +++ b/core/web/resolver/node_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" gqlerrors "github.com/graph-gophers/graphql-go/errors" @@ -39,7 +40,7 @@ func TestResolver_Nodes(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetRelayers").Return(&chainlinkmocks.FakeRelayerChainInteroperators{ Nodes: []types.NodeStatus{ { @@ -78,7 +79,7 @@ func TestResolver_Nodes(t *testing.T) { { name: "generic error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.relayerChainInterops.NodesErr = gError f.App.On("GetRelayers").Return(f.Mocks.relayerChainInterops) }, @@ -122,7 +123,7 @@ func Test_NodeQuery(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetRelayers").Return(&chainlinkmocks.FakeRelayerChainInteroperators{Relayers: []loop.Relayer{ testutils.MockRelayer{NodeStatuses: []types.NodeStatus{ { @@ -146,7 +147,7 @@ func Test_NodeQuery(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetRelayers").Return(&chainlinkmocks.FakeRelayerChainInteroperators{Relayers: []loop.Relayer{}}) }, query: query, diff --git a/core/web/resolver/ocr2_keys_test.go b/core/web/resolver/ocr2_keys_test.go index fc82d070dd9..2269149bc37 100644 --- a/core/web/resolver/ocr2_keys_test.go +++ b/core/web/resolver/ocr2_keys_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/hex" "encoding/json" "fmt" @@ -69,7 +70,7 @@ func TestResolver_GetOCR2KeyBundles(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -80,7 +81,7 @@ func TestResolver_GetOCR2KeyBundles(t *testing.T) { { name: "generic error on GetAll()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("GetAll").Return(nil, gError) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -150,7 +151,7 @@ func TestResolver_CreateOCR2KeyBundle(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Create", mock.Anything, chaintype.ChainType("evm")).Return(fakeKey, nil) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -162,7 +163,7 @@ func TestResolver_CreateOCR2KeyBundle(t *testing.T) { { name: "generic error on Create()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Create", mock.Anything, chaintype.ChainType("evm")).Return(nil, gError) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -238,7 +239,7 @@ func TestResolver_DeleteOCR2KeyBundle(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Delete", mock.Anything, fakeKey.ID()).Return(nil) f.Mocks.ocr2.On("Get", fakeKey.ID()).Return(fakeKey, nil) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) @@ -251,7 +252,7 @@ func TestResolver_DeleteOCR2KeyBundle(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Get", fakeKey.ID()).Return(fakeKey, gError) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -268,7 +269,7 @@ func TestResolver_DeleteOCR2KeyBundle(t *testing.T) { { name: "generic error on Delete()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Delete", mock.Anything, fakeKey.ID()).Return(gError) f.Mocks.ocr2.On("Get", fakeKey.ID()).Return(fakeKey, nil) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) diff --git a/core/web/resolver/ocr_test.go b/core/web/resolver/ocr_test.go index 5ca56c4bd04..e4b74e1d66b 100644 --- a/core/web/resolver/ocr_test.go +++ b/core/web/resolver/ocr_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "math/big" "testing" @@ -54,7 +55,7 @@ func TestResolver_GetOCRKeyBundles(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("OCR").Return(f.Mocks.ocr) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -105,7 +106,7 @@ func TestResolver_OCRCreateBundle(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr.On("Create", mock.Anything).Return(fakeKey, nil) f.Mocks.keystore.On("OCR").Return(f.Mocks.ocr) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -163,7 +164,7 @@ func TestResolver_OCRDeleteBundle(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr.On("Delete", mock.Anything, fakeKey.ID()).Return(fakeKey, nil) f.Mocks.keystore.On("OCR").Return(f.Mocks.ocr) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -175,7 +176,7 @@ func TestResolver_OCRDeleteBundle(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr. On("Delete", mock.Anything, fakeKey.ID()). Return(ocrkey.KeyV2{}, keystore.KeyNotFoundError{ID: "helloWorld", KeyType: "OCR"}) diff --git a/core/web/resolver/p2p_test.go b/core/web/resolver/p2p_test.go index 6502ffc821a..941787aba96 100644 --- a/core/web/resolver/p2p_test.go +++ b/core/web/resolver/p2p_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "fmt" "math/big" @@ -53,7 +54,7 @@ func TestResolver_GetP2PKeys(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.p2p.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("P2P").Return(f.Mocks.p2p) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -102,7 +103,7 @@ func TestResolver_CreateP2PKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.p2p.On("Create", mock.Anything).Return(fakeKey, nil) f.Mocks.keystore.On("P2P").Return(f.Mocks.p2p) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -163,7 +164,7 @@ func TestResolver_DeleteP2PKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.p2p.On("Delete", mock.Anything, peerID).Return(fakeKey, nil) f.Mocks.keystore.On("P2P").Return(f.Mocks.p2p) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -175,7 +176,7 @@ func TestResolver_DeleteP2PKey(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.p2p. On("Delete", mock.Anything, peerID). Return( diff --git a/core/web/resolver/resolver_test.go b/core/web/resolver/resolver_test.go index 56b7f076eca..9f3445e1cee 100644 --- a/core/web/resolver/resolver_test.go +++ b/core/web/resolver/resolver_test.go @@ -72,9 +72,6 @@ type gqlTestFramework struct { // The root GQL schema RootSchema *graphql.Schema - // Contains the context with an injected dataloader - Ctx context.Context - Mocks *mocks } @@ -88,7 +85,6 @@ func setupFramework(t *testing.T) *gqlTestFramework { schema.MustGetRootSchema(), &Resolver{App: app}, ) - ctx = loader.InjectDataloader(testutils.Context(t), app) ) // Setup mocks @@ -128,7 +124,6 @@ func setupFramework(t *testing.T) *gqlTestFramework { t: t, App: app, RootSchema: rootSchema, - Ctx: ctx, Mocks: m, } @@ -146,20 +141,18 @@ func (f *gqlTestFramework) Timestamp() time.Time { return time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) } -// injectAuthenticatedUser injects a session into the request context -func (f *gqlTestFramework) injectAuthenticatedUser() { - f.t.Helper() - +// withAuthenticatedUser injects a session into the request context +func (f *gqlTestFramework) withAuthenticatedUser(ctx context.Context) context.Context { user := clsessions.User{Email: "gqltester@chain.link", Role: clsessions.UserRoleAdmin} - f.Ctx = auth.WithGQLAuthenticatedSession(f.Ctx, user, "gqltesterSession") + return auth.WithGQLAuthenticatedSession(ctx, user, "gqltesterSession") } // GQLTestCase represents a single GQL request test. type GQLTestCase struct { name string authenticated bool - before func(*gqlTestFramework) + before func(context.Context, *gqlTestFramework) query string variables map[string]interface{} result string @@ -175,16 +168,15 @@ func RunGQLTests(t *testing.T, testCases []GQLTestCase) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - var ( - f = setupFramework(t) - ) + f := setupFramework(t) + ctx := loader.InjectDataloader(testutils.Context(t), f.App) if tc.authenticated { - f.injectAuthenticatedUser() + ctx = f.withAuthenticatedUser(ctx) } if tc.before != nil { - tc.before(f) + tc.before(ctx, f) } // This does not print out the correct stack trace as the `RunTest` @@ -193,7 +185,7 @@ func RunGQLTests(t *testing.T, testCases []GQLTestCase) { // // This would need to be fixed upstream. gqltesting.RunTest(t, &gqltesting.Test{ - Context: f.Ctx, + Context: ctx, Schema: f.RootSchema, Query: tc.query, Variables: tc.variables, diff --git a/core/web/resolver/solana_key_test.go b/core/web/resolver/solana_key_test.go index 5472f12081d..e788e9e7ce2 100644 --- a/core/web/resolver/solana_key_test.go +++ b/core/web/resolver/solana_key_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "errors" "fmt" "testing" @@ -40,7 +41,7 @@ func TestResolver_SolanaKeys(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.solana.On("GetAll").Return([]solkey.Key{k}, nil) f.Mocks.keystore.On("Solana").Return(f.Mocks.solana) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -51,7 +52,7 @@ func TestResolver_SolanaKeys(t *testing.T) { { name: "generic error on GetAll", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.solana.On("GetAll").Return([]solkey.Key{}, gError) f.Mocks.keystore.On("Solana").Return(f.Mocks.solana) f.App.On("GetKeyStore").Return(f.Mocks.keystore) diff --git a/core/web/resolver/spec_test.go b/core/web/resolver/spec_test.go index 43682c14ead..63002e566f1 100644 --- a/core/web/resolver/spec_test.go +++ b/core/web/resolver/spec_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" "time" @@ -34,7 +35,7 @@ func TestResolver_CronSpec(t *testing.T) { { name: "cron spec success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Cron, @@ -88,7 +89,7 @@ func TestResolver_DirectRequestSpec(t *testing.T) { { name: "direct request spec success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.DirectRequest, @@ -153,7 +154,7 @@ func TestResolver_FluxMonitorSpec(t *testing.T) { { name: "flux monitor spec with standard timers", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.FluxMonitor, @@ -220,7 +221,7 @@ func TestResolver_FluxMonitorSpec(t *testing.T) { { name: "flux monitor spec with drumbeat", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.FluxMonitor, @@ -303,7 +304,7 @@ func TestResolver_KeeperSpec(t *testing.T) { { name: "keeper spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Keeper, @@ -367,7 +368,7 @@ func TestResolver_OCRSpec(t *testing.T) { { name: "OCR spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.OffchainReporting, @@ -472,7 +473,7 @@ func TestResolver_OCR2Spec(t *testing.T) { { name: "OCR 2 spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.OffchainReporting2, @@ -574,7 +575,7 @@ func TestResolver_VRFSpec(t *testing.T) { { name: "vrf spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.VRF, @@ -670,7 +671,7 @@ func TestResolver_WebhookSpec(t *testing.T) { { name: "webhook spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Webhook, @@ -739,7 +740,7 @@ func TestResolver_BlockhashStoreSpec(t *testing.T) { { name: "blockhash store spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.BlockhashStore, @@ -843,7 +844,7 @@ func TestResolver_BlockHeaderFeederSpec(t *testing.T) { { name: "block header feeder spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.BlockHeaderFeeder, @@ -930,7 +931,7 @@ func TestResolver_BootstrapSpec(t *testing.T) { { name: "Bootstrap spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Bootstrap, @@ -1002,7 +1003,7 @@ func TestResolver_WorkflowSpec(t *testing.T) { { name: "Workflow spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Workflow, @@ -1060,7 +1061,7 @@ func TestResolver_GatewaySpec(t *testing.T) { { name: "Gateway spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Gateway, diff --git a/core/web/resolver/user_test.go b/core/web/resolver/user_test.go index 2662b1f5040..8d37af2e379 100644 --- a/core/web/resolver/user_test.go +++ b/core/web/resolver/user_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" gqlerrors "github.com/graph-gophers/graphql-go/errors" @@ -44,8 +45,8 @@ func TestResolver_UpdateUserPassword(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := auth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := auth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -73,8 +74,8 @@ func TestResolver_UpdateUserPassword(t *testing.T) { { name: "update password match error", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := auth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := auth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -99,8 +100,8 @@ func TestResolver_UpdateUserPassword(t *testing.T) { { name: "failed to clear session error", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := auth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := auth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -130,8 +131,8 @@ func TestResolver_UpdateUserPassword(t *testing.T) { { name: "failed to update current user password error", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := auth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := auth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) diff --git a/core/web/resolver/vrf_test.go b/core/web/resolver/vrf_test.go index 5101bc5937b..c77f7b73ff0 100644 --- a/core/web/resolver/vrf_test.go +++ b/core/web/resolver/vrf_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "fmt" "math/big" @@ -64,7 +65,7 @@ func TestResolver_GetVRFKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf.On("Get", fakeKey.PublicKey.String()).Return(fakeKey, nil) f.Mocks.keystore.On("VRF").Return(f.Mocks.vrf) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -76,7 +77,7 @@ func TestResolver_GetVRFKey(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf. On("Get", fakeKey.PublicKey.String()). Return(vrfkey.KeyV2{}, errors.Wrapf( @@ -146,7 +147,7 @@ func TestResolver_GetVRFKeys(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("VRF").Return(f.Mocks.vrf) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -198,7 +199,7 @@ func TestResolver_CreateVRFKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf.On("Create", mock.Anything).Return(fakeKey, nil) f.Mocks.keystore.On("VRF").Return(f.Mocks.vrf) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -261,7 +262,7 @@ func TestResolver_DeleteVRFKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf.On("Delete", mock.Anything, fakeKey.PublicKey.String()).Return(fakeKey, nil) f.Mocks.keystore.On("VRF").Return(f.Mocks.vrf) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -273,7 +274,7 @@ func TestResolver_DeleteVRFKey(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf. On("Delete", mock.Anything, fakeKey.PublicKey.String()). Return(vrfkey.KeyV2{}, errors.Wrapf( diff --git a/core/web/solana_chains_controller_test.go b/core/web/solana_chains_controller_test.go index 1377cb65aba..a2ac904b783 100644 --- a/core/web/solana_chains_controller_test.go +++ b/core/web/solana_chains_controller_test.go @@ -15,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" @@ -80,7 +79,7 @@ Nodes = [] t.Run(tc.name, func(t *testing.T) { t.Parallel() - controller := setupSolanaChainsControllerTestV2(t, &solana.TOMLConfig{ + controller := setupSolanaChainsControllerTestV2(t, &config.TOMLConfig{ ChainID: ptr(validId), Chain: config.Chain{ SkipPreflight: ptr(false), @@ -111,13 +110,13 @@ Nodes = [] func Test_SolanaChainsController_Index(t *testing.T) { t.Parallel() - chainA := &solana.TOMLConfig{ + chainA := &config.TOMLConfig{ ChainID: ptr(fmt.Sprintf("ChainlinktestA-%d", rand.Int31n(999999))), Chain: config.Chain{ TxTimeout: commoncfg.MustNewDuration(time.Hour), }, } - chainB := &solana.TOMLConfig{ + chainB := &config.TOMLConfig{ ChainID: ptr(fmt.Sprintf("ChainlinktestB-%d", rand.Int31n(999999))), Chain: config.Chain{ SkipPreflight: ptr(false), @@ -175,7 +174,7 @@ type TestSolanaChainsController struct { client cltest.HTTPClientCleaner } -func setupSolanaChainsControllerTestV2(t *testing.T, cfgs ...*solana.TOMLConfig) *TestSolanaChainsController { +func setupSolanaChainsControllerTestV2(t *testing.T, cfgs ...*config.TOMLConfig) *TestSolanaChainsController { for i := range cfgs { cfgs[i].SetDefaults() } diff --git a/crib/devspace.yaml b/crib/devspace.yaml index f22e710f943..229c0829d02 100644 --- a/crib/devspace.yaml +++ b/crib/devspace.yaml @@ -100,7 +100,7 @@ deployments: releaseName: "app" chart: name: ${CHAINLINK_CLUSTER_HELM_CHART_URI} - version: 0.6.0 + version: "1.1.0" # for simplicity, we define all the values here # they can be defined the same way in values.yml # devspace merges these "values" with the "values.yaml" before deploy @@ -158,7 +158,7 @@ deployments: # extraEnvVars: # "CL_MEDIAN_CMD": "chainlink-feeds" nodes: - - name: node-1 + node1: image: ${runtime.images.app} # default resources are 300m/1Gi # first node need more resources to build faster inside container @@ -209,13 +209,13 @@ deployments: # CollectorTarget = 'app-opentelemetry-collector:4317' # TLSCertPath = '' # Mode = 'unencrypted' - - name: node-2 + node2: image: ${runtime.images.app} - - name: node-3 + node3: image: ${runtime.images.app} - - name: node-4 + node4: image: ${runtime.images.app} - - name: node-5 + node5: image: ${runtime.images.app} # each CL node have a dedicated PostgreSQL 11.15 @@ -307,7 +307,7 @@ deployments: - path: / backend: service: - name: app-node-1 + name: app-node1 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-node2.${DEVSPACE_INGRESS_BASE_DOMAIN} @@ -316,7 +316,7 @@ deployments: - path: / backend: service: - name: app-node-2 + name: app-node2 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-node3.${DEVSPACE_INGRESS_BASE_DOMAIN} @@ -325,7 +325,7 @@ deployments: - path: / backend: service: - name: app-node-3 + name: app-node3 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-node4.${DEVSPACE_INGRESS_BASE_DOMAIN} @@ -334,7 +334,7 @@ deployments: - path: / backend: service: - name: app-node-4 + name: app-node4 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-node5.${DEVSPACE_INGRESS_BASE_DOMAIN} @@ -343,7 +343,7 @@ deployments: - path: / backend: service: - name: app-node-5 + name: app-node5 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-geth-1337-http.${DEVSPACE_INGRESS_BASE_DOMAIN} diff --git a/go.mod b/go.mod index 9b72abb4db0..9ce8d98b91d 100644 --- a/go.mod +++ b/go.mod @@ -76,7 +76,7 @@ require ( github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240508101745-af1ed7bc8a69 github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab - github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 + github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 github.com/smartcontractkit/chainlink-vrf v0.0.0-20231120191722-fef03814f868 github.com/smartcontractkit/libocr v0.0.0-20240419185742-fd3cab206b2c diff --git a/go.sum b/go.sum index e452df784af..a41f19e2cb1 100644 --- a/go.sum +++ b/go.sum @@ -1179,8 +1179,8 @@ github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab/go.mod h1:RPUY7r8GxgzXxS1ijtU1P/fpJomOXztXgUbEziNmbCA= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 h1:f3W82k9V/XA6ZP/VQVJcGMVR6CrL3pQrPJSwyQWVFys= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83/go.mod h1:RdAtOeBUWq2zByw2kEbwPlXaPIb7YlaDOmnn+nVUBJI= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc h1:ZqgatXFWsJR/hkvm2mKAta6ivXZqTw7542Iz9ucoOq0= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc/go.mod h1:sR0dMjjpvvEpX3qH8DPRANauPkbO9jgUUGYK95xjLRU= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 h1:ssh/w3oXWu+C6bE88GuFRC1+0Bx/4ihsbc80XMLrl2k= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69/go.mod h1:VsfjhvWgjxqWja4q+FlXEtX5lu8BSxn10xRo6gi948g= github.com/smartcontractkit/chainlink-vrf v0.0.0-20231120191722-fef03814f868 h1:FFdvEzlYwcuVHkdZ8YnZR/XomeMGbz5E2F2HZI3I3w8= diff --git a/integration-tests/actions/vrf/common/errors.go b/integration-tests/actions/vrf/common/errors.go index 62164d0b274..78b7457e29f 100644 --- a/integration-tests/actions/vrf/common/errors.go +++ b/integration-tests/actions/vrf/common/errors.go @@ -25,6 +25,7 @@ const ( ErrLoadingCoordinator = "error loading coordinator contract" ErrCreatingVRFKey = "error creating VRF key" - ErrWaitRandomWordsRequestedEvent = "error waiting for RandomWordsRequested event" - ErrWaitRandomWordsFulfilledEvent = "error waiting for RandomWordsFulfilled event" + ErrWaitRandomWordsRequestedEvent = "error waiting for RandomWordsRequested event" + ErrWaitRandomWordsFulfilledEvent = "error waiting for RandomWordsFulfilled event" + ErrFilterRandomWordsFulfilledEvent = "error filtering RandomWordsFulfilled event" ) diff --git a/integration-tests/actions/vrf/vrfv2/contract_steps.go b/integration-tests/actions/vrf/vrfv2/contract_steps.go index 4f0ac9a5b6c..92da7b9f86b 100644 --- a/integration-tests/actions/vrf/vrfv2/contract_steps.go +++ b/integration-tests/actions/vrf/vrfv2/contract_steps.go @@ -6,6 +6,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/rs/zerolog" "github.com/shopspring/decimal" @@ -392,6 +393,7 @@ func DirectFundingRequestRandomnessAndWaitForFulfillment( fulfillmentEvents, err := WaitRandomWordsFulfilledEvent( coordinator, randomWordsRequestedEvent.RequestId, + randomWordsRequestedEvent.Raw.BlockNumber, randomWordsFulfilledEventTimeout, l, ) @@ -431,6 +433,7 @@ func RequestRandomnessAndWaitForFulfillment( randomWordsFulfilledEvent, err := WaitRandomWordsFulfilledEvent( coordinator, randomWordsRequestedEvent.RequestId, + randomWordsRequestedEvent.Raw.BlockNumber, randomWordsFulfilledEventTimeout, l, ) @@ -582,6 +585,7 @@ func RequestRandomnessWithForceFulfillAndWaitForFulfillment( func WaitRandomWordsFulfilledEvent( coordinator contracts.Coordinator, requestId *big.Int, + randomWordsRequestedEventBlockNumber uint64, randomWordsFulfilledEventTimeout time.Duration, l zerolog.Logger, ) (*contracts.CoordinatorRandomWordsFulfilled, error) { @@ -592,7 +596,18 @@ func WaitRandomWordsFulfilledEvent( }, ) if err != nil { - return nil, fmt.Errorf("%s, err %w", vrfcommon.ErrWaitRandomWordsFulfilledEvent, err) + l.Warn(). + Str("requestID", requestId.String()). + Err(err).Msg("Error waiting for random words fulfilled event, trying to filter for the event") + randomWordsFulfilledEvent, err = coordinator.FilterRandomWordsFulfilledEvent( + &bind.FilterOpts{ + Start: randomWordsRequestedEventBlockNumber, + }, + requestId, + ) + if err != nil { + return nil, fmt.Errorf(vrfcommon.ErrGenericFormat, vrfcommon.ErrFilterRandomWordsFulfilledEvent, err) + } } vrfcommon.LogRandomWordsFulfilledEvent(l, coordinator, randomWordsFulfilledEvent, false, 0) return randomWordsFulfilledEvent, err diff --git a/integration-tests/actions/vrf/vrfv2plus/contract_steps.go b/integration-tests/actions/vrf/vrfv2plus/contract_steps.go index 6151931d566..ed6139c4f17 100644 --- a/integration-tests/actions/vrf/vrfv2plus/contract_steps.go +++ b/integration-tests/actions/vrf/vrfv2plus/contract_steps.go @@ -6,6 +6,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/rs/zerolog" "github.com/shopspring/decimal" @@ -356,6 +357,7 @@ func RequestRandomnessAndWaitForFulfillment( coordinator, randomWordsRequestedEvent.RequestId, subID, + randomWordsRequestedEvent.Raw.BlockNumber, isNativeBilling, config.RandomWordsFulfilledEventTimeout.Duration, l, @@ -449,6 +451,7 @@ func DirectFundingRequestRandomnessAndWaitForFulfillment( coordinator, randomWordsRequestedEvent.RequestId, subID, + randomWordsRequestedEvent.Raw.BlockNumber, isNativeBilling, config.RandomWordsFulfilledEventTimeout.Duration, l, @@ -460,6 +463,7 @@ func WaitRandomWordsFulfilledEvent( coordinator contracts.Coordinator, requestId *big.Int, subID *big.Int, + randomWordsRequestedEventBlockNumber uint64, isNativeBilling bool, randomWordsFulfilledEventTimeout time.Duration, l zerolog.Logger, @@ -473,7 +477,18 @@ func WaitRandomWordsFulfilledEvent( }, ) if err != nil { - return nil, fmt.Errorf(vrfcommon.ErrGenericFormat, vrfcommon.ErrWaitRandomWordsFulfilledEvent, err) + l.Warn(). + Str("requestID", requestId.String()). + Err(err).Msg("Error waiting for random words fulfilled event, trying to filter for the event") + randomWordsFulfilledEvent, err = coordinator.FilterRandomWordsFulfilledEvent( + &bind.FilterOpts{ + Start: randomWordsRequestedEventBlockNumber, + }, + requestId, + ) + if err != nil { + return nil, fmt.Errorf(vrfcommon.ErrGenericFormat, vrfcommon.ErrFilterRandomWordsFulfilledEvent, err) + } } vrfcommon.LogRandomWordsFulfilledEvent(l, coordinator, randomWordsFulfilledEvent, isNativeBilling, keyNum) diff --git a/integration-tests/contracts/contract_vrf_models.go b/integration-tests/contracts/contract_vrf_models.go index 9ed08048998..c30eadd3d3d 100644 --- a/integration-tests/contracts/contract_vrf_models.go +++ b/integration-tests/contracts/contract_vrf_models.go @@ -5,6 +5,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -72,6 +73,7 @@ type VRFCoordinatorV2 interface { ParseRandomWordsFulfilled(log types.Log) (*CoordinatorRandomWordsFulfilled, error) ParseLog(log types.Log) (generated.AbigenLog, error) FindSubscriptionID(subID uint64) (uint64, error) + FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) WaitForConfigSetEvent(timeout time.Duration) (*CoordinatorConfigSet, error) OracleWithdraw(recipient common.Address, amount *big.Int) error @@ -115,6 +117,7 @@ type VRFCoordinatorV2_5 interface { GetNativeTokenTotalBalance(ctx context.Context) (*big.Int, error) GetLinkTotalBalance(ctx context.Context) (*big.Int, error) FindSubscriptionID(subID *big.Int) (*big.Int, error) + FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) ParseRandomWordsRequested(log types.Log) (*CoordinatorRandomWordsRequested, error) ParseRandomWordsFulfilled(log types.Log) (*CoordinatorRandomWordsFulfilled, error) @@ -154,6 +157,7 @@ type VRFCoordinatorV2PlusUpgradedVersion interface { GetSubscription(ctx context.Context, subID *big.Int) (vrf_v2plus_upgraded_version.GetSubscription, error) GetActiveSubscriptionIds(ctx context.Context, startIndex *big.Int, maxCount *big.Int) ([]*big.Int, error) FindSubscriptionID() (*big.Int, error) + FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) ParseRandomWordsRequested(log types.Log) (*CoordinatorRandomWordsRequested, error) ParseRandomWordsFulfilled(log types.Log) (*CoordinatorRandomWordsFulfilled, error) diff --git a/integration-tests/contracts/ethereum_vrf_common.go b/integration-tests/contracts/ethereum_vrf_common.go index f0498b6efe6..62a1809cfa6 100644 --- a/integration-tests/contracts/ethereum_vrf_common.go +++ b/integration-tests/contracts/ethereum_vrf_common.go @@ -5,6 +5,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -18,6 +19,7 @@ type Coordinator interface { Address() string WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) WaitForConfigSetEvent(timeout time.Duration) (*CoordinatorConfigSet, error) + FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) } type Subscription struct { diff --git a/integration-tests/contracts/ethereum_vrfv2_contracts.go b/integration-tests/contracts/ethereum_vrfv2_contracts.go index 9a8ab1c5b5b..be588ea3e3a 100644 --- a/integration-tests/contracts/ethereum_vrfv2_contracts.go +++ b/integration-tests/contracts/ethereum_vrfv2_contracts.go @@ -629,11 +629,9 @@ func (v *EthereumVRFCoordinatorV2) FindSubscriptionID(subID uint64) (uint64, err if err != nil { return 0, err } - if !subscriptionIterator.Next() { return 0, fmt.Errorf("expected at least 1 subID for the given owner %s", owner) } - return subscriptionIterator.Event.SubId, nil } @@ -649,6 +647,26 @@ func (v *EthereumVRFCoordinatorV2) GetBlockHashStoreAddress(ctx context.Context) return blockHashStoreAddress, nil } +func (v *EthereumVRFCoordinatorV2) FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) { + iterator, err := v.coordinator.FilterRandomWordsFulfilled( + opts, + []*big.Int{requestId}, + ) + if err != nil { + return nil, err + } + if !iterator.Next() { + return nil, fmt.Errorf("expected at least 1 RandomWordsFulfilled event for request Id: %s", requestId.String()) + } + return &CoordinatorRandomWordsFulfilled{ + RequestId: iterator.Event.RequestId, + OutputSeed: iterator.Event.OutputSeed, + Payment: iterator.Event.Payment, + Success: iterator.Event.Success, + Raw: iterator.Event.Raw, + }, nil +} + func (v *EthereumVRFCoordinatorV2) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) { randomWordsFulfilledEventsChannel := make(chan *vrf_coordinator_v2.VRFCoordinatorV2RandomWordsFulfilled) subscription, err := v.coordinator.WatchRandomWordsFulfilled(nil, randomWordsFulfilledEventsChannel, filter.RequestIds) diff --git a/integration-tests/contracts/ethereum_vrfv2plus_contracts.go b/integration-tests/contracts/ethereum_vrfv2plus_contracts.go index ba7234fbbf3..882baafcd19 100644 --- a/integration-tests/contracts/ethereum_vrfv2plus_contracts.go +++ b/integration-tests/contracts/ethereum_vrfv2plus_contracts.go @@ -480,24 +480,28 @@ func (v *EthereumVRFCoordinatorV2_5) FindSubscriptionID(subID *big.Int) (*big.In return subscriptionIterator.Event.SubId, nil } -func (v *EthereumVRFCoordinatorV2_5) WaitForSubscriptionCanceledEvent(subID *big.Int, timeout time.Duration) (*vrf_coordinator_v2_5.VRFCoordinatorV25SubscriptionCanceled, error) { - eventsChannel := make(chan *vrf_coordinator_v2_5.VRFCoordinatorV25SubscriptionCanceled) - subscription, err := v.coordinator.WatchSubscriptionCanceled(nil, eventsChannel, []*big.Int{subID}) +func (v *EthereumVRFCoordinatorV2_5) FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) { + iterator, err := v.coordinator.FilterRandomWordsFulfilled( + opts, + []*big.Int{requestId}, + nil, + ) if err != nil { return nil, err } - defer subscription.Unsubscribe() - - for { - select { - case err := <-subscription.Err(): - return nil, err - case <-time.After(timeout): - return nil, fmt.Errorf("timeout waiting for SubscriptionCanceled event") - case sub := <-eventsChannel: - return sub, nil - } + if !iterator.Next() { + return nil, fmt.Errorf("expected at least 1 RandomWordsFulfilled event for request Id: %s", requestId.String()) } + return &CoordinatorRandomWordsFulfilled{ + RequestId: iterator.Event.RequestId, + OutputSeed: iterator.Event.OutputSeed, + SubId: iterator.Event.SubId.String(), + Payment: iterator.Event.Payment, + NativePayment: iterator.Event.NativePayment, + Success: iterator.Event.Success, + OnlyPremium: iterator.Event.OnlyPremium, + Raw: iterator.Event.Raw, + }, nil } func (v *EthereumVRFCoordinatorV2_5) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) { @@ -902,14 +906,36 @@ func (v *EthereumVRFCoordinatorV2PlusUpgradedVersion) FindSubscriptionID() (*big if err != nil { return nil, err } - if !subscriptionIterator.Next() { return nil, fmt.Errorf("expected at least 1 subID for the given owner %s", owner) } - return subscriptionIterator.Event.SubId, nil } +func (v *EthereumVRFCoordinatorV2PlusUpgradedVersion) FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) { + iterator, err := v.coordinator.FilterRandomWordsFulfilled( + opts, + []*big.Int{requestId}, + nil, + ) + if err != nil { + return nil, err + } + if !iterator.Next() { + return nil, fmt.Errorf("expected at least 1 RandomWordsFulfilled event for request Id: %s", requestId.String()) + } + return &CoordinatorRandomWordsFulfilled{ + RequestId: iterator.Event.RequestId, + OutputSeed: iterator.Event.OutputSeed, + SubId: iterator.Event.SubId.String(), + Payment: iterator.Event.Payment, + NativePayment: iterator.Event.NativePayment, + Success: iterator.Event.Success, + OnlyPremium: iterator.Event.OnlyPremium, + Raw: iterator.Event.Raw, + }, nil +} + func (v *EthereumVRFCoordinatorV2PlusUpgradedVersion) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) { randomWordsFulfilledEventsChannel := make(chan *vrf_v2plus_upgraded_version.VRFCoordinatorV2PlusUpgradedVersionRandomWordsFulfilled) subscription, err := v.coordinator.WatchRandomWordsFulfilled(nil, randomWordsFulfilledEventsChannel, filter.RequestIds, filter.SubIDs) diff --git a/integration-tests/docker/test_env/test_env_builder.go b/integration-tests/docker/test_env/test_env_builder.go index 852918cc7d4..f5a5e558572 100644 --- a/integration-tests/docker/test_env/test_env_builder.go +++ b/integration-tests/docker/test_env/test_env_builder.go @@ -279,7 +279,7 @@ func (b *CLTestEnvBuilder) Build() (*CLClusterTestEnv, error) { } // this clean up has to be added as the FIRST one, because cleanup functions are executed in reverse order (LIFO) - if b.t != nil && b.cleanUpType == CleanUpTypeStandard { + if b.t != nil && b.cleanUpType != CleanUpTypeNone { b.t.Cleanup(func() { b.l.Info().Msg("Shutting down LogStream") logPath, err := osutil.GetAbsoluteFolderPath("logs") @@ -306,21 +306,24 @@ func (b *CLTestEnvBuilder) Build() (*CLClusterTestEnv, error) { // we cannot do parallel processing here, because ProcessContainerLogs() locks a mutex that controls whether // new logs can be added to the log stream, so parallel processing would get stuck on waiting for it to be unlocked + LogScanningLoop: for i := 0; i < b.clNodesCount; i++ { // ignore count return, because we are only interested in the error _, err := logProcessor.ProcessContainerLogs(b.te.ClCluster.Nodes[i].ContainerName, processFn) if err != nil && !strings.Contains(err.Error(), testreporters.MultipleLogsAtLogLevelErr) && !strings.Contains(err.Error(), testreporters.OneLogAtLogLevelErr) { - b.l.Error().Err(err).Msg("Error processing logs") - return + b.l.Error().Err(err).Msg("Error processing CL node logs") + continue } else if err != nil && (strings.Contains(err.Error(), testreporters.MultipleLogsAtLogLevelErr) || strings.Contains(err.Error(), testreporters.OneLogAtLogLevelErr)) { flushLogStream = true - b.t.Fatalf("Found a concerning log in Chainklink Node logs: %v", err) + b.t.Errorf("Found a concerning log in Chainklink Node logs: %v", err) + break LogScanningLoop } } b.l.Info().Msg("Finished scanning Chainlink Node logs for concerning errors") } if flushLogStream { + b.l.Info().Msg("Flushing LogStream logs") // we can't do much if this fails, so we just log the error in LogStream if err := b.te.LogStream.FlushAndShutdown(); err != nil { b.l.Error().Err(err).Msg("Error flushing and shutting down LogStream") @@ -328,9 +331,10 @@ func (b *CLTestEnvBuilder) Build() (*CLClusterTestEnv, error) { b.te.LogStream.PrintLogTargetsLocations() b.te.LogStream.SaveLogLocationInTestSummary() } + b.l.Info().Msg("Finished shutting down LogStream") }) } else { - b.l.Warn().Msg("LogStream won't be cleaned up, because test instance is not set or cleanup type is not standard") + b.l.Warn().Msg("LogStream won't be cleaned up, because either test instance is not set or cleanup type is set to none") } } diff --git a/integration-tests/go.mod b/integration-tests/go.mod index 144b1f62643..9442426eeb3 100644 --- a/integration-tests/go.mod +++ b/integration-tests/go.mod @@ -376,7 +376,7 @@ require ( github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240508101745-af1ed7bc8a69 // indirect github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 // indirect github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab // indirect - github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 // indirect + github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc // indirect github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 // indirect github.com/smartcontractkit/chainlink-testing-framework/grafana v0.0.0-20240328204215-ac91f55f1449 // indirect github.com/smartcontractkit/tdh2/go/ocr2/decryptionplugin v0.0.0-20230906073235-9e478e5e19f1 // indirect diff --git a/integration-tests/go.sum b/integration-tests/go.sum index 98d30f02d35..477edf4ab53 100644 --- a/integration-tests/go.sum +++ b/integration-tests/go.sum @@ -1520,8 +1520,8 @@ github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab/go.mod h1:RPUY7r8GxgzXxS1ijtU1P/fpJomOXztXgUbEziNmbCA= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 h1:f3W82k9V/XA6ZP/VQVJcGMVR6CrL3pQrPJSwyQWVFys= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83/go.mod h1:RdAtOeBUWq2zByw2kEbwPlXaPIb7YlaDOmnn+nVUBJI= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc h1:ZqgatXFWsJR/hkvm2mKAta6ivXZqTw7542Iz9ucoOq0= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc/go.mod h1:sR0dMjjpvvEpX3qH8DPRANauPkbO9jgUUGYK95xjLRU= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 h1:ssh/w3oXWu+C6bE88GuFRC1+0Bx/4ihsbc80XMLrl2k= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69/go.mod h1:VsfjhvWgjxqWja4q+FlXEtX5lu8BSxn10xRo6gi948g= github.com/smartcontractkit/chainlink-testing-framework v1.28.15 h1:mga7N6jtXQ3UOCt43IdsEnCMBh9xjOWPaE9BiM6kr6Q= diff --git a/integration-tests/load/go.mod b/integration-tests/load/go.mod index 162e506196b..395199f599b 100644 --- a/integration-tests/load/go.mod +++ b/integration-tests/load/go.mod @@ -369,7 +369,7 @@ require ( github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240508101745-af1ed7bc8a69 // indirect github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 // indirect github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab // indirect - github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 // indirect + github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc // indirect github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 // indirect github.com/smartcontractkit/chainlink-testing-framework/grafana v0.0.0-20240328204215-ac91f55f1449 // indirect github.com/smartcontractkit/chainlink-vrf v0.0.0-20231120191722-fef03814f868 // indirect diff --git a/integration-tests/load/go.sum b/integration-tests/load/go.sum index 35c37e1ecf9..95624177f68 100644 --- a/integration-tests/load/go.sum +++ b/integration-tests/load/go.sum @@ -1510,8 +1510,8 @@ github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab/go.mod h1:RPUY7r8GxgzXxS1ijtU1P/fpJomOXztXgUbEziNmbCA= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 h1:f3W82k9V/XA6ZP/VQVJcGMVR6CrL3pQrPJSwyQWVFys= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83/go.mod h1:RdAtOeBUWq2zByw2kEbwPlXaPIb7YlaDOmnn+nVUBJI= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc h1:ZqgatXFWsJR/hkvm2mKAta6ivXZqTw7542Iz9ucoOq0= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc/go.mod h1:sR0dMjjpvvEpX3qH8DPRANauPkbO9jgUUGYK95xjLRU= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 h1:ssh/w3oXWu+C6bE88GuFRC1+0Bx/4ihsbc80XMLrl2k= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69/go.mod h1:VsfjhvWgjxqWja4q+FlXEtX5lu8BSxn10xRo6gi948g= github.com/smartcontractkit/chainlink-testing-framework v1.28.15 h1:mga7N6jtXQ3UOCt43IdsEnCMBh9xjOWPaE9BiM6kr6Q= diff --git a/integration-tests/smoke/vrfv2_test.go b/integration-tests/smoke/vrfv2_test.go index cc37ad983ce..18a017110c7 100644 --- a/integration-tests/smoke/vrfv2_test.go +++ b/integration-tests/smoke/vrfv2_test.go @@ -861,10 +861,11 @@ func TestVRFV2WithBHS(t *testing.T) { CleanupFn: cleanupFn, } newEnvConfig := vrfcommon.NewEnvConfig{ - NodesToCreate: []vrfcommon.VRFNodeType{vrfcommon.VRF, vrfcommon.BHS}, - NumberOfTxKeysToCreate: 0, - UseVRFOwner: false, - UseTestCoordinator: false, + NodesToCreate: []vrfcommon.VRFNodeType{vrfcommon.VRF, vrfcommon.BHS}, + NumberOfTxKeysToCreate: 0, + UseVRFOwner: false, + UseTestCoordinator: false, + ChainlinkNodeLogScannerSettings: test_env.DefaultChainlinkNodeLogScannerSettings, } testEnv, vrfContracts, vrfKey, nodeTypeToNodeMap, err = vrfv2.SetupVRFV2Universe(testcontext.Get(t), t, vrfEnvConfig, newEnvConfig, l) require.NoError(t, err, "Error setting up VRFV2 universe") @@ -1055,8 +1056,11 @@ func TestVRFV2NodeReorg(t *testing.T) { config, err := tc.GetConfig("Smoke", tc.VRFv2) require.NoError(t, err, "Error getting config") vrfv2Config := config.VRFv2 - chainID := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0].ChainID - + network := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0] + if !network.Simulated { + t.Skip("Skipped since Reorg test could only be run on Simulated chain.") + } + chainID := network.ChainID cleanupFn := func() { sethClient, err := env.GetSethClient(chainID) require.NoError(t, err, "Getting Seth client shouldn't fail") @@ -1165,6 +1169,7 @@ func TestVRFV2NodeReorg(t *testing.T) { _, err = vrfv2.WaitRandomWordsFulfilledEvent( vrfContracts.CoordinatorV2, randomWordsRequestedEvent.RequestId, + randomWordsRequestedEvent.Raw.BlockNumber, configCopy.VRFv2.General.RandomWordsFulfilledEventTimeout.Duration, l, ) @@ -1235,8 +1240,8 @@ func TestVRFv2BatchFulfillmentEnabledDisabled(t *testing.T) { config, err := tc.GetConfig("Smoke", tc.VRFv2) require.NoError(t, err, "Error getting config") vrfv2Config := config.VRFv2 - chainID := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0].ChainID - + network := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0] + chainID := network.ChainID cleanupFn := func() { sethClient, err := env.GetSethClient(chainID) require.NoError(t, err, "Getting Seth client shouldn't fail") @@ -1262,10 +1267,11 @@ func TestVRFv2BatchFulfillmentEnabledDisabled(t *testing.T) { CleanupFn: cleanupFn, } newEnvConfig := vrfcommon.NewEnvConfig{ - NodesToCreate: []vrfcommon.VRFNodeType{vrfcommon.VRF}, - NumberOfTxKeysToCreate: 0, - UseVRFOwner: false, - UseTestCoordinator: false, + NodesToCreate: []vrfcommon.VRFNodeType{vrfcommon.VRF}, + NumberOfTxKeysToCreate: 0, + UseVRFOwner: false, + UseTestCoordinator: false, + ChainlinkNodeLogScannerSettings: test_env.DefaultChainlinkNodeLogScannerSettings, } env, vrfContracts, vrfKey, nodeTypeToNodeMap, err = vrfv2.SetupVRFV2Universe(testcontext.Get(t), t, vrfEnvConfig, newEnvConfig, l) require.NoError(t, err, "Error setting up VRFv2 universe") @@ -1394,14 +1400,16 @@ func TestVRFv2BatchFulfillmentEnabledDisabled(t *testing.T) { // verify that VRF node sends fulfillments via BatchCoordinator contract require.Equal(t, vrfContracts.BatchCoordinatorV2.Address(), fulfillmentTXToAddress, "Fulfillment Tx To Address should be the BatchCoordinatorV2 Address when batch fulfillment is enabled") - fulfillmentTxReceipt, err := sethClient.Client.TransactionReceipt(testcontext.Get(t), fulfillmentTx.Hash()) - require.NoError(t, err) - - randomWordsFulfilledLogs, err := contracts.ParseRandomWordsFulfilledLogs(vrfContracts.CoordinatorV2, fulfillmentTxReceipt.Logs) - require.NoError(t, err) - // verify that all fulfillments should be inside one tx - require.Equal(t, int(randRequestCount), len(randomWordsFulfilledLogs)) + // This check is disabled for live testnets since each testnet has different gas usage for similar tx + if network.Simulated { + fulfillmentTxReceipt, err := sethClient.Client.TransactionReceipt(testcontext.Get(t), fulfillmentTx.Hash()) + require.NoError(t, err) + randomWordsFulfilledLogs, err := contracts.ParseRandomWordsFulfilledLogs(vrfContracts.CoordinatorV2, fulfillmentTxReceipt.Logs) + require.NoError(t, err) + require.Equal(t, 1, len(batchFulfillmentTxs)) + require.Equal(t, int(randRequestCount), len(randomWordsFulfilledLogs)) + } }) t.Run("Batch Fulfillment Disabled", func(t *testing.T) { configCopy := config.MustCopy().(tc.TestConfig) diff --git a/integration-tests/smoke/vrfv2plus_test.go b/integration-tests/smoke/vrfv2plus_test.go index 260f30d0e57..d1593373204 100644 --- a/integration-tests/smoke/vrfv2plus_test.go +++ b/integration-tests/smoke/vrfv2plus_test.go @@ -970,6 +970,7 @@ func TestVRFv2PlusMigration(t *testing.T) { BatchFulfillmentGasMultiplier: *configCopy.VRFv2Plus.General.VRFJobBatchFulfillmentGasMultiplier, PollPeriod: configCopy.VRFv2Plus.General.VRFJobPollPeriod.Duration, RequestTimeout: configCopy.VRFv2Plus.General.VRFJobRequestTimeout.Duration, + SimulationBlock: configCopy.VRFv2Plus.General.VRFJobSimulationBlock, } _, err = vrfv2plus.CreateVRFV2PlusJob( @@ -1141,6 +1142,7 @@ func TestVRFv2PlusMigration(t *testing.T) { BatchFulfillmentGasMultiplier: *configCopy.VRFv2Plus.General.VRFJobBatchFulfillmentGasMultiplier, PollPeriod: configCopy.VRFv2Plus.General.VRFJobPollPeriod.Duration, RequestTimeout: configCopy.VRFv2Plus.General.VRFJobRequestTimeout.Duration, + SimulationBlock: configCopy.VRFv2Plus.General.VRFJobSimulationBlock, } _, err = vrfv2plus.CreateVRFV2PlusJob( @@ -1736,42 +1738,12 @@ func TestVRFv2PlusReplayAfterTimeout(t *testing.T) { ) require.NoError(t, err, "error requesting randomness and waiting for requested event") - // 3. create new request in a subscription with balance and wait for fulfilment - fundingLinkAmt := big.NewFloat(*configCopy.VRFv2Plus.General.SubscriptionRefundingAmountLink) - fundingNativeAmt := big.NewFloat(*configCopy.VRFv2Plus.General.SubscriptionRefundingAmountNative) - l.Info(). - Str("Coordinator", vrfContracts.CoordinatorV2Plus.Address()). - Int("Number of Subs to create", 1). - Msg("Creating and funding subscriptions, adding consumers") - fundedSubIDs, err := vrfv2plus.CreateFundSubsAndAddConsumers( - testcontext.Get(t), - env, - chainID, - fundingLinkAmt, - fundingNativeAmt, - vrfContracts.LinkToken, - vrfContracts.CoordinatorV2Plus, - []contracts.VRFv2PlusLoadTestConsumer{consumers[1]}, - 1, - ) - require.NoError(t, err, "error creating funded sub in replay test") - _, randomWordsFulfilledEvent, err := vrfv2plus.RequestRandomnessAndWaitForFulfillment( - consumers[1], - vrfContracts.CoordinatorV2Plus, - vrfKey, - fundedSubIDs[0], - isNativeBilling, - configCopy.VRFv2Plus.General, - l, - 0, - ) - require.NoError(t, err, "error requesting randomness and waiting for fulfilment") - require.True(t, randomWordsFulfilledEvent.Success, "RandomWordsFulfilled Event's `Success` field should be true") - - // 4. wait for the request timeout (1s more) duration + // 3. wait for the request timeout (1s more) duration time.Sleep(timeout + 1*time.Second) - // 5. fund sub so that node can fulfill request + fundingLinkAmt := big.NewFloat(*configCopy.VRFv2Plus.General.SubscriptionRefundingAmountLink) + fundingNativeAmt := big.NewFloat(*configCopy.VRFv2Plus.General.SubscriptionRefundingAmountNative) + // 4. fund sub so that node can fulfill request err = vrfv2plus.FundSubscriptions( fundingLinkAmt, fundingNativeAmt, @@ -1781,12 +1753,12 @@ func TestVRFv2PlusReplayAfterTimeout(t *testing.T) { ) require.NoError(t, err, "error funding subs after request timeout") - // 6. no fulfilment should happen since timeout+1 seconds passed in the job + // 5. no fulfilment should happen since timeout+1 seconds passed in the job pendingReqExists, err := vrfContracts.CoordinatorV2Plus.PendingRequestsExist(testcontext.Get(t), subID) require.NoError(t, err, "error fetching PendingRequestsExist from coordinator") require.True(t, pendingReqExists, "pendingRequest must exist since subID was underfunded till request timeout") - // 7. remove job and add new job with requestTimeout = 1 hour + // 6. remove job and add new job with requestTimeout = 1 hour vrfNode, exists := nodeTypeToNodeMap[vrfcommon.VRF] require.True(t, exists, "VRF Node does not exist") resp, err := vrfNode.CLNode.API.DeleteJob(vrfNode.Job.Data.ID) @@ -1821,7 +1793,7 @@ func TestVRFv2PlusReplayAfterTimeout(t *testing.T) { vrfNode.Job = job }() - // 8. Check if initial req in underfunded sub is fulfilled now, since it has been topped up and timeout increased + // 7. Check if initial req in underfunded sub is fulfilled now, since it has been topped up and timeout increased l.Info().Str("reqID", initialReqRandomWordsRequestedEvent.RequestId.String()). Str("subID", subID.String()). Msg("Waiting for initalReqRandomWordsFulfilledEvent") @@ -1956,8 +1928,11 @@ func TestVRFv2PlusNodeReorg(t *testing.T) { config, err := tc.GetConfig("Smoke", tc.VRFv2Plus) require.NoError(t, err, "Error getting config") vrfv2PlusConfig := config.VRFv2Plus - chainID := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0].ChainID - + network := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0] + if !network.Simulated { + t.Skip("Skipped since Reorg test could only be run on Simulated chain.") + } + chainID := network.ChainID cleanupFn := func() { sethClient, err := env.GetSethClient(chainID) require.NoError(t, err, "Getting Seth client shouldn't fail") @@ -2065,6 +2040,7 @@ func TestVRFv2PlusNodeReorg(t *testing.T) { vrfContracts.CoordinatorV2Plus, randomWordsRequestedEvent.RequestId, subID, + randomWordsRequestedEvent.Raw.BlockNumber, isNativeBilling, configCopy.VRFv2Plus.General.RandomWordsFulfilledEventTimeout.Duration, l, @@ -2135,8 +2111,8 @@ func TestVRFv2PlusBatchFulfillmentEnabledDisabled(t *testing.T) { config, err := tc.GetConfig("Smoke", tc.VRFv2Plus) require.NoError(t, err, "Error getting config") vrfv2PlusConfig := config.VRFv2Plus - chainID := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0].ChainID - + network := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0] + chainID := network.ChainID cleanupFn := func() { sethClient, err := env.GetSethClient(chainID) require.NoError(t, err, "Getting Seth client shouldn't fail") @@ -2276,9 +2252,6 @@ func TestVRFv2PlusBatchFulfillmentEnabledDisabled(t *testing.T) { batchFulfillmentTxs = append(batchFulfillmentTxs, tx) } } - // verify that all fulfillments should be inside one tx - require.Equal(t, 1, len(batchFulfillmentTxs)) - fulfillmentTx, _, err := sethClient.Client.TransactionByHash(testcontext.Get(t), randomWordsFulfilledEvent.Raw.TxHash) require.NoError(t, err, "error getting tx from hash") @@ -2291,14 +2264,16 @@ func TestVRFv2PlusBatchFulfillmentEnabledDisabled(t *testing.T) { // verify that VRF node sends fulfillments via BatchCoordinator contract require.Equal(t, vrfContracts.BatchCoordinatorV2Plus.Address(), fulfillmentTXToAddress, "Fulfillment Tx To Address should be the BatchCoordinatorV2Plus Address when batch fulfillment is enabled") - fulfillmentTxReceipt, err := sethClient.Client.TransactionReceipt(testcontext.Get(t), fulfillmentTx.Hash()) - require.NoError(t, err) - - randomWordsFulfilledLogs, err := contracts.ParseRandomWordsFulfilledLogs(vrfContracts.CoordinatorV2Plus, fulfillmentTxReceipt.Logs) - require.NoError(t, err) - // verify that all fulfillments should be inside one tx - require.Equal(t, int(randRequestCount), len(randomWordsFulfilledLogs)) + // This check is disabled for live testnets since each testnet has different gas usage for similar tx + if network.Simulated { + fulfillmentTxReceipt, err := sethClient.Client.TransactionReceipt(testcontext.Get(t), fulfillmentTx.Hash()) + require.NoError(t, err) + randomWordsFulfilledLogs, err := contracts.ParseRandomWordsFulfilledLogs(vrfContracts.CoordinatorV2Plus, fulfillmentTxReceipt.Logs) + require.NoError(t, err) + require.Equal(t, 1, len(batchFulfillmentTxs)) + require.Equal(t, int(randRequestCount), len(randomWordsFulfilledLogs)) + } }) t.Run("Batch Fulfillment Disabled", func(t *testing.T) { configCopy := config.MustCopy().(tc.TestConfig) diff --git a/integration-tests/testconfig/default.toml b/integration-tests/testconfig/default.toml index f5faf23d4e8..0c8d411c04a 100644 --- a/integration-tests/testconfig/default.toml +++ b/integration-tests/testconfig/default.toml @@ -380,7 +380,7 @@ gas_fee_cap = 30_000_000_000 gas_tip_cap = 1_000_000_000 # how many last blocks to use, when estimating gas for a transaction -gas_price_estimation_blocks = 50 +gas_price_estimation_blocks = 30 # priority of the transaction, can be "fast", "standard" or "slow" (the higher the priority, the higher adjustment factor will be used for gas estimation) [default: "standard"] gas_price_estimation_tx_priority = "standard" diff --git a/integration-tests/testconfig/vrfv2/vrfv2.toml b/integration-tests/testconfig/vrfv2/vrfv2.toml index 3b447a082bf..4ff48a3181a 100644 --- a/integration-tests/testconfig/vrfv2/vrfv2.toml +++ b/integration-tests/testconfig/vrfv2/vrfv2.toml @@ -57,6 +57,7 @@ subscription_refunding_amount_link = 5.0 cl_node_max_gas_price_gwei = 10 link_native_feed_response = 1000000000000000000 +#todo - need to have separate minimum_confirmations config for Coordinator, CL Node and Consumer request minimum_confirmations = 3 number_of_words = 3 diff --git a/integration-tests/testconfig/vrfv2plus/vrfv2plus.toml b/integration-tests/testconfig/vrfv2plus/vrfv2plus.toml index 859945bad9a..717a62e997f 100644 --- a/integration-tests/testconfig/vrfv2plus/vrfv2plus.toml +++ b/integration-tests/testconfig/vrfv2plus/vrfv2plus.toml @@ -63,35 +63,40 @@ chainlink_node_funding = 0.5 [VRFv2Plus.General] cancel_subs_after_test_run = true use_existing_env = false -subscription_funding_amount_link = 5.0 -subscription_funding_amount_native=1 - -subscription_refunding_amount_link = 5.0 -subscription_refunding_amount_native = 1 - -cl_node_max_gas_price_gwei = 10 -link_native_feed_response = 1000000000000000000 +#todo - need to have separate minimum_confirmations config for Coordinator, CL Node and Consumer request minimum_confirmations = 3 +# Can be "LINK", "NATIVE" or "LINK_AND_NATIVE" subscription_billing_type = "LINK_AND_NATIVE" +#CL Node config +cl_node_max_gas_price_gwei = 10 +number_of_sending_keys_to_create = 0 + +# Randomness Request Config +number_of_sub_to_create = 1 number_of_words = 3 callback_gas_limit = 1000000 -max_gas_limit_coordinator_config = 2500000 -fallback_wei_per_unit_link = "60000000000000000" -staleness_seconds = 86400 -gas_after_payment_calculation = 33825 -number_of_sub_to_create = 1 -number_of_sending_keys_to_create = 0 +subscription_funding_amount_link = 5.0 +subscription_funding_amount_native=1 +subscription_refunding_amount_link = 5.0 +subscription_refunding_amount_native = 1 randomness_request_count_per_request = 1 randomness_request_count_per_request_deviation = 0 random_words_fulfilled_event_timeout = "2m" wait_for_256_blocks_timeout = "280s" +# Coordinator config +link_native_feed_response = 1000000000000000000 +max_gas_limit_coordinator_config = 2500000 +fallback_wei_per_unit_link = "60000000000000000" +staleness_seconds = 86400 +gas_after_payment_calculation = 33825 fulfillment_flat_fee_native_ppm=0 fulfillment_flat_fee_link_discount_ppm=0 native_premium_percentage=24 link_premium_percentage=20 +# Wrapper config wrapped_gas_overhead = 50000 coordinator_gas_overhead_native = 52000 coordinator_gas_overhead_link = 74000 @@ -109,6 +114,8 @@ vrf_job_batch_fulfillment_enabled = true vrf_job_batch_fulfillment_gas_multiplier = 1.15 vrf_job_poll_period = "1s" vrf_job_request_timeout = "24h" +# should be "latest" if minimum_confirmations>0, "pending" if minimum_confirmations=0 +vrf_job_simulation_block="latest" # BHS Job config bhs_job_wait_blocks = 30 diff --git a/integration-tests/types/config/node/core.go b/integration-tests/types/config/node/core.go index 1a55e3e38f4..290d3e57dfb 100644 --- a/integration-tests/types/config/node/core.go +++ b/integration-tests/types/config/node/core.go @@ -110,12 +110,6 @@ func WithPrivateEVMs(networks []blockchain.EVMNetwork, commonChainConfig *evmcfg evmConfig.Chain = overriddenChainCfg } } - if evmConfig.Chain.FinalityDepth == nil && network.FinalityDepth > 0 { - evmConfig.Chain.FinalityDepth = ptr.Ptr(uint32(network.FinalityDepth)) - } - if evmConfig.Chain.FinalityTagEnabled == nil && network.FinalityTag { - evmConfig.Chain.FinalityTagEnabled = ptr.Ptr(network.FinalityTag) - } evmConfigs = append(evmConfigs, evmConfig) } return func(c *chainlink.Config) { diff --git a/testdata/scripts/node/validate/warnings.txtar b/testdata/scripts/node/validate/warnings.txtar index cf121e959e1..54de3227a9e 100644 --- a/testdata/scripts/node/validate/warnings.txtar +++ b/testdata/scripts/node/validate/warnings.txtar @@ -9,6 +9,15 @@ CollectorTarget = 'otel-collector:4317' TLSCertPath = 'something' Mode = 'unencrypted' +[[EVM]] +ChainID = '10200' +ChainType = 'xdai' + +[[EVM.Nodes]] +Name = 'fake' +WSURL = 'wss://foo.bar/ws' +HTTPURL = 'https://foo.bar' + -- secrets.toml -- [Database] URL = 'postgresql://user:pass1234567890abcd@localhost:5432/dbname?sslmode=disable' @@ -32,6 +41,15 @@ CollectorTarget = 'otel-collector:4317' Mode = 'unencrypted' TLSCertPath = 'something' +[[EVM]] +ChainID = '10200' +ChainType = 'xdai' + +[[EVM.Nodes]] +Name = 'fake' +WSURL = 'wss://foo.bar/ws' +HTTPURL = 'https://foo.bar' + # Effective Configuration, with defaults applied: InsecureFastScrypt = false RootDir = '~/.chainlink' @@ -284,6 +302,94 @@ DeltaDial = '15s' DeltaReconcile = '1m0s' ListenAddresses = [] +[[EVM]] +ChainID = '10200' +AutoCreateKey = true +BlockBackfillDepth = 10 +BlockBackfillSkip = false +ChainType = 'xdai' +FinalityDepth = 100 +FinalityTagEnabled = false +LogBackfillBatchSize = 1000 +LogPollInterval = '5s' +LogKeepBlocksDepth = 100000 +LogPrunePageSize = 0 +BackupLogPollerBlockDelay = 100 +MinIncomingConfirmations = 3 +MinContractPayment = '0.00001 link' +NonceAutoSync = true +NoNewHeadsThreshold = '3m0s' +RPCDefaultBatchSize = 250 +RPCBlockQueryDelay = 1 + +[EVM.Transactions] +ForwardersEnabled = false +MaxInFlight = 16 +MaxQueued = 250 +ReaperInterval = '1h0m0s' +ReaperThreshold = '168h0m0s' +ResendAfterThreshold = '1m0s' + +[EVM.BalanceMonitor] +Enabled = true + +[EVM.GasEstimator] +Mode = 'BlockHistory' +PriceDefault = '20 gwei' +PriceMax = '500 gwei' +PriceMin = '1 gwei' +LimitDefault = 500000 +LimitMax = 500000 +LimitMultiplier = '1' +LimitTransfer = 21000 +BumpMin = '5 gwei' +BumpPercent = 20 +BumpThreshold = 3 +EIP1559DynamicFees = true +FeeCapDefault = '100 gwei' +TipCapDefault = '1 wei' +TipCapMin = '1 wei' + +[EVM.GasEstimator.BlockHistory] +BatchSize = 25 +BlockHistorySize = 8 +CheckInclusionBlocks = 12 +CheckInclusionPercentile = 90 +TransactionPercentile = 60 + +[EVM.HeadTracker] +HistoryDepth = 100 +MaxBufferSize = 3 +SamplingInterval = '1s' + +[EVM.NodePool] +PollFailureThreshold = 5 +PollInterval = '10s' +SelectionMode = 'HighestHead' +SyncThreshold = 5 +LeaseDuration = '0s' +NodeIsSyncingEnabled = false +FinalizedBlockPollInterval = '5s' + +[EVM.OCR] +ContractConfirmations = 4 +ContractTransmitterTransmitTimeout = '10s' +DatabaseTimeout = '10s' +DeltaCOverride = '168h0m0s' +DeltaCJitterOverride = '1h0m0s' +ObservationGracePeriod = '1s' + +[EVM.OCR2] +[EVM.OCR2.Automation] +GasLimit = 5400000 + +[[EVM.Nodes]] +Name = 'fake' +WSURL = 'wss://foo.bar/ws' +HTTPURL = 'https://foo.bar' + # Configuration warning: -Tracing.TLSCertPath: invalid value (something): must be empty when Tracing.Mode is 'unencrypted' -Valid configuration. +2 errors: + - EVM.ChainType: invalid value (xdai): deprecated and will be removed in v2.13.0, use 'gnosis' instead + - Tracing.TLSCertPath: invalid value (something): must be empty when Tracing.Mode is 'unencrypted' +Valid configuration. \ No newline at end of file