From 6ff0ecdef2ceb9864fefd423e4e7f9611f0f8795 Mon Sep 17 00:00:00 2001 From: Ilja Pavlovs Date: Wed, 22 May 2024 19:42:04 +0300 Subject: [PATCH 01/16] =?UTF-8?q?VRF-1054:=20VRF=20E2E=20tests=20-=20add?= =?UTF-8?q?=20additional=20way=20to=20catch=20rand=20fulfilment=E2=80=A6?= =?UTF-8?q?=20(#13166)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * VRF-1054: VRF E2E tests - add additional way to catch rand fulfilment event in case if chain is very fast * TT-1102: fixing fund return * VRF-1054: resolving conflicts --- .../actions/vrf/common/errors.go | 5 +- .../actions/vrf/vrfv2/contract_steps.go | 17 +++++- .../actions/vrf/vrfv2plus/contract_steps.go | 17 +++++- .../contracts/contract_vrf_models.go | 4 ++ .../contracts/ethereum_vrf_common.go | 2 + .../contracts/ethereum_vrfv2_contracts.go | 22 ++++++- .../contracts/ethereum_vrfv2plus_contracts.go | 58 ++++++++++++++----- integration-tests/smoke/vrfv2_test.go | 1 + integration-tests/smoke/vrfv2plus_test.go | 1 + 9 files changed, 105 insertions(+), 22 deletions(-) diff --git a/integration-tests/actions/vrf/common/errors.go b/integration-tests/actions/vrf/common/errors.go index 62164d0b274..78b7457e29f 100644 --- a/integration-tests/actions/vrf/common/errors.go +++ b/integration-tests/actions/vrf/common/errors.go @@ -25,6 +25,7 @@ const ( ErrLoadingCoordinator = "error loading coordinator contract" ErrCreatingVRFKey = "error creating VRF key" - ErrWaitRandomWordsRequestedEvent = "error waiting for RandomWordsRequested event" - ErrWaitRandomWordsFulfilledEvent = "error waiting for RandomWordsFulfilled event" + ErrWaitRandomWordsRequestedEvent = "error waiting for RandomWordsRequested event" + ErrWaitRandomWordsFulfilledEvent = "error waiting for RandomWordsFulfilled event" + ErrFilterRandomWordsFulfilledEvent = "error filtering RandomWordsFulfilled event" ) diff --git a/integration-tests/actions/vrf/vrfv2/contract_steps.go b/integration-tests/actions/vrf/vrfv2/contract_steps.go index 4f0ac9a5b6c..92da7b9f86b 100644 --- a/integration-tests/actions/vrf/vrfv2/contract_steps.go +++ b/integration-tests/actions/vrf/vrfv2/contract_steps.go @@ -6,6 +6,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/rs/zerolog" "github.com/shopspring/decimal" @@ -392,6 +393,7 @@ func DirectFundingRequestRandomnessAndWaitForFulfillment( fulfillmentEvents, err := WaitRandomWordsFulfilledEvent( coordinator, randomWordsRequestedEvent.RequestId, + randomWordsRequestedEvent.Raw.BlockNumber, randomWordsFulfilledEventTimeout, l, ) @@ -431,6 +433,7 @@ func RequestRandomnessAndWaitForFulfillment( randomWordsFulfilledEvent, err := WaitRandomWordsFulfilledEvent( coordinator, randomWordsRequestedEvent.RequestId, + randomWordsRequestedEvent.Raw.BlockNumber, randomWordsFulfilledEventTimeout, l, ) @@ -582,6 +585,7 @@ func RequestRandomnessWithForceFulfillAndWaitForFulfillment( func WaitRandomWordsFulfilledEvent( coordinator contracts.Coordinator, requestId *big.Int, + randomWordsRequestedEventBlockNumber uint64, randomWordsFulfilledEventTimeout time.Duration, l zerolog.Logger, ) (*contracts.CoordinatorRandomWordsFulfilled, error) { @@ -592,7 +596,18 @@ func WaitRandomWordsFulfilledEvent( }, ) if err != nil { - return nil, fmt.Errorf("%s, err %w", vrfcommon.ErrWaitRandomWordsFulfilledEvent, err) + l.Warn(). + Str("requestID", requestId.String()). 
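+ // On a very fast chain the fulfilment can be mined before the event
+ // subscription is live, so the watch above may time out even though
+ // RandomWordsFulfilled is already on-chain; the log filter below recovers
+ // it by scanning from the request's block onward.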
+ Err(err).Msg("Error waiting for random words fulfilled event, trying to filter for the event") + randomWordsFulfilledEvent, err = coordinator.FilterRandomWordsFulfilledEvent( + &bind.FilterOpts{ + Start: randomWordsRequestedEventBlockNumber, + }, + requestId, + ) + if err != nil { + return nil, fmt.Errorf(vrfcommon.ErrGenericFormat, vrfcommon.ErrFilterRandomWordsFulfilledEvent, err) + } } vrfcommon.LogRandomWordsFulfilledEvent(l, coordinator, randomWordsFulfilledEvent, false, 0) return randomWordsFulfilledEvent, err diff --git a/integration-tests/actions/vrf/vrfv2plus/contract_steps.go b/integration-tests/actions/vrf/vrfv2plus/contract_steps.go index 6151931d566..ed6139c4f17 100644 --- a/integration-tests/actions/vrf/vrfv2plus/contract_steps.go +++ b/integration-tests/actions/vrf/vrfv2plus/contract_steps.go @@ -6,6 +6,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/rs/zerolog" "github.com/shopspring/decimal" @@ -356,6 +357,7 @@ func RequestRandomnessAndWaitForFulfillment( coordinator, randomWordsRequestedEvent.RequestId, subID, + randomWordsRequestedEvent.Raw.BlockNumber, isNativeBilling, config.RandomWordsFulfilledEventTimeout.Duration, l, @@ -449,6 +451,7 @@ func DirectFundingRequestRandomnessAndWaitForFulfillment( coordinator, randomWordsRequestedEvent.RequestId, subID, + randomWordsRequestedEvent.Raw.BlockNumber, isNativeBilling, config.RandomWordsFulfilledEventTimeout.Duration, l, @@ -460,6 +463,7 @@ func WaitRandomWordsFulfilledEvent( coordinator contracts.Coordinator, requestId *big.Int, subID *big.Int, + randomWordsRequestedEventBlockNumber uint64, isNativeBilling bool, randomWordsFulfilledEventTimeout time.Duration, l zerolog.Logger, @@ -473,7 +477,18 @@ func WaitRandomWordsFulfilledEvent( }, ) if err != nil { - return nil, fmt.Errorf(vrfcommon.ErrGenericFormat, vrfcommon.ErrWaitRandomWordsFulfilledEvent, err) + l.Warn(). + Str("requestID", requestId.String()). 
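+ // Same fast-chain fallback as the v2 flow: when the subscription-based wait
+ // times out, filter RandomWordsFulfilled logs starting at the block of the
+ // originating RandomWordsRequested event.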
+ Err(err).Msg("Error waiting for random words fulfilled event, trying to filter for the event") + randomWordsFulfilledEvent, err = coordinator.FilterRandomWordsFulfilledEvent( + &bind.FilterOpts{ + Start: randomWordsRequestedEventBlockNumber, + }, + requestId, + ) + if err != nil { + return nil, fmt.Errorf(vrfcommon.ErrGenericFormat, vrfcommon.ErrFilterRandomWordsFulfilledEvent, err) + } } vrfcommon.LogRandomWordsFulfilledEvent(l, coordinator, randomWordsFulfilledEvent, isNativeBilling, keyNum) diff --git a/integration-tests/contracts/contract_vrf_models.go b/integration-tests/contracts/contract_vrf_models.go index 9ed08048998..c30eadd3d3d 100644 --- a/integration-tests/contracts/contract_vrf_models.go +++ b/integration-tests/contracts/contract_vrf_models.go @@ -5,6 +5,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -72,6 +73,7 @@ type VRFCoordinatorV2 interface { ParseRandomWordsFulfilled(log types.Log) (*CoordinatorRandomWordsFulfilled, error) ParseLog(log types.Log) (generated.AbigenLog, error) FindSubscriptionID(subID uint64) (uint64, error) + FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) WaitForConfigSetEvent(timeout time.Duration) (*CoordinatorConfigSet, error) OracleWithdraw(recipient common.Address, amount *big.Int) error @@ -115,6 +117,7 @@ type VRFCoordinatorV2_5 interface { GetNativeTokenTotalBalance(ctx context.Context) (*big.Int, error) GetLinkTotalBalance(ctx context.Context) (*big.Int, error) FindSubscriptionID(subID *big.Int) (*big.Int, error) + FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) ParseRandomWordsRequested(log types.Log) (*CoordinatorRandomWordsRequested, error) ParseRandomWordsFulfilled(log types.Log) (*CoordinatorRandomWordsFulfilled, error) @@ -154,6 +157,7 @@ type VRFCoordinatorV2PlusUpgradedVersion interface { GetSubscription(ctx context.Context, subID *big.Int) (vrf_v2plus_upgraded_version.GetSubscription, error) GetActiveSubscriptionIds(ctx context.Context, startIndex *big.Int, maxCount *big.Int) ([]*big.Int, error) FindSubscriptionID() (*big.Int, error) + FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) ParseRandomWordsRequested(log types.Log) (*CoordinatorRandomWordsRequested, error) ParseRandomWordsFulfilled(log types.Log) (*CoordinatorRandomWordsFulfilled, error) diff --git a/integration-tests/contracts/ethereum_vrf_common.go b/integration-tests/contracts/ethereum_vrf_common.go index f0498b6efe6..62a1809cfa6 100644 --- a/integration-tests/contracts/ethereum_vrf_common.go +++ b/integration-tests/contracts/ethereum_vrf_common.go @@ -5,6 +5,7 @@ import ( "math/big" "time" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -18,6 +19,7 @@ type Coordinator interface { Address() string WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) 
WaitForConfigSetEvent(timeout time.Duration) (*CoordinatorConfigSet, error) + FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) } type Subscription struct { diff --git a/integration-tests/contracts/ethereum_vrfv2_contracts.go b/integration-tests/contracts/ethereum_vrfv2_contracts.go index 9a8ab1c5b5b..be588ea3e3a 100644 --- a/integration-tests/contracts/ethereum_vrfv2_contracts.go +++ b/integration-tests/contracts/ethereum_vrfv2_contracts.go @@ -629,11 +629,9 @@ func (v *EthereumVRFCoordinatorV2) FindSubscriptionID(subID uint64) (uint64, err if err != nil { return 0, err } - if !subscriptionIterator.Next() { return 0, fmt.Errorf("expected at least 1 subID for the given owner %s", owner) } - return subscriptionIterator.Event.SubId, nil } @@ -649,6 +647,26 @@ func (v *EthereumVRFCoordinatorV2) GetBlockHashStoreAddress(ctx context.Context) return blockHashStoreAddress, nil } +func (v *EthereumVRFCoordinatorV2) FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) { + iterator, err := v.coordinator.FilterRandomWordsFulfilled( + opts, + []*big.Int{requestId}, + ) + if err != nil { + return nil, err + } + if !iterator.Next() { + return nil, fmt.Errorf("expected at least 1 RandomWordsFulfilled event for request Id: %s", requestId.String()) + } + return &CoordinatorRandomWordsFulfilled{ + RequestId: iterator.Event.RequestId, + OutputSeed: iterator.Event.OutputSeed, + Payment: iterator.Event.Payment, + Success: iterator.Event.Success, + Raw: iterator.Event.Raw, + }, nil +} + func (v *EthereumVRFCoordinatorV2) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) { randomWordsFulfilledEventsChannel := make(chan *vrf_coordinator_v2.VRFCoordinatorV2RandomWordsFulfilled) subscription, err := v.coordinator.WatchRandomWordsFulfilled(nil, randomWordsFulfilledEventsChannel, filter.RequestIds) diff --git a/integration-tests/contracts/ethereum_vrfv2plus_contracts.go b/integration-tests/contracts/ethereum_vrfv2plus_contracts.go index ba7234fbbf3..882baafcd19 100644 --- a/integration-tests/contracts/ethereum_vrfv2plus_contracts.go +++ b/integration-tests/contracts/ethereum_vrfv2plus_contracts.go @@ -480,24 +480,28 @@ func (v *EthereumVRFCoordinatorV2_5) FindSubscriptionID(subID *big.Int) (*big.In return subscriptionIterator.Event.SubId, nil } -func (v *EthereumVRFCoordinatorV2_5) WaitForSubscriptionCanceledEvent(subID *big.Int, timeout time.Duration) (*vrf_coordinator_v2_5.VRFCoordinatorV25SubscriptionCanceled, error) { - eventsChannel := make(chan *vrf_coordinator_v2_5.VRFCoordinatorV25SubscriptionCanceled) - subscription, err := v.coordinator.WatchSubscriptionCanceled(nil, eventsChannel, []*big.Int{subID}) +func (v *EthereumVRFCoordinatorV2_5) FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) { + iterator, err := v.coordinator.FilterRandomWordsFulfilled( + opts, + []*big.Int{requestId}, + nil, + ) if err != nil { return nil, err } - defer subscription.Unsubscribe() - - for { - select { - case err := <-subscription.Err(): - return nil, err - case <-time.After(timeout): - return nil, fmt.Errorf("timeout waiting for SubscriptionCanceled event") - case sub := <-eventsChannel: - return sub, nil - } + if !iterator.Next() { + return nil, fmt.Errorf("expected at least 1 RandomWordsFulfilled event for request Id: %s", requestId.String()) } + return 
&CoordinatorRandomWordsFulfilled{ + RequestId: iterator.Event.RequestId, + OutputSeed: iterator.Event.OutputSeed, + SubId: iterator.Event.SubId.String(), + Payment: iterator.Event.Payment, + NativePayment: iterator.Event.NativePayment, + Success: iterator.Event.Success, + OnlyPremium: iterator.Event.OnlyPremium, + Raw: iterator.Event.Raw, + }, nil } func (v *EthereumVRFCoordinatorV2_5) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) { @@ -902,14 +906,36 @@ func (v *EthereumVRFCoordinatorV2PlusUpgradedVersion) FindSubscriptionID() (*big if err != nil { return nil, err } - if !subscriptionIterator.Next() { return nil, fmt.Errorf("expected at least 1 subID for the given owner %s", owner) } - return subscriptionIterator.Event.SubId, nil } +func (v *EthereumVRFCoordinatorV2PlusUpgradedVersion) FilterRandomWordsFulfilledEvent(opts *bind.FilterOpts, requestId *big.Int) (*CoordinatorRandomWordsFulfilled, error) { + iterator, err := v.coordinator.FilterRandomWordsFulfilled( + opts, + []*big.Int{requestId}, + nil, + ) + if err != nil { + return nil, err + } + if !iterator.Next() { + return nil, fmt.Errorf("expected at least 1 RandomWordsFulfilled event for request Id: %s", requestId.String()) + } + return &CoordinatorRandomWordsFulfilled{ + RequestId: iterator.Event.RequestId, + OutputSeed: iterator.Event.OutputSeed, + SubId: iterator.Event.SubId.String(), + Payment: iterator.Event.Payment, + NativePayment: iterator.Event.NativePayment, + Success: iterator.Event.Success, + OnlyPremium: iterator.Event.OnlyPremium, + Raw: iterator.Event.Raw, + }, nil +} + func (v *EthereumVRFCoordinatorV2PlusUpgradedVersion) WaitForRandomWordsFulfilledEvent(filter RandomWordsFulfilledEventFilter) (*CoordinatorRandomWordsFulfilled, error) { randomWordsFulfilledEventsChannel := make(chan *vrf_v2plus_upgraded_version.VRFCoordinatorV2PlusUpgradedVersionRandomWordsFulfilled) subscription, err := v.coordinator.WatchRandomWordsFulfilled(nil, randomWordsFulfilledEventsChannel, filter.RequestIds, filter.SubIDs) diff --git a/integration-tests/smoke/vrfv2_test.go b/integration-tests/smoke/vrfv2_test.go index cc37ad983ce..88bf2ac110a 100644 --- a/integration-tests/smoke/vrfv2_test.go +++ b/integration-tests/smoke/vrfv2_test.go @@ -1165,6 +1165,7 @@ func TestVRFV2NodeReorg(t *testing.T) { _, err = vrfv2.WaitRandomWordsFulfilledEvent( vrfContracts.CoordinatorV2, randomWordsRequestedEvent.RequestId, + randomWordsRequestedEvent.Raw.BlockNumber, configCopy.VRFv2.General.RandomWordsFulfilledEventTimeout.Duration, l, ) diff --git a/integration-tests/smoke/vrfv2plus_test.go b/integration-tests/smoke/vrfv2plus_test.go index 260f30d0e57..c17a09dcf7c 100644 --- a/integration-tests/smoke/vrfv2plus_test.go +++ b/integration-tests/smoke/vrfv2plus_test.go @@ -2065,6 +2065,7 @@ func TestVRFv2PlusNodeReorg(t *testing.T) { vrfContracts.CoordinatorV2Plus, randomWordsRequestedEvent.RequestId, subID, + randomWordsRequestedEvent.Raw.BlockNumber, isNativeBilling, configCopy.VRFv2Plus.General.RandomWordsFulfilledEventTimeout.Duration, l, From 6625266cc60b0a3f2e60697eb01c5be3f4a09c5f Mon Sep 17 00:00:00 2001 From: Bartek Tofel Date: Wed, 22 May 2024 18:48:28 +0200 Subject: [PATCH 02/16] do not set programmatically finality tag (#13277) Co-authored-by: Ilja Pavlovs --- integration-tests/types/config/node/core.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/integration-tests/types/config/node/core.go b/integration-tests/types/config/node/core.go index 1a55e3e38f4..290d3e57dfb 
100644 --- a/integration-tests/types/config/node/core.go +++ b/integration-tests/types/config/node/core.go @@ -110,12 +110,6 @@ func WithPrivateEVMs(networks []blockchain.EVMNetwork, commonChainConfig *evmcfg evmConfig.Chain = overriddenChainCfg } } - if evmConfig.Chain.FinalityDepth == nil && network.FinalityDepth > 0 { - evmConfig.Chain.FinalityDepth = ptr.Ptr(uint32(network.FinalityDepth)) - } - if evmConfig.Chain.FinalityTagEnabled == nil && network.FinalityTag { - evmConfig.Chain.FinalityTagEnabled = ptr.Ptr(network.FinalityTag) - } evmConfigs = append(evmConfigs, evmConfig) } return func(c *chainlink.Config) { From 677abe19d1fc15e6fa248271eab0aacb2ae4c0b7 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Wed, 22 May 2024 12:12:25 -0500 Subject: [PATCH 03/16] golangci-lint: enable containedctx (#13171) --- .golangci.yml | 1 + common/client/node.go | 9 +- common/client/node_lifecycle.go | 55 ++++--- common/txmgr/confirmer.go | 21 +-- common/txmgr/mocks/tx_manager.go | 36 ++--- common/txmgr/resender.go | 31 ++-- common/txmgr/test_helpers.go | 8 +- common/txmgr/txmgr.go | 16 +- common/txmgr/types/forwarder_manager.go | 7 +- common/txmgr/types/mocks/forwarder_manager.go | 36 ++--- core/chains/evm/client/node.go | 10 +- core/chains/evm/client/node_lifecycle.go | 45 +++--- .../evm/forwarders/forwarder_manager.go | 45 +++--- .../evm/forwarders/forwarder_manager_test.go | 11 +- .../chains/evm/gas/block_history_estimator.go | 41 +++-- core/chains/evm/logpoller/log_poller.go | 55 +++---- .../evm/logpoller/log_poller_internal_test.go | 39 +++-- core/chains/evm/txmgr/evm_tx_store.go | 142 ++++++++---------- core/services/blockhashstore/delegate.go | 23 +-- core/services/blockheaderfeeder/delegate.go | 19 +-- core/services/feeds/connection_manager.go | 26 ++-- core/services/functions/listener.go | 23 +-- core/services/keeper/delegate.go | 2 +- core/services/keeper/integration_test.go | 2 +- core/services/ocr/config_overrider.go | 15 +- core/services/ocr/delegate.go | 2 +- core/services/ocr2/delegate.go | 8 +- core/services/ocr2/delegate_test.go | 16 +- .../ocr2keeper/evmregistry/v20/registry.go | 71 +++++---- .../evmregistry/v20/registry_test.go | 3 +- .../ocr2keeper/evmregistry/v21/registry.go | 74 ++++----- .../v21/registry_check_pipeline.go | 16 +- .../v21/registry_check_pipeline_test.go | 4 +- .../evmregistry/v21/registry_test.go | 8 +- .../plugins/ocr2keeper/integration_test.go | 5 +- core/services/pipeline/helpers_test.go | 3 +- core/services/pipeline/orm.go | 37 ++--- core/services/pipeline/orm_test.go | 6 +- core/services/pipeline/task.eth_tx.go | 2 +- core/services/relay/evm/evm.go | 13 +- .../relay/evm/request_round_tracker.go | 9 +- .../vrf/v2/listener_v2_log_listener_test.go | 70 +++++---- core/utils/thread_control_test.go | 5 +- core/web/resolver/api_token_test.go | 33 ++-- core/web/resolver/bridge_test.go | 19 +-- core/web/resolver/chain_test.go | 9 +- core/web/resolver/config_test.go | 7 +- core/web/resolver/csa_keys_test.go | 11 +- core/web/resolver/eth_key_test.go | 17 ++- core/web/resolver/eth_transaction_test.go | 19 +-- core/web/resolver/features_test.go | 3 +- .../feeds_manager_chain_config_test.go | 19 +-- core/web/resolver/feeds_manager_test.go | 17 ++- core/web/resolver/job_error_test.go | 13 +- core/web/resolver/job_proposal_spec_test.go | 29 ++-- core/web/resolver/job_proposal_test.go | 5 +- core/web/resolver/job_run_test.go | 19 +-- core/web/resolver/job_test.go | 23 +-- core/web/resolver/log_test.go | 11 +- core/web/resolver/node_test.go | 9 +- 
core/web/resolver/ocr2_keys_test.go | 15 +- core/web/resolver/ocr_test.go | 9 +- core/web/resolver/p2p_test.go | 9 +- core/web/resolver/resolver_test.go | 26 ++-- core/web/resolver/solana_key_test.go | 5 +- core/web/resolver/spec_test.go | 29 ++-- core/web/resolver/user_test.go | 17 ++- core/web/resolver/vrf_test.go | 13 +- 68 files changed, 762 insertions(+), 694 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 96a7de282e0..3834400ba67 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -15,6 +15,7 @@ linters: - noctx - depguard - whitespace + - containedctx linters-settings: exhaustive: default-signifies-exhaustive: true diff --git a/common/client/node.go b/common/client/node.go index 6450b086f10..bee1f3da3b8 100644 --- a/common/client/node.go +++ b/common/client/node.go @@ -106,10 +106,7 @@ type node[ stateLatestTotalDifficulty *big.Int stateLatestFinalizedBlockNumber int64 - // nodeCtx is the node lifetime's context - nodeCtx context.Context - // cancelNodeCtx cancels nodeCtx when stopping the node - cancelNodeCtx context.CancelFunc + stopCh services.StopChan // wg waits for subsidiary goroutines wg sync.WaitGroup @@ -148,7 +145,7 @@ func NewNode[ if httpuri != nil { n.http = httpuri } - n.nodeCtx, n.cancelNodeCtx = context.WithCancel(context.Background()) + n.stopCh = make(services.StopChan) lggr = logger.Named(lggr, "Node") lggr = logger.With(lggr, "nodeTier", Primary.String(), @@ -205,7 +202,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) close() error { n.stateMu.Lock() defer n.stateMu.Unlock() - n.cancelNodeCtx() + close(n.stopCh) n.state = nodeStateClosed return nil } diff --git a/common/client/node_lifecycle.go b/common/client/node_lifecycle.go index fa6397580c8..5947774e202 100644 --- a/common/client/node_lifecycle.go +++ b/common/client/node_lifecycle.go @@ -79,6 +79,8 @@ const rpcSubscriptionMethodNewHeads = "newHeads" // Should only be run ONCE per node, after a successful Dial func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -100,7 +102,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { lggr.Tracew("Alive loop starting", "nodeState", n.State()) headsC := make(chan HEAD) - sub, err := n.rpc.Subscribe(n.nodeCtx, headsC, rpcSubscriptionMethodNewHeads) + sub, err := n.rpc.Subscribe(ctx, headsC, rpcSubscriptionMethodNewHeads) if err != nil { lggr.Errorw("Initial subscribe for heads failed", "nodeState", n.State()) n.declareUnreachable() @@ -151,15 +153,16 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-pollCh: - var version string promPoolRPCNodePolls.WithLabelValues(n.chainID.String(), n.name).Inc() lggr.Tracew("Polling for version", "nodeState", n.State(), "pollFailures", pollFailures) - ctx, cancel := context.WithTimeout(n.nodeCtx, pollInterval) - version, err := n.RPC().ClientVersion(ctx) - cancel() + version, err := func(ctx context.Context) (string, error) { + ctx, cancel := context.WithTimeout(ctx, pollInterval) + defer cancel() + return n.RPC().ClientVersion(ctx) + }(ctx) if err != nil { // prevent overflow if pollFailures < math.MaxUint32 { @@ -240,9 +243,11 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { n.declareOutOfSync(func(num int64, td *big.Int) bool { return num < highestReceivedBlockNumber }) return case <-pollFinalizedHeadCh: - ctx, cancel := context.WithTimeout(n.nodeCtx, n.nodePoolCfg.FinalizedBlockPollInterval()) - latestFinalized, err := n.RPC().LatestFinalizedBlock(ctx) - 
cancel() + latestFinalized, err := func(ctx context.Context) (HEAD, error) { + ctx, cancel := context.WithTimeout(ctx, n.nodePoolCfg.FinalizedBlockPollInterval()) + defer cancel() + return n.RPC().LatestFinalizedBlock(ctx) + }(ctx) if err != nil { lggr.Warnw("Failed to fetch latest finalized block", "err", err) continue @@ -300,6 +305,8 @@ const ( // outOfSyncLoop takes an OutOfSync node and waits until isOutOfSync returns false to go back to live status func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -319,7 +326,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td lggr.Debugw("Trying to revive out-of-sync RPC node", "nodeState", n.State()) // Need to redial since out-of-sync nodes are automatically disconnected - state := n.createVerifiedConn(n.nodeCtx, lggr) + state := n.createVerifiedConn(ctx, lggr) if state != nodeStateAlive { n.declareState(state) return @@ -328,7 +335,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td lggr.Tracew("Successfully subscribed to heads feed on out-of-sync RPC node", "nodeState", n.State()) ch := make(chan HEAD) - sub, err := n.rpc.Subscribe(n.nodeCtx, ch, rpcSubscriptionMethodNewHeads) + sub, err := n.rpc.Subscribe(ctx, ch, rpcSubscriptionMethodNewHeads) if err != nil { lggr.Errorw("Failed to subscribe heads on out-of-sync RPC node", "nodeState", n.State(), "err", err) n.declareUnreachable() @@ -338,7 +345,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case head, open := <-ch: if !open { @@ -372,6 +379,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -394,12 +403,12 @@ func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(dialRetryBackoff.Duration()): lggr.Tracew("Trying to re-dial RPC node", "nodeState", n.State()) - err := n.rpc.Dial(n.nodeCtx) + err := n.rpc.Dial(ctx) if err != nil { lggr.Errorw(fmt.Sprintf("Failed to redial RPC node; still unreachable: %v", err), "err", err, "nodeState", n.State()) continue @@ -407,7 +416,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { n.setState(nodeStateDialed) - state := n.verifyConn(n.nodeCtx, lggr) + state := n.verifyConn(ctx, lggr) switch state { case nodeStateUnreachable: n.setState(nodeStateUnreachable) @@ -425,6 +434,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -443,7 +454,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { lggr := logger.Named(n.lfcLog, "InvalidChainID") // Need to redial since invalid chain ID nodes are automatically disconnected - state := n.createVerifiedConn(n.nodeCtx, lggr) + state := n.createVerifiedConn(ctx, lggr) if state != nodeStateInvalidChainID { n.declareState(state) return @@ -455,10 +466,10 @@ func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(chainIDRecheckBackoff.Duration()): - state := n.verifyConn(n.nodeCtx, 
lggr) + state := n.verifyConn(ctx, lggr) switch state { case nodeStateInvalidChainID: continue @@ -475,6 +486,8 @@ func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { func (n *node[CHAIN_ID, HEAD, RPC]) syncingLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -493,7 +506,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) syncingLoop() { lggr := logger.Sugared(logger.Named(n.lfcLog, "Syncing")) lggr.Debugw(fmt.Sprintf("Periodically re-checking RPC node %s with syncing status", n.String()), "nodeState", n.State()) // Need to redial since syncing nodes are automatically disconnected - state := n.createVerifiedConn(n.nodeCtx, lggr) + state := n.createVerifiedConn(ctx, lggr) if state != nodeStateSyncing { n.declareState(state) return @@ -503,11 +516,11 @@ func (n *node[CHAIN_ID, HEAD, RPC]) syncingLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(recheckBackoff.Duration()): lggr.Tracew("Trying to recheck if the node is still syncing", "nodeState", n.State()) - isSyncing, err := n.rpc.IsSyncing(n.nodeCtx) + isSyncing, err := n.rpc.IsSyncing(ctx) if err != nil { lggr.Errorw("Unexpected error while verifying RPC node synchronization status", "err", err, "nodeState", n.State()) n.declareUnreachable() diff --git a/common/txmgr/confirmer.go b/common/txmgr/confirmer.go index dd98df0a8fe..30fbbc48987 100644 --- a/common/txmgr/confirmer.go +++ b/common/txmgr/confirmer.go @@ -133,8 +133,7 @@ type Confirmer[ enabledAddresses []ADDR mb *mailbox.Mailbox[HEAD] - ctx context.Context - ctxCancel context.CancelFunc + stopCh services.StopChan wg sync.WaitGroup initSync sync.Mutex isStarted bool @@ -207,7 +206,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) sta return fmt.Errorf("Confirmer: failed to load EnabledAddressesForChain: %w", err) } - ec.ctx, ec.ctxCancel = context.WithCancel(context.Background()) + ec.stopCh = make(chan struct{}) ec.wg = sync.WaitGroup{} ec.wg.Add(1) go ec.runLoop() @@ -228,7 +227,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) clo if !ec.isStarted { return fmt.Errorf("Confirmer is not started: %w", services.ErrAlreadyStopped) } - ec.ctxCancel() + close(ec.stopCh) ec.wg.Wait() ec.isStarted = false return nil @@ -248,23 +247,25 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Hea func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() { defer ec.wg.Done() + ctx, cancel := ec.stopCh.NewCtx() + defer cancel() for { select { case <-ec.mb.Notify(): for { - if ec.ctx.Err() != nil { + if ctx.Err() != nil { return } head, exists := ec.mb.Retrieve() if !exists { break } - if err := ec.ProcessHead(ec.ctx, head); err != nil { + if err := ec.ProcessHead(ctx, head); err != nil { ec.lggr.Errorw("Error processing head", "err", err) continue } } - case <-ec.ctx.Done(): + case <-ctx.Done(): return } } @@ -940,7 +941,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Ens for _, etx := range etxs { if !hasReceiptInLongestChain(*etx, head) { - if err := ec.markForRebroadcast(*etx, head); err != nil { + if err := ec.markForRebroadcast(ctx, *etx, head); err != nil { return fmt.Errorf("markForRebroadcast failed for etx %v: %w", etx.ID, err) } } @@ -992,7 +993,7 @@ func hasReceiptInLongestChain[ } } -func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) markForRebroadcast(etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], 
head types.Head[BLOCK_HASH]) error { +func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) markForRebroadcast(ctx context.Context, etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], head types.Head[BLOCK_HASH]) error { if len(etx.TxAttempts) == 0 { return fmt.Errorf("invariant violation: expected tx %v to have at least one attempt", etx.ID) } @@ -1027,7 +1028,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) mar ec.lggr.Infow(fmt.Sprintf("Re-org detected. Rebroadcasting transaction %s which may have been re-org'd out of the main chain", attempt.Hash.String()), logValues...) // Put it back in progress and delete all receipts (they do not apply to the new chain) - if err := ec.txStore.UpdateTxForRebroadcast(ec.ctx, etx, attempt); err != nil { + if err := ec.txStore.UpdateTxForRebroadcast(ctx, etx, attempt); err != nil { return fmt.Errorf("markForRebroadcast failed: %w", err) } diff --git a/common/txmgr/mocks/tx_manager.go b/common/txmgr/mocks/tx_manager.go index a3e8c489314..974fd455903 100644 --- a/common/txmgr/mocks/tx_manager.go +++ b/common/txmgr/mocks/tx_manager.go @@ -273,9 +273,9 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTx return r0, r1 } -// GetForwarderForEOA provides a mock function with given fields: eoa -func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(eoa ADDR) (ADDR, error) { - ret := _m.Called(eoa) +// GetForwarderForEOA provides a mock function with given fields: ctx, eoa +func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(ctx context.Context, eoa ADDR) (ADDR, error) { + ret := _m.Called(ctx, eoa) if len(ret) == 0 { panic("no return value specified for GetForwarderForEOA") @@ -283,17 +283,17 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR) (ADDR, error)); ok { - return rf(eoa) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) (ADDR, error)); ok { + return rf(ctx, eoa) } - if rf, ok := ret.Get(0).(func(ADDR) ADDR); ok { - r0 = rf(eoa) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) ADDR); ok { + r0 = rf(ctx, eoa) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR) error); ok { - r1 = rf(eoa) + if rf, ok := ret.Get(1).(func(context.Context, ADDR) error); ok { + r1 = rf(ctx, eoa) } else { r1 = ret.Error(1) } @@ -301,9 +301,9 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor return r0, r1 } -// GetForwarderForEOAOCR2Feeds provides a mock function with given fields: eoa, ocr2AggregatorID -func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(eoa ADDR, ocr2AggregatorID ADDR) (ADDR, error) { - ret := _m.Called(eoa, ocr2AggregatorID) +// GetForwarderForEOAOCR2Feeds provides a mock function with given fields: ctx, eoa, ocr2AggregatorID +func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa ADDR, ocr2AggregatorID ADDR) (ADDR, error) { + ret := _m.Called(ctx, eoa, ocr2AggregatorID) if len(ret) == 0 { panic("no return value specified for GetForwarderForEOAOCR2Feeds") @@ -311,17 +311,17 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR, ADDR) (ADDR, error)); ok { - return rf(eoa, ocr2AggregatorID) + if rf, ok := 
ret.Get(0).(func(context.Context, ADDR, ADDR) (ADDR, error)); ok { + return rf(ctx, eoa, ocr2AggregatorID) } - if rf, ok := ret.Get(0).(func(ADDR, ADDR) ADDR); ok { - r0 = rf(eoa, ocr2AggregatorID) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) ADDR); ok { + r0 = rf(ctx, eoa, ocr2AggregatorID) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR, ADDR) error); ok { - r1 = rf(eoa, ocr2AggregatorID) + if rf, ok := ret.Get(1).(func(context.Context, ADDR, ADDR) error); ok { + r1 = rf(ctx, eoa, ocr2AggregatorID) } else { r1 = ret.Error(1) } diff --git a/common/txmgr/resender.go b/common/txmgr/resender.go index b752ec63f13..8483b7a0264 100644 --- a/common/txmgr/resender.go +++ b/common/txmgr/resender.go @@ -8,6 +8,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/chains/label" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink/v2/common/client" @@ -56,8 +57,7 @@ type Resender[ logger logger.SugaredLogger lastAlertTimestamps map[string]time.Time - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan chDone chan struct{} } @@ -83,7 +83,6 @@ func NewResender[ panic("Resender requires a non-zero threshold") } // todo: add context to txStore https://smartcontract-it.atlassian.net/browse/BCI-1585 - ctx, cancel := context.WithCancel(context.Background()) return &Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]{ txStore, client, @@ -95,8 +94,7 @@ func NewResender[ txConfig, logger.Sugared(logger.Named(lggr, "Resender")), make(map[string]time.Time), - ctx, - cancel, + make(chan struct{}), make(chan struct{}), } } @@ -109,14 +107,16 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(ctx // Stop is a comment which satisfies the linter func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Stop() { - er.cancel() + close(er.stopCh) <-er.chDone } func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() { defer close(er.chDone) + ctx, cancel := er.stopCh.NewCtx() + defer cancel() - if err := er.resendUnconfirmed(er.ctx); err != nil { + if err := er.resendUnconfirmed(ctx); err != nil { er.logger.Warnw("Failed to resend unconfirmed transactions", "err", err) } @@ -124,10 +124,10 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() defer ticker.Stop() for { select { - case <-er.ctx.Done(): + case <-ctx.Done(): return case <-ticker.C: - if err := er.resendUnconfirmed(er.ctx); err != nil { + if err := er.resendUnconfirmed(ctx); err != nil { er.logger.Warnw("Failed to resend unconfirmed transactions", "err", err) } } @@ -135,6 +135,9 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() } func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) resendUnconfirmed(ctx context.Context) error { + var cancel func() + ctx, cancel = er.stopCh.Ctx(ctx) + defer cancel() resendAddresses, err := er.ks.EnabledAddressesForChain(ctx, er.chainID) if err != nil { return fmt.Errorf("Resender failed getting enabled keys for chain %s: %w", er.chainID.String(), err) @@ -147,7 +150,7 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) resendUnco for _, k := range resendAddresses { var attempts []txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] - attempts, err = er.txStore.FindTxAttemptsRequiringResend(er.ctx, olderThan, 
maxInFlightTransactions, er.chainID, k) + attempts, err = er.txStore.FindTxAttemptsRequiringResend(ctx, olderThan, maxInFlightTransactions, er.chainID, k) if err != nil { return fmt.Errorf("failed to FindTxAttemptsRequiringResend: %w", err) } @@ -165,13 +168,13 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) resendUnco er.logger.Infow(fmt.Sprintf("Re-sending %d unconfirmed transactions that were last sent over %s ago. These transactions are taking longer than usual to be mined. %s", len(allAttempts), ageThreshold, label.NodeConnectivityProblemWarning), "n", len(allAttempts)) batchSize := int(er.config.RPCDefaultBatchSize()) - ctx, cancel := context.WithTimeout(er.ctx, batchSendTransactionTimeout) - defer cancel() - txErrTypes, _, broadcastTime, txIDs, err := er.client.BatchSendTransactions(ctx, allAttempts, batchSize, er.logger) + batchCtx, batchCancel := context.WithTimeout(ctx, batchSendTransactionTimeout) + defer batchCancel() + txErrTypes, _, broadcastTime, txIDs, err := er.client.BatchSendTransactions(batchCtx, allAttempts, batchSize, er.logger) // update broadcast times before checking additional errors if len(txIDs) > 0 { - if updateErr := er.txStore.UpdateBroadcastAts(er.ctx, broadcastTime, txIDs); updateErr != nil { + if updateErr := er.txStore.UpdateBroadcastAts(ctx, broadcastTime, txIDs); updateErr != nil { err = errors.Join(err, fmt.Errorf("failed to update broadcast time: %w", updateErr)) } } diff --git a/common/txmgr/test_helpers.go b/common/txmgr/test_helpers.go index dbc07861ffe..3051e0985d8 100644 --- a/common/txmgr/test_helpers.go +++ b/common/txmgr/test_helpers.go @@ -35,7 +35,9 @@ func (eb *Broadcaster[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) XXXT } func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestStartInternal() error { - return ec.startInternal(ec.ctx) + ctx, cancel := ec.stopCh.NewCtx() + defer cancel() + return ec.startInternal(ctx) } func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestCloseInternal() error { @@ -43,7 +45,9 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXX } func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestResendUnconfirmed() error { - return er.resendUnconfirmed(er.ctx) + ctx, cancel := er.stopCh.NewCtx() + defer cancel() + return er.resendUnconfirmed(ctx) } func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestAbandon(addr ADDR) (err error) { diff --git a/common/txmgr/txmgr.go b/common/txmgr/txmgr.go index 1c8b59a55cc..44b518fdaab 100644 --- a/common/txmgr/txmgr.go +++ b/common/txmgr/txmgr.go @@ -46,8 +46,8 @@ type TxManager[ services.Service Trigger(addr ADDR) CreateTransaction(ctx context.Context, txRequest txmgrtypes.TxRequest[ADDR, TX_HASH]) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) - GetForwarderForEOA(eoa ADDR) (forwarder ADDR, err error) - GetForwarderForEOAOCR2Feeds(eoa, ocr2AggregatorID ADDR) (forwarder ADDR, err error) + GetForwarderForEOA(ctx context.Context, eoa ADDR) (forwarder ADDR, err error) + GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa, ocr2AggregatorID ADDR) (forwarder ADDR, err error) RegisterResumeCallback(fn ResumeCallback) SendNativeToken(ctx context.Context, chainID CHAIN_ID, from, to ADDR, value big.Int, gasLimit uint64) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) Reset(addr ADDR, abandon bool) error @@ -546,20 +546,20 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, 
TX_HASH, BLOCK_HASH, R, SEQ, FEE]) CreateTran } // Calls forwarderMgr to get a proper forwarder for a given EOA. -func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOA(eoa ADDR) (forwarder ADDR, err error) { +func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOA(ctx context.Context, eoa ADDR) (forwarder ADDR, err error) { if !b.txConfig.ForwardersEnabled() { return forwarder, fmt.Errorf("forwarding is not enabled, to enable set Transactions.ForwardersEnabled =true") } - forwarder, err = b.fwdMgr.ForwarderFor(eoa) + forwarder, err = b.fwdMgr.ForwarderFor(ctx, eoa) return } // GetForwarderForEOAOCR2Feeds calls forwarderMgr to get a proper forwarder for a given EOA and checks if its set as a transmitter on the OCR2Aggregator contract. -func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) { +func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) { if !b.txConfig.ForwardersEnabled() { return forwarder, fmt.Errorf("forwarding is not enabled, to enable set Transactions.ForwardersEnabled =true") } - forwarder, err = b.fwdMgr.ForwarderForOCR2Feeds(eoa, ocr2Aggregator) + forwarder, err = b.fwdMgr.ForwarderForOCR2Feeds(ctx, eoa, ocr2Aggregator) return } @@ -656,10 +656,10 @@ func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) Tri func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) CreateTransaction(ctx context.Context, txRequest txmgrtypes.TxRequest[ADDR, TX_HASH]) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) { return etx, errors.New(n.ErrMsg) } -func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(addr ADDR) (fwdr ADDR, err error) { +func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(ctx context.Context, addr ADDR) (fwdr ADDR, err error) { return fwdr, err } -func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(_, _ ADDR) (fwdr ADDR, err error) { +func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(ctx context.Context, _, _ ADDR) (fwdr ADDR, err error) { return fwdr, err } diff --git a/common/txmgr/types/forwarder_manager.go b/common/txmgr/types/forwarder_manager.go index 3e51ffb1524..6acb491a1fb 100644 --- a/common/txmgr/types/forwarder_manager.go +++ b/common/txmgr/types/forwarder_manager.go @@ -1,15 +1,18 @@ package types import ( + "context" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink/v2/common/types" ) //go:generate mockery --quiet --name ForwarderManager --output ./mocks/ --case=underscore type ForwarderManager[ADDR types.Hashable] interface { services.Service - ForwarderFor(addr ADDR) (forwarder ADDR, err error) - ForwarderForOCR2Feeds(eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) + ForwarderFor(ctx context.Context, addr ADDR) (forwarder ADDR, err error) + ForwarderForOCR2Feeds(ctx context.Context, eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) // Converts payload to be forwarder-friendly ConvertPayload(dest ADDR, origPayload []byte) ([]byte, error) } diff --git a/common/txmgr/types/mocks/forwarder_manager.go b/common/txmgr/types/mocks/forwarder_manager.go index 
1021e776e9d..b2cf9bc9d35 100644 --- a/common/txmgr/types/mocks/forwarder_manager.go +++ b/common/txmgr/types/mocks/forwarder_manager.go @@ -63,9 +63,9 @@ func (_m *ForwarderManager[ADDR]) ConvertPayload(dest ADDR, origPayload []byte) return r0, r1 } -// ForwarderFor provides a mock function with given fields: addr -func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { - ret := _m.Called(addr) +// ForwarderFor provides a mock function with given fields: ctx, addr +func (_m *ForwarderManager[ADDR]) ForwarderFor(ctx context.Context, addr ADDR) (ADDR, error) { + ret := _m.Called(ctx, addr) if len(ret) == 0 { panic("no return value specified for ForwarderFor") @@ -73,17 +73,17 @@ func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR) (ADDR, error)); ok { - return rf(addr) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) (ADDR, error)); ok { + return rf(ctx, addr) } - if rf, ok := ret.Get(0).(func(ADDR) ADDR); ok { - r0 = rf(addr) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) ADDR); ok { + r0 = rf(ctx, addr) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR) error); ok { - r1 = rf(addr) + if rf, ok := ret.Get(1).(func(context.Context, ADDR) error); ok { + r1 = rf(ctx, addr) } else { r1 = ret.Error(1) } @@ -91,9 +91,9 @@ func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { return r0, r1 } -// ForwarderForOCR2Feeds provides a mock function with given fields: eoa, ocr2Aggregator -func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(eoa ADDR, ocr2Aggregator ADDR) (ADDR, error) { - ret := _m.Called(eoa, ocr2Aggregator) +// ForwarderForOCR2Feeds provides a mock function with given fields: ctx, eoa, ocr2Aggregator +func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(ctx context.Context, eoa ADDR, ocr2Aggregator ADDR) (ADDR, error) { + ret := _m.Called(ctx, eoa, ocr2Aggregator) if len(ret) == 0 { panic("no return value specified for ForwarderForOCR2Feeds") @@ -101,17 +101,17 @@ func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(eoa ADDR, ocr2Aggregator var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR, ADDR) (ADDR, error)); ok { - return rf(eoa, ocr2Aggregator) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) (ADDR, error)); ok { + return rf(ctx, eoa, ocr2Aggregator) } - if rf, ok := ret.Get(0).(func(ADDR, ADDR) ADDR); ok { - r0 = rf(eoa, ocr2Aggregator) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) ADDR); ok { + r0 = rf(ctx, eoa, ocr2Aggregator) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR, ADDR) error); ok { - r1 = rf(eoa, ocr2Aggregator) + if rf, ok := ret.Get(1).(func(context.Context, ADDR, ADDR) error); ok { + r1 = rf(ctx, eoa, ocr2Aggregator) } else { r1 = ret.Error(1) } diff --git a/core/chains/evm/client/node.go b/core/chains/evm/client/node.go index 92b7a8301e5..968cb34b9fe 100644 --- a/core/chains/evm/client/node.go +++ b/core/chains/evm/client/node.go @@ -166,10 +166,8 @@ type node struct { // this node. Closing and replacing should be serialized through // stateMu since it can happen on state transitions as well as node Close. 
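 // It is separate from stopCh, which now tracks the node's overall lifetime.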
chStopInFlight chan struct{} - // nodeCtx is the node lifetime's context - nodeCtx context.Context - // cancelNodeCtx cancels nodeCtx when stopping the node - cancelNodeCtx context.CancelFunc + + stopCh services.StopChan // wg waits for subsidiary goroutines wg sync.WaitGroup @@ -196,7 +194,7 @@ func NewNode(nodeCfg config.NodePool, noNewHeadsThreshold time.Duration, lggr lo n.http = &rawclient{uri: *httpuri} } n.chStopInFlight = make(chan struct{}) - n.nodeCtx, n.cancelNodeCtx = context.WithCancel(context.Background()) + n.stopCh = make(chan struct{}) lggr = logger.Named(lggr, "Node") lggr = logger.With(lggr, "nodeTier", "primary", @@ -367,7 +365,7 @@ func (n *node) Close() error { n.stateMu.Lock() defer n.stateMu.Unlock() - n.cancelNodeCtx() + close(n.stopCh) n.cancelInflightRequests() n.state = NodeStateClosed return nil diff --git a/core/chains/evm/client/node_lifecycle.go b/core/chains/evm/client/node_lifecycle.go index 41add532222..c18c8032009 100644 --- a/core/chains/evm/client/node_lifecycle.go +++ b/core/chains/evm/client/node_lifecycle.go @@ -75,6 +75,8 @@ const ( // Should only be run ONCE per node, after a successful Dial func (n *node) aliveLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -96,7 +98,7 @@ func (n *node) aliveLoop() { lggr.Tracew("Alive loop starting", "nodeState", n.State()) headsC := make(chan *evmtypes.Head) - sub, err := n.EthSubscribe(n.nodeCtx, headsC, "newHeads") + sub, err := n.EthSubscribe(ctx, headsC, "newHeads") if err != nil { lggr.Errorw("Initial subscribe for heads failed", "nodeState", n.State()) n.declareUnreachable() @@ -137,18 +139,19 @@ func (n *node) aliveLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-pollCh: - var version string promEVMPoolRPCNodePolls.WithLabelValues(n.chainID.String(), n.name).Inc() lggr.Tracew("Polling for version", "nodeState", n.State(), "pollFailures", pollFailures) - ctx, cancel := context.WithTimeout(n.nodeCtx, pollInterval) - ctx, cancel2 := n.makeQueryCtx(ctx) - err := n.CallContext(ctx, &version, "web3_clientVersion") - cancel2() - cancel() - if err != nil { + var version string + if err := func(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, pollInterval) + defer cancel() + ctx, cancel2 := n.makeQueryCtx(ctx) + defer cancel2() + return n.CallContext(ctx, &version, "web3_clientVersion") + }(ctx); err != nil { // prevent overflow if pollFailures < math.MaxUint32 { promEVMPoolRPCNodePollsFailed.WithLabelValues(n.chainID.String(), n.name).Inc() @@ -262,6 +265,8 @@ const ( // outOfSyncLoop takes an OutOfSync node and waits until isOutOfSync returns false to go back to live status func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -281,14 +286,14 @@ func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { lggr.Debugw("Trying to revive out-of-sync RPC node", "nodeState", n.State()) // Need to redial since out-of-sync nodes are automatically disconnected - if err := n.dial(n.nodeCtx); err != nil { + if err := n.dial(ctx); err != nil { lggr.Errorw("Failed to dial out-of-sync RPC node", "nodeState", n.State()) n.declareUnreachable() return } // Manually re-verify since out-of-sync nodes are automatically disconnected - if err := n.verify(n.nodeCtx); err != nil { + if err := n.verify(ctx); err != nil { lggr.Errorw(fmt.Sprintf("Failed to verify out-of-sync RPC node: %v", err), 
"err", err) n.declareInvalidChainID() return @@ -297,7 +302,7 @@ func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { lggr.Tracew("Successfully subscribed to heads feed on out-of-sync RPC node", "nodeState", n.State()) ch := make(chan *evmtypes.Head) - subCtx, cancel := n.makeQueryCtx(n.nodeCtx) + subCtx, cancel := n.makeQueryCtx(ctx) // raw call here to bypass node state checking sub, err := n.ws.rpc.EthSubscribe(subCtx, ch, "newHeads") cancel() @@ -310,7 +315,7 @@ func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case head, open := <-ch: if !open { @@ -344,6 +349,8 @@ func (n *node) outOfSyncLoop(isOutOfSync func(num int64, td *big.Int) bool) { func (n *node) unreachableLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -366,12 +373,12 @@ func (n *node) unreachableLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(dialRetryBackoff.Duration()): lggr.Tracew("Trying to re-dial RPC node", "nodeState", n.State()) - err := n.dial(n.nodeCtx) + err := n.dial(ctx) if err != nil { lggr.Errorw(fmt.Sprintf("Failed to redial RPC node; still unreachable: %v", err), "err", err, "nodeState", n.State()) continue @@ -379,7 +386,7 @@ func (n *node) unreachableLoop() { n.setState(NodeStateDialed) - err = n.verify(n.nodeCtx) + err = n.verify(ctx) if pkgerrors.Is(err, errInvalidChainID) { lggr.Errorw("Failed to redial RPC node; remote endpoint returned the wrong chain ID", "err", err) @@ -400,6 +407,8 @@ func (n *node) unreachableLoop() { func (n *node) invalidChainIDLoop() { defer n.wg.Done() + ctx, cancel := n.stopCh.NewCtx() + defer cancel() { // sanity check @@ -422,10 +431,10 @@ func (n *node) invalidChainIDLoop() { for { select { - case <-n.nodeCtx.Done(): + case <-ctx.Done(): return case <-time.After(chainIDRecheckBackoff.Duration()): - err := n.verify(n.nodeCtx) + err := n.verify(ctx) if pkgerrors.Is(err, errInvalidChainID) { lggr.Errorw("Failed to verify RPC node; remote endpoint returned the wrong chain ID", "err", err) continue diff --git a/core/chains/evm/forwarders/forwarder_manager.go b/core/chains/evm/forwarders/forwarder_manager.go index 9505cdfbbbf..fca35274708 100644 --- a/core/chains/evm/forwarders/forwarder_manager.go +++ b/core/chains/evm/forwarders/forwarder_manager.go @@ -49,8 +49,7 @@ type FwdMgr struct { authRcvr authorized_receiver.AuthorizedReceiverInterface offchainAgg offchain_aggregator_wrapper.OffchainAggregatorInterface - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan cacheMu sync.RWMutex wg sync.WaitGroup @@ -66,7 +65,7 @@ func NewFwdMgr(ds sqlutil.DataSource, client evmclient.Client, logpoller evmlogp logpoller: logpoller, sendersCache: make(map[common.Address][]common.Address), } - fwdMgr.ctx, fwdMgr.cancel = context.WithCancel(context.Background()) + fwdMgr.stopCh = make(chan struct{}) return &fwdMgr } @@ -86,7 +85,7 @@ func (f *FwdMgr) Start(ctx context.Context) error { } if len(fwdrs) != 0 { f.initForwardersCache(ctx, fwdrs) - if err = f.subscribeForwardersLogs(fwdrs); err != nil { + if err = f.subscribeForwardersLogs(ctx, fwdrs); err != nil { return err } } @@ -111,15 +110,15 @@ func FilterName(addr common.Address) string { return evmlogpoller.FilterName("ForwarderManager AuthorizedSendersChanged", addr.String()) } -func (f *FwdMgr) ForwarderFor(addr common.Address) (forwarder common.Address, err error) { +func (f *FwdMgr) 
ForwarderFor(ctx context.Context, addr common.Address) (forwarder common.Address, err error) { // Gets forwarders for current chain. - fwdrs, err := f.ORM.FindForwardersByChain(f.ctx, big.Big(*f.evmClient.ConfiguredChainID())) + fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*f.evmClient.ConfiguredChainID())) if err != nil { return common.Address{}, err } for _, fwdr := range fwdrs { - eoas, err := f.getContractSenders(fwdr.Address) + eoas, err := f.getContractSenders(ctx, fwdr.Address) if err != nil { f.logger.Errorw("Failed to get forwarder senders", "forwarder", fwdr.Address, "err", err) continue @@ -133,8 +132,8 @@ func (f *FwdMgr) ForwarderFor(addr common.Address) (forwarder common.Address, er return common.Address{}, pkgerrors.Errorf("Cannot find forwarder for given EOA") } -func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forwarder common.Address, err error) { - fwdrs, err := f.ORM.FindForwardersByChain(f.ctx, big.Big(*f.evmClient.ConfiguredChainID())) +func (f *FwdMgr) ForwarderForOCR2Feeds(ctx context.Context, eoa, ocr2Aggregator common.Address) (forwarder common.Address, err error) { + fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*f.evmClient.ConfiguredChainID())) if err != nil { return common.Address{}, err } @@ -144,7 +143,7 @@ func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forw return common.Address{}, err } - transmitters, err := offchainAggregator.GetTransmitters(&bind.CallOpts{Context: f.ctx}) + transmitters, err := offchainAggregator.GetTransmitters(&bind.CallOpts{Context: ctx}) if err != nil { return common.Address{}, pkgerrors.Errorf("failed to get ocr2 aggregator transmitters: %s", err.Error()) } @@ -155,7 +154,7 @@ func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forw continue } - eoas, err := f.getContractSenders(fwdr.Address) + eoas, err := f.getContractSenders(ctx, fwdr.Address) if err != nil { f.logger.Errorw("Failed to get forwarder senders", "forwarder", fwdr.Address, "err", err) continue @@ -191,16 +190,16 @@ func (f *FwdMgr) getForwardedPayload(dest common.Address, origPayload []byte) ([ return dataBytes, nil } -func (f *FwdMgr) getContractSenders(addr common.Address) ([]common.Address, error) { +func (f *FwdMgr) getContractSenders(ctx context.Context, addr common.Address) ([]common.Address, error) { if senders, ok := f.getCachedSenders(addr); ok { return senders, nil } - senders, err := f.getAuthorizedSenders(f.ctx, addr) + senders, err := f.getAuthorizedSenders(ctx, addr) if err != nil { return nil, pkgerrors.Wrapf(err, "Failed to call getAuthorizedSenders on %s", addr) } f.setCachedSenders(addr, senders) - if err = f.subscribeSendersChangedLogs(addr); err != nil { + if err = f.subscribeSendersChangedLogs(ctx, addr); err != nil { return nil, err } return senders, nil @@ -230,23 +229,23 @@ func (f *FwdMgr) initForwardersCache(ctx context.Context, fwdrs []Forwarder) { } } -func (f *FwdMgr) subscribeForwardersLogs(fwdrs []Forwarder) error { +func (f *FwdMgr) subscribeForwardersLogs(ctx context.Context, fwdrs []Forwarder) error { for _, fwdr := range fwdrs { - if err := f.subscribeSendersChangedLogs(fwdr.Address); err != nil { + if err := f.subscribeSendersChangedLogs(ctx, fwdr.Address); err != nil { return err } } return nil } -func (f *FwdMgr) subscribeSendersChangedLogs(addr common.Address) error { +func (f *FwdMgr) subscribeSendersChangedLogs(ctx context.Context, addr common.Address) error { if err := f.logpoller.Ready(); err != nil { f.logger.Warnw("Unable to 
subscribe to AuthorizedSendersChanged logs", "forwarder", addr, "err", err) return nil } err := f.logpoller.RegisterFilter( - f.ctx, + ctx, evmlogpoller.Filter{ Name: FilterName(addr), EventSigs: []common.Hash{authChangedTopic}, @@ -270,8 +269,10 @@ func (f *FwdMgr) getCachedSenders(addr common.Address) ([]common.Address, bool) func (f *FwdMgr) runLoop() { defer f.wg.Done() - tick := time.After(0) + ctx, cancel := f.stopCh.NewCtx() + defer cancel() + tick := time.After(0) for ; ; tick = time.After(utils.WithJitter(time.Minute)) { select { case <-tick: @@ -287,7 +288,7 @@ func (f *FwdMgr) runLoop() { } logs, err := f.logpoller.LatestLogEventSigsAddrsWithConfs( - f.ctx, + ctx, f.latestBlock, []common.Hash{authChangedTopic}, addrs, @@ -308,7 +309,7 @@ func (f *FwdMgr) runLoop() { } } - case <-f.ctx.Done(): + case <-ctx.Done(): return } } @@ -352,7 +353,7 @@ func (f *FwdMgr) collectAddresses() (addrs []common.Address) { // Stop cancels all outgoings calls and stops internal ticker loop. func (f *FwdMgr) Close() error { return f.StopOnce("EVMForwarderManager", func() (err error) { - f.cancel() + close(f.stopCh) f.wg.Wait() return nil }) diff --git a/core/chains/evm/forwarders/forwarder_manager_test.go b/core/chains/evm/forwarders/forwarder_manager_test.go index 993efacac4a..020446aa547 100644 --- a/core/chains/evm/forwarders/forwarder_manager_test.go +++ b/core/chains/evm/forwarders/forwarder_manager_test.go @@ -18,6 +18,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/testhelpers" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" @@ -86,7 +87,7 @@ func TestFwdMgr_MaybeForwardTransaction(t *testing.T) { require.Equal(t, lst[0].Address, forwarderAddr) require.NoError(t, fwdMgr.Start(testutils.Context(t))) - addr, err := fwdMgr.ForwarderFor(owner.From) + addr, err := fwdMgr.ForwarderFor(ctx, owner.From) require.NoError(t, err) require.Equal(t, addr.String(), forwarderAddr.String()) err = fwdMgr.Close() @@ -148,7 +149,7 @@ func TestFwdMgr_AccountUnauthorizedToForward_SkipsForwarding(t *testing.T) { err = fwdMgr.Start(testutils.Context(t)) require.NoError(t, err) - addr, err := fwdMgr.ForwarderFor(owner.From) + addr, err := fwdMgr.ForwarderFor(ctx, owner.From) require.ErrorContains(t, err, "Cannot find forwarder for given EOA") require.True(t, utils.IsZero(addr)) err = fwdMgr.Close() @@ -214,7 +215,7 @@ func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { fwdMgr = forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM()) require.NoError(t, fwdMgr.Start(testutils.Context(t))) // cannot find forwarder because it isn't authorized nor added as a transmitter - addr, err := fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + addr, err := fwdMgr.ForwarderForOCR2Feeds(ctx, owner.From, ocr2Address) require.ErrorContains(t, err, "Cannot find forwarder for given EOA") require.True(t, utils.IsZero(addr)) @@ -227,7 +228,7 @@ func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { require.Equal(t, owner.From, authorizedSenders[0]) // cannot find forwarder because it isn't added as a transmitter - addr, err = fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + addr, err = fwdMgr.ForwarderForOCR2Feeds(ctx, owner.From, ocr2Address) require.ErrorContains(t, err, "Cannot find forwarder for given EOA") require.True(t, utils.IsZero(addr)) @@ -251,7 +252,7 @@ func 
TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { // create new fwd to have an empty cache that has to fetch authorized forwarders from log poller fwdMgr = forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM()) require.NoError(t, fwdMgr.Start(testutils.Context(t))) - addr, err = fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + addr, err = fwdMgr.ForwarderForOCR2Feeds(ctx, owner.From, ocr2Address) require.NoError(t, err, "forwarder should be valid and found because it is both authorized and set as a transmitter") require.Equal(t, forwarderAddr, addr) require.NoError(t, fwdMgr.Close()) diff --git a/core/chains/evm/gas/block_history_estimator.go b/core/chains/evm/gas/block_history_estimator.go index 82f1c46fff8..f6d15f7aff7 100644 --- a/core/chains/evm/gas/block_history_estimator.go +++ b/core/chains/evm/gas/block_history_estimator.go @@ -104,13 +104,12 @@ type BlockHistoryEstimator struct { bhConfig BlockHistoryConfig // NOTE: it is assumed that blocks will be kept sorted by // block number ascending - blocks []evmtypes.Block - blocksMu sync.RWMutex - size int64 - mb *mailbox.Mailbox[*evmtypes.Head] - wg *sync.WaitGroup - ctx context.Context - ctxCancel context.CancelFunc + blocks []evmtypes.Block + blocksMu sync.RWMutex + size int64 + mb *mailbox.Mailbox[*evmtypes.Head] + wg *sync.WaitGroup + stopCh services.StopChan gasPrice *assets.Wei tipCap *assets.Wei @@ -128,9 +127,7 @@ type BlockHistoryEstimator struct { // for new heads and updates the base gas price dynamically based on the // configured percentile of gas prices in that block func NewBlockHistoryEstimator(lggr logger.Logger, ethClient feeEstimatorClient, cfg chainConfig, eCfg estimatorGasEstimatorConfig, bhCfg BlockHistoryConfig, chainID *big.Int, l1Oracle rollups.L1Oracle) EvmEstimator { - ctx, cancel := context.WithCancel(context.Background()) - - b := &BlockHistoryEstimator{ + return &BlockHistoryEstimator{ ethClient: ethClient, chainID: chainID, config: cfg, @@ -138,16 +135,13 @@ func NewBlockHistoryEstimator(lggr logger.Logger, ethClient feeEstimatorClient, bhConfig: bhCfg, blocks: make([]evmtypes.Block, 0), // Must have enough blocks for both estimator and connectivity checker - size: int64(mathutil.Max(bhCfg.BlockHistorySize(), bhCfg.CheckInclusionBlocks())), - mb: mailbox.NewSingle[*evmtypes.Head](), - wg: new(sync.WaitGroup), - ctx: ctx, - ctxCancel: cancel, - logger: logger.Sugared(logger.Named(lggr, "BlockHistoryEstimator")), - l1Oracle: l1Oracle, + size: int64(mathutil.Max(bhCfg.BlockHistorySize(), bhCfg.CheckInclusionBlocks())), + mb: mailbox.NewSingle[*evmtypes.Head](), + wg: new(sync.WaitGroup), + stopCh: make(chan struct{}), + logger: logger.Sugared(logger.Named(lggr, "BlockHistoryEstimator")), + l1Oracle: l1Oracle, } - - return b } // OnNewLongestChain recalculates and sets global gas price if a sampled new head comes @@ -240,7 +234,7 @@ func (b *BlockHistoryEstimator) L1Oracle() rollups.L1Oracle { func (b *BlockHistoryEstimator) Close() error { return b.StopOnce("BlockHistoryEstimator", func() error { - b.ctxCancel() + close(b.stopCh) b.wg.Wait() return nil }) @@ -482,9 +476,12 @@ func (b *BlockHistoryEstimator) BumpDynamicFee(_ context.Context, originalFee Dy func (b *BlockHistoryEstimator) runLoop() { defer b.wg.Done() + ctx, cancel := b.stopCh.NewCtx() + defer cancel() + for { select { - case <-b.ctx.Done(): + case <-ctx.Done(): return case <-b.mb.Notify(): head, exists := b.mb.Retrieve() @@ -492,7 +489,7 @@ func (b *BlockHistoryEstimator) runLoop() { b.logger.Debug("No head to 
retrieve") continue } - b.FetchBlocksAndRecalculate(b.ctx, head) + b.FetchBlocksAndRecalculate(ctx, head) } } } diff --git a/core/chains/evm/logpoller/log_poller.go b/core/chains/evm/logpoller/log_poller.go index e964b86221c..26978b18d48 100644 --- a/core/chains/evm/logpoller/log_poller.go +++ b/core/chains/evm/logpoller/log_poller.go @@ -119,8 +119,7 @@ type logPoller struct { replayStart chan int64 replayComplete chan error - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan wg sync.WaitGroup // This flag is raised whenever the log poller detects that the chain's finality has been violated. // It can happen when reorg is deeper than the latest finalized block that LogPoller saw in a previous PollAndSave tick. @@ -152,10 +151,8 @@ type Opts struct { // How fast that can be done depends largely on network speed and DB, but even for the fastest // support chain, polygon, which has 2s block times, we need RPCs roughly with <= 500ms latency func NewLogPoller(orm ORM, ec Client, lggr logger.Logger, opts Opts) *logPoller { - ctx, cancel := context.WithCancel(context.Background()) return &logPoller{ - ctx: ctx, - cancel: cancel, + stopCh: make(chan struct{}), ec: ec, orm: orm, lggr: logger.Sugared(logger.Named(lggr, "LogPoller")), @@ -465,21 +462,23 @@ func (lp *logPoller) savedFinalizedBlockNumber(ctx context.Context) (int64, erro } func (lp *logPoller) recvReplayComplete() { + defer lp.wg.Done() err := <-lp.replayComplete if err != nil { lp.lggr.Error(err) } - lp.wg.Done() } // Asynchronous wrapper for Replay() func (lp *logPoller) ReplayAsync(fromBlock int64) { lp.wg.Add(1) go func() { - if err := lp.Replay(lp.ctx, fromBlock); err != nil { + defer lp.wg.Done() + ctx, cancel := lp.stopCh.NewCtx() + defer cancel() + if err := lp.Replay(ctx, fromBlock); err != nil { lp.lggr.Error(err) } - lp.wg.Done() }() } @@ -498,7 +497,7 @@ func (lp *logPoller) Close() error { case lp.replayComplete <- ErrLogPollerShutdown: default: } - lp.cancel() + close(lp.stopCh) lp.wg.Wait() return nil }) @@ -535,10 +534,10 @@ func (lp *logPoller) GetReplayFromBlock(ctx context.Context, requested int64) (i return mathutil.Min(requested, lastProcessed.BlockNumber), nil } -func (lp *logPoller) loadFilters() error { +func (lp *logPoller) loadFilters(ctx context.Context) error { lp.filterMu.Lock() defer lp.filterMu.Unlock() - filters, err := lp.orm.LoadFilters(lp.ctx) + filters, err := lp.orm.LoadFilters(ctx) if err != nil { return pkgerrors.Wrapf(err, "Failed to load initial filters from db, retrying") @@ -551,6 +550,8 @@ func (lp *logPoller) loadFilters() error { func (lp *logPoller) run() { defer lp.wg.Done() + ctx, cancel := lp.stopCh.NewCtx() + defer cancel() logPollTick := time.After(0) // stagger these somewhat, so they don't all run back-to-back backupLogPollTick := time.After(100 * time.Millisecond) @@ -558,14 +559,14 @@ func (lp *logPoller) run() { for { select { - case <-lp.ctx.Done(): + case <-ctx.Done(): return case fromBlockReq := <-lp.replayStart: - lp.handleReplayRequest(fromBlockReq, filtersLoaded) + lp.handleReplayRequest(ctx, fromBlockReq, filtersLoaded) case <-logPollTick: logPollTick = time.After(utils.WithJitter(lp.pollPeriod)) if !filtersLoaded { - if err := lp.loadFilters(); err != nil { + if err := lp.loadFilters(ctx); err != nil { lp.lggr.Errorw("Failed loading filters in main logpoller loop, retrying later", "err", err) continue } @@ -574,7 +575,7 @@ func (lp *logPoller) run() { // Always start from the latest block in the db. 
var start int64 - lastProcessed, err := lp.orm.SelectLatestBlock(lp.ctx) + lastProcessed, err := lp.orm.SelectLatestBlock(ctx) if err != nil { if !pkgerrors.Is(err, sql.ErrNoRows) { // Assume transient db reading issue, retry forever. @@ -583,7 +584,7 @@ func (lp *logPoller) run() { } // Otherwise this is the first poll _ever_ on a new chain. // Only safe thing to do is to start at the first finalized block. - latestBlock, latestFinalizedBlockNumber, err := lp.latestBlocks(lp.ctx) + latestBlock, latestFinalizedBlockNumber, err := lp.latestBlocks(ctx) if err != nil { lp.lggr.Warnw("Unable to get latest for first poll", "err", err) continue @@ -600,7 +601,7 @@ func (lp *logPoller) run() { } else { start = lastProcessed.BlockNumber + 1 } - lp.PollAndSaveLogs(lp.ctx, start) + lp.PollAndSaveLogs(ctx, start) case <-backupLogPollTick: if lp.backupPollerBlockDelay == 0 { continue // backup poller is disabled @@ -618,13 +619,15 @@ func (lp *logPoller) run() { lp.lggr.Warnw("Backup log poller ran before filters loaded, skipping") continue } - lp.BackupPollAndSaveLogs(lp.ctx) + lp.BackupPollAndSaveLogs(ctx) } } } func (lp *logPoller) backgroundWorkerRun() { defer lp.wg.Done() + ctx, cancel := lp.stopCh.NewCtx() + defer cancel() // Avoid putting too much pressure on the database by staggering the pruning of old blocks and logs. // Usually, node after restart will have some work to boot the plugins and other services. @@ -634,11 +637,11 @@ func (lp *logPoller) backgroundWorkerRun() { for { select { - case <-lp.ctx.Done(): + case <-ctx.Done(): return case <-blockPruneTick: blockPruneTick = time.After(utils.WithJitter(lp.pollPeriod * 1000)) - if allRemoved, err := lp.PruneOldBlocks(lp.ctx); err != nil { + if allRemoved, err := lp.PruneOldBlocks(ctx); err != nil { lp.lggr.Errorw("Unable to prune old blocks", "err", err) } else if !allRemoved { // Tick faster when cleanup can't keep up with the pace of new blocks @@ -646,7 +649,7 @@ func (lp *logPoller) backgroundWorkerRun() { } case <-logPruneTick: logPruneTick = time.After(utils.WithJitter(lp.pollPeriod * 2401)) // = 7^5 avoids common factors with 1000 - if allRemoved, err := lp.PruneExpiredLogs(lp.ctx); err != nil { + if allRemoved, err := lp.PruneExpiredLogs(ctx); err != nil { lp.lggr.Errorw("Unable to prune expired logs", "err", err) } else if !allRemoved { // Tick faster when cleanup can't keep up with the pace of new logs @@ -656,26 +659,26 @@ func (lp *logPoller) backgroundWorkerRun() { } } -func (lp *logPoller) handleReplayRequest(fromBlockReq int64, filtersLoaded bool) { - fromBlock, err := lp.GetReplayFromBlock(lp.ctx, fromBlockReq) +func (lp *logPoller) handleReplayRequest(ctx context.Context, fromBlockReq int64, filtersLoaded bool) { + fromBlock, err := lp.GetReplayFromBlock(ctx, fromBlockReq) if err == nil { if !filtersLoaded { lp.lggr.Warnw("Received replayReq before filters loaded", "fromBlock", fromBlock, "requested", fromBlockReq) - if err = lp.loadFilters(); err != nil { + if err = lp.loadFilters(ctx); err != nil { lp.lggr.Errorw("Failed loading filters during Replay", "err", err, "fromBlock", fromBlock) } } if err == nil { // Serially process replay requests. 
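// PollAndSaveLogs runs synchronously inside the main run loop, so at most
// one replay executes at a time and regular polling pauses until it finishes.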
lp.lggr.Infow("Executing replay", "fromBlock", fromBlock, "requested", fromBlockReq) - lp.PollAndSaveLogs(lp.ctx, fromBlock) + lp.PollAndSaveLogs(ctx, fromBlock) lp.lggr.Infow("Executing replay finished", "fromBlock", fromBlock, "requested", fromBlockReq) } } else { lp.lggr.Errorw("Error executing replay, could not get fromBlock", "err", err) } select { - case <-lp.ctx.Done(): + case <-ctx.Done(): // We're shutting down, notify client and exit select { case lp.replayComplete <- ErrReplayRequestAborted: diff --git a/core/chains/evm/logpoller/log_poller_internal_test.go b/core/chains/evm/logpoller/log_poller_internal_test.go index b7dbb074568..bc295105874 100644 --- a/core/chains/evm/logpoller/log_poller_internal_test.go +++ b/core/chains/evm/logpoller/log_poller_internal_test.go @@ -280,7 +280,6 @@ func TestLogPoller_Replay(t *testing.T) { chainID := testutils.FixtureChainID db := pgtest.NewSqlxDB(t) orm := NewORM(chainID, db, lggr) - ctx := testutils.Context(t) head := evmtypes.Head{Number: 4} events := []common.Hash{EmitterABI.Events["Log1"].ID} @@ -312,18 +311,21 @@ func TestLogPoller_Replay(t *testing.T) { } lp := NewLogPoller(orm, ec, lggr, lpOpts) - // process 1 log in block 3 - lp.PollAndSaveLogs(ctx, 4) - latest, err := lp.LatestBlock(ctx) - require.NoError(t, err) - require.Equal(t, int64(4), latest.BlockNumber) - require.Equal(t, int64(1), latest.FinalizedBlockNumber) + { + ctx := testutils.Context(t) + // process 1 log in block 3 + lp.PollAndSaveLogs(ctx, 4) + latest, err := lp.LatestBlock(ctx) + require.NoError(t, err) + require.Equal(t, int64(4), latest.BlockNumber) + require.Equal(t, int64(1), latest.FinalizedBlockNumber) + } t.Run("abort before replayStart received", func(t *testing.T) { // Replay() should abort immediately if caller's context is cancelled before request signal is read cancelCtx, cancel := context.WithCancel(testutils.Context(t)) cancel() - err = lp.Replay(cancelCtx, 3) + err := lp.Replay(cancelCtx, 3) assert.ErrorIs(t, err, ErrReplayRequestAborted) }) @@ -338,6 +340,7 @@ func TestLogPoller_Replay(t *testing.T) { // Replay() should return error code received from replayComplete t.Run("returns error code on replay complete", func(t *testing.T) { + ctx := testutils.Context(t) ec.On("FilterLogs", mock.Anything, mock.Anything).Return([]types.Log{log1}, nil).Once() mockBatchCallContext(t, ec) anyErr := pkgerrors.New("any error") @@ -368,6 +371,7 @@ func TestLogPoller_Replay(t *testing.T) { // Main lp.run() loop shouldn't get stuck if client aborts t.Run("client abort doesnt hang run loop", func(t *testing.T) { + ctx := testutils.Context(t) lp.backupPollerNextBlock = 0 pass := make(chan struct{}) @@ -420,6 +424,7 @@ func TestLogPoller_Replay(t *testing.T) { // run() should abort if log poller shuts down while replay is in progress t.Run("shutdown during replay", func(t *testing.T) { + ctx := testutils.Context(t) lp.backupPollerNextBlock = 0 pass := make(chan struct{}) @@ -438,8 +443,15 @@ func TestLogPoller_Replay(t *testing.T) { }() }) ec.On("FilterLogs", mock.Anything, mock.Anything).Once().Return([]types.Log{log1}, nil).Run(func(args mock.Arguments) { - lp.cancel() - close(pass) + go func() { + assert.NoError(t, lp.Close()) + + // prevent double close + lp.reset() + assert.NoError(t, lp.Start(ctx)) + + close(pass) + }() }) ec.On("FilterLogs", mock.Anything, mock.Anything).Return([]types.Log{log1}, nil) @@ -468,6 +480,7 @@ func TestLogPoller_Replay(t *testing.T) { }) t.Run("ReplayAsync error", func(t *testing.T) { + ctx := testutils.Context(t) t.Cleanup(lp.reset) 
servicetest.Run(t, lp) head = evmtypes.Head{Number: 4} @@ -481,7 +494,7 @@ func TestLogPoller_Replay(t *testing.T) { select { case lp.replayComplete <- anyErr: time.Sleep(2 * time.Second) - case <-lp.ctx.Done(): + case <-ctx.Done(): t.Error("timed out waiting to send replaceComplete") } require.Equal(t, 1, observedLogs.Len()) @@ -489,6 +502,7 @@ func TestLogPoller_Replay(t *testing.T) { }) t.Run("run regular replay when there are not blocks in db", func(t *testing.T) { + ctx := testutils.Context(t) err := lp.orm.DeleteLogsAndBlocksAfter(ctx, 0) require.NoError(t, err) @@ -497,6 +511,7 @@ func TestLogPoller_Replay(t *testing.T) { }) t.Run("run only backfill when everything is finalized", func(t *testing.T) { + ctx := testutils.Context(t) err := lp.orm.DeleteLogsAndBlocksAfter(ctx, 0) require.NoError(t, err) @@ -513,7 +528,7 @@ func TestLogPoller_Replay(t *testing.T) { func (lp *logPoller) reset() { lp.StateMachine = services.StateMachine{} - lp.ctx, lp.cancel = context.WithCancel(context.Background()) + lp.stopCh = make(chan struct{}) } func Test_latestBlockAndFinalityDepth(t *testing.T) { diff --git a/core/chains/evm/txmgr/evm_tx_store.go b/core/chains/evm/txmgr/evm_tx_store.go index 22b9b6678fa..505938d3026 100644 --- a/core/chains/evm/txmgr/evm_tx_store.go +++ b/core/chains/evm/txmgr/evm_tx_store.go @@ -19,6 +19,7 @@ import ( nullv4 "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/null" @@ -76,10 +77,9 @@ type TestEvmTxStore interface { } type evmTxStore struct { - q sqlutil.DataSource - logger logger.SugaredLogger - ctx context.Context - ctxCancel context.CancelFunc + q sqlutil.DataSource + logger logger.SugaredLogger + stopCh services.StopChan } var _ EvmTxStore = (*evmTxStore)(nil) @@ -345,12 +345,10 @@ func NewTxStore( lggr logger.Logger, ) *evmTxStore { namedLogger := logger.Named(lggr, "TxmStore") - ctx, cancel := context.WithCancel(context.Background()) return &evmTxStore{ - q: db, - logger: logger.Sugared(namedLogger), - ctx: ctx, - ctxCancel: cancel, + q: db, + logger: logger.Sugared(namedLogger), + stopCh: make(chan struct{}), } } @@ -361,7 +359,7 @@ RETURNING *; ` func (o *evmTxStore) Close() { - o.ctxCancel() + close(o.stopCh) } func (o *evmTxStore) preloadTxAttempts(ctx context.Context, txs []Tx) error { @@ -398,7 +396,7 @@ func (o *evmTxStore) preloadTxAttempts(ctx context.Context, txs []Tx) error { func (o *evmTxStore) PreloadTxes(ctx context.Context, attempts []TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() return o.preloadTxesAtomic(ctx, attempts) } @@ -573,7 +571,7 @@ func (o *evmTxStore) InsertReceipt(ctx context.Context, receipt *evmtypes.Receip func (o *evmTxStore) GetFatalTransactions(ctx context.Context) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { stmt := `SELECT * FROM evm.txes WHERE state = 'fatal_error'` @@ -651,7 +649,7 @@ func (o *evmTxStore) LoadTxesAttempts(ctx context.Context, etxs []*Tx) error { func (o *evmTxStore) LoadTxAttempts(ctx context.Context, etx *Tx) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() return 
o.loadTxAttemptsAtomic(ctx, etx) } @@ -716,7 +714,7 @@ func loadConfirmedAttemptsReceipts(ctx context.Context, q sqlutil.DataSource, at // eth_tx that was last sent before or at the given time (up to limit) func (o *evmTxStore) FindTxAttemptsRequiringResend(ctx context.Context, olderThan time.Time, maxInFlightTransactions uint32, chainID *big.Int, address common.Address) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var limit null.Uint32 if maxInFlightTransactions > 0 { @@ -740,7 +738,7 @@ LIMIT $4 func (o *evmTxStore) UpdateBroadcastAts(ctx context.Context, now time.Time, etxIDs []int64) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() // Deliberately do nothing on NULL broadcast_at because that indicates the // tx has been moved into a state where broadcast_at is not relevant, e.g. @@ -758,7 +756,7 @@ func (o *evmTxStore) UpdateBroadcastAts(ctx context.Context, now time.Time, etxI // the attempt is already broadcast it _must_ have been before this head. func (o *evmTxStore) SetBroadcastBeforeBlockNum(ctx context.Context, blockNum int64, chainID *big.Int) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() _, err := o.q.ExecContext(ctx, `UPDATE evm.tx_attempts @@ -773,7 +771,7 @@ AND evm.txes.id = evm.tx_attempts.eth_tx_id AND evm.txes.evm_chain_id = $2`, func (o *evmTxStore) FindTxAttemptsConfirmedMissingReceipt(ctx context.Context, chainID *big.Int) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbAttempts []DbEthTxAttempt err = o.q.SelectContext(ctx, &dbAttempts, @@ -792,7 +790,7 @@ func (o *evmTxStore) FindTxAttemptsConfirmedMissingReceipt(ctx context.Context, func (o *evmTxStore) UpdateTxsUnconfirmed(ctx context.Context, ids []int64) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() _, err := o.q.ExecContext(ctx, `UPDATE evm.txes SET state='unconfirmed' WHERE id = ANY($1)`, pq.Array(ids)) @@ -804,7 +802,7 @@ func (o *evmTxStore) UpdateTxsUnconfirmed(ctx context.Context, ids []int64) erro func (o *evmTxStore) FindTxAttemptsRequiringReceiptFetch(ctx context.Context, chainID *big.Int) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbAttempts []DbEthTxAttempt @@ -826,7 +824,7 @@ ORDER BY evm.txes.nonce ASC, evm.tx_attempts.gas_price DESC, evm.tx_attempts.gas func (o *evmTxStore) SaveFetchedReceipts(ctx context.Context, r []*evmtypes.Receipt, chainID *big.Int) (err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() receipts := toOnchainReceipt(r) if len(receipts) == 0 { @@ -930,7 +928,7 @@ func (o *evmTxStore) SaveFetchedReceipts(ctx context.Context, r []*evmtypes.Rece // attempts are below the finality depth from current head. 
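// As with the other store methods in this file, the caller's ctx is merged
// with stopCh via stopCh.Ctx, so in-flight queries are aborted when the
// store is closed.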
func (o *evmTxStore) MarkAllConfirmedMissingReceipt(ctx context.Context, chainID *big.Int) (err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() res, err := o.q.ExecContext(ctx, ` UPDATE evm.txes @@ -961,7 +959,7 @@ WHERE state = 'unconfirmed' func (o *evmTxStore) GetInProgressTxAttempts(ctx context.Context, address common.Address, chainID *big.Int) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbAttempts []DbEthTxAttempt @@ -985,7 +983,7 @@ func (o *evmTxStore) FindTxesPendingCallback(ctx context.Context, blockNum int64 var rs []dbReceiptPlus var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.q.SelectContext(ctx, &rs, ` SELECT evm.txes.pipeline_task_run_id, evm.receipts.receipt, COALESCE((evm.txes.meta->>'FailOnRevert')::boolean, false) "FailOnRevert" FROM evm.txes @@ -1004,7 +1002,7 @@ func (o *evmTxStore) FindTxesPendingCallback(ctx context.Context, blockNum int64 // Update tx to mark that its callback has been signaled func (o *evmTxStore) UpdateTxCallbackCompleted(ctx context.Context, pipelineTaskRunId uuid.UUID, chainId *big.Int) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() _, err := o.q.ExecContext(ctx, `UPDATE evm.txes SET callback_completed = TRUE WHERE pipeline_task_run_id = $1 AND evm_chain_id = $2`, pipelineTaskRunId, chainId.String()) if err != nil { @@ -1015,7 +1013,7 @@ func (o *evmTxStore) UpdateTxCallbackCompleted(ctx context.Context, pipelineTask func (o *evmTxStore) FindLatestSequence(ctx context.Context, fromAddress common.Address, chainId *big.Int) (nonce evmtypes.Nonce, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() sql := `SELECT nonce FROM evm.txes WHERE from_address = $1 AND evm_chain_id = $2 AND nonce IS NOT NULL ORDER BY nonce DESC LIMIT 1` err = o.q.GetContext(ctx, &nonce, sql, fromAddress, chainId.String()) @@ -1025,7 +1023,7 @@ func (o *evmTxStore) FindLatestSequence(ctx context.Context, fromAddress common. 
// FindTxWithIdempotencyKey returns any broadcast ethtx with the given idempotencyKey and chainID func (o *evmTxStore) FindTxWithIdempotencyKey(ctx context.Context, idempotencyKey string, chainID *big.Int) (etx *Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtx DbEthTx err = o.q.GetContext(ctx, &dbEtx, `SELECT * FROM evm.txes WHERE idempotency_key = $1 and evm_chain_id = $2`, idempotencyKey, chainID.String()) @@ -1043,7 +1041,7 @@ func (o *evmTxStore) FindTxWithIdempotencyKey(ctx context.Context, idempotencyKe // FindTxWithSequence returns any broadcast ethtx with the given nonce func (o *evmTxStore) FindTxWithSequence(ctx context.Context, fromAddress common.Address, nonce evmtypes.Nonce) (etx *Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() etx = new(Tx) err = o.Transact(ctx, true, func(orm *evmTxStore) error { @@ -1092,7 +1090,7 @@ AND evm.tx_attempts.eth_tx_id = $1 func (o *evmTxStore) UpdateTxForRebroadcast(ctx context.Context, etx Tx, etxAttempt TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() return o.Transact(ctx, false, func(orm *evmTxStore) error { if err := deleteEthReceipts(ctx, orm, etx.ID); err != nil { @@ -1107,7 +1105,7 @@ func (o *evmTxStore) UpdateTxForRebroadcast(ctx context.Context, etx Tx, etxAtte func (o *evmTxStore) FindTransactionsConfirmedInBlockRange(ctx context.Context, highBlockNumber, lowBlockNumber int64, chainID *big.Int) (etxs []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtxs []DbEthTx @@ -1134,7 +1132,7 @@ ORDER BY nonce ASC func (o *evmTxStore) FindEarliestUnconfirmedBroadcastTime(ctx context.Context, chainID *big.Int) (broadcastAt nullv4.Time, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { if err = orm.q.QueryRowxContext(ctx, `SELECT min(initial_broadcast_at) FROM evm.txes WHERE state = 'unconfirmed' AND evm_chain_id = $1`, chainID.String()).Scan(&broadcastAt); err != nil { @@ -1147,7 +1145,7 @@ func (o *evmTxStore) FindEarliestUnconfirmedBroadcastTime(ctx context.Context, c func (o *evmTxStore) FindEarliestUnconfirmedTxAttemptBlock(ctx context.Context, chainID *big.Int) (earliestUnconfirmedTxBlock nullv4.Int, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { err = orm.q.QueryRowxContext(ctx, ` @@ -1165,7 +1163,7 @@ AND evm_chain_id = $1`, chainID.String()).Scan(&earliestUnconfirmedTxBlock) func (o *evmTxStore) IsTxFinalized(ctx context.Context, blockHeight int64, txID int64, chainID *big.Int) (finalized bool, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var count int32 @@ -1198,7 +1196,7 @@ func (o *evmTxStore) saveAttemptWithNewState(ctx context.Context, attempt TxAtte func (o *evmTxStore) SaveInsufficientFundsAttempt(ctx context.Context, timeout time.Duration, attempt *TxAttempt, broadcastAt time.Time) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, 
cancel = o.stopCh.Ctx(ctx) defer cancel() if !(attempt.State == txmgrtypes.TxAttemptInProgress || attempt.State == txmgrtypes.TxAttemptInsufficientFunds) { return errors.New("expected state to be either in_progress or insufficient_eth") @@ -1221,14 +1219,14 @@ func (o *evmTxStore) saveSentAttempt(ctx context.Context, timeout time.Duration, func (o *evmTxStore) SaveSentAttempt(ctx context.Context, timeout time.Duration, attempt *TxAttempt, broadcastAt time.Time) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() return o.saveSentAttempt(ctx, timeout, attempt, broadcastAt) } func (o *evmTxStore) SaveConfirmedMissingReceiptAttempt(ctx context.Context, timeout time.Duration, attempt *TxAttempt, broadcastAt time.Time) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err := o.Transact(ctx, false, func(orm *evmTxStore) error { if err := orm.saveSentAttempt(ctx, timeout, attempt, broadcastAt); err != nil { @@ -1244,7 +1242,7 @@ func (o *evmTxStore) SaveConfirmedMissingReceiptAttempt(ctx context.Context, tim func (o *evmTxStore) DeleteInProgressAttempt(ctx context.Context, attempt TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if attempt.State != txmgrtypes.TxAttemptInProgress { return errors.New("DeleteInProgressAttempt: expected attempt state to be in_progress") @@ -1259,7 +1257,7 @@ func (o *evmTxStore) DeleteInProgressAttempt(ctx context.Context, attempt TxAtte // SaveInProgressAttempt inserts or updates an attempt func (o *evmTxStore) SaveInProgressAttempt(ctx context.Context, attempt *TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if attempt.State != txmgrtypes.TxAttemptInProgress { return errors.New("SaveInProgressAttempt failed: attempt state must be in_progress") @@ -1293,7 +1291,7 @@ func (o *evmTxStore) SaveInProgressAttempt(ctx context.Context, attempt *TxAttem func (o *evmTxStore) GetAbandonedTransactionsByBatch(ctx context.Context, chainID *big.Int, enabledAddrs []common.Address, offset, limit uint) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var enabledAddrsBytea [][]byte @@ -1317,7 +1315,7 @@ func (o *evmTxStore) GetAbandonedTransactionsByBatch(ctx context.Context, chainI func (o *evmTxStore) GetTxByID(ctx context.Context, id int64) (txe *Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { @@ -1352,7 +1350,7 @@ func (o *evmTxStore) FindTxsRequiringGasBump(ctx context.Context, address common return } var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { stmt := ` @@ -1379,7 +1377,7 @@ ORDER BY nonce ASC // block func (o *evmTxStore) FindTxsRequiringResubmissionDueToInsufficientFunds(ctx context.Context, address common.Address, chainID *big.Int) (etxs []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtxs []DbEthTx @@ -1409,7 +1407,7 @@ ORDER BY nonce ASC // receipt and 
thus cannot pass on any transaction hash func (o *evmTxStore) MarkOldTxesMissingReceiptAsErrored(ctx context.Context, blockNum int64, finalityDepth uint32, chainID *big.Int) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() // cutoffBlockNum is a block height // Any 'confirmed_missing_receipt' eth_tx with all attempts older than this block height will be marked as errored @@ -1502,7 +1500,7 @@ GROUP BY e.id func (o *evmTxStore) SaveReplacementInProgressAttempt(ctx context.Context, oldAttempt TxAttempt, replacementAttempt *TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if oldAttempt.State != txmgrtypes.TxAttemptInProgress || replacementAttempt.State != txmgrtypes.TxAttemptInProgress { return errors.New("expected attempts to be in_progress") @@ -1529,7 +1527,7 @@ func (o *evmTxStore) SaveReplacementInProgressAttempt(ctx context.Context, oldAt // Finds earliest saved transaction that has yet to be broadcast from the given address func (o *evmTxStore) FindNextUnstartedTransactionFromAddress(ctx context.Context, fromAddress common.Address, chainID *big.Int) (*Tx, error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtx DbEthTx err := o.q.GetContext(ctx, &dbEtx, `SELECT * FROM evm.txes WHERE from_address = $1 AND state = 'unstarted' AND evm_chain_id = $2 ORDER BY value ASC, created_at ASC, id ASC`, fromAddress, chainID.String()) @@ -1544,7 +1542,7 @@ func (o *evmTxStore) FindNextUnstartedTransactionFromAddress(ctx context.Context func (o *evmTxStore) UpdateTxFatalError(ctx context.Context, etx *Tx) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if etx.State != txmgr.TxInProgress && etx.State != txmgr.TxUnstarted { return pkgerrors.Errorf("can only transition to fatal_error from in_progress or unstarted, transaction is currently %s", etx.State) @@ -1571,7 +1569,7 @@ func (o *evmTxStore) UpdateTxFatalError(ctx context.Context, etx *Tx) error { // Updates eth attempt from in_progress to broadcast. Also updates the eth tx to unconfirmed. func (o *evmTxStore) UpdateTxAttemptInProgressToBroadcast(ctx context.Context, etx *Tx, attempt TxAttempt, NewAttemptState txmgrtypes.TxAttemptState) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if etx.BroadcastAt == nil { return errors.New("unconfirmed transaction must have broadcast_at time") @@ -1609,7 +1607,7 @@ func (o *evmTxStore) UpdateTxAttemptInProgressToBroadcast(ctx context.Context, e // Updates eth tx from unstarted to in_progress and inserts in_progress eth attempt func (o *evmTxStore) UpdateTxUnstartedToInProgress(ctx context.Context, etx *Tx, attempt *TxAttempt) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if etx.Sequence == nil { return errors.New("in_progress transaction must have nonce") @@ -1681,7 +1679,7 @@ func (o *evmTxStore) UpdateTxUnstartedToInProgress(ctx context.Context, etx *Tx, // It may or may not have been broadcast to an eth node. 
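// The broadcaster maintains at most one in_progress transaction per
// from_address, so a single-row lookup suffices here.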
func (o *evmTxStore) GetTxInProgress(ctx context.Context, fromAddress common.Address) (etx *Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() etx = new(Tx) if err != nil { @@ -1712,7 +1710,7 @@ func (o *evmTxStore) GetTxInProgress(ctx context.Context, fromAddress common.Add func (o *evmTxStore) HasInProgressTransaction(ctx context.Context, account common.Address, chainID *big.Int) (exists bool, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.q.GetContext(ctx, &exists, `SELECT EXISTS(SELECT 1 FROM evm.txes WHERE state = 'in_progress' AND from_address = $1 AND evm_chain_id = $2)`, account, chainID.String()) return exists, pkgerrors.Wrap(err, "hasInProgressTransaction failed") @@ -1720,7 +1718,7 @@ func (o *evmTxStore) HasInProgressTransaction(ctx context.Context, account commo func (o *evmTxStore) countTransactionsWithState(ctx context.Context, fromAddress common.Address, state txmgrtypes.TxState, chainID *big.Int) (count uint32, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.q.GetContext(ctx, &count, `SELECT count(*) FROM evm.txes WHERE from_address = $1 AND state = $2 AND evm_chain_id = $3`, fromAddress, state, chainID.String()) @@ -1735,7 +1733,7 @@ func (o *evmTxStore) CountUnconfirmedTransactions(ctx context.Context, fromAddre // CountTransactionsByState returns the number of transactions with any fromAddress in the given state func (o *evmTxStore) CountTransactionsByState(ctx context.Context, state txmgrtypes.TxState, chainID *big.Int) (count uint32, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.q.GetContext(ctx, &count, `SELECT count(*) FROM evm.txes WHERE state = $1 AND evm_chain_id = $2`, state, chainID.String()) @@ -1752,7 +1750,7 @@ func (o *evmTxStore) CountUnstartedTransactions(ctx context.Context, fromAddress func (o *evmTxStore) CheckTxQueueCapacity(ctx context.Context, fromAddress common.Address, maxQueuedTransactions uint64, chainID *big.Int) (err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() if maxQueuedTransactions == 0 { return nil @@ -1772,7 +1770,7 @@ func (o *evmTxStore) CheckTxQueueCapacity(ctx context.Context, fromAddress commo func (o *evmTxStore) CreateTransaction(ctx context.Context, txRequest TxRequest, chainID *big.Int) (tx Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtx DbEthTx err = o.Transact(ctx, false, func(orm *evmTxStore) error { @@ -1806,7 +1804,7 @@ RETURNING "txes".* func (o *evmTxStore) PruneUnstartedTxQueue(ctx context.Context, queueSize uint32, subject uuid.UUID) (ids []int64, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, false, func(orm *evmTxStore) error { err := orm.q.SelectContext(ctx, &ids, ` @@ -1834,7 +1832,7 @@ id < ( func (o *evmTxStore) ReapTxHistory(ctx context.Context, minBlockNumberToKeep int64, timeThreshold time.Time, chainID *big.Int) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() // Delete old confirmed evm.txes @@ -1893,7 +1891,7 @@ AND 
evm_chain_id = $2`, timeThreshold, chainID.String()) func (o *evmTxStore) Abandon(ctx context.Context, chainID *big.Int, addr common.Address) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() _, err := o.q.ExecContext(ctx, `UPDATE evm.txes SET state='fatal_error', nonce = NULL, error = 'abandoned' WHERE state IN ('unconfirmed', 'in_progress', 'unstarted') AND evm_chain_id = $1 AND from_address = $2`, chainID.String(), addr) return err @@ -1902,7 +1900,7 @@ func (o *evmTxStore) Abandon(ctx context.Context, chainID *big.Int, addr common. // Find transactions by a field in the TxMeta blob and transaction states func (o *evmTxStore) FindTxesByMetaFieldAndStates(ctx context.Context, metaField string, metaValue string, states []txmgrtypes.TxState, chainID *big.Int) ([]*Tx, error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtxs []DbEthTx sql := fmt.Sprintf("SELECT * FROM evm.txes WHERE evm_chain_id = $1 AND meta->>'%s' = $2 AND state = ANY($3)", metaField) @@ -1915,7 +1913,7 @@ func (o *evmTxStore) FindTxesByMetaFieldAndStates(ctx context.Context, metaField // Find transactions with a non-null TxMeta field that was provided by transaction states func (o *evmTxStore) FindTxesWithMetaFieldByStates(ctx context.Context, metaField string, states []txmgrtypes.TxState, chainID *big.Int) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtxs []DbEthTx sql := fmt.Sprintf("SELECT * FROM evm.txes WHERE meta->'%s' IS NOT NULL AND state = ANY($1) AND evm_chain_id = $2", metaField) @@ -1928,7 +1926,7 @@ func (o *evmTxStore) FindTxesWithMetaFieldByStates(ctx context.Context, metaFiel // Find transactions with a non-null TxMeta field that was provided and a receipt block number greater than or equal to the one provided func (o *evmTxStore) FindTxesWithMetaFieldByReceiptBlockNum(ctx context.Context, metaField string, blockNum int64, chainID *big.Int) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtxs []DbEthTx sql := fmt.Sprintf("SELECT et.* FROM evm.txes et JOIN evm.tx_attempts eta on et.id = eta.eth_tx_id JOIN evm.receipts er on eta.hash = er.tx_hash WHERE et.meta->'%s' IS NOT NULL AND er.block_number >= $1 AND et.evm_chain_id = $2", metaField) @@ -1941,7 +1939,7 @@ func (o *evmTxStore) FindTxesWithMetaFieldByReceiptBlockNum(ctx context.Context, // Find transactions loaded with transaction attempts and receipts by transaction IDs and states func (o *evmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []txmgrtypes.TxState, chainID *big.Int) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtxs []DbEthTx @@ -1964,7 +1962,7 @@ func (o *evmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Co // For testing only, get all txes in the DB func (o *evmTxStore) GetAllTxes(ctx context.Context) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbEtxs []DbEthTx sql := "SELECT * FROM evm.txes" @@ -1977,7 +1975,7 @@ func (o *evmTxStore) GetAllTxes(ctx context.Context) 
(txes []*Tx, err error) { // For testing only, get all tx attempts in the DB func (o *evmTxStore) GetAllTxAttempts(ctx context.Context) (attempts []TxAttempt, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() var dbAttempts []DbEthTxAttempt sql := "SELECT * FROM evm.tx_attempts" @@ -1988,7 +1986,7 @@ func (o *evmTxStore) GetAllTxAttempts(ctx context.Context) (attempts []TxAttempt func (o *evmTxStore) CountTxesByStateAndSubject(ctx context.Context, state txmgrtypes.TxState, subject uuid.UUID) (count int, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() sql := "SELECT COUNT(*) FROM evm.txes WHERE state = $1 AND subject = $2" err = o.q.GetContext(ctx, &count, sql, state, subject) @@ -1997,7 +1995,7 @@ func (o *evmTxStore) CountTxesByStateAndSubject(ctx context.Context, state txmgr func (o *evmTxStore) FindTxesByFromAddressAndState(ctx context.Context, fromAddress common.Address, state string) (txes []*Tx, err error) { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() sql := "SELECT * FROM evm.txes WHERE from_address = $1 AND state = $2" var dbEtxs []DbEthTx @@ -2009,23 +2007,9 @@ func (o *evmTxStore) FindTxesByFromAddressAndState(ctx context.Context, fromAddr func (o *evmTxStore) UpdateTxAttemptBroadcastBeforeBlockNum(ctx context.Context, id int64, blockNum uint) error { var cancel context.CancelFunc - ctx, cancel = o.mergeContexts(ctx) + ctx, cancel = o.stopCh.Ctx(ctx) defer cancel() sql := "UPDATE evm.tx_attempts SET broadcast_before_block_num = $1 WHERE eth_tx_id = $2" _, err := o.q.ExecContext(ctx, sql, blockNum, id) return err } - -// Returns a context that contains the values of the provided context, -// and which is canceled when either the provided context or TxStore parent context is canceled. -func (o *evmTxStore) mergeContexts(ctx context.Context) (context.Context, context.CancelFunc) { - var cancel context.CancelCauseFunc - ctx, cancel = context.WithCancelCause(ctx) - stop := context.AfterFunc(o.ctx, func() { - cancel(context.Cause(o.ctx)) - }) - return ctx, func() { - stop() - cancel(context.Canceled) - } -} diff --git a/core/services/blockhashstore/delegate.go b/core/services/blockhashstore/delegate.go index 2181084aeec..172dbafc4a4 100644 --- a/core/services/blockhashstore/delegate.go +++ b/core/services/blockhashstore/delegate.go @@ -215,29 +215,32 @@ type service struct { pollPeriod time.Duration runTimeout time.Duration logger logger.Logger - parentCtx context.Context - cancel context.CancelFunc + stopCh services.StopChan } // Start the BHS feeder service, satisfying the job.Service interface. 
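// Start launches two goroutines, one for heartbeats and one for the jittered
// poll ticker; each derives its context from stopCh, so Close (which closes
// stopCh and waits on the WaitGroup) unblocks both.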
func (s *service) Start(context.Context) error { return s.StartOnce("BHS Feeder Service", func() error { s.logger.Infow("Starting BHS feeder") - ticker := time.NewTicker(utils.WithJitter(s.pollPeriod)) - s.parentCtx, s.cancel = context.WithCancel(context.Background()) + s.stopCh = make(chan struct{}) s.wg.Add(2) go func() { defer s.wg.Done() - s.feeder.StartHeartbeats(s.parentCtx, &realTimer{}) + ctx, cancel := s.stopCh.NewCtx() + defer cancel() + s.feeder.StartHeartbeats(ctx, &realTimer{}) }() go func() { defer s.wg.Done() + ctx, cancel := s.stopCh.NewCtx() + defer cancel() + ticker := time.NewTicker(utils.WithJitter(s.pollPeriod)) defer ticker.Stop() for { select { case <-ticker.C: - s.runFeeder() - case <-s.parentCtx.Done(): + s.runFeeder(ctx) + case <-ctx.Done(): return } } @@ -250,15 +253,15 @@ func (s *service) Start(context.Context) error { func (s *service) Close() error { return s.StopOnce("BHS Feeder Service", func() error { s.logger.Infow("Stopping BHS feeder") - s.cancel() + close(s.stopCh) s.wg.Wait() return nil }) } -func (s *service) runFeeder() { +func (s *service) runFeeder(ctx context.Context) { s.logger.Debugw("Running BHS feeder") - ctx, cancel := context.WithTimeout(s.parentCtx, s.runTimeout) + ctx, cancel := context.WithTimeout(ctx, s.runTimeout) defer cancel() err := s.feeder.Run(ctx) if err == nil { diff --git a/core/services/blockheaderfeeder/delegate.go b/core/services/blockheaderfeeder/delegate.go index 36d1d1cf895..830c2e23377 100644 --- a/core/services/blockheaderfeeder/delegate.go +++ b/core/services/blockheaderfeeder/delegate.go @@ -229,24 +229,25 @@ type service struct { pollPeriod time.Duration runTimeout time.Duration logger logger.Logger - parentCtx context.Context - cancel context.CancelFunc + stopCh services.StopChan } // Start the BHS feeder service, satisfying the job.Service interface. 
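// Same pattern as the BHS feeder above, except a single goroutine signals
// completion by closing s.done instead of using a WaitGroup.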
func (s *service) Start(context.Context) error { return s.StartOnce("Block Header Feeder Service", func() error { s.logger.Infow("Starting BlockHeaderFeeder") - ticker := time.NewTicker(utils.WithJitter(s.pollPeriod)) - s.parentCtx, s.cancel = context.WithCancel(context.Background()) + s.stopCh = make(chan struct{}) go func() { defer close(s.done) + ctx, cancel := s.stopCh.NewCtx() + defer cancel() + ticker := time.NewTicker(utils.WithJitter(s.pollPeriod)) defer ticker.Stop() for { select { case <-ticker.C: - s.runFeeder() - case <-s.parentCtx.Done(): + s.runFeeder(ctx) + case <-ctx.Done(): return } } @@ -259,15 +260,15 @@ func (s *service) Start(context.Context) error { func (s *service) Close() error { return s.StopOnce("Block Header Feeder Service", func() error { s.logger.Infow("Stopping BlockHeaderFeeder") - s.cancel() + close(s.stopCh) <-s.done return nil }) } -func (s *service) runFeeder() { +func (s *service) runFeeder(ctx context.Context) { s.logger.Debugw("Running BlockHeaderFeeder") - ctx, cancel := context.WithTimeout(s.parentCtx, s.runTimeout) + ctx, cancel := context.WithTimeout(ctx, s.runTimeout) defer cancel() err := s.feeder.Run(ctx) if err == nil { diff --git a/core/services/feeds/connection_manager.go b/core/services/feeds/connection_manager.go index 6339339ab7c..d388bc0899f 100644 --- a/core/services/feeds/connection_manager.go +++ b/core/services/feeds/connection_manager.go @@ -1,7 +1,6 @@ package feeds import ( - "context" "crypto/ed25519" "sync" @@ -10,6 +9,7 @@ import ( "github.com/smartcontractkit/wsrpc" "github.com/smartcontractkit/wsrpc/connectivity" + "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/recovery" pb "github.com/smartcontractkit/chainlink/v2/core/services/feeds/proto" @@ -35,10 +35,7 @@ type connectionsManager struct { } type connection struct { - // ctx allows us to cancel any connections which are currently blocking - // while waiting to establish a connection to FMS. - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan connected bool client pb.FeedsManagerClient @@ -81,11 +78,8 @@ type ConnectOpts struct { // Eventually when FMS does come back up, wsrpc will establish the connection // without any interaction on behalf of the node operator. func (mgr *connectionsManager) Connect(opts ConnectOpts) { - ctx, cancel := context.WithCancel(context.Background()) - conn := &connection{ - ctx: ctx, - cancel: cancel, + stopCh: make(chan struct{}), connected: false, } @@ -96,11 +90,13 @@ func (mgr *connectionsManager) Connect(opts ConnectOpts) { mgr.mu.Unlock() go recovery.WrapRecover(mgr.lggr, func() { + ctx, cancel := conn.stopCh.NewCtx() + defer cancel() defer mgr.wgClosed.Done() mgr.lggr.Infow("Connecting to Feeds Manager...", "feedsManagerID", opts.FeedsManagerID) - clientConn, err := wsrpc.DialWithContext(conn.ctx, opts.URI, + clientConn, err := wsrpc.DialWithContext(ctx, opts.URI, wsrpc.WithTransportCreds(opts.Privkey, ed25519.PublicKey(opts.Pubkey)), wsrpc.WithBlock(), wsrpc.WithLogger(mgr.lggr), @@ -108,7 +104,7 @@ func (mgr *connectionsManager) Connect(opts ConnectOpts) { if err != nil { // We only want to log if there was an error that did not occur // from a context cancel. 
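// ctx.Err() is non-nil only when stopCh was closed by Disconnect or Close,
// in which case the dial error is expected and logged at info level instead.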
- if conn.ctx.Err() == nil { + if ctx.Err() == nil { mgr.lggr.Warnf("Error connecting to Feeds Manager server: %v", err) } else { mgr.lggr.Infof("Closing wsrpc websocket connection: %v", err) @@ -139,7 +135,7 @@ func (mgr *connectionsManager) Connect(opts ConnectOpts) { for { s := clientConn.GetState() - clientConn.WaitForStateChange(conn.ctx, s) + clientConn.WaitForStateChange(ctx, s) s = clientConn.GetState() @@ -155,7 +151,7 @@ func (mgr *connectionsManager) Connect(opts ConnectOpts) { }() // Wait for close - <-conn.ctx.Done() + <-ctx.Done() }) } @@ -169,7 +165,7 @@ func (mgr *connectionsManager) Disconnect(id int64) error { return errors.New("feeds manager is not connected") } - conn.cancel() + close(conn.stopCh) delete(mgr.connections, id) mgr.lggr.Infow("Disconnected Feeds Manager", "feedsManagerID", id) @@ -181,7 +177,7 @@ func (mgr *connectionsManager) Disconnect(id int64) error { func (mgr *connectionsManager) Close() { mgr.mu.Lock() for _, conn := range mgr.connections { - conn.cancel() + close(conn.stopCh) } mgr.mu.Unlock() diff --git a/core/services/functions/listener.go b/core/services/functions/listener.go index d2033ff74de..bfcbf10f692 100644 --- a/core/services/functions/listener.go +++ b/core/services/functions/listener.go @@ -130,9 +130,7 @@ type functionsListener struct { job job.Job bridgeAccessor BridgeAccessor shutdownWaitGroup sync.WaitGroup - serviceContext context.Context - serviceCancel context.CancelFunc - chStop chan struct{} + chStop services.StopChan pluginORM ORM pluginConfig config.PluginConfig s4Storage s4.Storage @@ -186,12 +184,10 @@ func NewFunctionsListener( // Start complies with job.Service func (l *functionsListener) Start(context.Context) error { return l.StartOnce("FunctionsListener", func() error { - l.serviceContext, l.serviceCancel = context.WithCancel(context.Background()) - switch l.pluginConfig.ContractVersion { case 1: l.shutdownWaitGroup.Add(1) - go l.processOracleEventsV1(l.serviceContext) + go l.processOracleEventsV1() default: return fmt.Errorf("unsupported contract version: %d", l.pluginConfig.ContractVersion) } @@ -213,15 +209,16 @@ func (l *functionsListener) Start(context.Context) error { // Close complies with job.Service func (l *functionsListener) Close() error { return l.StopOnce("FunctionsListener", func() error { - l.serviceCancel() close(l.chStop) l.shutdownWaitGroup.Wait() return nil }) } -func (l *functionsListener) processOracleEventsV1(ctx context.Context) { +func (l *functionsListener) processOracleEventsV1() { defer l.shutdownWaitGroup.Done() + ctx, cancel := l.chStop.NewCtx() + defer cancel() freqMillis := l.pluginConfig.ListenerEventsCheckFrequencyMillis if freqMillis == 0 { l.logger.Errorw("ListenerEventsCheckFrequencyMillis must set to more than 0 in PluginConfig") @@ -255,11 +252,17 @@ func (l *functionsListener) processOracleEventsV1(ctx context.Context) { } func (l *functionsListener) getNewHandlerContext() (context.Context, context.CancelFunc) { + ctx, cancel := l.chStop.NewCtx() timeoutSec := l.pluginConfig.ListenerEventHandlerTimeoutSec if timeoutSec == 0 { - return context.WithCancel(l.serviceContext) + return ctx, cancel + } + var cancel2 func() + ctx, cancel2 = context.WithTimeout(ctx, time.Duration(timeoutSec)*time.Second) + return ctx, func() { + cancel2() + cancel() } - return context.WithTimeout(l.serviceContext, time.Duration(timeoutSec)*time.Second) } func (l *functionsListener) setError(ctx context.Context, requestId RequestID, errType ErrType, errBytes []byte) { diff --git 
a/core/services/keeper/delegate.go b/core/services/keeper/delegate.go index 71a0c5c43a9..c9d189b30c5 100644 --- a/core/services/keeper/delegate.go +++ b/core/services/keeper/delegate.go @@ -93,7 +93,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services // In the case of forwarding, the keeper address is the forwarder contract deployed onchain between EOA and Registry. effectiveKeeperAddress := spec.KeeperSpec.FromAddress.Address() if spec.ForwardingAllowed { - fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(spec.KeeperSpec.FromAddress.Address()) + fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(ctx, spec.KeeperSpec.FromAddress.Address()) if fwderr == nil { effectiveKeeperAddress = fwdrAddress } else { diff --git a/core/services/keeper/integration_test.go b/core/services/keeper/integration_test.go index 9e4cf5f9041..cbbe89b3f21 100644 --- a/core/services/keeper/integration_test.go +++ b/core/services/keeper/integration_test.go @@ -417,7 +417,7 @@ func TestKeeperForwarderEthIntegration(t *testing.T) { _, err = forwarderORM.CreateForwarder(ctx, fwdrAddress, chainID) require.NoError(t, err) - addr, err := app.GetRelayers().LegacyEVMChains().Slice()[0].TxManager().GetForwarderForEOA(nodeAddress) + addr, err := app.GetRelayers().LegacyEVMChains().Slice()[0].TxManager().GetForwarderForEOA(ctx, nodeAddress) require.NoError(t, err) require.Equal(t, addr, fwdrAddress) diff --git a/core/services/ocr/config_overrider.go b/core/services/ocr/config_overrider.go index 435efa437c7..067c06f58ce 100644 --- a/core/services/ocr/config_overrider.go +++ b/core/services/ocr/config_overrider.go @@ -31,9 +31,8 @@ type ConfigOverriderImpl struct { DeltaCFromAddress time.Duration // Start/Stop lifecycle - ctx context.Context - ctxCancel context.CancelFunc - chDone chan struct{} + chStop services.StopChan + chDone chan struct{} mu sync.RWMutex } @@ -63,7 +62,6 @@ func NewConfigOverriderImpl( addressSeconds := addressBig.Mod(addressBig, big.NewInt(jitterSeconds)).Uint64() deltaC := cfg.DeltaCOverride() + time.Duration(addressSeconds)*time.Second - ctx, cancel := context.WithCancel(context.Background()) co := ConfigOverriderImpl{ services.StateMachine{}, logger, @@ -73,8 +71,7 @@ func NewConfigOverriderImpl( time.Now(), InitialHibernationStatus, deltaC, - ctx, - cancel, + make(chan struct{}), make(chan struct{}), sync.RWMutex{}, } @@ -96,7 +93,7 @@ func (c *ConfigOverriderImpl) Start(context.Context) error { func (c *ConfigOverriderImpl) Close() error { return c.StopOnce("OCRContractTracker", func() error { - c.ctxCancel() + close(c.chStop) <-c.chDone return nil }) @@ -104,11 +101,13 @@ func (c *ConfigOverriderImpl) Close() error { func (c *ConfigOverriderImpl) eventLoop() { defer close(c.chDone) + ctx, cancel := c.chStop.NewCtx() + defer cancel() c.pollTicker.Resume() defer c.pollTicker.Destroy() for { select { - case <-c.ctx.Done(): + case <-ctx.Done(): return case <-c.pollTicker.Ticks(): if err := c.updateFlagsStatus(); err != nil { diff --git a/core/services/ocr/delegate.go b/core/services/ocr/delegate.go index e748823ad71..a47e7ec9e7d 100644 --- a/core/services/ocr/delegate.go +++ b/core/services/ocr/delegate.go @@ -216,7 +216,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] // In the case of forwarding, the transmitter address is the forwarder contract deployed onchain between EOA and OCR contract. 
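// GetForwarderForEOA now takes the caller's ctx; when the lookup fails,
// effectiveTransmitterAddress falls back to the EOA address (else branch below).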
effectiveTransmitterAddress := concreteSpec.TransmitterAddress.Address() if jb.ForwardingAllowed { - fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(effectiveTransmitterAddress) + fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(ctx, effectiveTransmitterAddress) if fwderr == nil { effectiveTransmitterAddress = fwdrAddress } else { diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index 8ea43582126..350cbc8d593 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -382,7 +382,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi if err2 != nil { return nil, fmt.Errorf("ServicesForSpec: could not get EVM chain %s: %w", rid.ChainID, err2) } - effectiveTransmitterID, err2 = GetEVMEffectiveTransmitterID(&jb, chain, lggr) + effectiveTransmitterID, err2 = GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) if err2 != nil { return nil, fmt.Errorf("ServicesForSpec failed to get evm transmitterID: %w", err2) } @@ -470,7 +470,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi } } -func GetEVMEffectiveTransmitterID(jb *job.Job, chain legacyevm.Chain, lggr logger.SugaredLogger) (string, error) { +func GetEVMEffectiveTransmitterID(ctx context.Context, jb *job.Job, chain legacyevm.Chain, lggr logger.SugaredLogger) (string, error) { spec := jb.OCR2OracleSpec if spec.PluginType == types.Mercury || spec.PluginType == types.LLO { return spec.TransmitterID.String, nil @@ -501,9 +501,9 @@ func GetEVMEffectiveTransmitterID(jb *job.Job, chain legacyevm.Chain, lggr logge var effectiveTransmitterID common.Address // Median forwarders need special handling because of OCR2Aggregator transmitters whitelist. if spec.PluginType == types.Median { - effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOAOCR2Feeds(common.HexToAddress(spec.TransmitterID.String), common.HexToAddress(spec.ContractID)) + effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOAOCR2Feeds(ctx, common.HexToAddress(spec.TransmitterID.String), common.HexToAddress(spec.ContractID)) } else { - effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOA(common.HexToAddress(spec.TransmitterID.String)) + effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOA(ctx, common.HexToAddress(spec.TransmitterID.String)) } if err == nil { diff --git a/core/services/ocr2/delegate_test.go b/core/services/ocr2/delegate_test.go index bc5c2df2bbe..1e4be66c7d1 100644 --- a/core/services/ocr2/delegate_test.go +++ b/core/services/ocr2/delegate_test.go @@ -5,10 +5,12 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink-common/pkg/types" + evmcfg "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" txmmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr/mocks" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" @@ -27,7 +29,6 @@ import ( ) func TestGetEVMEffectiveTransmitterID(t *testing.T) { - ctx := testutils.Context(t) customChainID := big.New(testutils.NewRandomEVMChainID()) config := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -41,7 +42,7 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { }) db := pgtest.NewSqlxDB(t) keyStore := cltest.NewKeyStore(t, db) - require.NoError(t, keyStore.OCR2().Add(ctx, cltest.DefaultOCR2Key)) + 
require.NoError(t, keyStore.OCR2().Add(testutils.Context(t), cltest.DefaultOCR2Key)) lggr := logger.TestLogger(t) txManager := txmmocks.NewMockEvmTxManager(t) @@ -67,7 +68,7 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = tc.sendingKeys jb.ForwardingAllowed = tc.forwardingEnabled - args := []interface{}{tc.getForwarderForEOAArg} + args := []interface{}{mock.Anything, tc.getForwarderForEOAArg} getForwarderMethodName := "GetForwarderForEOA" if tc.pluginType == types.Median { getForwarderMethodName = "GetForwarderForEOAOCR2Feeds" @@ -144,13 +145,14 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { } t.Run("when sending keys are not defined, the first one should be set to transmitterID", func(t *testing.T) { + ctx := testutils.Context(t) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) jb.OCR2OracleSpec.TransmitterID = null.StringFrom("some transmitterID string") jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = nil chain, err := legacyChains.Get(customChainID.String()) require.NoError(t, err) - effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(&jb, chain, lggr) + effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) require.NoError(t, err) require.Equal(t, "some transmitterID string", effectiveTransmitterID) require.Equal(t, []string{"some transmitterID string"}, jb.OCR2OracleSpec.RelayConfig["sendingKeys"].([]string)) @@ -158,13 +160,14 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + ctx := testutils.Context(t) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) setTestCase(&jb, tc, txManager) chain, err := legacyChains.Get(customChainID.String()) require.NoError(t, err) - effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(&jb, chain, lggr) + effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) if tc.expectedError { require.Error(t, err) } else { @@ -180,13 +183,14 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { } t.Run("when forwarders are enabled and chain retrieval fails, error should be handled", func(t *testing.T) { + ctx := testutils.Context(t) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) jb.ForwardingAllowed = true jb.OCR2OracleSpec.TransmitterID = null.StringFrom("0x7e57000000000000000000000000000000000001") chain, err := legacyChains.Get("not an id") require.Error(t, err) - _, err = ocr2.GetEVMEffectiveTransmitterID(&jb, chain, lggr) + _, err = ocr2.GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) require.Error(t, err) }) } diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry.go index da4dd17d96f..56c200f9b13 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry.go @@ -102,10 +102,12 @@ func NewEVMRegistryService(addr common.Address, client legacyevm.Chain, lggr log enc: EVMAutomationEncoder20{}, } - r.ctx, r.cancel = context.WithCancel(context.Background()) + r.stopCh = make(chan struct{}) r.reInit = 
time.NewTimer(reInitializationDelay) - if err := r.registerEvents(client.ID().Uint64(), addr); err != nil { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() + if err := r.registerEvents(ctx, client.ID().Uint64(), addr); err != nil { return nil, fmt.Errorf("logPoller error while registering automation events: %w", err) } @@ -152,8 +154,7 @@ type EvmRegistry struct { mu sync.RWMutex txHashes map[string]bool lastPollBlock int64 - ctx context.Context - cancel context.CancelFunc + stopCh services.StopChan active map[string]activeUpkeep headFunc func(ocr2keepers.BlockKey) runState int @@ -209,8 +210,10 @@ func (r *EvmRegistry) Start(_ context.Context) error { defer r.mu.Unlock() // initialize the upkeep keys; if the reInit timer returns, do it again { - go func(cx context.Context, tmr *time.Timer, lggr logger.Logger, f func() error) { - err := f() + go func(tmr *time.Timer, lggr logger.Logger, f func(context.Context) error) { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() + err := f(ctx) if err != nil { lggr.Errorf("failed to initialize upkeeps", err) } @@ -218,53 +221,57 @@ func (r *EvmRegistry) Start(_ context.Context) error { for { select { case <-tmr.C: - err = f() + err = f(ctx) if err != nil { lggr.Errorf("failed to re-initialize upkeeps", err) } tmr.Reset(reInitializationDelay) - case <-cx.Done(): + case <-ctx.Done(): return } } - }(r.ctx, r.reInit, r.lggr, r.initialize) + }(r.reInit, r.lggr, r.initialize) } // start polling logs on an interval { - go func(cx context.Context, lggr logger.Logger, f func() error) { + go func(lggr logger.Logger, f func(context.Context) error) { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() ticker := time.NewTicker(time.Second) for { select { case <-ticker.C: - err := f() + err := f(ctx) if err != nil { lggr.Errorf("failed to poll logs for upkeeps", err) } - case <-cx.Done(): + case <-ctx.Done(): ticker.Stop() return } } - }(r.ctx, r.lggr, r.pollLogs) + }(r.lggr, r.pollLogs) } // run process to process logs from log channel { - go func(cx context.Context, ch chan logpoller.Log, lggr logger.Logger, f func(logpoller.Log) error) { + go func(ch chan logpoller.Log, lggr logger.Logger, f func(context.Context, logpoller.Log) error) { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() for { select { case l := <-ch: - err := f(l) + err := f(ctx, l) if err != nil { lggr.Errorf("failed to process log for upkeep", err) } - case <-cx.Done(): + case <-ctx.Done(): return } } - }(r.ctx, r.chLog, r.lggr, r.processUpkeepStateLog) + }(r.chLog, r.lggr, r.processUpkeepStateLog) } r.runState = 1 @@ -276,7 +283,7 @@ func (r *EvmRegistry) Close() error { return r.sync.StopOnce("AutomationRegistry", func() error { r.mu.Lock() defer r.mu.Unlock() - r.cancel() + close(r.stopCh) r.runState = 0 r.runError = nil return nil @@ -303,8 +310,8 @@ func (r *EvmRegistry) HealthReport() map[string]error { return map[string]error{r.Name(): r.sync.Healthy()} } -func (r *EvmRegistry) initialize() error { - startupCtx, cancel := context.WithTimeout(r.ctx, reInitializationDelay) +func (r *EvmRegistry) initialize(ctx context.Context) error { + startupCtx, cancel := context.WithTimeout(ctx, reInitializationDelay) defer cancel() idMap := make(map[string]activeUpkeep) @@ -345,12 +352,12 @@ func (r *EvmRegistry) initialize() error { return nil } -func (r *EvmRegistry) pollLogs() error { +func (r *EvmRegistry) pollLogs(ctx context.Context) error { var latest int64 var end logpoller.LogPollerBlock var err error - if end, err = r.poller.LatestBlock(r.ctx); err != nil { + if end, err = 
r.poller.LatestBlock(ctx); err != nil { return fmt.Errorf("%w: %s", ErrHeadNotAvailable, err) } @@ -367,7 +374,7 @@ func (r *EvmRegistry) pollLogs() error { { var logs []logpoller.Log if logs, err = r.poller.LogsWithSigs( - r.ctx, + ctx, end.BlockNumber-logEventLookback, end.BlockNumber, upkeepStateEvents, @@ -388,17 +395,17 @@ func UpkeepFilterName(addr common.Address) string { return logpoller.FilterName("EvmRegistry - Upkeep events for", addr.String()) } -func (r *EvmRegistry) registerEvents(chainID uint64, addr common.Address) error { +func (r *EvmRegistry) registerEvents(ctx context.Context, chainID uint64, addr common.Address) error { // Add log filters for the log poller so that it can poll and find the logs that // we need - return r.poller.RegisterFilter(r.ctx, logpoller.Filter{ + return r.poller.RegisterFilter(ctx, logpoller.Filter{ Name: UpkeepFilterName(addr), EventSigs: append(upkeepStateEvents, upkeepActiveEvents...), Addresses: []common.Address{addr}, }) } -func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error { +func (r *EvmRegistry) processUpkeepStateLog(ctx context.Context, l logpoller.Log) error { hash := l.TxHash.String() if _, ok := r.txHashes[hash]; ok { return nil @@ -414,22 +421,22 @@ func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error { switch l := abilog.(type) { case *keeper_registry_wrapper2_0.KeeperRegistryUpkeepRegistered: r.lggr.Debugf("KeeperRegistryUpkeepRegistered log detected for upkeep ID %s in transaction %s", l.Id.String(), hash) - r.addToActive(l.Id, false) + r.addToActive(ctx, l.Id, false) case *keeper_registry_wrapper2_0.KeeperRegistryUpkeepReceived: r.lggr.Debugf("KeeperRegistryUpkeepReceived log detected for upkeep ID %s in transaction %s", l.Id.String(), hash) - r.addToActive(l.Id, false) + r.addToActive(ctx, l.Id, false) case *keeper_registry_wrapper2_0.KeeperRegistryUpkeepUnpaused: r.lggr.Debugf("KeeperRegistryUpkeepUnpaused log detected for upkeep ID %s in transaction %s", l.Id.String(), hash) - r.addToActive(l.Id, false) + r.addToActive(ctx, l.Id, false) case *keeper_registry_wrapper2_0.KeeperRegistryUpkeepGasLimitSet: r.lggr.Debugf("KeeperRegistryUpkeepGasLimitSet log detected for upkeep ID %s in transaction %s", l.Id.String(), hash) - r.addToActive(l.Id, true) + r.addToActive(ctx, l.Id, true) } return nil } -func (r *EvmRegistry) addToActive(id *big.Int, force bool) { +func (r *EvmRegistry) addToActive(ctx context.Context, id *big.Int, force bool) { r.mu.Lock() defer r.mu.Unlock() @@ -438,7 +445,7 @@ func (r *EvmRegistry) addToActive(id *big.Int, force bool) { } if _, ok := r.active[id.String()]; !ok || force { - actives, err := r.getUpkeepConfigs(r.ctx, []*big.Int{id}) + actives, err := r.getUpkeepConfigs(ctx, []*big.Int{id}) if err != nil { r.lggr.Errorf("failed to get upkeep configs during adding active upkeep: %w", err) return diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry_test.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry_test.go index 592563f0b04..8100980dd6b 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry_test.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20/registry_test.go @@ -186,6 +186,7 @@ func TestPollLogs(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { + ctx := testutils.Context(t) mp := new(mocks.LogPoller) if test.LatestBlock != nil { @@ -205,7 +206,7 @@ func TestPollLogs(t *testing.T) { chLog: make(chan logpoller.Log, 10), } - err := rg.pollLogs() + err := 
rg.pollLogs(ctx)
 			assert.Equal(t, test.ExpectedLastPoll, rg.lastPollBlock)
 			if test.ExpectedErr != nil {
diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry.go
index 082318518a5..6bab073b9b3 100644
--- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry.go
+++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry.go
@@ -98,7 +98,7 @@ func NewEvmRegistry(
 	hc := http.DefaultClient
 	return &EvmRegistry{
-		ctx:        context.Background(),
+		stopCh:     make(chan struct{}),
 		threadCtrl: utils.NewThreadControl(),
 		lggr:       lggr.Named(RegistryServiceName),
 		poller:     client.LogPoller(),
@@ -190,7 +190,7 @@ type EvmRegistry struct {
 	logProcessed  map[string]bool
 	active        ActiveUpkeepList
 	lastPollBlock int64
-	ctx           context.Context
+	stopCh        services.StopChan
 	headFunc      func(ocr2keepers.BlockKey)
 	mercury       *MercuryConfig
 	hc            HttpClient
@@ -207,13 +207,13 @@ func (r *EvmRegistry) Name() string {
 func (r *EvmRegistry) Start(ctx context.Context) error {
 	return r.StartOnce(RegistryServiceName, func() error {
-		if err := r.registerEvents(r.chainID, r.addr); err != nil {
+		if err := r.registerEvents(ctx, r.chainID, r.addr); err != nil {
 			return fmt.Errorf("logPoller error while registering automation events: %w", err)
 		}
 		r.threadCtrl.Go(func(ctx context.Context) {
 			lggr := r.lggr.With("where", "upkeeps_refresh")
-			err := r.refreshActiveUpkeeps()
+			err := r.refreshActiveUpkeeps(ctx)
 			if err != nil {
 				lggr.Errorf("failed to initialize upkeeps", err)
 			}
@@ -224,7 +224,7 @@ func (r *EvmRegistry) Start(ctx context.Context) error {
 			for {
 				select {
 				case <-ticker.C:
-					err = r.refreshActiveUpkeeps()
+					err = r.refreshActiveUpkeeps(ctx)
 					if err != nil {
 						lggr.Errorf("failed to refresh upkeeps", err)
 					}
@@ -242,7 +242,7 @@
 			for {
 				select {
 				case <-ticker.C:
-					err := r.pollUpkeepStateLogs()
+					err := r.pollUpkeepStateLogs(ctx)
 					if err != nil {
 						lggr.Errorf("failed to poll logs for upkeeps", err)
 					}
@@ -259,7 +259,7 @@
 			for {
 				select {
 				case l := <-ch:
-					err := r.processUpkeepStateLog(l)
+					err := r.processUpkeepStateLog(ctx, l)
 					if err != nil {
 						lggr.Errorf("failed to process log for upkeep", err)
 					}
@@ -284,9 +284,9 @@ func (r *EvmRegistry) HealthReport() map[string]error {
 	return map[string]error{RegistryServiceName: r.Healthy()}
 }
-func (r *EvmRegistry) refreshActiveUpkeeps() error {
+func (r *EvmRegistry) refreshActiveUpkeeps(ctx context.Context) error {
 	// Allow for max timeout of refreshInterval
-	ctx, cancel := context.WithTimeout(r.ctx, refreshInterval)
+	ctx, cancel := context.WithTimeout(ctx, refreshInterval)
 	defer cancel()
 	r.lggr.Debugf("Refreshing active upkeeps list")
@@ -311,17 +311,17 @@
 		}
 	}
-	_, err = r.logEventProvider.RefreshActiveUpkeeps(r.ctx, logTriggerIDs...)
+	_, err = r.logEventProvider.RefreshActiveUpkeeps(ctx, logTriggerIDs...)
 if err != nil {
 		return fmt.Errorf("failed to refresh active upkeep ids in log event provider: %w", err)
 	}
 	// Try to refresh log trigger config for all log upkeeps
-	return r.refreshLogTriggerUpkeeps(logTriggerIDs)
+	return r.refreshLogTriggerUpkeeps(ctx, logTriggerIDs)
 }
 // refreshLogTriggerUpkeeps refreshes the active upkeep ids for log trigger upkeeps
-func (r *EvmRegistry) refreshLogTriggerUpkeeps(ids []*big.Int) error {
+func (r *EvmRegistry) refreshLogTriggerUpkeeps(ctx context.Context, ids []*big.Int) error {
 	var err error
 	for i := 0; i < len(ids); i += logTriggerRefreshBatchSize {
 		end := i + logTriggerRefreshBatchSize
@@ -330,7 +330,7 @@ func (r *EvmRegistry) refreshLogTriggerUpkeeps(ids []*big.Int) error {
 		}
 		idBatch := ids[i:end]
-		if batchErr := r.refreshLogTriggerUpkeepsBatch(idBatch); batchErr != nil {
+		if batchErr := r.refreshLogTriggerUpkeepsBatch(ctx, idBatch); batchErr != nil {
 			multierr.AppendInto(&err, batchErr)
 		}
@@ -340,17 +340,17 @@
 	return err
 }
-func (r *EvmRegistry) refreshLogTriggerUpkeepsBatch(logTriggerIDs []*big.Int) error {
+func (r *EvmRegistry) refreshLogTriggerUpkeepsBatch(ctx context.Context, logTriggerIDs []*big.Int) error {
 	var logTriggerHashes []common.Hash
 	for _, id := range logTriggerIDs {
 		logTriggerHashes = append(logTriggerHashes, common.BigToHash(id))
 	}
-	unpausedLogs, err := r.poller.IndexedLogs(r.ctx, ac.IAutomationV21PlusCommonUpkeepUnpaused{}.Topic(), r.addr, 1, logTriggerHashes, evmtypes.Confirmations(r.finalityDepth))
+	unpausedLogs, err := r.poller.IndexedLogs(ctx, ac.IAutomationV21PlusCommonUpkeepUnpaused{}.Topic(), r.addr, 1, logTriggerHashes, evmtypes.Confirmations(r.finalityDepth))
 	if err != nil {
 		return err
 	}
-	configSetLogs, err := r.poller.IndexedLogs(r.ctx, ac.IAutomationV21PlusCommonUpkeepTriggerConfigSet{}.Topic(), r.addr, 1, logTriggerHashes, evmtypes.Confirmations(r.finalityDepth))
+	configSetLogs, err := r.poller.IndexedLogs(ctx, ac.IAutomationV21PlusCommonUpkeepTriggerConfigSet{}.Topic(), r.addr, 1, logTriggerHashes, evmtypes.Confirmations(r.finalityDepth))
 	if err != nil {
 		return err
 	}
@@ -400,7 +400,7 @@ func (r *EvmRegistry) refreshLogTriggerUpkeepsBatch(logTriggerIDs []*big.Int) er
 		if unpausedBlockNumbers[id.String()] > logBlock {
 			logBlock = unpausedBlockNumbers[id.String()]
 		}
-		if err := r.updateTriggerConfig(id, config, logBlock); err != nil {
+		if err := r.updateTriggerConfig(ctx, id, config, logBlock); err != nil {
 			merr = goerrors.Join(merr, fmt.Errorf("failed to update trigger config for upkeep id %s: %w", id.String(), err))
 		}
 	}
@@ -408,12 +408,12 @@
 	return merr
 }
-func (r *EvmRegistry) pollUpkeepStateLogs() error {
+func (r *EvmRegistry) pollUpkeepStateLogs(ctx context.Context) error {
 	var latest int64
 	var end logpoller.LogPollerBlock
 	var err error
-	if end, err = r.poller.LatestBlock(r.ctx); err != nil {
+	if end, err = r.poller.LatestBlock(ctx); err != nil {
 		return fmt.Errorf("%w: %s", ErrHeadNotAvailable, err)
 	}
@@ -429,7 +429,7 @@
 	var logs []logpoller.Log
 	if logs, err = r.poller.LogsWithSigs(
-		r.ctx,
+		ctx,
 		end.BlockNumber-logEventLookback,
 		end.BlockNumber,
 		upkeepStateEvents,
@@ -445,7 +445,7 @@
 	return nil
 }
-func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error {
+func (r *EvmRegistry) processUpkeepStateLog(ctx context.Context, l logpoller.Log) error {
 	lid := fmt.Sprintf("%s%d", l.TxHash.String(), l.LogIndex)
 	r.mu.Lock()
 	if _, ok := r.logProcessed[lid]; ok {
@@ -465,16 +465,16 @@ func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error {
 	switch l := abilog.(type) {
 	case *ac.IAutomationV21PlusCommonUpkeepPaused:
 		r.lggr.Debugf("KeeperRegistryUpkeepPaused log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash)
-		r.removeFromActive(r.ctx, l.Id)
+		r.removeFromActive(ctx, l.Id)
 	case *ac.IAutomationV21PlusCommonUpkeepCanceled:
 		r.lggr.Debugf("KeeperRegistryUpkeepCanceled log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash)
-		r.removeFromActive(r.ctx, l.Id)
+		r.removeFromActive(ctx, l.Id)
 	case *ac.IAutomationV21PlusCommonUpkeepMigrated:
 		r.lggr.Debugf("AutomationV2CommonUpkeepMigrated log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash)
-		r.removeFromActive(r.ctx, l.Id)
+		r.removeFromActive(ctx, l.Id)
 	case *ac.IAutomationV21PlusCommonUpkeepTriggerConfigSet:
 		r.lggr.Debugf("KeeperRegistryUpkeepTriggerConfigSet log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash)
-		if err := r.updateTriggerConfig(l.Id, l.TriggerConfig, rawLog.BlockNumber); err != nil {
+		if err := r.updateTriggerConfig(ctx, l.Id, l.TriggerConfig, rawLog.BlockNumber); err != nil {
 			r.lggr.Warnf("failed to update trigger config upon AutomationV2CommonUpkeepTriggerConfigSet for upkeep ID %s: %s", l.Id.String(), err)
 		}
 	case *ac.IAutomationV21PlusCommonUpkeepRegistered:
@@ -483,19 +483,19 @@ func (r *EvmRegistry) processUpkeepStateLog(l logpoller.Log) error {
 		trigger := core.GetUpkeepType(*uid)
 		r.lggr.Debugf("KeeperRegistryUpkeepRegistered log detected for upkeep ID %s (trigger=%d) in transaction %s", l.Id.String(), trigger, txHash)
 		r.active.Add(l.Id)
-		if err := r.updateTriggerConfig(l.Id, nil, rawLog.BlockNumber); err != nil {
+		if err := r.updateTriggerConfig(ctx, l.Id, nil, rawLog.BlockNumber); err != nil {
 			r.lggr.Warnf("failed to update trigger config upon AutomationV2CommonUpkeepRegistered for upkeep ID %s: %s", l.Id.String(), err)
 		}
 	case *ac.IAutomationV21PlusCommonUpkeepReceived:
 		r.lggr.Debugf("KeeperRegistryUpkeepReceived log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash)
 		r.active.Add(l.Id)
-		if err := r.updateTriggerConfig(l.Id, nil, rawLog.BlockNumber); err != nil {
+		if err := r.updateTriggerConfig(ctx, l.Id, nil, rawLog.BlockNumber); err != nil {
 			r.lggr.Warnf("failed to update trigger config upon AutomationV2CommonUpkeepReceived for upkeep ID %s: %s", l.Id.String(), err)
 		}
 	case *ac.IAutomationV21PlusCommonUpkeepUnpaused:
 		r.lggr.Debugf("KeeperRegistryUpkeepUnpaused log detected for upkeep ID %s in transaction %s", l.Id.String(), txHash)
 		r.active.Add(l.Id)
-		if err := r.updateTriggerConfig(l.Id, nil, rawLog.BlockNumber); err != nil {
+		if err := r.updateTriggerConfig(ctx, l.Id, nil, rawLog.BlockNumber); err != nil {
 			r.lggr.Warnf("failed to update trigger config upon AutomationV2CommonUpkeepUnpaused for upkeep ID %s: %s", l.Id.String(), err)
 		}
 	default:
@@ -510,9 +510,9 @@ func RegistryUpkeepFilterName(addr common.Address) string {
 }
 // registerEvents registers upkeep state events from keeper registry on log poller
-func (r *EvmRegistry) registerEvents(_ uint64, addr common.Address) error {
+func (r *EvmRegistry) registerEvents(ctx context.Context, _ uint64, addr common.Address) error {
 	// Add log filters for the log poller so that it can poll and find the logs that we need
-	return r.poller.RegisterFilter(r.ctx, logpoller.Filter{
+	return r.poller.RegisterFilter(ctx, logpoller.Filter{
 		Name:
RegistryUpkeepFilterName(addr), EventSigs: upkeepStateEvents, Addresses: []common.Address{addr}, @@ -591,13 +591,13 @@ func (r *EvmRegistry) getLatestIDsFromContract(ctx context.Context) ([]*big.Int, } // updateTriggerConfig updates the trigger config for an upkeep. it will re-register a filter for this upkeep. -func (r *EvmRegistry) updateTriggerConfig(id *big.Int, cfg []byte, logBlock uint64) error { +func (r *EvmRegistry) updateTriggerConfig(ctx context.Context, id *big.Int, cfg []byte, logBlock uint64) error { uid := &ocr2keepers.UpkeepIdentifier{} uid.FromBigInt(id) switch core.GetUpkeepType(*uid) { case types2.LogTrigger: if len(cfg) == 0 { - fetched, err := r.fetchTriggerConfig(id) + fetched, err := r.fetchTriggerConfig(ctx, id) if err != nil { return errors.Wrap(err, "failed to fetch log upkeep config") } @@ -609,7 +609,7 @@ func (r *EvmRegistry) updateTriggerConfig(id *big.Int, cfg []byte, logBlock uint r.lggr.Warnw("failed to unpack log upkeep config", "upkeepID", id.String(), "err", err) return nil } - if err := r.logEventProvider.RegisterFilter(r.ctx, logprovider.FilterOptions{ + if err := r.logEventProvider.RegisterFilter(ctx, logprovider.FilterOptions{ TriggerConfig: logprovider.LogTriggerConfig(parsed), UpkeepID: id, UpdateBlock: logBlock, @@ -623,8 +623,8 @@ func (r *EvmRegistry) updateTriggerConfig(id *big.Int, cfg []byte, logBlock uint } // fetchTriggerConfig fetches trigger config in raw bytes for an upkeep. -func (r *EvmRegistry) fetchTriggerConfig(id *big.Int) ([]byte, error) { - opts := r.buildCallOpts(r.ctx, nil) +func (r *EvmRegistry) fetchTriggerConfig(ctx context.Context, id *big.Int) ([]byte, error) { + opts := r.buildCallOpts(ctx, nil) cfg, err := r.registry.GetUpkeepTriggerConfig(opts, id) if err != nil { r.lggr.Warnw("failed to get trigger config", "err", err) @@ -634,8 +634,8 @@ func (r *EvmRegistry) fetchTriggerConfig(id *big.Int) ([]byte, error) { } // fetchUpkeepOffchainConfig fetches upkeep offchain config in raw bytes for an upkeep. 
-func (r *EvmRegistry) fetchUpkeepOffchainConfig(id *big.Int) ([]byte, error) { - opts := r.buildCallOpts(r.ctx, nil) +func (r *EvmRegistry) fetchUpkeepOffchainConfig(ctx context.Context, id *big.Int) ([]byte, error) { + opts := r.buildCallOpts(ctx, nil) ui, err := r.registry.GetUpkeep(opts, id) if err != nil { return []byte{}, err diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline.go index 491099496cb..5294530140b 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline.go @@ -96,8 +96,8 @@ func (r *EvmRegistry) getBlockAndUpkeepId(upkeepID ocr2keepers.UpkeepIdentifier, return block, common.BytesToHash(trigger.BlockHash[:]), upkeepID.BigInt() } -func (r *EvmRegistry) getBlockHash(blockNumber *big.Int) (common.Hash, error) { - blocks, err := r.poller.GetBlocksRange(r.ctx, []uint64{blockNumber.Uint64()}) +func (r *EvmRegistry) getBlockHash(ctx context.Context, blockNumber *big.Int) (common.Hash, error) { + blocks, err := r.poller.GetBlocksRange(ctx, []uint64{blockNumber.Uint64()}) if err != nil { return [32]byte{}, err } @@ -109,7 +109,7 @@ func (r *EvmRegistry) getBlockHash(blockNumber *big.Int) (common.Hash, error) { } // verifyCheckBlock checks that the check block and hash are valid, returns the pipeline execution state and retryable -func (r *EvmRegistry) verifyCheckBlock(_ context.Context, checkBlock, upkeepId *big.Int, checkHash common.Hash) (state encoding.PipelineExecutionState, retryable bool) { +func (r *EvmRegistry) verifyCheckBlock(ctx context.Context, checkBlock, upkeepId *big.Int, checkHash common.Hash) (state encoding.PipelineExecutionState, retryable bool) { // verify check block number and hash are valid h, ok := r.bs.queryBlocksMap(checkBlock.Int64()) // if this block number/hash combo exists in block subscriber, this check block and hash still exist on chain and are valid @@ -119,7 +119,7 @@ func (r *EvmRegistry) verifyCheckBlock(_ context.Context, checkBlock, upkeepId * return encoding.NoPipelineError, false } r.lggr.Warnf("check block %s does not exist in block subscriber or hash does not match for upkeepId %s. 
this may be caused by block subscriber outdated due to re-org, querying eth client to confirm", checkBlock, upkeepId) - b, err := r.getBlockHash(checkBlock) + b, err := r.getBlockHash(ctx, checkBlock) if err != nil { r.lggr.Warnf("failed to query block %s: %s", checkBlock, err.Error()) return encoding.RpcFlakyFailure, true @@ -132,7 +132,7 @@ func (r *EvmRegistry) verifyCheckBlock(_ context.Context, checkBlock, upkeepId * } // verifyLogExists checks that the log still exists on chain, returns failure reason, pipeline error, and retryable -func (r *EvmRegistry) verifyLogExists(upkeepId *big.Int, p ocr2keepers.UpkeepPayload) (encoding.UpkeepFailureReason, encoding.PipelineExecutionState, bool) { +func (r *EvmRegistry) verifyLogExists(ctx context.Context, upkeepId *big.Int, p ocr2keepers.UpkeepPayload) (encoding.UpkeepFailureReason, encoding.PipelineExecutionState, bool) { logBlockNumber := int64(p.Trigger.LogTriggerExtension.BlockNumber) logBlockHash := common.BytesToHash(p.Trigger.LogTriggerExtension.BlockHash[:]) checkBlockHash := common.BytesToHash(p.Trigger.BlockHash[:]) @@ -158,7 +158,7 @@ func (r *EvmRegistry) verifyLogExists(upkeepId *big.Int, p ocr2keepers.UpkeepPay r.lggr.Debugf("log block not provided, querying eth client for tx hash %s for upkeepId %s", hexutil.Encode(p.Trigger.LogTriggerExtension.TxHash[:]), upkeepId) } // query eth client as a fallback - bn, bh, err := core.GetTxBlock(r.ctx, r.client, p.Trigger.LogTriggerExtension.TxHash) + bn, bh, err := core.GetTxBlock(ctx, r.client, p.Trigger.LogTriggerExtension.TxHash) if err != nil { // primitive way of checking errors if strings.Contains(err.Error(), "missing required field") || strings.Contains(err.Error(), "not found") { @@ -202,7 +202,7 @@ func (r *EvmRegistry) checkUpkeeps(ctx context.Context, payloads []ocr2keepers.U uid.FromBigInt(upkeepId) switch core.GetUpkeepType(*uid) { case types.LogTrigger: - reason, state, retryable := r.verifyLogExists(upkeepId, p) + reason, state, retryable := r.verifyLogExists(ctx, upkeepId, p) if reason != encoding.UpkeepFailureReasonNone || state != encoding.NoPipelineError { results[i] = encoding.GetIneligibleCheckResultWithoutPerformData(p, reason, state, retryable) continue @@ -306,7 +306,7 @@ func (r *EvmRegistry) simulatePerformUpkeeps(ctx context.Context, checkResults [ block, _, upkeepId := r.getBlockAndUpkeepId(cr.UpkeepID, cr.Trigger) - oc, err := r.fetchUpkeepOffchainConfig(upkeepId) + oc, err := r.fetchUpkeepOffchainConfig(ctx, upkeepId) if err != nil { // this is mostly caused by RPC flakiness r.lggr.Errorw("failed get offchain config, gas price check will be disabled", "err", err, "upkeepId", upkeepId, "block", block) diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline_test.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline_test.go index f1b9cc66ae4..6f8785fda78 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline_test.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_check_pipeline_test.go @@ -346,13 +346,13 @@ func TestRegistry_VerifyLogExists(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + ctx := testutils.Context(t) bs := &BlockSubscriber{ blocks: tc.blocks, } e := &EvmRegistry{ lggr: lggr, bs: bs, - ctx: testutils.Context(t), } if tc.makeEthCall { @@ -370,7 +370,7 @@ func TestRegistry_VerifyLogExists(t *testing.T) { e.client = client } - reason, state, retryable := e.verifyLogExists(tc.upkeepId, tc.payload) + 
reason, state, retryable := e.verifyLogExists(ctx, tc.upkeepId, tc.payload) assert.Equal(t, tc.reason, reason) assert.Equal(t, tc.state, state) assert.Equal(t, tc.retryable, retryable) diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_test.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_test.go index 34b3f822dbf..ab530f877ae 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_test.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/registry_test.go @@ -16,7 +16,9 @@ import ( types2 "github.com/smartcontractkit/chainlink-automation/pkg/v3/types" "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" types3 "github.com/smartcontractkit/chainlink/v2/core/chains/evm/headtracker/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" @@ -186,6 +188,7 @@ func TestPollLogs(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { + ctx := testutils.Context(t) mp := new(mocks.LogPoller) if test.LatestBlock != nil { @@ -205,7 +208,7 @@ func TestPollLogs(t *testing.T) { chLog: make(chan logpoller.Log, 10), } - err := rg.pollUpkeepStateLogs() + err := rg.pollUpkeepStateLogs(ctx) assert.Equal(t, test.ExpectedLastPoll, rg.lastPollBlock) if test.ExpectedErr != nil { @@ -542,6 +545,7 @@ func TestRegistry_refreshLogTriggerUpkeeps(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { + ctx := tests.Context(t) lggr := logger.TestLogger(t) var hb types3.HeadBroadcaster var lp logpoller.LogPoller @@ -559,7 +563,7 @@ func TestRegistry_refreshLogTriggerUpkeeps(t *testing.T) { lggr: lggr, } - err := registry.refreshLogTriggerUpkeeps(tc.ids) + err := registry.refreshLogTriggerUpkeeps(ctx, tc.ids) if tc.expectsErr { assert.Error(t, err) assert.Equal(t, err.Error(), tc.wantErr.Error()) diff --git a/core/services/ocr2/plugins/ocr2keeper/integration_test.go b/core/services/ocr2/plugins/ocr2keeper/integration_test.go index 29e56460c36..1c9c6025391 100644 --- a/core/services/ocr2/plugins/ocr2keeper/integration_test.go +++ b/core/services/ocr2/plugins/ocr2keeper/integration_test.go @@ -427,6 +427,7 @@ func setupForwarderForNode( backend *backends.SimulatedBackend, recipient common.Address, linkAddr common.Address) common.Address { + ctx := testutils.Context(t) faddr, _, authorizedForwarder, err := authorized_forwarder.DeployAuthorizedForwarder(caller, backend, linkAddr, caller.From, recipient, []byte{}) require.NoError(t, err) @@ -438,12 +439,12 @@ func setupForwarderForNode( // add forwarder address to be tracked in db forwarderORM := forwarders.NewORM(app.GetDB()) chainID := ubig.Big(*backend.Blockchain().Config().ChainID) - _, err = forwarderORM.CreateForwarder(testutils.Context(t), faddr, chainID) + _, err = forwarderORM.CreateForwarder(ctx, faddr, chainID) require.NoError(t, err) chain, err := app.GetRelayers().LegacyEVMChains().Get((*big.Int)(&chainID).String()) require.NoError(t, err) - fwdr, err := chain.TxManager().GetForwarderForEOA(recipient) + fwdr, err := chain.TxManager().GetForwarderForEOA(ctx, recipient) require.NoError(t, err) require.Equal(t, faddr, fwdr) diff --git a/core/services/pipeline/helpers_test.go b/core/services/pipeline/helpers_test.go index 97d81f56f74..7068209aa18 100644 --- a/core/services/pipeline/helpers_test.go +++ 
b/core/services/pipeline/helpers_test.go @@ -1,6 +1,7 @@ package pipeline import ( + "context" "net/http" "github.com/google/uuid" @@ -64,4 +65,4 @@ func (t *ETHTxTask) HelperSetDependencies(legacyChains legacyevm.LegacyChainCont t.jobType = jobType } -func (o *orm) Prune(pipelineSpecID int32) { o.prune(o.ds, pipelineSpecID) } +func (o *orm) Prune(ctx context.Context, pipelineSpecID int32) { o.prune(ctx, o.ds, pipelineSpecID) } diff --git a/core/services/pipeline/orm.go b/core/services/pipeline/orm.go index 06774e06e99..266b605ed42 100644 --- a/core/services/pipeline/orm.go +++ b/core/services/pipeline/orm.go @@ -108,22 +108,19 @@ type orm struct { lggr logger.Logger maxSuccessfulRuns uint64 // jobID => count - pm sync.Map - wg sync.WaitGroup - ctx context.Context - cncl context.CancelFunc + pm sync.Map + wg sync.WaitGroup + stopCh services.StopChan } var _ ORM = (*orm)(nil) func NewORM(ds sqlutil.DataSource, lggr logger.Logger, jobPipelineMaxSuccessfulRuns uint64) *orm { - ctx, cancel := context.WithCancel(context.Background()) return &orm{ ds: ds, lggr: lggr.Named("PipelineORM"), maxSuccessfulRuns: jobPipelineMaxSuccessfulRuns, - ctx: ctx, - cncl: cancel, + stopCh: make(chan struct{}), } } @@ -142,7 +139,7 @@ func (o *orm) Start(_ context.Context) error { func (o *orm) Close() error { return o.StopOnce("PipelineORM", func() error { - o.cncl() + close(o.stopCh) o.wg.Wait() return nil }) @@ -177,13 +174,11 @@ func (o *orm) DataSource() sqlutil.DataSource { return o.ds } func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { return o.withDataSource(ds) } func (o *orm) withDataSource(ds sqlutil.DataSource) *orm { - ctx, cancel := context.WithCancel(context.Background()) return &orm{ ds: ds, lggr: o.lggr, maxSuccessfulRuns: o.maxSuccessfulRuns, - ctx: ctx, - cncl: cancel, + stopCh: make(chan struct{}), } } @@ -231,7 +226,7 @@ func (o *orm) CreateRun(ctx context.Context, run *Run) (err error) { // InsertRun inserts a run into the database func (o *orm) InsertRun(ctx context.Context, run *Run) error { if run.Status() == RunStatusCompleted { - defer o.prune(o.ds, run.PruningKey) + defer o.prune(ctx, o.ds, run.PruningKey) } query, args, err := o.ds.BindNamed(`INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) VALUES (:pipeline_spec_id, :pruning_key, :meta, :all_errors, :fatal_errors, :inputs, :outputs, :created_at, :finished_at, :state) @@ -287,7 +282,7 @@ func (o *orm) StoreRun(ctx context.Context, run *Run) (restart bool, err error) return fmt.Errorf("failed to update pipeline run %d to %s: %w", run.ID, run.State, err) } } else { - defer o.prune(tx.ds, run.PruningKey) + defer o.prune(ctx, tx.ds, run.PruningKey) // Simply finish the run, no need to do any sort of locking if run.Outputs.Val == nil || len(run.FatalErrors)+len(run.AllErrors) == 0 { return fmt.Errorf("run must have both Outputs and Errors, got Outputs: %#v, FatalErrors: %#v, AllErrors: %#v", run.Outputs.Val, run.FatalErrors, run.AllErrors) @@ -401,7 +396,7 @@ RETURNING id defer func() { for pruningKey := range pruningKeysm { - o.prune(tx.ds, pruningKey) + o.prune(ctx, tx.ds, pruningKey) } }() @@ -510,7 +505,7 @@ func (o *orm) insertFinishedRun(ctx context.Context, run *Run, saveSuccessfulTas return nil } - defer o.prune(o.ds, run.PruningKey) + defer o.prune(ctx, o.ds, run.PruningKey) sql = ` INSERT INTO pipeline_task_runs (pipeline_run_id, id, type, index, output, error, dot_id, created_at, finished_at) VALUES (:pipeline_run_id, :id, :type, 
:index, :output, :error, :dot_id, :created_at, :finished_at);`
@@ -709,13 +704,13 @@ const syncLimit = 1000
 //
 // Note this does not guarantee the pipeline_runs table is kept to exactly the
 // max length, rather that it doesn't grow excessively larger than it.
-func (o *orm) prune(tx sqlutil.DataSource, jobID int32) {
+func (o *orm) prune(ctx context.Context, tx sqlutil.DataSource, jobID int32) {
 	if jobID == 0 {
 		o.lggr.Panic("expected a non-zero job ID")
 	}
 	// For small maxSuccessfulRuns it's fast enough to prune every time
 	if o.maxSuccessfulRuns < syncLimit {
-		o.withDataSource(tx).execPrune(o.ctx, jobID)
+		o.withDataSource(tx).execPrune(ctx, jobID)
 		return
 	}
 	// for large maxSuccessfulRuns we do it async on a sampled basis
@@ -725,15 +720,15 @@
 	if val%every == 0 {
 		ok := o.IfStarted(func() {
 			o.wg.Add(1)
-			go func() {
+			go func(ctx context.Context) {
 				o.lggr.Debugw("Pruning runs", "jobID", jobID, "count", val, "every", every, "maxSuccessfulRuns", o.maxSuccessfulRuns)
 				defer o.wg.Done()
-				ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(o.ctx), time.Minute)
+				ctx, cancel := o.stopCh.CtxCancel(context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute))
 				defer cancel()
 				// Must not use tx here since it could be stale by the time we execute async.
 				o.execPrune(ctx, jobID)
-			}()
+			}(context.WithoutCancel(ctx)) // don't propagate cancellation
 		})
 		if !ok {
 			o.lggr.Warnw("Cannot prune: ORM is not running", "jobID", jobID)
@@ -743,7 +738,7 @@
 }
 func (o *orm) execPrune(ctx context.Context, jobID int32) {
-	res, err := o.ds.ExecContext(o.ctx, `DELETE FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 AND id NOT IN (
 		SELECT id FROM pipeline_runs
 		WHERE pruning_key = $1 AND state = $2
 		ORDER BY id DESC
diff --git a/core/services/pipeline/orm_test.go b/core/services/pipeline/orm_test.go
index 8c99635c8d1..877aa9e4aa5 100644
--- a/core/services/pipeline/orm_test.go
+++ b/core/services/pipeline/orm_test.go
@@ -16,6 +16,7 @@ import (
 	"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"
 	"github.com/smartcontractkit/chainlink-common/pkg/utils/hex"
 	"github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable"
+	"github.com/smartcontractkit/chainlink-common/pkg/utils/tests"
 	"github.com/smartcontractkit/chainlink/v2/core/bridges"
 	"github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big"
@@ -862,7 +863,8 @@ func Test_Prune(t *testing.T) {
 	jobID := ps1.ID
 	t.Run("when there are no runs to prune, does nothing", func(t *testing.T) {
-		porm.Prune(jobID)
+		ctx := tests.Context(t)
+		porm.Prune(ctx, jobID)
 		// no error logs; it did nothing
 		assert.Empty(t, observed.All())
@@ -898,7 +900,7 @@
 		cltest.MustInsertPipelineRunWithStatus(t, db, ps2.ID, pipeline.RunStatusSuspended, jobID2)
 	}
-	porm.Prune(jobID2)
+	porm.Prune(tests.Context(t), jobID2)
 	cnt := pgtest.MustCount(t, db, "SELECT count(*) FROM pipeline_runs WHERE pipeline_spec_id = $1 AND state = $2", ps1.ID, pipeline.RunStatusCompleted)
 	assert.Equal(t, cnt, 20)
diff --git a/core/services/pipeline/task.eth_tx.go b/core/services/pipeline/task.eth_tx.go
index 354651acbb4..964591cacd2 100644
--- a/core/services/pipeline/task.eth_tx.go
+++ b/core/services/pipeline/task.eth_tx.go
@@ -140,7 +140,7 @@ func (t *ETHTxTask) Run(ctx context.Context, lggr logger.Logger, vars Vars,
inpu var forwarderAddress common.Address if t.forwardingAllowed { var fwderr error - forwarderAddress, fwderr = chain.TxManager().GetForwarderForEOA(fromAddr) + forwarderAddress, fwderr = chain.TxManager().GetForwarderForEOA(ctx, fromAddr) if fwderr != nil { lggr.Warnw("Skipping forwarding for job, will fallback to default behavior", "err", fwderr) } diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index 97241044206..5a0ccffaf71 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -439,8 +439,7 @@ type configWatcher struct { chain legacyevm.Chain runReplay bool fromBlock uint64 - replayCtx context.Context - replayCancel context.CancelFunc + stopCh services.StopChan wg sync.WaitGroup } @@ -452,7 +451,6 @@ func newConfigWatcher(lggr logger.Logger, fromBlock uint64, runReplay bool, ) *configWatcher { - replayCtx, replayCancel := context.WithCancel(context.Background()) return &configWatcher{ lggr: lggr.Named("ConfigWatcher").Named(contractAddress.String()), contractAddress: contractAddress, @@ -461,8 +459,7 @@ func newConfigWatcher(lggr logger.Logger, chain: chain, runReplay: runReplay, fromBlock: fromBlock, - replayCtx: replayCtx, - replayCancel: replayCancel, + stopCh: make(chan struct{}), } } @@ -477,8 +474,10 @@ func (c *configWatcher) Start(ctx context.Context) error { c.wg.Add(1) go func() { defer c.wg.Done() + ctx, cancel := c.stopCh.NewCtx() + defer cancel() c.lggr.Infow("starting replay for config", "fromBlock", c.fromBlock) - if err := c.configPoller.Replay(c.replayCtx, int64(c.fromBlock)); err != nil { + if err := c.configPoller.Replay(ctx, int64(c.fromBlock)); err != nil { c.lggr.Errorf("error replaying for config", "err", err) } else { c.lggr.Infow("completed replaying for config", "fromBlock", c.fromBlock) @@ -492,7 +491,7 @@ func (c *configWatcher) Start(ctx context.Context) error { func (c *configWatcher) Close() error { return c.StopOnce(fmt.Sprintf("configWatcher %x", c.contractAddress), func() error { - c.replayCancel() + close(c.stopCh) c.wg.Wait() return c.configPoller.Close() }) diff --git a/core/services/relay/evm/request_round_tracker.go b/core/services/relay/evm/request_round_tracker.go index fe6b6826eb2..7cf13775693 100644 --- a/core/services/relay/evm/request_round_tracker.go +++ b/core/services/relay/evm/request_round_tracker.go @@ -37,8 +37,7 @@ type RequestRoundTracker struct { blockTranslator ocrcommon.BlockTranslator // Start/Stop lifecycle - ctx context.Context - ctxCancel context.CancelFunc + stopCh services.StopChan unsubscribeLogs func() // LatestRoundRequested @@ -58,7 +57,6 @@ func NewRequestRoundTracker( odb RequestRoundDB, chain ocrcommon.Config, ) (o *RequestRoundTracker) { - ctx, cancel := context.WithCancel(context.Background()) return &RequestRoundTracker{ ethClient: ethClient, contract: contract, @@ -69,8 +67,7 @@ func NewRequestRoundTracker( odb: odb, ds: ds, blockTranslator: ocrcommon.NewBlockTranslator(chain, ethClient, lggr), - ctx: ctx, - ctxCancel: cancel, + stopCh: make(chan struct{}), } } @@ -98,7 +95,7 @@ func (t *RequestRoundTracker) Start(ctx context.Context) error { // Close should be called after teardown of the OCR job relying on this tracker func (t *RequestRoundTracker) Close() error { return t.StopOnce("RequestRoundTracker", func() error { - t.ctxCancel() + close(t.stopCh) t.unsubscribeLogs() return nil }) diff --git a/core/services/vrf/v2/listener_v2_log_listener_test.go b/core/services/vrf/v2/listener_v2_log_listener_test.go index a393aec3ee3..5b827a5291d 100644 --- 
a/core/services/vrf/v2/listener_v2_log_listener_test.go +++ b/core/services/vrf/v2/listener_v2_log_listener_test.go @@ -1,7 +1,6 @@ package v2 import ( - "context" "fmt" "math/big" "strings" @@ -20,6 +19,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -56,7 +56,6 @@ type vrfLogPollerListenerTH struct { EthDB ethdb.Database Db *sqlx.DB Listener *listenerV2 - Ctx context.Context } func setupVRFLogPollerListenerTH(t *testing.T, @@ -173,7 +172,6 @@ func setupVRFLogPollerListenerTH(t *testing.T, EthDB: ethDB, Db: db, Listener: listener, - Ctx: ctx, } mockChainUpdateFn(chain, th) return th @@ -189,6 +187,7 @@ func setupVRFLogPollerListenerTH(t *testing.T, func TestInitProcessedBlock_NoVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, th *vrfLogPollerListenerTH) { @@ -226,7 +225,7 @@ func TestInitProcessedBlock_NoVRFReqs(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, len(logs)) - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(6), lastProcessedBlock) } @@ -262,6 +261,7 @@ func TestLogPollerFilterRegistered(t *testing.T) { func TestInitProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -298,7 +298,7 @@ func TestInitProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) { } // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. 
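Editor's note: the tests in this file drop the stored th.Ctx field in favor of a per-test ctx := tests.Context(t). A helper of that name is conventionally just a context tied to the test's lifetime via t.Cleanup; a minimal sketch under that assumption (the actual chainlink-common helper may additionally honor the test deadline):

package testhelpers

import (
	"context"
	"testing"
)

// Context returns a context that is cancelled when the test (or subtest)
// finishes, so goroutines started by the test observe shutdown instead of
// outliving it.
func Context(tb testing.TB) context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	tb.Cleanup(cancel)
	return ctx
}

Because the context is scoped to the test rather than stored on the shared test harness, parallel subtests no longer share one cancellation domain.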
- require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 2 (VRF req/resp block) + 5 (EmitLog blocks) = 11 latestBlock := int64(2 + 2 + 2 + 5) @@ -308,17 +308,18 @@ func TestInitProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) { // Then test if log poller is able to replay from finalizedBlockNumber (8 --> onwards) // since there are no pending VRF requests // Blocks: 1 2 3 4 [5;Request] [6;Fulfilment] 7 8 9 10 11 - require.NoError(t, th.LogPoller.Replay(th.Ctx, latestBlock)) + require.NoError(t, th.LogPoller.Replay(ctx, latestBlock)) // initializeLastProcessedBlock must return the finalizedBlockNumber (8) instead of // VRF request block number (5), since all VRF requests are fulfilled - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(8), lastProcessedBlock) } func TestInitProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -353,7 +354,7 @@ func TestInitProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) { } // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 1 (VRF req block) + 5 (EmitLog blocks) = 10 latestBlock := int64(2 + 2 + 1 + 5) @@ -362,17 +363,18 @@ func TestInitProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) { // Replay from block 10 (latest) onwards, so that log poller has a latest block // Then test if log poller is able to replay from earliestUnprocessedBlock (5 --> onwards) // Blocks: 1 2 3 4 [5;Request] 6 7 8 9 10 - require.NoError(t, th.LogPoller.Replay(th.Ctx, latestBlock)) + require.NoError(t, th.LogPoller.Replay(ctx, latestBlock)) // initializeLastProcessedBlock must return the unfulfilled VRF // request block number (5) instead of finalizedBlockNumber (8) - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(5), lastProcessedBlock) } func TestInitProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -413,7 +415,7 @@ func TestInitProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { } // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. 
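Editor's note: the assertions in these initializeLastProcessedBlock tests all encode one rule — return the block of the earliest unfulfilled VRF request if one exists, otherwise the latest finalized block. A hypothetical distillation for the reader (this helper is illustrative only and does not exist in the listener):

package main

import "fmt"

// expectedLastProcessedBlock mirrors the rule the tests assert: an
// unfulfilled request pins processing to its block; otherwise processing
// may advance to the finalized block.
func expectedLastProcessedBlock(earliestUnfulfilledReqBlock *int64, finalizedBlock int64) int64 {
	if earliestUnfulfilledReqBlock != nil {
		return *earliestUnfulfilledReqBlock
	}
	return finalizedBlock
}

func main() {
	req := int64(5)
	fmt.Println(expectedLastProcessedBlock(&req, 8)) // 5: unfulfilled request wins
	fmt.Println(expectedLastProcessedBlock(nil, 8))  // 8: fall back to the finalized block
}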
- require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 3*5 (EmitLog + VRF req/resp blocks) = 19 latestBlock := int64(2 + 2 + 3*5) @@ -424,17 +426,18 @@ func TestInitProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { // Blocks: 1 2 3 4 5 [6;Request] [7;Request] 8 [9;Request] [10;Request] // 11 [12;Request] [13;Request] 14 [15;Request] [16;Request] // 17 [18;Request] [19;Request] - require.NoError(t, th.LogPoller.Replay(th.Ctx, latestBlock)) + require.NoError(t, th.LogPoller.Replay(ctx, latestBlock)) // initializeLastProcessedBlock must return the earliest unfulfilled VRF request block // number instead of finalizedBlockNumber - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(6), lastProcessedBlock) } func TestInitProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -480,7 +483,7 @@ func TestInitProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { } // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 3*5 (EmitLog + VRF req/resp blocks) = 19 latestBlock := int64(2 + 2 + 3*5) @@ -490,11 +493,11 @@ func TestInitProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { // Blocks: 1 2 3 4 5 [6;Request] [7;Request;6-Fulfilment] 8 [9;Request] [10;Request;9-Fulfilment] // 11 [12;Request] [13;Request;12-Fulfilment] 14 [15;Request] [16;Request;15-Fulfilment] // 17 [18;Request] [19;Request;18-Fulfilment] - require.NoError(t, th.LogPoller.Replay(th.Ctx, latestBlock)) + require.NoError(t, th.LogPoller.Replay(ctx, latestBlock)) // initializeLastProcessedBlock must return the earliest unfulfilled VRF request block // number instead of finalizedBlockNumber - lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(th.Ctx) + lastProcessedBlock, err := th.Listener.initializeLastProcessedBlock(ctx) require.Nil(t, err) require.Equal(t, int64(7), lastProcessedBlock) } @@ -511,6 +514,7 @@ func TestInitProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { func TestUpdateLastProcessedBlock_NoVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -552,22 +556,23 @@ func TestUpdateLastProcessedBlock_NoVRFReqs(t *testing.T) { // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 2 (VRF req blocks) + 5 (EmitLog blocks) = 11 // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. 
-	require.NoError(t, th.LogPoller.Start(th.Ctx))
+	require.NoError(t, th.LogPoller.Start(ctx))
 	// We have to replay from before VRF request log, since updateLastProcessedBlock
 	// does not internally call LogPoller.Replay
-	require.NoError(t, th.LogPoller.Replay(th.Ctx, 4))
+	require.NoError(t, th.LogPoller.Replay(ctx, 4))
 	// updateLastProcessedBlock must return the finalizedBlockNumber as there are
 	// no VRF requests, after currLastProcessedBlock (block 6). The VRF requests
 	// made above are before the currLastProcessedBlock (7) passed in below
-	lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 7)
+	lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 7)
 	require.Nil(t, err)
 	require.Equal(t, int64(8), lastProcessedBlock)
 }
 func TestUpdateLastProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) {
 	t.Parallel()
+	ctx := tests.Context(t)
 	finalityDepth := int64(3)
 	th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) {
@@ -607,22 +612,23 @@ func TestUpdateLastProcessedBlock_NoUnfulfilledVRFReqs(t *testing.T) {
 	// Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 2 (VRF req/resp blocks) + 5 (EmitLog blocks) = 11
 	// Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db.
-	require.NoError(t, th.LogPoller.Start(th.Ctx))
+	require.NoError(t, th.LogPoller.Start(ctx))
 	// We have to replay from before VRF request log, since updateLastProcessedBlock
 	// does not internally call LogPoller.Replay
-	require.NoError(t, th.LogPoller.Replay(th.Ctx, 4))
+	require.NoError(t, th.LogPoller.Replay(ctx, 4))
 	// updateLastProcessedBlock must return the finalizedBlockNumber (8) though we have
 	// a VRF req at block (5) after currLastProcessedBlock (4) passed below, because
 	// the VRF request is fulfilled
-	lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 4)
+	lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 4)
 	require.Nil(t, err)
 	require.Equal(t, int64(8), lastProcessedBlock)
 }
 func TestUpdateLastProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) {
 	t.Parallel()
+	ctx := tests.Context(t)
 	finalityDepth := int64(3)
 	th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) {
@@ -658,22 +664,23 @@ func TestUpdateLastProcessedBlock_OneUnfulfilledVRFReq(t *testing.T) {
 	// Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 1 (VRF req block) + 5 (EmitLog blocks) = 10
 	// Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db.
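Editor's note: stepping back from the tests for a moment — the conversion running through this whole series (store a chStop services.StopChan, close it once in Close(), and have each goroutine derive its context via chStop.NewCtx()) has one regular shape. Below is a minimal, self-contained sketch that mirrors those semantics; the StopChan type here is a local stand-in for, not a copy of, the chainlink-common implementation:

package main

import (
	"context"
	"fmt"
	"time"
)

// StopChan stands in for services.StopChan: closing it cancels every
// context derived from it via NewCtx.
type StopChan chan struct{}

// NewCtx returns a context that is cancelled once the channel is closed.
func (s StopChan) NewCtx() (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		select {
		case <-s:
			cancel()
		case <-ctx.Done():
		}
	}()
	return ctx, cancel
}

type worker struct {
	stopCh StopChan
	done   chan struct{}
}

func (w *worker) Start() {
	go func() {
		defer close(w.done)
		// Derive the context inside the goroutine, not in the constructor,
		// so construction never pins a background context.
		ctx, cancel := w.stopCh.NewCtx()
		defer cancel()
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				fmt.Println("tick")
			case <-ctx.Done():
				return
			}
		}
	}()
}

func (w *worker) Close() {
	close(w.stopCh) // one close cancels every derived context
	<-w.done
}

func main() {
	w := &worker{stopCh: make(StopChan), done: make(chan struct{})}
	w.Start()
	time.Sleep(25 * time.Millisecond)
	w.Close()
	fmt.Println("stopped")
}

This is why the diffs above can delete the ctx/cancel struct fields: lifecycle state collapses to a single channel that Close() owns.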
- require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // We've to replay from before VRF request log, since updateLastProcessedBlock // does not internally call LogPoller.Replay - require.NoError(t, th.LogPoller.Replay(th.Ctx, 4)) + require.NoError(t, th.LogPoller.Replay(ctx, 4)) // updateLastProcessedBlock must return the VRF req at block (5) instead of // finalizedBlockNumber (8) after currLastProcessedBlock (4) passed below, // because the VRF request is unfulfilled - lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 4) + lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 4) require.Nil(t, err) require.Equal(t, int64(5), lastProcessedBlock) } func TestUpdateLastProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -715,22 +722,23 @@ func TestUpdateLastProcessedBlock_SomeUnfulfilledVRFReqs(t *testing.T) { // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 3*5 (EmitLog + VRF req blocks) = 19 // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // We've to replay from before VRF request log, since updateLastProcessedBlock // does not internally call LogPoller.Replay - require.NoError(t, th.LogPoller.Replay(th.Ctx, 4)) + require.NoError(t, th.LogPoller.Replay(ctx, 4)) // updateLastProcessedBlock must return the VRF req at block (6) instead of // finalizedBlockNumber (16) after currLastProcessedBlock (4) passed below, // as block 6 contains the earliest unfulfilled VRF request - lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 4) + lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 4) require.Nil(t, err) require.Equal(t, int64(6), lastProcessedBlock) } func TestUpdateLastProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { t.Parallel() + ctx := tests.Context(t) finalityDepth := int64(3) th := setupVRFLogPollerListenerTH(t, false, finalityDepth, 3, 2, 1000, func(mockChain *evmmocks.Chain, curTH *vrfLogPollerListenerTH) { @@ -776,17 +784,17 @@ func TestUpdateLastProcessedBlock_UnfulfilledNFulfilledVRFReqs(t *testing.T) { // Blocks till now: 2 (in SetupTH) + 2 (empty blocks) + 3*5 (EmitLog + VRF req blocks) = 19 // Calling Start() after RegisterFilter() simulates a node restart after job creation, should reload Filter from db. - require.NoError(t, th.LogPoller.Start(th.Ctx)) + require.NoError(t, th.LogPoller.Start(ctx)) // We've to replay from before VRF request log, since updateLastProcessedBlock // does not internally call LogPoller.Replay - require.NoError(t, th.LogPoller.Replay(th.Ctx, 4)) + require.NoError(t, th.LogPoller.Replay(ctx, 4)) // updateLastProcessedBlock must return the VRF req at block (7) instead of // finalizedBlockNumber (16) after currLastProcessedBlock (4) passed below, // as block 7 contains the earliest unfulfilled VRF request. VRF request // in block 6 has been fulfilled in block 7. 
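Every updateLastProcessedBlock test in this file shares the same skeleton, which this diff reworks to thread the test-scoped context from tests.Context(t) instead of the context previously cached on the harness (condensed from the hunks above, not new code):

    ctx := tests.Context(t)                         // cancelled at test cleanup
    require.NoError(t, th.LogPoller.Start(ctx))     // simulates restart after job creation
    require.NoError(t, th.LogPoller.Replay(ctx, 4)) // updateLastProcessedBlock never replays itself
    lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 4)
    require.Nil(t, err)

Only the emitted logs and the expected cursor differ: 8 (the finalized head) when nothing after the current cursor is outstanding, and the block of the earliest unfulfilled request otherwise (5, 6, and then 7 in the surrounding tests).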
- lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(th.Ctx, 4) + lastProcessedBlock, err := th.Listener.updateLastProcessedBlock(ctx, 4) require.Nil(t, err) require.Equal(t, int64(7), lastProcessedBlock) } diff --git a/core/utils/thread_control_test.go b/core/utils/thread_control_test.go index dff3740eda7..51d5c00a578 100644 --- a/core/utils/thread_control_test.go +++ b/core/utils/thread_control_test.go @@ -49,7 +49,8 @@ func TestThreadControl_GoCtx(t *testing.T) { start := time.Now() wg.Wait() - require.True(t, time.Since(start) > timeout-1) - require.True(t, time.Since(start) < 2*timeout) + end := time.Since(start) + require.True(t, end > timeout-1) + require.True(t, end < 2*timeout) require.Equal(t, int32(1), finished.Load()) } diff --git a/core/web/resolver/api_token_test.go b/core/web/resolver/api_token_test.go index bb82e91ed9b..d35d5d6d9a8 100644 --- a/core/web/resolver/api_token_test.go +++ b/core/web/resolver/api_token_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" gqlerrors "github.com/graph-gophers/graphql-go/errors" @@ -52,8 +53,8 @@ func TestResolver_CreateAPIToken(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -85,8 +86,8 @@ func TestResolver_CreateAPIToken(t *testing.T) { { name: "input errors", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -110,8 +111,8 @@ func TestResolver_CreateAPIToken(t *testing.T) { { name: "failed to find user", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -138,8 +139,8 @@ func TestResolver_CreateAPIToken(t *testing.T) { { name: "failed to generate token", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -208,8 +209,8 @@ func TestResolver_DeleteAPIToken(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -239,8 +240,8 @@ func TestResolver_DeleteAPIToken(t *testing.T) { { name: "input errors", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -264,8 +265,8 @@ func TestResolver_DeleteAPIToken(t *testing.T) { { name: "failed to find user", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + 
before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -292,8 +293,8 @@ func TestResolver_DeleteAPIToken(t *testing.T) { { name: "failed to delete token", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := webauth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := webauth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) diff --git a/core/web/resolver/bridge_test.go b/core/web/resolver/bridge_test.go index 2244ddf3dac..ae18fe59d2d 100644 --- a/core/web/resolver/bridge_test.go +++ b/core/web/resolver/bridge_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "encoding/json" "net/url" @@ -46,7 +47,7 @@ func Test_Bridges(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("BridgeTypes", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return([]bridges.BridgeType{ { @@ -116,7 +117,7 @@ func Test_Bridge(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{ Name: name, @@ -143,7 +144,7 @@ func Test_Bridge(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, sql.ErrNoRows) }, @@ -198,7 +199,7 @@ func Test_CreateBridge(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, sql.ErrNoRows) f.Mocks.bridgeORM.On("CreateBridgeType", mock.Anything, mock.IsType(&bridges.BridgeType{})). 
@@ -286,7 +287,7 @@ func Test_UpdateBridge(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { // Initialize the existing bridge bridge := bridges.BridgeType{ Name: name, @@ -340,7 +341,7 @@ func Test_UpdateBridge(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, sql.ErrNoRows) }, @@ -407,7 +408,7 @@ func Test_DeleteBridgeMutation(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { bridge := bridges.BridgeType{ Name: name, URL: models.WebURL(*bridgeURL), @@ -460,7 +461,7 @@ func Test_DeleteBridgeMutation(t *testing.T) { variables: map[string]interface{}{ "id": "bridge1", }, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, sql.ErrNoRows) f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) }, @@ -479,7 +480,7 @@ func Test_DeleteBridgeMutation(t *testing.T) { variables: map[string]interface{}{ "id": "bridge1", }, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, nil) f.Mocks.jobORM.On("FindJobIDsWithBridge", mock.Anything, name.String()).Return([]int32{1}, nil) f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) diff --git a/core/web/resolver/chain_test.go b/core/web/resolver/chain_test.go index 5e51356d928..75d7e36a5b5 100644 --- a/core/web/resolver/chain_test.go +++ b/core/web/resolver/chain_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "fmt" "testing" @@ -76,7 +77,7 @@ ResendAfterThreshold = '1h0m0s' { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { chainConf := evmtoml.EVMConfig{ ChainID: &chainID, Enabled: chain.Enabled, @@ -113,7 +114,7 @@ ResendAfterThreshold = '1h0m0s' { name: "no chains", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetRelayers").Return(&chainlinkmocks.FakeRelayerChainInteroperators{Relayers: []loop.Relayer{}}) }, query: query, @@ -192,7 +193,7 @@ ResendAfterThreshold = '1h0m0s' { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("EVMORM").Return(f.Mocks.evmORM) f.Mocks.evmORM.PutChains(evmtoml.EVMConfig{ ChainID: &chainID, @@ -212,7 +213,7 @@ ResendAfterThreshold = '1h0m0s' { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("EVMORM").Return(f.Mocks.evmORM) }, query: query, diff --git a/core/web/resolver/config_test.go b/core/web/resolver/config_test.go index a04b3fa2484..f380e4db55a 100644 --- a/core/web/resolver/config_test.go +++ b/core/web/resolver/config_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" _ "embed" "encoding/json" "fmt" @@ -38,7 +39,7 @@ func TestResolver_ConfigV2(t *testing.T) { { name: "empty", authenticated: true, - before: func(f *gqlTestFramework) { + 
before: func(ctx context.Context, f *gqlTestFramework) { opts := chainlink.GeneralConfigOpts{} cfg, err := opts.New() require.NoError(t, err) @@ -50,7 +51,7 @@ func TestResolver_ConfigV2(t *testing.T) { { name: "full", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { opts := chainlink.GeneralConfigOpts{ ConfigStrings: []string{configFull}, SecretsStrings: []string{}, @@ -65,7 +66,7 @@ func TestResolver_ConfigV2(t *testing.T) { { name: "partial", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { opts := chainlink.GeneralConfigOpts{ ConfigStrings: []string{configMulti}, SecretsStrings: []string{}, diff --git a/core/web/resolver/csa_keys_test.go b/core/web/resolver/csa_keys_test.go index 1048d9aa4bc..94513b53e45 100644 --- a/core/web/resolver/csa_keys_test.go +++ b/core/web/resolver/csa_keys_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "fmt" "testing" @@ -58,7 +59,7 @@ func Test_CSAKeysQuery(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.csa.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -109,7 +110,7 @@ func Test_CreateCSAKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.csa.On("Create", mock.Anything).Return(fakeKey, nil) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -120,7 +121,7 @@ func Test_CreateCSAKey(t *testing.T) { { name: "csa key exists error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.csa.On("Create", mock.Anything).Return(csakey.KeyV2{}, keystore.ErrCSAKeyExists) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -178,7 +179,7 @@ func Test_DeleteCSAKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetKeyStore").Return(f.Mocks.keystore) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.Mocks.csa.On("Delete", mock.Anything, fakeKey.ID()).Return(fakeKey, nil) @@ -190,7 +191,7 @@ func Test_DeleteCSAKey(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetKeyStore").Return(f.Mocks.keystore) f.Mocks.keystore.On("CSA").Return(f.Mocks.csa) f.Mocks.csa. 
diff --git a/core/web/resolver/eth_key_test.go b/core/web/resolver/eth_key_test.go index 40a60263f06..55cdc230bd2 100644 --- a/core/web/resolver/eth_key_test.go +++ b/core/web/resolver/eth_key_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "fmt" "testing" @@ -80,7 +81,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "success on prod", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -147,7 +148,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "success with no chains", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -202,7 +203,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "generic error on GetAll()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ethKs.On("GetAll", mock.Anything).Return(nil, gError) f.Mocks.keystore.On("Eth").Return(f.Mocks.ethKs) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -221,7 +222,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "generic error on GetStatesForKeys()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ethKs.On("GetAll", mock.Anything).Return(keys, nil) f.Mocks.ethKs.On("GetStatesForKeys", mock.Anything, keys).Return(nil, gError) f.Mocks.keystore.On("Eth").Return(f.Mocks.ethKs) @@ -241,7 +242,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "generic error on Get()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -273,7 +274,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "Empty set on legacy evm chains", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -304,7 +305,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "generic error on GetLINKBalance()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.MustEIP55Address(address.Hex()), @@ -366,7 +367,7 @@ func TestResolver_ETHKeys(t *testing.T) { { name: "success with no eth balance", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { states := []ethkey.State{ { Address: evmtypes.EIP55AddressFromAddress(address), diff --git a/core/web/resolver/eth_transaction_test.go b/core/web/resolver/eth_transaction_test.go index 5568a6664cb..f288450edcc 100644 --- a/core/web/resolver/eth_transaction_test.go +++ b/core/web/resolver/eth_transaction_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "errors" "math/big" @@ -68,7 +69,7 @@ func TestResolver_EthTransaction(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("FindTxByHash", mock.Anything, hash).Return(&txmgr.Tx{ ID: 1, ToAddress: common.HexToAddress("0x5431F5F973781809D18643b87B44921b11355d81"), @@ -130,7 +131,7 @@ func 
TestResolver_EthTransaction(t *testing.T) { { name: "success without nil values", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { num := int64(2) nonce := evmtypes.Nonce(num) @@ -195,7 +196,7 @@ func TestResolver_EthTransaction(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("FindTxByHash", mock.Anything, hash).Return(nil, sql.ErrNoRows) f.App.On("TxmStorageService").Return(f.Mocks.txmStore) }, @@ -212,7 +213,7 @@ func TestResolver_EthTransaction(t *testing.T) { { name: "generic error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("FindTxByHash", mock.Anything, hash).Return(nil, gError) f.App.On("TxmStorageService").Return(f.Mocks.txmStore) }, @@ -266,7 +267,7 @@ func TestResolver_EthTransactions(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { num := int64(2) f.Mocks.txmStore.On("Transactions", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return([]txmgr.Tx{ @@ -319,7 +320,7 @@ func TestResolver_EthTransactions(t *testing.T) { { name: "generic error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("Transactions", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return(nil, 0, gError) f.App.On("TxmStorageService").Return(f.Mocks.txmStore) }, @@ -364,7 +365,7 @@ func TestResolver_EthTransactionsAttempts(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { num := int64(2) f.Mocks.txmStore.On("TxAttempts", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return([]txmgr.TxAttempt{ @@ -397,7 +398,7 @@ func TestResolver_EthTransactionsAttempts(t *testing.T) { { name: "success with nil values", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("TxAttempts", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return([]txmgr.TxAttempt{ { Hash: hash, @@ -427,7 +428,7 @@ func TestResolver_EthTransactionsAttempts(t *testing.T) { { name: "generic error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.txmStore.On("TxAttempts", mock.Anything, PageDefaultOffset, PageDefaultLimit).Return(nil, 0, gError) f.App.On("TxmStorageService").Return(f.Mocks.txmStore) }, diff --git a/core/web/resolver/features_test.go b/core/web/resolver/features_test.go index f14f71abc90..76394f038b0 100644 --- a/core/web/resolver/features_test.go +++ b/core/web/resolver/features_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" @@ -23,7 +24,7 @@ func Test_ToFeatures(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetConfig").Return(configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { t, f := true, false c.Feature.UICSAKeys = &f diff --git a/core/web/resolver/feeds_manager_chain_config_test.go 
b/core/web/resolver/feeds_manager_chain_config_test.go index 31208aa0581..c5dd77c14a1 100644 --- a/core/web/resolver/feeds_manager_chain_config_test.go +++ b/core/web/resolver/feeds_manager_chain_config_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "testing" @@ -71,7 +72,7 @@ func Test_CreateFeedsManagerChainConfig(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CreateChainConfig", mock.Anything, feeds.ChainConfig{ FeedsManagerID: mgrID, @@ -144,7 +145,7 @@ func Test_CreateFeedsManagerChainConfig(t *testing.T) { { name: "create call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CreateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(int64(0), sql.ErrNoRows) }, @@ -161,7 +162,7 @@ func Test_CreateFeedsManagerChainConfig(t *testing.T) { { name: "get call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CreateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(cfgID, nil) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(nil, sql.ErrNoRows) @@ -209,7 +210,7 @@ func Test_DeleteFeedsManagerChainConfig(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(&feeds.ChainConfig{ ID: cfgID, @@ -230,7 +231,7 @@ func Test_DeleteFeedsManagerChainConfig(t *testing.T) { { name: "get call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(nil, sql.ErrNoRows) }, @@ -247,7 +248,7 @@ func Test_DeleteFeedsManagerChainConfig(t *testing.T) { { name: "delete call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(&feeds.ChainConfig{ ID: cfgID, @@ -324,7 +325,7 @@ func Test_UpdateFeedsManagerChainConfig(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateChainConfig", mock.Anything, feeds.ChainConfig{ ID: cfgID, @@ -393,7 +394,7 @@ func Test_UpdateFeedsManagerChainConfig(t *testing.T) { { name: "update call not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(int64(0), sql.ErrNoRows) }, @@ -410,7 +411,7 @@ func Test_UpdateFeedsManagerChainConfig(t *testing.T) { { name: "get call not found", authenticated: true, - before: func(f 
*gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(cfgID, nil) f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(nil, sql.ErrNoRows) diff --git a/core/web/resolver/feeds_manager_test.go b/core/web/resolver/feeds_manager_test.go index a3ea80a6443..bafb50ab0d5 100644 --- a/core/web/resolver/feeds_manager_test.go +++ b/core/web/resolver/feeds_manager_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "testing" @@ -40,7 +41,7 @@ func Test_FeedsManagers(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ListJobProposalsByManagersIDs", mock.Anything, []int64{1}).Return([]feeds.JobProposal{ { @@ -113,7 +114,7 @@ func Test_FeedsManager(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(&feeds.FeedsManager{ ID: mgrID, @@ -140,7 +141,7 @@ func Test_FeedsManager(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(nil, sql.ErrNoRows) }, @@ -212,7 +213,7 @@ func Test_CreateFeedsManager(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RegisterManager", mock.Anything, feeds.RegisterManagerParams{ Name: name, @@ -247,7 +248,7 @@ func Test_CreateFeedsManager(t *testing.T) { { name: "single feeds manager error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc. On("RegisterManager", mock.Anything, mock.IsType(feeds.RegisterManagerParams{})). 
@@ -266,7 +267,7 @@ func Test_CreateFeedsManager(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RegisterManager", mock.Anything, mock.IsType(feeds.RegisterManagerParams{})).Return(mgrID, nil) f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(nil, sql.ErrNoRows) @@ -358,7 +359,7 @@ func Test_UpdateFeedsManager(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateManager", mock.Anything, feeds.FeedsManager{ ID: mgrID, @@ -394,7 +395,7 @@ func Test_UpdateFeedsManager(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateManager", mock.Anything, mock.IsType(feeds.FeedsManager{})).Return(nil) f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(nil, sql.ErrNoRows) diff --git a/core/web/resolver/job_error_test.go b/core/web/resolver/job_error_test.go index 69899a3ec47..e3af3230e27 100644 --- a/core/web/resolver/job_error_test.go +++ b/core/web/resolver/job_error_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "encoding/json" "testing" @@ -27,7 +28,7 @@ func TestResolver_JobErrors(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: int32(1), @@ -123,7 +124,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{ ID: id, Occurrences: 5, @@ -140,7 +141,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "not found on FindSpecError()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -158,7 +159,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "not found on DismissError()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, nil) f.Mocks.jobORM.On("DismissError", mock.Anything, id).Return(sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) @@ -177,7 +178,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "generic error on FindSpecError()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -196,7 +197,7 @@ func TestResolver_DismissJobError(t *testing.T) { { name: "generic error on DismissError()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f 
*gqlTestFramework) { f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, nil) f.Mocks.jobORM.On("DismissError", mock.Anything, id).Return(gError) f.App.On("JobORM").Return(f.Mocks.jobORM) diff --git a/core/web/resolver/job_proposal_spec_test.go b/core/web/resolver/job_proposal_spec_test.go index 5875a5acb69..46364998c9c 100644 --- a/core/web/resolver/job_proposal_spec_test.go +++ b/core/web/resolver/job_proposal_spec_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "testing" "time" @@ -50,7 +51,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ @@ -64,7 +65,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { { name: "not found error on approval", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(sql.ErrNoRows) }, @@ -81,7 +82,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { { name: "not found error on fetch", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) @@ -99,7 +100,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { { name: "unprocessable error on approval if job already exists", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(feeds.ErrJobAlreadyExists) }, @@ -154,7 +155,7 @@ func TestResolver_CancelJobProposalSpec(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CancelSpec", mock.Anything, specID).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ @@ -168,7 +169,7 @@ func TestResolver_CancelJobProposalSpec(t *testing.T) { { name: "not found error on cancel", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CancelSpec", mock.Anything, specID).Return(sql.ErrNoRows) }, @@ -185,7 +186,7 @@ func TestResolver_CancelJobProposalSpec(t *testing.T) { { name: "not found error on fetch", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CancelSpec", mock.Anything, specID).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) @@ -241,7 +242,7 @@ func TestResolver_RejectJobProposalSpec(t *testing.T) { { name: "success", authenticated: true, - before: func(f 
*gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RejectSpec", mock.Anything, specID).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ @@ -255,7 +256,7 @@ func TestResolver_RejectJobProposalSpec(t *testing.T) { { name: "not found error on reject", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RejectSpec", mock.Anything, specID).Return(sql.ErrNoRows) }, @@ -272,7 +273,7 @@ func TestResolver_RejectJobProposalSpec(t *testing.T) { { name: "not found error on fetch", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RejectSpec", mock.Anything, specID).Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) @@ -331,7 +332,7 @@ func TestResolver_UpdateJobProposalSpecDefinition(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateSpecDefinition", mock.Anything, specID, "").Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ @@ -345,7 +346,7 @@ func TestResolver_UpdateJobProposalSpecDefinition(t *testing.T) { { name: "not found error on update", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateSpecDefinition", mock.Anything, specID, "").Return(sql.ErrNoRows) }, @@ -362,7 +363,7 @@ func TestResolver_UpdateJobProposalSpecDefinition(t *testing.T) { { name: "not found error on fetch", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateSpecDefinition", mock.Anything, specID, "").Return(nil) f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) @@ -443,7 +444,7 @@ func TestResolver_GetJobProposal_Spec(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.feedsSvc.On("GetJobProposal", mock.Anything, jpID).Return(&feeds.JobProposal{ ID: jpID, Status: feeds.JobProposalStatusApproved, diff --git a/core/web/resolver/job_proposal_test.go b/core/web/resolver/job_proposal_test.go index 5544b39c936..3c09435e56e 100644 --- a/core/web/resolver/job_proposal_test.go +++ b/core/web/resolver/job_proposal_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "fmt" "testing" @@ -64,7 +65,7 @@ func TestResolver_GetJobProposal(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.feedsSvc.On("ListManagersByIDs", mock.Anything, []int64{1}).Return([]feeds.FeedsManager{ { ID: 1, @@ -89,7 +90,7 @@ func TestResolver_GetJobProposal(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { 
f.Mocks.feedsSvc.On("GetJobProposal", mock.Anything, jpID).Return(nil, sql.ErrNoRows) f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) }, diff --git a/core/web/resolver/job_run_test.go b/core/web/resolver/job_run_test.go index 51631864e8c..3029710bcc4 100644 --- a/core/web/resolver/job_run_test.go +++ b/core/web/resolver/job_run_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "testing" @@ -39,7 +40,7 @@ func TestQuery_PaginatedJobRuns(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("PipelineRuns", mock.Anything, (*int32)(nil), PageDefaultOffset, PageDefaultLimit).Return([]pipeline.Run{ { ID: int64(200), @@ -63,7 +64,7 @@ func TestQuery_PaginatedJobRuns(t *testing.T) { { name: "generic error on PipelineRuns()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("PipelineRuns", mock.Anything, (*int32)(nil), PageDefaultOffset, PageDefaultLimit).Return(nil, 0, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -130,7 +131,7 @@ func TestResolver_JobRun(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindPipelineRunByID", mock.Anything, int64(2)).Return(pipeline.Run{ ID: 2, PipelineSpecID: 5, @@ -179,7 +180,7 @@ func TestResolver_JobRun(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindPipelineRunByID", mock.Anything, int64(2)).Return(pipeline.Run{}, sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -196,7 +197,7 @@ func TestResolver_JobRun(t *testing.T) { { name: "generic error on FindPipelineRunByID()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindPipelineRunByID", mock.Anything, int64(2)).Return(pipeline.Run{}, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -284,7 +285,7 @@ func TestResolver_RunJob(t *testing.T) { { name: "success without body", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), nil) f.Mocks.pipelineORM.On("FindRun", mock.Anything, int64(25)).Return(pipeline.Run{ ID: 2, @@ -337,7 +338,7 @@ func TestResolver_RunJob(t *testing.T) { { name: "not found job error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), webhook.ErrJobNotExists) }, query: mutation, @@ -355,7 +356,7 @@ func TestResolver_RunJob(t *testing.T) { { name: "generic error on RunJobV2", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), gError) }, query: mutation, @@ -375,7 +376,7 @@ func TestResolver_RunJob(t *testing.T) { { name: "generic error on FindRun", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, 
(map[string]interface{})(nil)).Return(int64(25), nil) f.Mocks.pipelineORM.On("FindRun", mock.Anything, int64(25)).Return(pipeline.Run{}, gError) f.App.On("PipelineORM").Return(f.Mocks.pipelineORM) diff --git a/core/web/resolver/job_test.go b/core/web/resolver/job_test.go index 0615e47a621..e00c4604bca 100644 --- a/core/web/resolver/job_test.go +++ b/core/web/resolver/job_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "database/sql" "encoding/json" "fmt" @@ -68,7 +69,7 @@ func TestResolver_Jobs(t *testing.T) { { name: "get jobs success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { plnSpecID := int32(12) f.App.On("JobORM").Return(f.Mocks.jobORM) @@ -206,7 +207,7 @@ func TestResolver_Job(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: 1, @@ -238,7 +239,7 @@ func TestResolver_Job(t *testing.T) { { name: "not found", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, sql.ErrNoRows) }, @@ -255,7 +256,7 @@ func TestResolver_Job(t *testing.T) { { name: "show job when chainID is disabled", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: 1, @@ -351,7 +352,7 @@ func TestResolver_CreateJob(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetConfig").Return(f.Mocks.cfg) f.App.On("AddJobV2", mock.Anything, &jb).Return(nil) }, @@ -378,7 +379,7 @@ func TestResolver_CreateJob(t *testing.T) { { name: "generic error when adding the job", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetConfig").Return(f.Mocks.cfg) f.App.On("AddJobV2", mock.Anything, &jb).Return(gError) }, @@ -452,7 +453,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: id, Name: null.StringFrom("test-job"), @@ -470,7 +471,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: "not found on FindJob()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -488,7 +489,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: "not found on DeleteJob()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, nil) f.App.On("JobORM").Return(f.Mocks.jobORM) f.App.On("DeleteJob", mock.Anything, id).Return(sql.ErrNoRows) @@ -507,7 +508,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: 
"generic error on FindJob()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -526,7 +527,7 @@ func TestResolver_DeleteJob(t *testing.T) { { name: "generic error on DeleteJob()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, nil) f.App.On("JobORM").Return(f.Mocks.jobORM) f.App.On("DeleteJob", mock.Anything, id).Return(gError) diff --git a/core/web/resolver/log_test.go b/core/web/resolver/log_test.go index 8b1b941da5a..cf5620845b2 100644 --- a/core/web/resolver/log_test.go +++ b/core/web/resolver/log_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" gqlerrors "github.com/graph-gophers/graphql-go/errors" @@ -35,7 +36,7 @@ func TestResolver_SetSQLLogging(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.cfg.On("SetLogSQL", true).Return(nil) f.App.On("GetConfig").Return(f.Mocks.cfg) }, @@ -79,7 +80,7 @@ func TestResolver_SQLLogging(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.cfg.On("Database").Return(&databaseConfig{logSQL: false}) f.App.On("GetConfig").Return(f.Mocks.cfg) }, @@ -126,7 +127,7 @@ func TestResolver_GlobalLogLevel(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.cfg.On("Log").Return(&log{level: warnLvl}) f.App.On("GetConfig").Return(f.Mocks.cfg) }, @@ -178,7 +179,7 @@ func TestResolver_SetGlobalLogLevel(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("SetLogLevel", errorLvl).Return(nil) }, query: mutation, @@ -195,7 +196,7 @@ func TestResolver_SetGlobalLogLevel(t *testing.T) { { name: "generic error on SetLogLevel", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("SetLogLevel", errorLvl).Return(gError) }, query: mutation, diff --git a/core/web/resolver/node_test.go b/core/web/resolver/node_test.go index e103a470097..870f694990f 100644 --- a/core/web/resolver/node_test.go +++ b/core/web/resolver/node_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" gqlerrors "github.com/graph-gophers/graphql-go/errors" @@ -39,7 +40,7 @@ func TestResolver_Nodes(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetRelayers").Return(&chainlinkmocks.FakeRelayerChainInteroperators{ Nodes: []types.NodeStatus{ { @@ -78,7 +79,7 @@ func TestResolver_Nodes(t *testing.T) { { name: "generic error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.relayerChainInterops.NodesErr = gError f.App.On("GetRelayers").Return(f.Mocks.relayerChainInterops) }, @@ -122,7 +123,7 @@ func Test_NodeQuery(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx 
context.Context, f *gqlTestFramework) { f.App.On("GetRelayers").Return(&chainlinkmocks.FakeRelayerChainInteroperators{Relayers: []loop.Relayer{ testutils.MockRelayer{NodeStatuses: []types.NodeStatus{ { @@ -146,7 +147,7 @@ func Test_NodeQuery(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("GetRelayers").Return(&chainlinkmocks.FakeRelayerChainInteroperators{Relayers: []loop.Relayer{}}) }, query: query, diff --git a/core/web/resolver/ocr2_keys_test.go b/core/web/resolver/ocr2_keys_test.go index fc82d070dd9..2269149bc37 100644 --- a/core/web/resolver/ocr2_keys_test.go +++ b/core/web/resolver/ocr2_keys_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/hex" "encoding/json" "fmt" @@ -69,7 +70,7 @@ func TestResolver_GetOCR2KeyBundles(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -80,7 +81,7 @@ func TestResolver_GetOCR2KeyBundles(t *testing.T) { { name: "generic error on GetAll()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("GetAll").Return(nil, gError) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -150,7 +151,7 @@ func TestResolver_CreateOCR2KeyBundle(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Create", mock.Anything, chaintype.ChainType("evm")).Return(fakeKey, nil) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -162,7 +163,7 @@ func TestResolver_CreateOCR2KeyBundle(t *testing.T) { { name: "generic error on Create()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Create", mock.Anything, chaintype.ChainType("evm")).Return(nil, gError) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -238,7 +239,7 @@ func TestResolver_DeleteOCR2KeyBundle(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Delete", mock.Anything, fakeKey.ID()).Return(nil) f.Mocks.ocr2.On("Get", fakeKey.ID()).Return(fakeKey, nil) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) @@ -251,7 +252,7 @@ func TestResolver_DeleteOCR2KeyBundle(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Get", fakeKey.ID()).Return(fakeKey, gError) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -268,7 +269,7 @@ func TestResolver_DeleteOCR2KeyBundle(t *testing.T) { { name: "generic error on Delete()", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr2.On("Delete", mock.Anything, fakeKey.ID()).Return(gError) f.Mocks.ocr2.On("Get", fakeKey.ID()).Return(fakeKey, nil) f.Mocks.keystore.On("OCR2").Return(f.Mocks.ocr2) diff --git 
a/core/web/resolver/ocr_test.go b/core/web/resolver/ocr_test.go index 5ca56c4bd04..e4b74e1d66b 100644 --- a/core/web/resolver/ocr_test.go +++ b/core/web/resolver/ocr_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "math/big" "testing" @@ -54,7 +55,7 @@ func TestResolver_GetOCRKeyBundles(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("OCR").Return(f.Mocks.ocr) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -105,7 +106,7 @@ func TestResolver_OCRCreateBundle(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr.On("Create", mock.Anything).Return(fakeKey, nil) f.Mocks.keystore.On("OCR").Return(f.Mocks.ocr) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -163,7 +164,7 @@ func TestResolver_OCRDeleteBundle(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr.On("Delete", mock.Anything, fakeKey.ID()).Return(fakeKey, nil) f.Mocks.keystore.On("OCR").Return(f.Mocks.ocr) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -175,7 +176,7 @@ func TestResolver_OCRDeleteBundle(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.ocr. On("Delete", mock.Anything, fakeKey.ID()). Return(ocrkey.KeyV2{}, keystore.KeyNotFoundError{ID: "helloWorld", KeyType: "OCR"}) diff --git a/core/web/resolver/p2p_test.go b/core/web/resolver/p2p_test.go index 6502ffc821a..941787aba96 100644 --- a/core/web/resolver/p2p_test.go +++ b/core/web/resolver/p2p_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "fmt" "math/big" @@ -53,7 +54,7 @@ func TestResolver_GetP2PKeys(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.p2p.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("P2P").Return(f.Mocks.p2p) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -102,7 +103,7 @@ func TestResolver_CreateP2PKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.p2p.On("Create", mock.Anything).Return(fakeKey, nil) f.Mocks.keystore.On("P2P").Return(f.Mocks.p2p) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -163,7 +164,7 @@ func TestResolver_DeleteP2PKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.p2p.On("Delete", mock.Anything, peerID).Return(fakeKey, nil) f.Mocks.keystore.On("P2P").Return(f.Mocks.p2p) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -175,7 +176,7 @@ func TestResolver_DeleteP2PKey(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.p2p. On("Delete", mock.Anything, peerID). 
Return( diff --git a/core/web/resolver/resolver_test.go b/core/web/resolver/resolver_test.go index 56b7f076eca..9f3445e1cee 100644 --- a/core/web/resolver/resolver_test.go +++ b/core/web/resolver/resolver_test.go @@ -72,9 +72,6 @@ type gqlTestFramework struct { // The root GQL schema RootSchema *graphql.Schema - // Contains the context with an injected dataloader - Ctx context.Context - Mocks *mocks } @@ -88,7 +85,6 @@ func setupFramework(t *testing.T) *gqlTestFramework { schema.MustGetRootSchema(), &Resolver{App: app}, ) - ctx = loader.InjectDataloader(testutils.Context(t), app) ) // Setup mocks @@ -128,7 +124,6 @@ func setupFramework(t *testing.T) *gqlTestFramework { t: t, App: app, RootSchema: rootSchema, - Ctx: ctx, Mocks: m, } @@ -146,20 +141,18 @@ func (f *gqlTestFramework) Timestamp() time.Time { return time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) } -// injectAuthenticatedUser injects a session into the request context -func (f *gqlTestFramework) injectAuthenticatedUser() { - f.t.Helper() - +// withAuthenticatedUser injects a session into the request context +func (f *gqlTestFramework) withAuthenticatedUser(ctx context.Context) context.Context { user := clsessions.User{Email: "gqltester@chain.link", Role: clsessions.UserRoleAdmin} - f.Ctx = auth.WithGQLAuthenticatedSession(f.Ctx, user, "gqltesterSession") + return auth.WithGQLAuthenticatedSession(ctx, user, "gqltesterSession") } // GQLTestCase represents a single GQL request test. type GQLTestCase struct { name string authenticated bool - before func(*gqlTestFramework) + before func(context.Context, *gqlTestFramework) query string variables map[string]interface{} result string @@ -175,16 +168,15 @@ func RunGQLTests(t *testing.T, testCases []GQLTestCase) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - var ( - f = setupFramework(t) - ) + f := setupFramework(t) + ctx := loader.InjectDataloader(testutils.Context(t), f.App) if tc.authenticated { - f.injectAuthenticatedUser() + ctx = f.withAuthenticatedUser(ctx) } if tc.before != nil { - tc.before(f) + tc.before(ctx, f) } // This does not print out the correct stack trace as the `RunTest` @@ -193,7 +185,7 @@ func RunGQLTests(t *testing.T, testCases []GQLTestCase) { // // This would need to be fixed upstream. 
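The resolver_test.go hunk above is the heart of this refactor: gqlTestFramework stops caching a context (in line with the context package's guidance not to store contexts in struct fields), RunGQLTests derives a fresh dataloader context per subtest, and injectAuthenticatedUser becomes withAuthenticatedUser, which returns a derived context instead of mutating framework state. A new-style test case therefore reads everything from the context it is handed; an illustrative case (not part of this patch; mock and helper names borrowed from the files above):

    {
    	name:          "success",
    	authenticated: true,
    	before: func(ctx context.Context, f *gqlTestFramework) {
    		// The session now comes from the ctx passed in, not from f.Ctx.
    		session, ok := webauth.GetGQLAuthenticatedSession(ctx)
    		require.True(t, ok)
    		require.NotNil(t, session)
    		f.App.On("JobORM").Return(f.Mocks.jobORM)
    	},
    	query:  query,
    	result: `{"jobs": null}`, // placeholder result for the sketch
    },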
gqltesting.RunTest(t, &gqltesting.Test{ - Context: f.Ctx, + Context: ctx, Schema: f.RootSchema, Query: tc.query, Variables: tc.variables, diff --git a/core/web/resolver/solana_key_test.go b/core/web/resolver/solana_key_test.go index 5472f12081d..e788e9e7ce2 100644 --- a/core/web/resolver/solana_key_test.go +++ b/core/web/resolver/solana_key_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "errors" "fmt" "testing" @@ -40,7 +41,7 @@ func TestResolver_SolanaKeys(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.solana.On("GetAll").Return([]solkey.Key{k}, nil) f.Mocks.keystore.On("Solana").Return(f.Mocks.solana) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -51,7 +52,7 @@ func TestResolver_SolanaKeys(t *testing.T) { { name: "generic error on GetAll", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.solana.On("GetAll").Return([]solkey.Key{}, gError) f.Mocks.keystore.On("Solana").Return(f.Mocks.solana) f.App.On("GetKeyStore").Return(f.Mocks.keystore) diff --git a/core/web/resolver/spec_test.go b/core/web/resolver/spec_test.go index 43682c14ead..63002e566f1 100644 --- a/core/web/resolver/spec_test.go +++ b/core/web/resolver/spec_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" "time" @@ -34,7 +35,7 @@ func TestResolver_CronSpec(t *testing.T) { { name: "cron spec success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Cron, @@ -88,7 +89,7 @@ func TestResolver_DirectRequestSpec(t *testing.T) { { name: "direct request spec success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.DirectRequest, @@ -153,7 +154,7 @@ func TestResolver_FluxMonitorSpec(t *testing.T) { { name: "flux monitor spec with standard timers", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.FluxMonitor, @@ -220,7 +221,7 @@ func TestResolver_FluxMonitorSpec(t *testing.T) { { name: "flux monitor spec with drumbeat", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.FluxMonitor, @@ -303,7 +304,7 @@ func TestResolver_KeeperSpec(t *testing.T) { { name: "keeper spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Keeper, @@ -367,7 +368,7 @@ func TestResolver_OCRSpec(t *testing.T) { { name: "OCR spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", 
mock.Anything, id).Return(job.Job{ Type: job.OffchainReporting, @@ -472,7 +473,7 @@ func TestResolver_OCR2Spec(t *testing.T) { { name: "OCR 2 spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.OffchainReporting2, @@ -574,7 +575,7 @@ func TestResolver_VRFSpec(t *testing.T) { { name: "vrf spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.VRF, @@ -670,7 +671,7 @@ func TestResolver_WebhookSpec(t *testing.T) { { name: "webhook spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Webhook, @@ -739,7 +740,7 @@ func TestResolver_BlockhashStoreSpec(t *testing.T) { { name: "blockhash store spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.BlockhashStore, @@ -843,7 +844,7 @@ func TestResolver_BlockHeaderFeederSpec(t *testing.T) { { name: "block header feeder spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.BlockHeaderFeeder, @@ -930,7 +931,7 @@ func TestResolver_BootstrapSpec(t *testing.T) { { name: "Bootstrap spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Bootstrap, @@ -1002,7 +1003,7 @@ func TestResolver_WorkflowSpec(t *testing.T) { { name: "Workflow spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Workflow, @@ -1060,7 +1061,7 @@ func TestResolver_GatewaySpec(t *testing.T) { { name: "Gateway spec", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Gateway, diff --git a/core/web/resolver/user_test.go b/core/web/resolver/user_test.go index 2662b1f5040..8d37af2e379 100644 --- a/core/web/resolver/user_test.go +++ b/core/web/resolver/user_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "testing" gqlerrors "github.com/graph-gophers/graphql-go/errors" @@ -44,8 +45,8 @@ func TestResolver_UpdateUserPassword(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := auth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := auth.GetGQLAuthenticatedSession(ctx) 
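            // The session is expected to be present here: RunGQLTests injects it
            // into ctx via withAuthenticatedUser whenever the test case sets
            // authenticated to true, so the assertions below guard that invariant.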
require.True(t, ok) require.NotNil(t, session) @@ -73,8 +74,8 @@ func TestResolver_UpdateUserPassword(t *testing.T) { { name: "update password match error", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := auth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := auth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -99,8 +100,8 @@ func TestResolver_UpdateUserPassword(t *testing.T) { { name: "failed to clear session error", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := auth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := auth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) @@ -130,8 +131,8 @@ func TestResolver_UpdateUserPassword(t *testing.T) { { name: "failed to update current user password error", authenticated: true, - before: func(f *gqlTestFramework) { - session, ok := auth.GetGQLAuthenticatedSession(f.Ctx) + before: func(ctx context.Context, f *gqlTestFramework) { + session, ok := auth.GetGQLAuthenticatedSession(ctx) require.True(t, ok) require.NotNil(t, session) diff --git a/core/web/resolver/vrf_test.go b/core/web/resolver/vrf_test.go index 5101bc5937b..c77f7b73ff0 100644 --- a/core/web/resolver/vrf_test.go +++ b/core/web/resolver/vrf_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "encoding/json" "fmt" "math/big" @@ -64,7 +65,7 @@ func TestResolver_GetVRFKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf.On("Get", fakeKey.PublicKey.String()).Return(fakeKey, nil) f.Mocks.keystore.On("VRF").Return(f.Mocks.vrf) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -76,7 +77,7 @@ func TestResolver_GetVRFKey(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf. On("Get", fakeKey.PublicKey.String()). Return(vrfkey.KeyV2{}, errors.Wrapf( @@ -146,7 +147,7 @@ func TestResolver_GetVRFKeys(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf.On("GetAll").Return(fakeKeys, nil) f.Mocks.keystore.On("VRF").Return(f.Mocks.vrf) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -198,7 +199,7 @@ func TestResolver_CreateVRFKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf.On("Create", mock.Anything).Return(fakeKey, nil) f.Mocks.keystore.On("VRF").Return(f.Mocks.vrf) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -261,7 +262,7 @@ func TestResolver_DeleteVRFKey(t *testing.T) { { name: "success", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf.On("Delete", mock.Anything, fakeKey.PublicKey.String()).Return(fakeKey, nil) f.Mocks.keystore.On("VRF").Return(f.Mocks.vrf) f.App.On("GetKeyStore").Return(f.Mocks.keystore) @@ -273,7 +274,7 @@ func TestResolver_DeleteVRFKey(t *testing.T) { { name: "not found error", authenticated: true, - before: func(f *gqlTestFramework) { + before: func(ctx context.Context, f *gqlTestFramework) { f.Mocks.vrf. On("Delete", mock.Anything, fakeKey.PublicKey.String()). 
Return(vrfkey.KeyV2{}, errors.Wrapf( From eb6b50d31323c324aaa2bf8d1cf465f97a7893fd Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Wed, 22 May 2024 10:19:06 -0700 Subject: [PATCH 04/16] EVM encoder support for tuples (#13202) credit goes to @archseer --- .changeset/silent-jars-relax.md | 5 + core/chains/evm/abi/selector_parser.go | 148 ++++++------------- core/chains/evm/abi/selector_parser_test.go | 47 +++++- core/services/relay/evm/cap_encoder.go | 2 +- core/services/relay/evm/cap_encoder_test.go | 149 ++++++++++++++++++-- core/services/relay/evm/codec_test.go | 69 ++++++++- 6 files changed, 301 insertions(+), 119 deletions(-) create mode 100644 .changeset/silent-jars-relax.md diff --git a/.changeset/silent-jars-relax.md b/.changeset/silent-jars-relax.md new file mode 100644 index 00000000000..3b076da1226 --- /dev/null +++ b/.changeset/silent-jars-relax.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +#internal [Keystone] EVM encoder support for tuples diff --git a/core/chains/evm/abi/selector_parser.go b/core/chains/evm/abi/selector_parser.go index 30e687ba33a..329ed6eb181 100644 --- a/core/chains/evm/abi/selector_parser.go +++ b/core/chains/evm/abi/selector_parser.go @@ -1,5 +1,5 @@ -// Sourced from https://github.com/ethereum/go-ethereum/blob/fe91d476ba3e29316b6dc99b6efd4a571481d888/accounts/abi/selector_parser.go#L126 -// Modified assembleArgs to retain argument names +// Originally sourced from https://github.com/ethereum/go-ethereum/blob/fe91d476ba3e29316b6dc99b6efd4a571481d888/accounts/abi/selector_parser.go#L126 +// Modified to suppor parsing selectors with argument names // Copyright 2022 The go-ethereum Authors // This file is part of the go-ethereum library. @@ -83,55 +83,24 @@ func parseElementaryType(unescapedSelector string) (string, string, error) { return parsedType, rest, nil } -func parseCompositeType(unescapedSelector string) ([]interface{}, string, error) { +func parseCompositeType(unescapedSelector string) ([]abi.ArgumentMarshaling, string, error) { if len(unescapedSelector) == 0 || unescapedSelector[0] != '(' { return nil, "", fmt.Errorf("expected '(', got %c", unescapedSelector[0]) } - parsedType, rest, err := parseType(unescapedSelector[1:]) - if err != nil { - return nil, "", fmt.Errorf("failed to parse type: %v", err) - } - result := []interface{}{parsedType} - for len(rest) > 0 && rest[0] != ')' { - parsedType, rest, err = parseType(rest[1:]) - if err != nil { - return nil, "", fmt.Errorf("failed to parse type: %v", err) - } - result = append(result, parsedType) - } - if len(rest) == 0 || rest[0] != ')' { - return nil, "", fmt.Errorf("expected ')', got '%s'", rest) - } - if len(rest) >= 3 && rest[1] == '[' && rest[2] == ']' { - return append(result, "[]"), rest[3:], nil - } - return result, rest[1:], nil -} - -func parseType(unescapedSelector string) (interface{}, string, error) { - if len(unescapedSelector) == 0 { - return nil, "", errors.New("empty type") - } - if unescapedSelector[0] == '(' { - return parseCompositeType(unescapedSelector) - } - return parseElementaryType(unescapedSelector) -} - -func parseArgs(unescapedSelector string) ([]abi.ArgumentMarshaling, error) { - if len(unescapedSelector) == 0 || unescapedSelector[0] != '(' { - return nil, fmt.Errorf("expected '(', got %c", unescapedSelector[0]) - } + rest := unescapedSelector[1:] // skip over the opening `(` result := []abi.ArgumentMarshaling{} - rest := unescapedSelector[1:] var parsedType any var err error + i := 0 for len(rest) > 0 && rest[0] != ')' { - 
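        // Each iteration of this loop parses one parameter of the form
        // "<type> [name]": leading whitespace is skipped, the type is parsed
        // first, and an optional argument name may follow; unnamed parameters
        // receive a generated placeholder name ("name0", "name1", ...). This
        // inverts the name-first order of the removed lines below.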
// parse method name - var name string - name, rest, err = parseIdentifier(rest[:]) + // skip any leading whitespace + for rest[0] == ' ' { + rest = rest[1:] + } + + parsedType, rest, err = parseType(rest[0:]) if err != nil { - return nil, fmt.Errorf("failed to parse name: %v", err) + return nil, "", fmt.Errorf("failed to parse type: %v", err) } // skip whitespace between name and identifier @@ -139,40 +108,58 @@ func parseArgs(unescapedSelector string) ([]abi.ArgumentMarshaling, error) { rest = rest[1:] } - // parse type - parsedType, rest, err = parseType(rest[:]) - if err != nil { - return nil, fmt.Errorf("failed to parse type: %v", err) + name := fmt.Sprintf("name%d", i) + // if we're at a delimiter the parameter is unnamed + if !(rest[0] == ',' || rest[0] == ')') { + // attempt to parse name + name, rest, err = parseIdentifier(rest[:]) + if err != nil { + return nil, "", fmt.Errorf("failed to parse name: %v", err) + } } arg, err := assembleArg(name, parsedType) if err != nil { - return nil, fmt.Errorf("failed to parse type: %v", err) + return nil, "", fmt.Errorf("failed to parse type: %v", err) } result = append(result, arg) + i++ + // skip trailing whitespace, consume comma for rest[0] == ' ' || rest[0] == ',' { rest = rest[1:] } } if len(rest) == 0 || rest[0] != ')' { - return nil, fmt.Errorf("expected ')', got '%s'", rest) + return nil, "", fmt.Errorf("expected ')', got '%s'", rest) } - if len(rest) > 1 { - return nil, fmt.Errorf("failed to parse selector '%s': unexpected string '%s'", unescapedSelector, rest) + if len(rest) >= 3 && rest[1] == '[' && rest[2] == ']' { + // emits a sentinel value that later gets removed when assembling + array, err := assembleArg("", "[]") + if err != nil { + panic("unreachable") + } + return append(result, array), rest[3:], nil } - return result, nil + return result, rest[1:], nil +} + +// type-name rule +func parseType(unescapedSelector string) (interface{}, string, error) { + if len(unescapedSelector) == 0 { + return nil, "", errors.New("empty type") + } + if unescapedSelector[0] == '(' { + return parseCompositeType(unescapedSelector) + } + return parseElementaryType(unescapedSelector) } func assembleArg(name string, arg any) (abi.ArgumentMarshaling, error) { if s, ok := arg.(string); ok { return abi.ArgumentMarshaling{Name: name, Type: s, InternalType: s, Components: nil, Indexed: false}, nil - } else if components, ok := arg.([]interface{}); ok { - subArgs, err := assembleArgs(components) - if err != nil { - return abi.ArgumentMarshaling{}, fmt.Errorf("failed to assemble components: %v", err) - } + } else if subArgs, ok := arg.([]abi.ArgumentMarshaling); ok { tupleType := "tuple" if len(subArgs) != 0 && subArgs[len(subArgs)-1].Type == "[]" { subArgs = subArgs[:len(subArgs)-1] @@ -183,20 +170,6 @@ func assembleArg(name string, arg any) (abi.ArgumentMarshaling, error) { return abi.ArgumentMarshaling{}, fmt.Errorf("failed to assemble args: unexpected type %T", arg) } -func assembleArgs(args []interface{}) ([]abi.ArgumentMarshaling, error) { - arguments := make([]abi.ArgumentMarshaling, 0) - for i, arg := range args { - // generate dummy name to avoid unmarshal issues - name := fmt.Sprintf("name%d", i) - arg, err := assembleArg(name, arg) - if err != nil { - return nil, err - } - arguments = append(arguments, arg) - } - return arguments, nil -} - // ParseSelector converts a method selector into a struct that can be JSON encoded // and consumed by other functions in this package. 
// Note, although uppercase letters are not part of the ABI spec, this function @@ -204,46 +177,19 @@ func assembleArgs(args []interface{}) ([]abi.ArgumentMarshaling, error) { func ParseSelector(unescapedSelector string) (abi.SelectorMarshaling, error) { name, rest, err := parseIdentifier(unescapedSelector) if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': %v", unescapedSelector, err) + return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector identifier '%s': %v", unescapedSelector, err) } - args := []interface{}{} + args := []abi.ArgumentMarshaling{} if len(rest) >= 2 && rest[0] == '(' && rest[1] == ')' { rest = rest[2:] } else { args, rest, err = parseCompositeType(rest) if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': %v", unescapedSelector, err) + return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector args '%s': %v", unescapedSelector, err) } } if len(rest) > 0 { return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': unexpected string '%s'", unescapedSelector, rest) } - - // Reassemble the fake ABI and construct the JSON - fakeArgs, err := assembleArgs(args) - if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector: %v", err) - } - - return abi.SelectorMarshaling{Name: name, Type: "function", Inputs: fakeArgs}, nil -} - -// ParseSelector converts a method selector into a struct that can be JSON encoded -// and consumed by other functions in this package. -// Note, although uppercase letters are not part of the ABI spec, this function -// still accepts it as the general format is valid. -func ParseSignature(unescapedSelector string) (abi.SelectorMarshaling, error) { - name, rest, err := parseIdentifier(unescapedSelector) - if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': %v", unescapedSelector, err) - } - args := []abi.ArgumentMarshaling{} - if len(rest) < 2 || rest[0] != '(' || rest[1] != ')' { - args, err = parseArgs(rest) - if err != nil { - return abi.SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': %v", unescapedSelector, err) - } - } - return abi.SelectorMarshaling{Name: name, Type: "function", Inputs: args}, nil } diff --git a/core/chains/evm/abi/selector_parser_test.go b/core/chains/evm/abi/selector_parser_test.go index caae3744678..8ef37a193f8 100644 --- a/core/chains/evm/abi/selector_parser_test.go +++ b/core/chains/evm/abi/selector_parser_test.go @@ -25,6 +25,8 @@ import ( "testing" "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParseSelector(t *testing.T) { @@ -83,7 +85,7 @@ func TestParseSelector(t *testing.T) { } } -func TestParseSignature(t *testing.T) { +func TestParseSelectorWithNames(t *testing.T) { t.Parallel() mkType := func(name string, typeOrComponents interface{}) abi.ArgumentMarshaling { if typeName, ok := typeOrComponents.(string); ok { @@ -102,13 +104,14 @@ func TestParseSignature(t *testing.T) { args []abi.ArgumentMarshaling }{ {"noargs()", "noargs", []abi.ArgumentMarshaling{}}, - {"simple(a uint256, b uint256, c uint256)", "simple", []abi.ArgumentMarshaling{mkType("a", "uint256"), mkType("b", "uint256"), mkType("c", "uint256")}}, - {"other(foo uint256, bar address)", "other", []abi.ArgumentMarshaling{mkType("foo", "uint256"), mkType("bar", "address")}}, - {"withArray(a uint256[], b address[2], c uint8[4][][5])", "withArray", 
[]abi.ArgumentMarshaling{mkType("a", "uint256[]"), mkType("b", "address[2]"), mkType("c", "uint8[4][][5]")}}, - {"singleNest(d bytes32, e uint8, f (uint256,uint256), g address)", "singleNest", []abi.ArgumentMarshaling{mkType("d", "bytes32"), mkType("e", "uint8"), mkType("f", []abi.ArgumentMarshaling{mkType("name0", "uint256"), mkType("name1", "uint256")}), mkType("g", "address")}}, + {"simple(uint256 a , uint256 b, uint256 c)", "simple", []abi.ArgumentMarshaling{mkType("a", "uint256"), mkType("b", "uint256"), mkType("c", "uint256")}}, + {"other(uint256 foo, address bar )", "other", []abi.ArgumentMarshaling{mkType("foo", "uint256"), mkType("bar", "address")}}, + {"withArray(uint256[] a, address[2] b, uint8[4][][5] c)", "withArray", []abi.ArgumentMarshaling{mkType("a", "uint256[]"), mkType("b", "address[2]"), mkType("c", "uint8[4][][5]")}}, + {"singleNest(bytes32 d, uint8 e, (uint256,uint256) f, address g)", "singleNest", []abi.ArgumentMarshaling{mkType("d", "bytes32"), mkType("e", "uint8"), mkType("f", []abi.ArgumentMarshaling{mkType("name0", "uint256"), mkType("name1", "uint256")}), mkType("g", "address")}}, + {"singleNest(bytes32 d, uint8 e, (uint256 first, uint256 second ) f, address g)", "singleNest", []abi.ArgumentMarshaling{mkType("d", "bytes32"), mkType("e", "uint8"), mkType("f", []abi.ArgumentMarshaling{mkType("first", "uint256"), mkType("second", "uint256")}), mkType("g", "address")}}, } for i, tt := range tests { - selector, err := ParseSignature(tt.input) + selector, err := ParseSelector(tt.input) if err != nil { t.Errorf("test %d: failed to parse selector '%v': %v", i, tt.input, err) } @@ -124,3 +127,35 @@ func TestParseSignature(t *testing.T) { } } } + +func TestParseSelectorErrors(t *testing.T) { + type errorTestCases struct { + description string + input string + expectedError string + } + + for _, scenario := range []errorTestCases{ + { + description: "invalid name", + input: "123()", + expectedError: "failed to parse selector identifier '123()': invalid token start: 1", + }, + { + description: "missing closing parenthesis", + input: "noargs(", + expectedError: "failed to parse selector args 'noargs(': expected ')', got ''", + }, + { + description: "missing opening parenthesis", + input: "noargs)", + expectedError: "failed to parse selector args 'noargs)': expected '(', got )", + }, + } { + t.Run(scenario.description, func(t *testing.T) { + _, err := ParseSelector(scenario.input) + require.Error(t, err) + assert.Equal(t, scenario.expectedError, err.Error()) + }) + } +} diff --git a/core/services/relay/evm/cap_encoder.go b/core/services/relay/evm/cap_encoder.go index 00fa3bc8773..e0e3a2cf0f5 100644 --- a/core/services/relay/evm/cap_encoder.go +++ b/core/services/relay/evm/cap_encoder.go @@ -35,7 +35,7 @@ func NewEVMEncoder(config *values.Map) (consensustypes.Encoder, error) { if !ok { return nil, fmt.Errorf("expected %s to be a string", abiConfigFieldName) } - selector, err := abiutil.ParseSignature("inner(" + selectorStr + ")") + selector, err := abiutil.ParseSelector("inner(" + selectorStr + ")") if err != nil { return nil, err } diff --git a/core/services/relay/evm/cap_encoder_test.go b/core/services/relay/evm/cap_encoder_test.go index 10a19fd962b..8c56fb9075a 100644 --- a/core/services/relay/evm/cap_encoder_test.go +++ b/core/services/relay/evm/cap_encoder_test.go @@ -2,6 +2,7 @@ package evm_test import ( "encoding/hex" + "math/big" "testing" "github.com/stretchr/testify/assert" @@ -27,9 +28,9 @@ var ( wrongLength = "8d4e66" ) -func TestEVMEncoder(t *testing.T) { +func 
TestEVMEncoder_SingleField(t *testing.T) { config := map[string]any{ - "abi": "mercury_reports bytes[]", + "abi": "bytes[] Full_reports", } wrapped, err := values.NewMap(config) require.NoError(t, err) @@ -38,7 +39,7 @@ func TestEVMEncoder(t *testing.T) { // output of a DF2.0 aggregator + metadata fields appended by OCR input := map[string]any{ - "mercury_reports": []any{reportA, reportB}, + "Full_reports": []any{reportA, reportB}, consensustypes.WorkflowIDFieldName: workflowID, consensustypes.ExecutionIDFieldName: executionID, } @@ -48,28 +49,156 @@ func TestEVMEncoder(t *testing.T) { require.NoError(t, err) expected := - // start of the outer tuple ((user_fields), workflow_id, workflow_execution_id) + // start of the outer tuple workflowID + donID + executionID + workflowOwnerID + // start of the inner tuple (user_fields) - "0000000000000000000000000000000000000000000000000000000000000020" + // offset of mercury_reports array - "0000000000000000000000000000000000000000000000000000000000000002" + // length of mercury_reports array + "0000000000000000000000000000000000000000000000000000000000000020" + // offset of Full_reports array + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Full_reports array "0000000000000000000000000000000000000000000000000000000000000040" + // offset of reportA "0000000000000000000000000000000000000000000000000000000000000080" + // offset of reportB "0000000000000000000000000000000000000000000000000000000000000003" + // length of reportA "0102030000000000000000000000000000000000000000000000000000000000" + // reportA "0000000000000000000000000000000000000000000000000000000000000004" + // length of reportB "aabbccdd00000000000000000000000000000000000000000000000000000000" // reportB - // end of the inner tuple (user_fields) + + require.Equal(t, expected, hex.EncodeToString(encoded)) +} + +func TestEVMEncoder_TwoFields(t *testing.T) { + config := map[string]any{ + "abi": "uint256[] Prices, uint32[] Timestamps", + } + wrapped, err := values.NewMap(config) + require.NoError(t, err) + enc, err := evm.NewEVMEncoder(wrapped) + require.NoError(t, err) + + // output of a DF2.0 aggregator + metadata fields appended by OCR + input := map[string]any{ + "Prices": []any{big.NewInt(234), big.NewInt(456)}, + "Timestamps": []any{int64(111), int64(222)}, + consensustypes.WorkflowIDFieldName: workflowID, + consensustypes.ExecutionIDFieldName: executionID, + } + wrapped, err = values.NewMap(input) + require.NoError(t, err) + encoded, err := enc.Encode(testutils.Context(t), *wrapped) + require.NoError(t, err) + + expected := + // start of the outer tuple + workflowID + + donID + + executionID + + workflowOwnerID + + // start of the inner tuple (user_fields) + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of Prices array + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset of Timestamps array + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Prices array + "00000000000000000000000000000000000000000000000000000000000000ea" + // Prices[0] + "00000000000000000000000000000000000000000000000000000000000001c8" + // Prices[1] + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Timestamps array + "000000000000000000000000000000000000000000000000000000000000006f" + // Timestamps[0] + "00000000000000000000000000000000000000000000000000000000000000de" // Timestamps[1] + + require.Equal(t, expected, hex.EncodeToString(encoded)) +} + +func 
TestEVMEncoder_Tuple(t *testing.T) { + config := map[string]any{ + "abi": "(uint256[] Prices, uint32[] Timestamps) Elem", + } + wrapped, err := values.NewMap(config) + require.NoError(t, err) + enc, err := evm.NewEVMEncoder(wrapped) + require.NoError(t, err) + + // output of a DF2.0 aggregator + metadata fields appended by OCR + input := map[string]any{ + "Elem": map[string]any{ + "Prices": []any{big.NewInt(234), big.NewInt(456)}, + "Timestamps": []any{int64(111), int64(222)}, + }, + consensustypes.WorkflowIDFieldName: workflowID, + consensustypes.ExecutionIDFieldName: executionID, + } + wrapped, err = values.NewMap(input) + require.NoError(t, err) + encoded, err := enc.Encode(testutils.Context(t), *wrapped) + require.NoError(t, err) + + expected := + // start of the outer tuple + workflowID + + donID + + executionID + + workflowOwnerID + + // start of the inner tuple (user_fields) + "0000000000000000000000000000000000000000000000000000000000000020" + // offset of Elem tuple + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of Prices array + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset of Timestamps array + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Prices array + "00000000000000000000000000000000000000000000000000000000000000ea" + // Prices[0] = 234 + "00000000000000000000000000000000000000000000000000000000000001c8" + // Prices[1] = 456 + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Timestamps array + "000000000000000000000000000000000000000000000000000000000000006f" + // Timestamps[0] = 111 + "00000000000000000000000000000000000000000000000000000000000000de" // Timestamps[1] = 222 + + require.Equal(t, expected, hex.EncodeToString(encoded)) +} + +func TestEVMEncoder_ListOfTuples(t *testing.T) { + config := map[string]any{ + "abi": "(uint256 Price, uint32 Timestamp)[] Elems", + } + wrapped, err := values.NewMap(config) + require.NoError(t, err) + enc, err := evm.NewEVMEncoder(wrapped) + require.NoError(t, err) + + // output of a DF2.0 aggregator + metadata fields appended by OCR + input := map[string]any{ + "Elems": []any{ + map[string]any{ + "Price": big.NewInt(234), + "Timestamp": int64(111), + }, + map[string]any{ + "Price": big.NewInt(456), + "Timestamp": int64(222), + }, + }, + consensustypes.WorkflowIDFieldName: workflowID, + consensustypes.ExecutionIDFieldName: executionID, + } + wrapped, err = values.NewMap(input) + require.NoError(t, err) + encoded, err := enc.Encode(testutils.Context(t), *wrapped) + require.NoError(t, err) + + expected := + // start of the outer tuple + workflowID + + donID + + executionID + + workflowOwnerID + + // start of the inner tuple (user_fields) + "0000000000000000000000000000000000000000000000000000000000000020" + // offset of Elem list + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Elem list + "00000000000000000000000000000000000000000000000000000000000000ea" + // Elem[0].Price = 234 + "000000000000000000000000000000000000000000000000000000000000006f" + // Elem[0].Timestamp = 111 + "00000000000000000000000000000000000000000000000000000000000001c8" + // Elem[1].Price = 456 + "00000000000000000000000000000000000000000000000000000000000000de" // Elem[1].Timestamp = 222 require.Equal(t, expected, hex.EncodeToString(encoded)) } func TestEVMEncoder_InvalidIDs(t *testing.T) { config := map[string]any{ - "abi": "mercury_reports bytes[]", + "abi": "bytes[] Full_reports", } wrapped, err 
:= values.NewMap(config) require.NoError(t, err) @@ -79,7 +208,7 @@ func TestEVMEncoder_InvalidIDs(t *testing.T) { // output of a DF2.0 aggregator + metadata fields appended by OCR // using an invalid ID input := map[string]any{ - "mercury_reports": []any{reportA, reportB}, + "Full_reports": []any{reportA, reportB}, consensustypes.WorkflowIDFieldName: invalidID, consensustypes.ExecutionIDFieldName: executionID, } @@ -90,7 +219,7 @@ func TestEVMEncoder_InvalidIDs(t *testing.T) { // using valid hex string of wrong length input = map[string]any{ - "mercury_reports": []any{reportA, reportB}, + "full_reports": []any{reportA, reportB}, consensustypes.WorkflowIDFieldName: wrongLength, consensustypes.ExecutionIDFieldName: executionID, } diff --git a/core/services/relay/evm/codec_test.go b/core/services/relay/evm/codec_test.go index 0597560aaec..0773b274107 100644 --- a/core/services/relay/evm/codec_test.go +++ b/core/services/relay/evm/codec_test.go @@ -2,6 +2,7 @@ package evm_test import ( "encoding/json" + "math/big" "testing" "github.com/ethereum/go-ethereum/accounts/abi" @@ -68,13 +69,79 @@ func TestCodec_SimpleEncode(t *testing.T) { require.NoError(t, err) expected := "0000000000000000000000000000000000000000000000000000000000000006" + // int32(6) - "0000000000000000000000000000000000000000000000000000000000000040" + // total bytes occupied by the string (64) + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of the beginning of second value (64 bytes) "0000000000000000000000000000000000000000000000000000000000000007" + // length of the string (7 chars) "6162636465666700000000000000000000000000000000000000000000000000" // actual string require.Equal(t, expected, hexutil.Encode(result)[2:]) } +func TestCodec_EncodeTuple(t *testing.T) { + codecName := "my_codec" + input := map[string]any{ + "Report": int32(6), + "Nested": map[string]any{ + "Meta": "abcdefg", + "Count": int32(14), + "Other": "12334", + }, + } + evmEncoderConfig := `[{"Name":"Report","Type":"int32"},{"Name":"Nested","Type":"tuple","Components":[{"Name":"Other","Type":"string"},{"Name":"Count","Type":"int32"},{"Name":"Meta","Type":"string"}]}]` + + codecConfig := types.CodecConfig{Configs: map[string]types.ChainCodecConfig{ + codecName: {TypeABI: evmEncoderConfig}, + }} + c, err := evm.NewCodec(codecConfig) + require.NoError(t, err) + + result, err := c.Encode(testutils.Context(t), input, codecName) + require.NoError(t, err) + expected := + "0000000000000000000000000000000000000000000000000000000000000006" + // Report integer (=6) + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of the first dynamic value (tuple, 64 bytes) + "0000000000000000000000000000000000000000000000000000000000000060" + // offset of the first nested dynamic value (string, 96 bytes) + "000000000000000000000000000000000000000000000000000000000000000e" + // "Count" integer (=14) + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset of the second nested dynamic value (string, 160 bytes) + "0000000000000000000000000000000000000000000000000000000000000005" + // length of the "Meta" string (5 chars) + "3132333334000000000000000000000000000000000000000000000000000000" + // "Other" string (="12334") + "0000000000000000000000000000000000000000000000000000000000000007" + // length of the "Other" string (7 chars) + "6162636465666700000000000000000000000000000000000000000000000000" // "Meta" string (="abcdefg") + + require.Equal(t, expected, hexutil.Encode(result)[2:]) +} + +func 
TestCodec_EncodeTupleWithLists(t *testing.T) { + codecName := "my_codec" + input := map[string]any{ + "Elem": map[string]any{ + "Prices": []any{big.NewInt(234), big.NewInt(456)}, + "Timestamps": []any{int64(111), int64(222)}, + }, + } + evmEncoderConfig := `[{"Name":"Elem","Type":"tuple","InternalType":"tuple","Components":[{"Name":"Prices","Type":"uint256[]","InternalType":"uint256[]","Components":null,"Indexed":false},{"Name":"Timestamps","Type":"uint32[]","InternalType":"uint32[]","Components":null,"Indexed":false}],"Indexed":false}]` + + codecConfig := types.CodecConfig{Configs: map[string]types.ChainCodecConfig{ + codecName: {TypeABI: evmEncoderConfig}, + }} + c, err := evm.NewCodec(codecConfig) + require.NoError(t, err) + + result, err := c.Encode(testutils.Context(t), input, codecName) + require.NoError(t, err) + expected := + "0000000000000000000000000000000000000000000000000000000000000020" + // offset of Elem tuple + "0000000000000000000000000000000000000000000000000000000000000040" + // offset of Prices array + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset of Timestamps array + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Prices array + "00000000000000000000000000000000000000000000000000000000000000ea" + // Prices[0] = 234 + "00000000000000000000000000000000000000000000000000000000000001c8" + // Prices[1] = 456 + "0000000000000000000000000000000000000000000000000000000000000002" + // length of Timestamps array + "000000000000000000000000000000000000000000000000000000000000006f" + // Timestamps[0] = 111 + "00000000000000000000000000000000000000000000000000000000000000de" // Timestamps[1] = 222 + + require.Equal(t, expected, hexutil.Encode(result)[2:]) +} + type codecInterfaceTester struct{} func (it *codecInterfaceTester) Setup(_ *testing.T) {} From 087e2a744f59d12330531ff0d8e1aa99fca2b290 Mon Sep 17 00:00:00 2001 From: Ilja Pavlovs Date: Wed, 22 May 2024 20:37:13 +0300 Subject: [PATCH 05/16] =?UTF-8?q?VRF-1102:=20disable=20tests=20and=20check?= =?UTF-8?q?s=20depending=20on=20network=20type=20(Simulat=E2=80=A6=20(#132?= =?UTF-8?q?90)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * VRF-1102: disable tests and checks depending on network type (Simulated or live testnet) * VRF-1102: fixing tests --- integration-tests/smoke/vrfv2_test.go | 27 +++++++++++--------- integration-tests/smoke/vrfv2plus_test.go | 30 ++++++++++++----------- integration-tests/testconfig/default.toml | 2 +- 3 files changed, 33 insertions(+), 26 deletions(-) diff --git a/integration-tests/smoke/vrfv2_test.go b/integration-tests/smoke/vrfv2_test.go index 88bf2ac110a..c3e516b093a 100644 --- a/integration-tests/smoke/vrfv2_test.go +++ b/integration-tests/smoke/vrfv2_test.go @@ -1055,8 +1055,11 @@ func TestVRFV2NodeReorg(t *testing.T) { config, err := tc.GetConfig("Smoke", tc.VRFv2) require.NoError(t, err, "Error getting config") vrfv2Config := config.VRFv2 - chainID := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0].ChainID - + network := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0] + if !network.Simulated { + t.Skip("Skipped since Reorg test could only be run on Simulated chain.") + } + chainID := network.ChainID cleanupFn := func() { sethClient, err := env.GetSethClient(chainID) require.NoError(t, err, "Getting Seth client shouldn't fail") @@ -1236,8 +1239,8 @@ func TestVRFv2BatchFulfillmentEnabledDisabled(t *testing.T) { config, err := 
tc.GetConfig("Smoke", tc.VRFv2) require.NoError(t, err, "Error getting config") vrfv2Config := config.VRFv2 - chainID := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0].ChainID - + network := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0] + chainID := network.ChainID cleanupFn := func() { sethClient, err := env.GetSethClient(chainID) require.NoError(t, err, "Getting Seth client shouldn't fail") @@ -1395,14 +1398,16 @@ func TestVRFv2BatchFulfillmentEnabledDisabled(t *testing.T) { // verify that VRF node sends fulfillments via BatchCoordinator contract require.Equal(t, vrfContracts.BatchCoordinatorV2.Address(), fulfillmentTXToAddress, "Fulfillment Tx To Address should be the BatchCoordinatorV2 Address when batch fulfillment is enabled") - fulfillmentTxReceipt, err := sethClient.Client.TransactionReceipt(testcontext.Get(t), fulfillmentTx.Hash()) - require.NoError(t, err) - - randomWordsFulfilledLogs, err := contracts.ParseRandomWordsFulfilledLogs(vrfContracts.CoordinatorV2, fulfillmentTxReceipt.Logs) - require.NoError(t, err) - // verify that all fulfillments should be inside one tx - require.Equal(t, int(randRequestCount), len(randomWordsFulfilledLogs)) + // This check is disabled for live testnets since each testnet has different gas usage for similar tx + if network.Simulated { + fulfillmentTxReceipt, err := sethClient.Client.TransactionReceipt(testcontext.Get(t), fulfillmentTx.Hash()) + require.NoError(t, err) + randomWordsFulfilledLogs, err := contracts.ParseRandomWordsFulfilledLogs(vrfContracts.CoordinatorV2, fulfillmentTxReceipt.Logs) + require.NoError(t, err) + require.Equal(t, 1, len(batchFulfillmentTxs)) + require.Equal(t, int(randRequestCount), len(randomWordsFulfilledLogs)) + } }) t.Run("Batch Fulfillment Disabled", func(t *testing.T) { configCopy := config.MustCopy().(tc.TestConfig) diff --git a/integration-tests/smoke/vrfv2plus_test.go b/integration-tests/smoke/vrfv2plus_test.go index c17a09dcf7c..a4b542004d2 100644 --- a/integration-tests/smoke/vrfv2plus_test.go +++ b/integration-tests/smoke/vrfv2plus_test.go @@ -1956,8 +1956,11 @@ func TestVRFv2PlusNodeReorg(t *testing.T) { config, err := tc.GetConfig("Smoke", tc.VRFv2Plus) require.NoError(t, err, "Error getting config") vrfv2PlusConfig := config.VRFv2Plus - chainID := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0].ChainID - + network := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0] + if !network.Simulated { + t.Skip("Skipped since Reorg test could only be run on Simulated chain.") + } + chainID := network.ChainID cleanupFn := func() { sethClient, err := env.GetSethClient(chainID) require.NoError(t, err, "Getting Seth client shouldn't fail") @@ -2136,8 +2139,8 @@ func TestVRFv2PlusBatchFulfillmentEnabledDisabled(t *testing.T) { config, err := tc.GetConfig("Smoke", tc.VRFv2Plus) require.NoError(t, err, "Error getting config") vrfv2PlusConfig := config.VRFv2Plus - chainID := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0].ChainID - + network := networks.MustGetSelectedNetworkConfig(config.GetNetworkConfig())[0] + chainID := network.ChainID cleanupFn := func() { sethClient, err := env.GetSethClient(chainID) require.NoError(t, err, "Getting Seth client shouldn't fail") @@ -2277,9 +2280,6 @@ func TestVRFv2PlusBatchFulfillmentEnabledDisabled(t *testing.T) { batchFulfillmentTxs = append(batchFulfillmentTxs, tx) } } - // verify that all fulfillments should be inside one tx - require.Equal(t, 1, len(batchFulfillmentTxs)) - 
fulfillmentTx, _, err := sethClient.Client.TransactionByHash(testcontext.Get(t), randomWordsFulfilledEvent.Raw.TxHash) require.NoError(t, err, "error getting tx from hash") @@ -2292,14 +2292,16 @@ func TestVRFv2PlusBatchFulfillmentEnabledDisabled(t *testing.T) { // verify that VRF node sends fulfillments via BatchCoordinator contract require.Equal(t, vrfContracts.BatchCoordinatorV2Plus.Address(), fulfillmentTXToAddress, "Fulfillment Tx To Address should be the BatchCoordinatorV2Plus Address when batch fulfillment is enabled") - fulfillmentTxReceipt, err := sethClient.Client.TransactionReceipt(testcontext.Get(t), fulfillmentTx.Hash()) - require.NoError(t, err) - - randomWordsFulfilledLogs, err := contracts.ParseRandomWordsFulfilledLogs(vrfContracts.CoordinatorV2Plus, fulfillmentTxReceipt.Logs) - require.NoError(t, err) - // verify that all fulfillments should be inside one tx - require.Equal(t, int(randRequestCount), len(randomWordsFulfilledLogs)) + // This check is disabled for live testnets since each testnet has different gas usage for similar tx + if network.Simulated { + fulfillmentTxReceipt, err := sethClient.Client.TransactionReceipt(testcontext.Get(t), fulfillmentTx.Hash()) + require.NoError(t, err) + randomWordsFulfilledLogs, err := contracts.ParseRandomWordsFulfilledLogs(vrfContracts.CoordinatorV2Plus, fulfillmentTxReceipt.Logs) + require.NoError(t, err) + require.Equal(t, 1, len(batchFulfillmentTxs)) + require.Equal(t, int(randRequestCount), len(randomWordsFulfilledLogs)) + } }) t.Run("Batch Fulfillment Disabled", func(t *testing.T) { configCopy := config.MustCopy().(tc.TestConfig) diff --git a/integration-tests/testconfig/default.toml b/integration-tests/testconfig/default.toml index f5faf23d4e8..0c8d411c04a 100644 --- a/integration-tests/testconfig/default.toml +++ b/integration-tests/testconfig/default.toml @@ -380,7 +380,7 @@ gas_fee_cap = 30_000_000_000 gas_tip_cap = 1_000_000_000 # how many last blocks to use, when estimating gas for a transaction -gas_price_estimation_blocks = 50 +gas_price_estimation_blocks = 30 # priority of the transaction, can be "fast", "standard" or "slow" (the higher the priority, the higher adjustment factor will be used for gas estimation) [default: "standard"] gas_price_estimation_tx_priority = "standard" From 61391260340ba74f3510e6ded4fdace6829630b7 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Wed, 22 May 2024 12:54:24 -0500 Subject: [PATCH 06/16] Pipeline Data Corruption (#13286) * Pipeline Data Corruption The unit test `TestDivide_Example` was acting flakey in the CI pipeline which suggested a flaw in the divide and multiply operations. When running the test, the expected result would be one of the input values or the division result in failure cases. This implied that results were either received out of order or were being sorted incorrectly. The pipeline runner does a final sort on the results, so that ruled out the received out of order possibility. On inspection of the sorting index on each task, every index was the zero value. This resulted in occasional correct and incorrect sorting, causing the test flake. To correct the problem, the test was updated such that the expected result has an index of `1`, leaving all other tasks with a `0` index. 
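For context, a minimal, self-contained Go sketch of the sorting behavior the fix relies on; the `taskResult` type and values here are illustrative stand-ins, not the actual `pipeline` package types. Terminal task results are ordered by their explicit `index`, so duplicate or all-zero indexes make the final ordering nondeterministic.

	package main

	import (
		"fmt"
		"sort"
	)

	// taskResult is a hypothetical stand-in for a terminal pipeline task's output.
	type taskResult struct {
		dotID string
		index int32 // explicit indexes are 0-based; -1 marks "unset"
		value string
	}

	func main() {
		results := []taskResult{
			{dotID: "answer2", index: 1, value: "b"},
			{dotID: "answer1", index: 0, value: "a"},
		}
		// The final sort is only meaningful if every terminal task carries a
		// unique, explicit index; shared indexes make this ordering arbitrary.
		sort.SliceStable(results, func(i, j int) bool {
			return results[i].index < results[j].index
		})
		fmt.Println(results[0].dotID, results[1].dotID) // answer1 answer2
	}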
* fix test * updated changeset --- .changeset/blue-camels-begin.md | 5 +++ core/services/ocr/validate_test.go | 52 ++++++++++++++++++++++ core/services/pipeline/common.go | 19 ++++++++ core/services/pipeline/graph.go | 11 +++++ core/services/pipeline/graph_test.go | 12 ++--- core/services/pipeline/task.divide_test.go | 19 ++++---- 6 files changed, 103 insertions(+), 15 deletions(-) create mode 100644 .changeset/blue-camels-begin.md diff --git a/.changeset/blue-camels-begin.md b/.changeset/blue-camels-begin.md new file mode 100644 index 00000000000..3ad57286e91 --- /dev/null +++ b/.changeset/blue-camels-begin.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +enforce proper result indexing on pipeline results #breaking_change diff --git a/core/services/ocr/validate_test.go b/core/services/ocr/validate_test.go index 6e68559d09d..59213c7168e 100644 --- a/core/services/ocr/validate_test.go +++ b/core/services/ocr/validate_test.go @@ -27,6 +27,58 @@ func TestValidateOracleSpec(t *testing.T) { overrides func(c *chainlink.Config, s *chainlink.Secrets) assertion func(t *testing.T, os job.Job, err error) }{ + { + name: "invalid result sorting index", + toml: ` +ds1 [type=memo value=10000.1234]; +ds2 [type=memo value=100]; + +div_by_ds2 [type=divide divisor="$(ds2)"]; + +ds1 -> div_by_ds2 -> answer1; + +answer1 [type=multiply times=10000 index=-1]; +`, + assertion: func(t *testing.T, os job.Job, err error) { + require.Error(t, err) + }, + }, + { + name: "duplicate sorting indexes not allowed", + toml: ` +ds1 [type=memo value=10000.1234]; +ds2 [type=memo value=100]; + +div_by_ds2 [type=divide divisor="$(ds2)"]; + +ds1 -> div_by_ds2 -> answer1; +ds1 -> div_by_ds2 -> answer2; + +answer1 [type=multiply times=10000 index=0]; +answer2 [type=multiply times=10000 index=0]; +`, + assertion: func(t *testing.T, os job.Job, err error) { + require.Error(t, err) + }, + }, + { + name: "invalid result sorting index", + toml: ` +type = "offchainreporting" +schemaVersion = 1 +contractAddress = "0x613a38AC1659769640aaE063C651F48E0250454C" +isBootstrapPeer = false +observationSource = """ +ds1 [type=bridge name=voter_turnout]; +ds1_parse [type=jsonparse path="one,two"]; +ds1_multiply [type=multiply times=1.23]; +ds1 -> ds1_parse -> ds1_multiply -> answer1; +answer1 [type=median index=-1]; +"""`, + assertion: func(t *testing.T, os job.Job, err error) { + require.Error(t, err) + }, + }, { name: "minimal non-bootstrap oracle spec", toml: ` diff --git a/core/services/pipeline/common.go b/core/services/pipeline/common.go index a0fc28c6862..5d843b8b918 100644 --- a/core/services/pipeline/common.go +++ b/core/services/pipeline/common.go @@ -415,9 +415,11 @@ func UnmarshalTaskFromMap(taskType TaskType, taskMap interface{}, ID int, dotID return nil, pkgerrors.Errorf(`unknown task type: "%v"`, taskType) } + metadata := mapstructure.Metadata{} decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ Result: task, WeaklyTypedInput: true, + Metadata: &metadata, DecodeHook: mapstructure.ComposeDecodeHookFunc( mapstructure.StringToTimeDurationHookFunc(), func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { @@ -441,6 +443,23 @@ func UnmarshalTaskFromMap(taskType TaskType, taskMap interface{}, ID int, dotID if err != nil { return nil, err } + + // valid explicit index values are 0-based + for _, key := range metadata.Keys { + if key == "index" { + if task.OutputIndex() < 0 { + return nil, errors.New("result sorting indexes should start with 0") + } + } + } + + // the 'unset' value should be -1 
to allow explicit indexes to be 0-based + for _, key := range metadata.Unset { + if key == "index" { + task.Base().Index = -1 + } + } + return task, nil } diff --git a/core/services/pipeline/graph.go b/core/services/pipeline/graph.go index c3914e698c6..12bec2bc8b9 100644 --- a/core/services/pipeline/graph.go +++ b/core/services/pipeline/graph.go @@ -235,6 +235,8 @@ func Parse(text string) (*Pipeline, error) { // we need a temporary mapping of graph.IDs to positional ids after toposort ids := make(map[int64]int) + resultIdxs := make(map[int32]struct{}) + // use the new ordering as the id so that we can easily reproduce the original toposort for id, node := range nodes { node, is := node.(*GraphNode) @@ -251,6 +253,15 @@ func Parse(text string) (*Pipeline, error) { return nil, err } + if task.OutputIndex() > 0 { + _, exists := resultIdxs[task.OutputIndex()] + if exists { + return nil, errors.New("duplicate sorting indexes detected") + } + + resultIdxs[task.OutputIndex()] = struct{}{} + } + // re-link the edges for inputs := g.To(node.ID()); inputs.Next(); { isImplicitEdge := g.IsImplicitEdge(inputs.Node().ID(), node.ID()) diff --git a/core/services/pipeline/graph_test.go b/core/services/pipeline/graph_test.go index b3960bb1f46..c6248a38e24 100644 --- a/core/services/pipeline/graph_test.go +++ b/core/services/pipeline/graph_test.go @@ -171,27 +171,27 @@ func TestGraph_TasksInDependencyOrder(t *testing.T) { "ds1_multiply", []pipeline.TaskDependency{{PropagateResult: true, InputTask: pipeline.Task(ds1_parse)}}, []pipeline.Task{answer1}, - 0) + -1) ds2_multiply.BaseTask = pipeline.NewBaseTask( 5, "ds2_multiply", []pipeline.TaskDependency{{PropagateResult: true, InputTask: pipeline.Task(ds2_parse)}}, []pipeline.Task{answer1}, - 0) + -1) ds1_parse.BaseTask = pipeline.NewBaseTask( 1, "ds1_parse", []pipeline.TaskDependency{{PropagateResult: true, InputTask: pipeline.Task(ds1)}}, []pipeline.Task{ds1_multiply}, - 0) + -1) ds2_parse.BaseTask = pipeline.NewBaseTask( 4, "ds2_parse", []pipeline.TaskDependency{{PropagateResult: true, InputTask: pipeline.Task(ds2)}}, []pipeline.Task{ds2_multiply}, - 0) - ds1.BaseTask = pipeline.NewBaseTask(0, "ds1", nil, []pipeline.Task{ds1_parse}, 0) - ds2.BaseTask = pipeline.NewBaseTask(3, "ds2", nil, []pipeline.Task{ds2_parse}, 0) + -1) + ds1.BaseTask = pipeline.NewBaseTask(0, "ds1", nil, []pipeline.Task{ds1_parse}, -1) + ds2.BaseTask = pipeline.NewBaseTask(3, "ds2", nil, []pipeline.Task{ds2_parse}, -1) for i, task := range p.Tasks { // Make sure inputs appear before the task, and outputs don't diff --git a/core/services/pipeline/task.divide_test.go b/core/services/pipeline/task.divide_test.go index 3c2f57c7a07..8eb8e4de063 100644 --- a/core/services/pipeline/task.divide_test.go +++ b/core/services/pipeline/task.divide_test.go @@ -3,6 +3,7 @@ package pipeline_test import ( "fmt" "math" + "reflect" "testing" "github.com/pkg/errors" @@ -198,19 +199,17 @@ func TestDivideTask_Overflow(t *testing.T) { } func TestDivide_Example(t *testing.T) { - testutils.SkipFlakey(t, "BCF-3236") t.Parallel() dag := ` -ds1 [type=memo value=10000.1234] +ds1 [type=memo value=10000.1234]; +ds2 [type=memo value=100]; -ds2 [type=memo value=100] +div_by_ds2 [type=divide divisor="$(ds2)"]; +multiply [type=multiply times=10000 index=0]; -div_by_ds2 [type=divide divisor="$(ds2)"] +ds1 -> div_by_ds2 -> multiply; -multiply [type=multiply times=10000 index=0] - -ds1->div_by_ds2->multiply; ` db := pgtest.NewSqlxDB(t) @@ -223,12 +222,14 @@ ds1->div_by_ds2->multiply; lggr := logger.TestLogger(t) _, trrs, 
err := r.ExecuteRun(testutils.Context(t), spec, vars, lggr) - require.NoError(t, err) + require.NoError(t, err) require.Len(t, trrs, 4) finalResult := trrs[3] - assert.Nil(t, finalResult.Result.Error) + require.NoError(t, finalResult.Result.Error) + require.Equal(t, reflect.TypeOf(decimal.Decimal{}), reflect.TypeOf(finalResult.Result.Value)) + assert.Equal(t, "1000012.34", finalResult.Result.Value.(decimal.Decimal).String()) } From 5a87f4a59e3c6c92b08ebefc5090017693785729 Mon Sep 17 00:00:00 2001 From: Dylan Tinianov Date: Wed, 22 May 2024 14:22:15 -0400 Subject: [PATCH 07/16] Remove ClientErrors interface from common (#13279) * Move ClientErrors interface to common * Update blue-camels-promise.md * Remove client errors from common * Update blue-camels-promise.md * Delete clienterrors.go --- .changeset/blue-camels-promise.md | 5 +++++ common/client/node.go | 3 --- common/client/node_test.go | 6 ------ core/chains/evm/client/chain_client.go | 4 +--- core/chains/evm/client/errors.go | 3 ++- core/chains/evm/config/chain_scoped_node_pool.go | 4 +++- core/chains/evm/txmgr/client.go | 6 +++--- 7 files changed, 14 insertions(+), 17 deletions(-) create mode 100644 .changeset/blue-camels-promise.md diff --git a/.changeset/blue-camels-promise.md b/.changeset/blue-camels-promise.md new file mode 100644 index 00000000000..48d7fd47565 --- /dev/null +++ b/.changeset/blue-camels-promise.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +#changed Remove ClientErrors interface from common diff --git a/common/client/node.go b/common/client/node.go index bee1f3da3b8..1d0a799321b 100644 --- a/common/client/node.go +++ b/common/client/node.go @@ -9,8 +9,6 @@ import ( "sync" "time" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" - "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -47,7 +45,6 @@ type NodeConfig interface { SyncThreshold() uint32 NodeIsSyncingEnabled() bool FinalizedBlockPollInterval() time.Duration - Errors() config.ClientErrors } type ChainConfig interface { diff --git a/common/client/node_test.go b/common/client/node_test.go index 85c96145740..a97f26555a9 100644 --- a/common/client/node_test.go +++ b/common/client/node_test.go @@ -9,7 +9,6 @@ import ( clientMocks "github.com/smartcontractkit/chainlink/v2/common/client/mocks" "github.com/smartcontractkit/chainlink/v2/common/types" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" ) type testNodeConfig struct { @@ -19,7 +18,6 @@ type testNodeConfig struct { syncThreshold uint32 nodeIsSyncingEnabled bool finalizedBlockPollInterval time.Duration - errors config.ClientErrors } func (n testNodeConfig) PollFailureThreshold() uint32 { @@ -46,10 +44,6 @@ func (n testNodeConfig) FinalizedBlockPollInterval() time.Duration { return n.finalizedBlockPollInterval } -func (n testNodeConfig) Errors() config.ClientErrors { - return n.errors -} - type testNode struct { *node[types.ID, Head, NodeClient[types.ID, Head]] } diff --git a/core/chains/evm/client/chain_client.go b/core/chains/evm/client/chain_client.go index 3ee10a600da..04b1ff29387 100644 --- a/core/chains/evm/client/chain_client.go +++ b/core/chains/evm/client/chain_client.go @@ -5,8 +5,6 @@ import ( "math/big" "time" - evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" - "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -14,10 +12,10 @@ import ( commonassets 
"github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/logger" - commonclient "github.com/smartcontractkit/chainlink/v2/common/client" "github.com/smartcontractkit/chainlink/v2/common/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" + evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) diff --git a/core/chains/evm/client/errors.go b/core/chains/evm/client/errors.go index b7e0d9317eb..bcc8ff961f0 100644 --- a/core/chains/evm/client/errors.go +++ b/core/chains/evm/client/errors.go @@ -6,6 +6,8 @@ import ( "fmt" "regexp" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" pkgerrors "github.com/pkg/errors" @@ -13,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" commonclient "github.com/smartcontractkit/chainlink/v2/common/client" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/label" ) diff --git a/core/chains/evm/config/chain_scoped_node_pool.go b/core/chains/evm/config/chain_scoped_node_pool.go index 7f071e5506f..50269366829 100644 --- a/core/chains/evm/config/chain_scoped_node_pool.go +++ b/core/chains/evm/config/chain_scoped_node_pool.go @@ -38,4 +38,6 @@ func (n *NodePoolConfig) FinalizedBlockPollInterval() time.Duration { return n.C.FinalizedBlockPollInterval.Duration() } -func (n *NodePoolConfig) Errors() ClientErrors { return &clientErrorsConfig{c: n.C.Errors} } +func (n *NodePoolConfig) Errors() ClientErrors { + return &clientErrorsConfig{c: n.C.Errors} +} diff --git a/core/chains/evm/txmgr/client.go b/core/chains/evm/txmgr/client.go index 8ba3c841289..661a180af50 100644 --- a/core/chains/evm/txmgr/client.go +++ b/core/chains/evm/txmgr/client.go @@ -18,7 +18,7 @@ import ( commonclient "github.com/smartcontractkit/chainlink/v2/common/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" - evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) @@ -27,10 +27,10 @@ var _ TxmClient = (*evmTxmClient)(nil) type evmTxmClient struct { client client.Client - clientErrors evmconfig.ClientErrors + clientErrors config.ClientErrors } -func NewEvmTxmClient(c client.Client, clientErrors evmconfig.ClientErrors) *evmTxmClient { +func NewEvmTxmClient(c client.Client, clientErrors config.ClientErrors) *evmTxmClient { return &evmTxmClient{client: c, clientErrors: clientErrors} } From 143741012c4d0b148ada9d5aa237ff932cd3005b Mon Sep 17 00:00:00 2001 From: Cedric Date: Thu, 23 May 2024 02:37:00 +0100 Subject: [PATCH 08/16] Add ON DELETE CASCADE to workflow tables (#13165) --- .changeset/tame-mice-give.md | 5 +++ ...0237_add_workflow_executions_on_delete.sql | 31 +++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 .changeset/tame-mice-give.md create mode 100644 core/store/migrate/migrations/0237_add_workflow_executions_on_delete.sql diff --git a/.changeset/tame-mice-give.md b/.changeset/tame-mice-give.md new file mode 100644 index 00000000000..7cd59b154ad --- /dev/null +++ b/.changeset/tame-mice-give.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +#db_update 
Add ON DELETE CASCADE to workflow tables diff --git a/core/store/migrate/migrations/0237_add_workflow_executions_on_delete.sql b/core/store/migrate/migrations/0237_add_workflow_executions_on_delete.sql new file mode 100644 index 00000000000..87670d0ab61 --- /dev/null +++ b/core/store/migrate/migrations/0237_add_workflow_executions_on_delete.sql @@ -0,0 +1,31 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE workflow_executions +DROP CONSTRAINT workflow_executions_workflow_id_fkey, +ADD CONSTRAINT workflow_executions_workflow_id_fkey + FOREIGN KEY (workflow_id) + REFERENCES workflow_specs(workflow_id) + ON DELETE CASCADE; + +ALTER TABLE workflow_steps +DROP CONSTRAINT workflow_steps_workflow_execution_id_fkey, +ADD CONSTRAINT workflow_steps_workflow_execution_id_fkey + FOREIGN KEY (workflow_execution_id) + REFERENCES workflow_executions(id) + ON DELETE CASCADE; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE workflow_executions +DROP CONSTRAINT workflow_executions_workflow_id_fkey, +ADD CONSTRAINT workflow_executions_workflow_id_fkey + FOREIGN KEY (workflow_id) + REFERENCES workflow_specs(workflow_id); + +ALTER TABLE workflow_steps +DROP CONSTRAINT workflow_steps_workflow_execution_id_fkey, +ADD CONSTRAINT workflow_steps_workflow_execution_id_fkey + FOREIGN KEY (workflow_execution_id) + REFERENCES workflow_executions(id); +-- +goose StatementEnd From 5db47b63b3f2d0addf521904940d780caf9f57eb Mon Sep 17 00:00:00 2001 From: krehermann Date: Wed, 22 May 2024 23:22:43 -0600 Subject: [PATCH 09/16] KS-205: add workflow name to spec (#13265) * KS-205: add workflow name to spec * fix test * fix sql and test * fix tests * remove empty wf owner,name check * fix bad merge of main * rename migration --------- Co-authored-by: Bolek <1416262+bolekk@users.noreply.github.com> --- .changeset/wild-berries-cry.md | 5 + core/services/feeds/service_test.go | 3 +- core/services/job/job_orm_test.go | 183 ++++++++++++++++++ core/services/job/mocks/orm.go | 28 +++ core/services/job/models.go | 5 + core/services/job/orm.go | 23 ++- core/services/workflows/delegate_test.go | 11 ++ .../migrations/0238_workflow_spec_name.sql | 22 +++ core/testdata/testspecs/v2_specs.go | 5 +- core/web/jobs_controller_test.go | 6 +- core/web/presenters/job.go | 2 + core/web/presenters/job_test.go | 2 + 12 files changed, 288 insertions(+), 7 deletions(-) create mode 100644 .changeset/wild-berries-cry.md create mode 100644 core/store/migrate/migrations/0238_workflow_spec_name.sql diff --git a/.changeset/wild-berries-cry.md b/.changeset/wild-berries-cry.md new file mode 100644 index 00000000000..196de1a124e --- /dev/null +++ b/.changeset/wild-berries-cry.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +#db_update Add name to workflow spec. 
Add unique constraint to (owner,name) for workflow spec diff --git a/core/services/feeds/service_test.go b/core/services/feeds/service_test.go index 43d75f712a0..b8cd590a402 100644 --- a/core/services/feeds/service_test.go +++ b/core/services/feeds/service_test.go @@ -643,6 +643,7 @@ func Test_Service_ProposeJob(t *testing.T) { // variables for workflow spec wfID = "15c631d295ef5e32deb99a10ee6804bc4af1385568f9b3363f6552ac6dbb2cef" wfOwner = "00000000000000000000000000000000000000aa" + wfName = "my-workflow" specYaml = ` triggers: - id: "a-trigger" @@ -666,7 +667,7 @@ targets: inputs: consensus_output: $(a-consensus.outputs) ` - wfSpec = testspecs.GenerateWorkflowSpec(wfID, wfOwner, specYaml).Toml() + wfSpec = testspecs.GenerateWorkflowSpec(wfID, wfOwner, wfName, specYaml).Toml() proposalIDWF = int64(11) remoteUUIDWF = uuid.New() argsWF = &feeds.ProposeJobArgs{ diff --git a/core/services/job/job_orm_test.go b/core/services/job/job_orm_test.go index 3c7d5a7afa5..c13cf9da4b1 100644 --- a/core/services/job/job_orm_test.go +++ b/core/services/job/job_orm_test.go @@ -10,10 +10,12 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/google/uuid" "github.com/lib/pq" + "github.com/pelletier/go-toml/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/guregu/null.v4" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable" @@ -1803,6 +1805,187 @@ func Test_CountPipelineRunsByJobID(t *testing.T) { }) } +func Test_ORM_FindJobByWorkflow(t *testing.T) { + type fields struct { + ds sqlutil.DataSource + } + type args struct { + spec job.WorkflowSpec + before func(t *testing.T, o job.ORM, s job.WorkflowSpec) int32 + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "wf not job found", + fields: fields{ + ds: pgtest.NewSqlxDB(t), + }, + args: args{ + // before is nil, so no job is inserted + spec: job.WorkflowSpec{ + ID: 1, + WorkflowID: "workflow 1", + Workflow: "abcd", + WorkflowOwner: "me", + WorkflowName: "myworkflow", + }, + }, + wantErr: true, + }, + { + name: "wf job found", + fields: fields{ + ds: pgtest.NewSqlxDB(t), + }, + args: args{ + spec: job.WorkflowSpec{ + ID: 1, + WorkflowID: "workflow 2", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "myworkflow", + }, + before: mustInsertWFJob, + }, + wantErr: false, + }, + + { + name: "wf wrong name", + fields: fields{ + ds: pgtest.NewSqlxDB(t), + }, + args: args{ + spec: job.WorkflowSpec{ + ID: 1, + WorkflowID: "workflow 3", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "wf3", + }, + before: func(t *testing.T, o job.ORM, s job.WorkflowSpec) int32 { + s.WorkflowName = "notmyworkflow" + return mustInsertWFJob(t, o, s) + }, + }, + wantErr: true, + }, + { + name: "wf wrong owner", + fields: fields{ + ds: pgtest.NewSqlxDB(t), + }, + args: args{ + spec: job.WorkflowSpec{ + ID: 1, + WorkflowID: "workflow 4", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "wf4", + }, + before: func(t *testing.T, o job.ORM, s job.WorkflowSpec) int32 { + s.WorkflowOwner = "not me" + return mustInsertWFJob(t, o, s) + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ks := cltest.NewKeyStore(t, tt.fields.ds) + pipelineORM := pipeline.NewORM(tt.fields.ds, logger.TestLogger(t), configtest.NewTestGeneralConfig(t).JobPipeline().MaxSuccessfulRuns()) + 
bridgesORM := bridges.NewORM(tt.fields.ds) + o := NewTestORM(t, tt.fields.ds, pipelineORM, bridgesORM, ks) + var wantJobID int32 + if tt.args.before != nil { + wantJobID = tt.args.before(t, o, tt.args.spec) + } + ctx := testutils.Context(t) + gotJ, err := o.FindJobIDByWorkflow(ctx, tt.args.spec) + if (err != nil) != tt.wantErr { + t.Errorf("orm.FindJobByWorkflow() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + assert.Equal(t, wantJobID, gotJ, "mismatch job id") + } + }) + } + + t.Run("multiple jobs", func(t *testing.T) { + db := pgtest.NewSqlxDB(t) + o := NewTestORM(t, + db, + pipeline.NewORM(db, + logger.TestLogger(t), + configtest.NewTestGeneralConfig(t).JobPipeline().MaxSuccessfulRuns()), + bridges.NewORM(db), + cltest.NewKeyStore(t, db)) + ctx := testutils.Context(t) + s1 := job.WorkflowSpec{ + WorkflowID: "workflowid", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "a_common_name", + } + wantJobID1 := mustInsertWFJob(t, o, s1) + + s2 := job.WorkflowSpec{ + WorkflowID: "another workflowid", + Workflow: "anything", + WorkflowOwner: "me", + WorkflowName: "another workflow name", + } + wantJobID2 := mustInsertWFJob(t, o, s2) + + s3 := job.WorkflowSpec{ + WorkflowID: "xworkflowid", + Workflow: "anything", + WorkflowOwner: "someone else", + WorkflowName: "a_common_name", + } + wantJobID3 := mustInsertWFJob(t, o, s3) + + expectedIDs := []int32{wantJobID1, wantJobID2, wantJobID3} + for i, s := range []job.WorkflowSpec{s1, s2, s3} { + gotJ, err := o.FindJobIDByWorkflow(ctx, s) + require.NoError(t, err) + assert.Equal(t, expectedIDs[i], gotJ, "mismatch job id case %d, spec %v", i, s) + j, err := o.FindJob(ctx, expectedIDs[i]) + require.NoError(t, err) + assert.NotNil(t, j) + t.Logf("found job %v", j) + assert.EqualValues(t, j.WorkflowSpec.Workflow, s.Workflow) + assert.EqualValues(t, j.WorkflowSpec.WorkflowID, s.WorkflowID) + assert.EqualValues(t, j.WorkflowSpec.WorkflowOwner, s.WorkflowOwner) + assert.EqualValues(t, j.WorkflowSpec.WorkflowName, s.WorkflowName) + } + }) +} + +func mustInsertWFJob(t *testing.T, orm job.ORM, s job.WorkflowSpec) int32 { + t.Helper() + ctx := testutils.Context(t) + _, err := toml.Marshal(s.Workflow) + require.NoError(t, err) + j := job.Job{ + Type: job.Workflow, + WorkflowSpec: &s, + ExternalJobID: uuid.New(), + Name: null.StringFrom(s.WorkflowOwner + "_" + s.WorkflowName), + SchemaVersion: 1, + } + err = orm.CreateJob(ctx, &j) + require.NoError(t, err) + return j.ID +} + func mustInsertPipelineRun(t *testing.T, orm pipeline.ORM, j job.Job) pipeline.Run { t.Helper() ctx := testutils.Context(t) diff --git a/core/services/job/mocks/orm.go b/core/services/job/mocks/orm.go index ec60137de93..e8911b25af3 100644 --- a/core/services/job/mocks/orm.go +++ b/core/services/job/mocks/orm.go @@ -248,6 +248,34 @@ func (_m *ORM) FindJobIDByAddress(ctx context.Context, address types.EIP55Addres return r0, r1 } +// FindJobIDByWorkflow provides a mock function with given fields: ctx, spec +func (_m *ORM) FindJobIDByWorkflow(ctx context.Context, spec job.WorkflowSpec) (int32, error) { + ret := _m.Called(ctx, spec) + + if len(ret) == 0 { + panic("no return value specified for FindJobIDByWorkflow") + } + + var r0 int32 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, job.WorkflowSpec) (int32, error)); ok { + return rf(ctx, spec) + } + if rf, ok := ret.Get(0).(func(context.Context, job.WorkflowSpec) int32); ok { + r0 = rf(ctx, spec) + } else { + r0 = ret.Get(0).(int32) + } + + if rf, ok := ret.Get(1).(func(context.Context, 
job.WorkflowSpec) error); ok { + r1 = rf(ctx, spec) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindJobIDsWithBridge provides a mock function with given fields: ctx, name func (_m *ORM) FindJobIDsWithBridge(ctx context.Context, name string) ([]int32, error) { ret := _m.Called(ctx, name) diff --git a/core/services/job/models.go b/core/services/job/models.go index 578e9e079b8..9601df2e02d 100644 --- a/core/services/job/models.go +++ b/core/services/job/models.go @@ -845,6 +845,7 @@ type WorkflowSpec struct { WorkflowID string `toml:"workflowId"` Workflow string `toml:"workflow"` WorkflowOwner string `toml:"workflowOwner"` + WorkflowName string `toml:"workflowName"` CreatedAt time.Time `toml:"-"` UpdatedAt time.Time `toml:"-"` } @@ -863,5 +864,9 @@ func (w *WorkflowSpec) Validate() error { return fmt.Errorf("incorrect length for owner %s: expected %d, got %d", w.WorkflowOwner, workflowOwnerLen, len(w.WorkflowOwner)) } + if w.WorkflowName == "" { + return fmt.Errorf("workflow name is required") + } + return nil } diff --git a/core/services/job/orm.go b/core/services/job/orm.go index d54d6fba522..71a4ebebb1e 100644 --- a/core/services/job/orm.go +++ b/core/services/job/orm.go @@ -78,6 +78,8 @@ type ORM interface { DataSource() sqlutil.DataSource WithDataSource(source sqlutil.DataSource) ORM + + FindJobIDByWorkflow(ctx context.Context, spec WorkflowSpec) (int32, error) } type ORMConfig interface { @@ -395,8 +397,8 @@ func (o *orm) CreateJob(ctx context.Context, jb *Job) error { case Stream: // 'stream' type has no associated spec, nothing to do here case Workflow: - sql := `INSERT INTO workflow_specs (workflow, workflow_id, workflow_owner, created_at, updated_at) - VALUES (:workflow, :workflow_id, :workflow_owner, NOW(), NOW()) + sql := `INSERT INTO workflow_specs (workflow, workflow_id, workflow_owner, workflow_name, created_at, updated_at) + VALUES (:workflow, :workflow_id, :workflow_owner, :workflow_name, NOW(), NOW()) RETURNING id;` specID, err := tx.prepareQuerySpecID(ctx, sql, jb.WorkflowSpec) if err != nil { @@ -1043,6 +1045,23 @@ func (o *orm) FindJobIDsWithBridge(ctx context.Context, name string) (jids []int return } +func (o *orm) FindJobIDByWorkflow(ctx context.Context, spec WorkflowSpec) (jobID int32, err error) { + stmt := ` +SELECT jobs.id FROM jobs +INNER JOIN workflow_specs ws on jobs.workflow_spec_id = ws.id AND ws.workflow_owner = $1 AND ws.workflow_name = $2 +` + err = o.ds.GetContext(ctx, &jobID, stmt, spec.WorkflowOwner, spec.WorkflowName) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + err = fmt.Errorf("error searching for job by workflow (owner,name) ('%s','%s'): %w", spec.WorkflowOwner, spec.WorkflowName, err) + } + err = fmt.Errorf("FindJobIDByWorkflow failed: %w", err) + return + } + + return +} + // PipelineRunsByJobsIDs returns pipeline runs for multiple jobs, not preloading data func (o *orm) PipelineRunsByJobsIDs(ctx context.Context, ids []int32) (runs []pipeline.Run, err error) { err = o.transact(ctx, false, func(tx *orm) error { diff --git a/core/services/workflows/delegate_test.go b/core/services/workflows/delegate_test.go index 68abfa2f7a1..d87e6d68466 100644 --- a/core/services/workflows/delegate_test.go +++ b/core/services/workflows/delegate_test.go @@ -23,6 +23,7 @@ type = "workflow" schemaVersion = 1 workflowId = "15c631d295ef5e32deb99a10ee6804bc4af1385568f9b3363f6552ac6dbb2cef" workflowOwner = "00000000000000000000000000000000000000aa" +workflowName = "test" `, true, }, @@ -38,6 +39,16 @@ invalid syntax{{{{ ` type = "work flows" 
schemaVersion = 1 +`, + false, + }, + { + "missing name", + ` +type = "workflow" +schemaVersion = 1 +workflowId = "15c631d295ef5e32deb99a10ee6804bc4af1385568f9b3363f6552ac6dbb2cef" +workflowOwner = "00000000000000000000000000000000000000aa" `, false, }, diff --git a/core/store/migrate/migrations/0238_workflow_spec_name.sql b/core/store/migrate/migrations/0238_workflow_spec_name.sql new file mode 100644 index 00000000000..8b9986b4da9 --- /dev/null +++ b/core/store/migrate/migrations/0238_workflow_spec_name.sql @@ -0,0 +1,22 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE workflow_specs ADD COLUMN workflow_name varchar(255); + +-- ensure that we can forward migrate to non-null name +UPDATE workflow_specs +SET + workflow_name = workflow_id +WHERE + workflow_name IS NULL; + +ALTER TABLE workflow_specs ALTER COLUMN workflow_name SET NOT NULL; + +-- unique constraint on workflow_owner and workflow_name +ALTER TABLE workflow_specs ADD CONSTRAINT unique_workflow_owner_name unique (workflow_owner, workflow_name); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE workflow_specs DROP CONSTRAINT unique_workflow_owner_name; +ALTER TABLE workflow_specs DROP COLUMN workflow_name; +-- +goose StatementEnd \ No newline at end of file diff --git a/core/testdata/testspecs/v2_specs.go b/core/testdata/testspecs/v2_specs.go index a0d8ea863e2..fb0e019d931 100644 --- a/core/testdata/testspecs/v2_specs.go +++ b/core/testdata/testspecs/v2_specs.go @@ -872,17 +872,18 @@ func (w WorkflowSpec) Toml() string { return w.toml } -func GenerateWorkflowSpec(id, owner, spec string) WorkflowSpec { +func GenerateWorkflowSpec(id, owner, name, spec string) WorkflowSpec { template := ` type = "workflow" schemaVersion = 1 name = "test-spec" workflowId = "%s" workflowOwner = "%s" +workflowName = "%s" workflow = """ %s """ ` - toml := fmt.Sprintf(template, id, owner, spec) + toml := fmt.Sprintf(template, id, owner, name, spec) return WorkflowSpec{toml: toml} } diff --git a/core/web/jobs_controller_test.go b/core/web/jobs_controller_test.go index 8aaae0d5ba3..359f9ba8b1c 100644 --- a/core/web/jobs_controller_test.go +++ b/core/web/jobs_controller_test.go @@ -394,6 +394,7 @@ func TestJobController_Create_HappyPath(t *testing.T) { tomlTemplate: func(_ string) string { id := "15c631d295ef5e32deb99a10ee6804bc4af1385568f9b3363f6552ac6dbb2cef" owner := "00000000000000000000000000000000000000aa" + name := "my-test-workflow" workflow := ` triggers: - id: "mercury-trigger" @@ -441,14 +442,14 @@ targets: params: ["$(report)"] abi: "receive(report bytes)" ` - return testspecs.GenerateWorkflowSpec(id, owner, workflow).Toml() + return testspecs.GenerateWorkflowSpec(id, owner, name, workflow).Toml() }, assertion: func(t *testing.T, nameAndExternalJobID string, r *http.Response) { require.Equal(t, http.StatusOK, r.StatusCode) resp := cltest.ParseResponseBody(t, r) resource := presenters.JobResource{} err := web.ParseJSONAPIResponse(resp, &resource) - require.NoError(t, err) + require.NoError(t, err, "failed to parse response body: %s", resp) jb, err := jorm.FindJob(testutils.Context(t), mustInt32FromString(t, resource.ID)) require.NoError(t, err) @@ -457,6 +458,7 @@ targets: assert.Equal(t, jb.WorkflowSpec.Workflow, resource.WorkflowSpec.Workflow) assert.Equal(t, jb.WorkflowSpec.WorkflowID, resource.WorkflowSpec.WorkflowID) assert.Equal(t, jb.WorkflowSpec.WorkflowOwner, resource.WorkflowSpec.WorkflowOwner) + assert.Equal(t, jb.WorkflowSpec.WorkflowName, resource.WorkflowSpec.WorkflowName) }, }, } diff --git 
a/core/web/presenters/job.go b/core/web/presenters/job.go index 12b958a346d..ff59bc9bd11 100644 --- a/core/web/presenters/job.go +++ b/core/web/presenters/job.go @@ -433,6 +433,7 @@ type WorkflowSpec struct { Workflow string `json:"workflow"` WorkflowID string `json:"workflowId"` WorkflowOwner string `json:"workflowOwner"` + WorkflowName string `json:"workflowName"` CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` } @@ -442,6 +443,7 @@ func NewWorkflowSpec(spec *job.WorkflowSpec) *WorkflowSpec { Workflow: spec.Workflow, WorkflowID: spec.WorkflowID, WorkflowOwner: spec.WorkflowOwner, + WorkflowName: spec.WorkflowName, CreatedAt: spec.CreatedAt, UpdatedAt: spec.UpdatedAt, } diff --git a/core/web/presenters/job_test.go b/core/web/presenters/job_test.go index 7d3c31465db..ba485d27789 100644 --- a/core/web/presenters/job_test.go +++ b/core/web/presenters/job_test.go @@ -861,6 +861,7 @@ func TestJob(t *testing.T) { WorkflowID: "", Workflow: ``, WorkflowOwner: "", + WorkflowName: "", }, PipelineSpec: &pipeline.Spec{ ID: 1, @@ -896,6 +897,7 @@ func TestJob(t *testing.T) { "workflow": "", "workflowId": "", "workflowOwner": "", + "workflowName": "", "createdAt":"0001-01-01T00:00:00Z", "updatedAt":"0001-01-01T00:00:00Z" }, From 38a8f8d914abf60310e441cc25358ebfd15dc66c Mon Sep 17 00:00:00 2001 From: Radek Scheibinger Date: Thu, 23 May 2024 16:00:37 +0200 Subject: [PATCH 10/16] Bump crib chart and update config (#13238) * Bump crib chart and update config The latest version of crib chainlink chart uses map instead of array for defining nodes * Bump chart and update config --- crib/devspace.yaml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/crib/devspace.yaml b/crib/devspace.yaml index f22e710f943..229c0829d02 100644 --- a/crib/devspace.yaml +++ b/crib/devspace.yaml @@ -100,7 +100,7 @@ deployments: releaseName: "app" chart: name: ${CHAINLINK_CLUSTER_HELM_CHART_URI} - version: 0.6.0 + version: "1.1.0" # for simplicity, we define all the values here # they can be defined the same way in values.yml # devspace merges these "values" with the "values.yaml" before deploy @@ -158,7 +158,7 @@ deployments: # extraEnvVars: # "CL_MEDIAN_CMD": "chainlink-feeds" nodes: - - name: node-1 + node1: image: ${runtime.images.app} # default resources are 300m/1Gi # first node need more resources to build faster inside container @@ -209,13 +209,13 @@ deployments: # CollectorTarget = 'app-opentelemetry-collector:4317' # TLSCertPath = '' # Mode = 'unencrypted' - - name: node-2 + node2: image: ${runtime.images.app} - - name: node-3 + node3: image: ${runtime.images.app} - - name: node-4 + node4: image: ${runtime.images.app} - - name: node-5 + node5: image: ${runtime.images.app} # each CL node have a dedicated PostgreSQL 11.15 @@ -307,7 +307,7 @@ deployments: - path: / backend: service: - name: app-node-1 + name: app-node1 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-node2.${DEVSPACE_INGRESS_BASE_DOMAIN} @@ -316,7 +316,7 @@ deployments: - path: / backend: service: - name: app-node-2 + name: app-node2 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-node3.${DEVSPACE_INGRESS_BASE_DOMAIN} @@ -325,7 +325,7 @@ deployments: - path: / backend: service: - name: app-node-3 + name: app-node3 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-node4.${DEVSPACE_INGRESS_BASE_DOMAIN} @@ -334,7 +334,7 @@ deployments: - path: / backend: service: - name: app-node-4 + name: app-node4 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-node5.${DEVSPACE_INGRESS_BASE_DOMAIN} @@ 
-343,7 +343,7 @@ deployments: - path: / backend: service: - name: app-node-5 + name: app-node5 port: number: 6688 - host: ${DEVSPACE_NAMESPACE}-geth-1337-http.${DEVSPACE_INGRESS_BASE_DOMAIN}
From d90cd654ec396bb43c26e897bcbaa190226ceb81 Mon Sep 17 00:00:00 2001 From: Gabriel Paradiso Date: Thu, 23 May 2024 16:24:40 +0200 Subject: [PATCH 11/16] [FUN-1332] Allowlist optimisation (#12588)
* feat: update allowlist in batches giving priority to latest allowed addresses * fix: adjust iteration and add tests on updateAllowedSendersInBatches * fix: make a deep copy of the map to avoid race conditions * feat: extra step to fetch latest added addresses while batching * fix: check allowlist size is bigger than the batchsize * chore: remove leftover and add modify tests to be closer to a real scenario * chore: simplify lastBatchIdxStart * chore: remove newlines to pass sonarqube check
--- .../handlers/functions/allowlist/allowlist.go | 129 ++++++++--- .../allowlist/allowlist_internal_test.go | 216 ++++++++++++++++++ 2 files changed, 311 insertions(+), 34 deletions(-) create mode 100644 core/services/gateway/handlers/functions/allowlist/allowlist_internal_test.go
diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist.go b/core/services/gateway/handlers/functions/allowlist/allowlist.go index f0fe5c8c829..2a27f51471a 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist.go @@ -210,33 +210,23 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b return errors.Wrap(err, "unexpected error during functions_allow_list.NewTermsOfServiceAllowList") } - var allowedSenderList []common.Address - typeAndVersion, err := tosContract.TypeAndVersion(&bind.CallOpts{ - Pending: false, - BlockNumber: blockNum, - Context: ctx, - }) - if err != nil { - return errors.Wrap(err, "failed to fetch the tos contract type and version") - } - - currentVersion, err := ExtractContractVersion(typeAndVersion) + currentVersion, err := fetchTosCurrentVersion(ctx, tosContract, blockNum) if err != nil { - return fmt.Errorf("failed to extract version: %w", err) + return fmt.Errorf("failed to fetch tos current version: %w", err) } if semver.Compare(tosContractMinBatchProcessingVersion, currentVersion) <= 0 { - err = a.syncBlockedSenders(ctx, tosContract, blockNum) + err = a.updateAllowedSendersInBatches(ctx, tosContract, blockNum) if err != nil { - return errors.Wrap(err, "failed to sync the stored allowed and blocked senders") + return errors.Wrap(err, "failed to get allowed senders in range") } - allowedSenderList, err = a.getAllowedSendersBatched(ctx, tosContract, blockNum) + err := a.syncBlockedSenders(ctx, tosContract, blockNum) if err != nil { - return errors.Wrap(err, "failed to get allowed senders in rage") + return errors.Wrap(err, "failed to sync the stored allowed and blocked senders") } } else { - allowedSenderList, err = tosContract.GetAllAllowedSenders(&bind.CallOpts{ + allowedSenderList, err := tosContract.GetAllAllowedSenders(&bind.CallOpts{ Pending: false, BlockNumber: blockNum, Context: ctx, @@ -254,50 +244,108 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b if err != nil { a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) } + + a.update(allowedSenderList) } - a.update(allowedSenderList) return nil }
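The semver gate above is what keeps older ToS deployments on the legacy path. A minimal, self-contained sketch of that decision (assuming golang.org/x/mod/semver for Compare; the constant's value and the type-and-version strings are illustrative, and extractContractVersion is a simplified stand-in for the ExtractContractVersion helper later in this diff):

package main

import (
	"fmt"
	"regexp"

	"golang.org/x/mod/semver"
)

// Illustrative value only; the real constant lives in the allowlist package.
const tosContractMinBatchProcessingVersion = "v1.1.0"

// Simplified stand-in for ExtractContractVersion: pull a vX.Y.Z token out of
// the contract's TypeAndVersion string.
func extractContractVersion(str string) (string, error) {
	re := regexp.MustCompile(`v(\d+).(\d+).(\d+)`)
	version := re.FindString(str)
	if version == "" {
		return "", fmt.Errorf("version not found in %q", str)
	}
	return version, nil
}

func main() {
	for _, typeAndVersion := range []string{
		"Functions Terms of Service Allow List v1.0.0", // hypothetical pre-batching deployment
		"Functions Terms of Service Allow List v1.1.0", // hypothetical batch-capable deployment
	} {
		version, err := extractContractVersion(typeAndVersion)
		if err != nil {
			panic(err)
		}
		// Compare(min, current) <= 0 means the deployed contract is at least
		// the minimum batch-processing version, so the batched path is taken.
		if semver.Compare(tosContractMinBatchProcessingVersion, version) <= 0 {
			fmt.Println(version, "-> updateAllowedSendersInBatches + syncBlockedSenders")
		} else {
			fmt.Println(version, "-> single GetAllAllowedSenders call")
		}
	}
}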
-func (a *onchainAllowlist) getAllowedSendersBatched(ctx context.Context, tosContract *functions_allow_list.TermsOfServiceAllowList, blockNum *big.Int) ([]common.Address, error) { - allowedSenderList := make([]common.Address, 0) - count, err := tosContract.GetAllowedSendersCount(&bind.CallOpts{ +// updateAllowedSendersInBatches will update the node's in-memory state and the orm layer representing the allowlist. +// it will get the current node's in-memory allowlist and start fetching and adding from the tos contract in batches. +// the iteration order will give priority to new allowed senders; if new addresses are added while iterating over the batches, +// an extra step will be executed to keep this up to date. +func (a *onchainAllowlist) updateAllowedSendersInBatches(ctx context.Context, tosContract functions_allow_list.TermsOfServiceAllowListInterface, blockNum *big.Int) error { + // currentAllowedSenderList will be the starting point from which we will be adding the new allowed senders + currentAllowedSenderList := make(map[common.Address]struct{}, 0) + if cal := a.allowlist.Load(); cal != nil { + for k := range *cal { + currentAllowedSenderList[k] = struct{}{} + } + } + + currentAllowedSenderCount, err := tosContract.GetAllowedSendersCount(&bind.CallOpts{ Pending: false, BlockNumber: blockNum, Context: ctx, }) if err != nil { - return nil, errors.Wrap(err, "unexpected error during functions_allow_list.GetAllowedSendersCount") + return errors.Wrap(err, "unexpected error during functions_allow_list.GetAllowedSendersCount") } throttleTicker := time.NewTicker(time.Duration(a.config.FetchingDelayInRangeSec) * time.Second) - for idxStart := uint64(0); idxStart < count; idxStart += uint64(a.config.OnchainAllowlistBatchSize) { - <-throttleTicker.C - idxEnd := idxStart + uint64(a.config.OnchainAllowlistBatchSize) - if idxEnd >= count { - idxEnd = count - 1 + for i := int64(currentAllowedSenderCount); i > 0; i -= int64(a.config.OnchainAllowlistBatchSize) { + <-throttleTicker.C + var idxStart uint64 + if uint64(i) > uint64(a.config.OnchainAllowlistBatchSize) { + idxStart = uint64(i) - uint64(a.config.OnchainAllowlistBatchSize) } - allowedSendersBatch, err := tosContract.GetAllowedSendersInRange(&bind.CallOpts{ + idxEnd := uint64(i) - 1 + + // before continuing, we evaluate whether the size of the list changed; if it did, we trigger an extra step + // to fetch the latest added addresses from the list + updatedAllowedSenderCount, err := tosContract.GetAllowedSendersCount(&bind.CallOpts{ Pending: false, BlockNumber: blockNum, Context: ctx, - }, idxStart, idxEnd) + }) if err != nil { - return nil, errors.Wrap(err, "error calling GetAllowedSendersInRange") + return errors.Wrap(err, "unexpected error while fetching the updated functions_allow_list.GetAllowedSendersCount") + } + + if updatedAllowedSenderCount > currentAllowedSenderCount { + lastBatchIdxStart := currentAllowedSenderCount + lastBatchIdxEnd := updatedAllowedSenderCount - 1 + currentAllowedSenderCount = updatedAllowedSenderCount + + err = a.updateAllowedSendersBatch(ctx, tosContract, blockNum, lastBatchIdxStart, lastBatchIdxEnd, currentAllowedSenderList) + if err != nil { + return err + } } - allowedSenderList = append(allowedSenderList, allowedSendersBatch...) - err = a.orm.CreateAllowedSenders(ctx, allowedSendersBatch) + err = a.updateAllowedSendersBatch(ctx, tosContract, blockNum, idxStart, idxEnd, currentAllowedSenderList) if err != nil { - a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) + return err } } throttleTicker.Stop() - return allowedSenderList, nil + return nil +} +
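The countdown loop above is the heart of the change, and its bounds are easy to misread; this standalone sketch reproduces just the index arithmetic (no chainlink imports; both bounds are inclusive, matching GetAllowedSendersInRange):

package main

import "fmt"

// batchRanges mirrors the loop bounds of updateAllowedSendersInBatches:
// walk from the newest on-chain index down to 0, clamping the start at 0.
func batchRanges(count, batchSize uint64) [][2]uint64 {
	var ranges [][2]uint64
	for i := int64(count); i > 0; i -= int64(batchSize) {
		var idxStart uint64
		if uint64(i) > batchSize {
			idxStart = uint64(i) - batchSize
		}
		ranges = append(ranges, [2]uint64{idxStart, uint64(i) - 1}) // inclusive bounds
	}
	return ranges
}

func main() {
	// 53 allowed senders in batches of 10 prints
	// [[43 52] [33 42] [23 32] [13 22] [3 12] [0 2]]:
	// newest addresses are fetched and persisted first, which is exactly the
	// CreateAllowedSenders call order asserted by the new internal test below.
	fmt.Println(batchRanges(53, 10))
}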
+func (a *onchainAllowlist) updateAllowedSendersBatch( + ctx context.Context, + tosContract functions_allow_list.TermsOfServiceAllowListInterface, + blockNum *big.Int, + idxStart uint64, + idxEnd uint64, + currentAllowedSenderList map[common.Address]struct{}, +) error { + allowedSendersBatch, err := tosContract.GetAllowedSendersInRange(&bind.CallOpts{ + Pending: false, + BlockNumber: blockNum, + Context: ctx, + }, idxStart, idxEnd) + if err != nil { + return errors.Wrap(err, "error calling GetAllowedSendersInRange") + } + + // add the fetched batch to the currentAllowedSenderList and replace the existing allowlist + for _, addr := range allowedSendersBatch { + currentAllowedSenderList[addr] = struct{}{} + } + a.allowlist.Store(&currentAllowedSenderList) + a.lggr.Infow("allowlist updated in batches successfully", "len", len(currentAllowedSenderList)) + + // persist each batch to the underlying orm layer + err = a.orm.CreateAllowedSenders(ctx, allowedSendersBatch) + if err != nil { + a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) + } + return nil +} // syncBlockedSenders fetches the list of blocked addresses from the contract in batches @@ -370,6 +418,19 @@ func (a *onchainAllowlist) loadStoredAllowedSenderList(ctx context.Context) { a.update(allowedList) } +func fetchTosCurrentVersion(ctx context.Context, tosContract *functions_allow_list.TermsOfServiceAllowList, blockNum *big.Int) (string, error) { + typeAndVersion, err := tosContract.TypeAndVersion(&bind.CallOpts{ + Pending: false, + BlockNumber: blockNum, + Context: ctx, + }) + if err != nil { + return "", errors.Wrap(err, "failed to fetch the tos contract type and version") + } + + return ExtractContractVersion(typeAndVersion) +} + func ExtractContractVersion(str string) (string, error) { pattern := `v(\d+).(\d+).(\d+)` re := regexp.MustCompile(pattern)
diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist_internal_test.go b/core/services/gateway/handlers/functions/allowlist/allowlist_internal_test.go new file mode 100644 index 00000000000..966db032636 --- /dev/null +++ b/core/services/gateway/handlers/functions/allowlist/allowlist_internal_test.go @@ -0,0 +1,216 @@ +package allowlist + +import ( + "context" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/functions/generated/functions_allow_list" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/logger" + amocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/allowlist/mocks" +) + +func TestUpdateAllowedSendersInBatches(t *testing.T) { + t.Run("OK-simple_update_in_batches", func(t *testing.T) { + ctx := context.Background() + config := OnchainAllowlistConfig{ + ContractAddress: testutils.NewAddress(), + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + StoredAllowlistBatchSize: 2, +
OnchainAllowlistBatchSize: 10, + FetchingDelayInRangeSec: 1, + } + + // allowlistSize defines how big the mocked allowlist will be + allowlistSize := 53 + // allowlist represents the actual allowlist the tos contract will return + allowlist := make([]common.Address, 0, allowlistSize) + // expectedAllowlist will be used to compare the actual status with what we actually want + expectedAllowlist := make(map[common.Address]struct{}, 0) + + // we load both the expectedAllowlist and the allowlist the contract will return with some new addresses + for i := 0; i < allowlistSize; i++ { + addr := testutils.NewAddress() + allowlist = append(allowlist, addr) + expectedAllowlist[addr] = struct{}{} + } + + tosContract := NewTosContractMock(allowlist) + + // with the orm mock we can validate the actual order in which the allowlist is fetched giving priority to newest addresses + orm := amocks.NewORM(t) + firstCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[43:53]).Times(1).Return(nil) + secondCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[33:43]).Times(1).Return(nil).NotBefore(firstCall) + thirdCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[23:33]).Times(1).Return(nil).NotBefore(secondCall) + forthCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[13:23]).Times(1).Return(nil).NotBefore(thirdCall) + fifthCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[3:13]).Times(1).Return(nil).NotBefore(forthCall) + orm.On("CreateAllowedSenders", context.Background(), allowlist[0:3]).Times(1).Return(nil).NotBefore(fifthCall) + + onchainAllowlist := &onchainAllowlist{ + config: config, + orm: orm, + blockConfirmations: big.NewInt(int64(config.BlockConfirmations)), + lggr: logger.TestLogger(t).Named("OnchainAllowlist"), + stopCh: make(services.StopChan), + } + + // we set the onchain allowlist to an empty state before updating it in batches + emptyMap := make(map[common.Address]struct{}) + onchainAllowlist.allowlist.Store(&emptyMap) + + err := onchainAllowlist.updateAllowedSendersInBatches(ctx, tosContract, big.NewInt(0)) + require.NoError(t, err) + + currentAllowlist := onchainAllowlist.allowlist.Load() + require.Equal(t, &expectedAllowlist, currentAllowlist) + }) + + t.Run("OK-new_address_added_while_updating_in_batches", func(t *testing.T) { + ctx := context.Background() + config := OnchainAllowlistConfig{ + ContractAddress: testutils.NewAddress(), + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + StoredAllowlistBatchSize: 2, + OnchainAllowlistBatchSize: 10, + FetchingDelayInRangeSec: 1, + } + + // allowlistSize defines how big the initial mocked allowlist will be + allowlistSize := 50 + // allowlist represents the actual allowlist the tos contract will return + allowlist := make([]common.Address, 0) + // expectedAllowlist will be used to compare the actual status with what we actually want + expectedAllowlist := make(map[common.Address]struct{}, 0) + + // we load both the expectedAllowlist and the allowlist the contract will return with some new addresses + for i := 0; i < allowlistSize; i++ { + addr := testutils.NewAddress() + allowlist = append(allowlist, addr) + expectedAllowlist[addr] = struct{}{} + } + + tosContract := NewTosContractMock(allowlist) + + // with the orm mock we can validate the actual order in which the allowlist is fetched giving priority to newest addresses + orm := amocks.NewORM(t) + firstCall := orm.On("CreateAllowedSenders", 
context.Background(), allowlist[40:50]).Times(1).Run(func(args mock.Arguments) { + // after the first call we update the tosContract by adding a new address + addr := testutils.NewAddress() + allowlist = append(allowlist, addr) + expectedAllowlist[addr] = struct{}{} + *tosContract = *NewTosContractMock(allowlist) + }).Return(nil) + + // this is the extra step that will fetch the new address we want to validate + extraStepCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[50:51]).Times(1).Return(nil).NotBefore(firstCall) + + secondCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[30:40]).Times(1).Return(nil).NotBefore(extraStepCall) + thirdCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[20:30]).Times(1).Return(nil).NotBefore(secondCall) + forthCall := orm.On("CreateAllowedSenders", context.Background(), allowlist[10:20]).Times(1).Return(nil).NotBefore(thirdCall) + orm.On("CreateAllowedSenders", context.Background(), allowlist[0:10]).Times(1).Return(nil).NotBefore(forthCall) + + onchainAllowlist := &onchainAllowlist{ + config: config, + orm: orm, + blockConfirmations: big.NewInt(int64(config.BlockConfirmations)), + lggr: logger.TestLogger(t).Named("OnchainAllowlist"), + stopCh: make(services.StopChan), + } + + // we set the onchain allowlist to an empty state before updating it in batches + emptyMap := make(map[common.Address]struct{}) + onchainAllowlist.allowlist.Store(&emptyMap) + + err := onchainAllowlist.updateAllowedSendersInBatches(ctx, tosContract, big.NewInt(0)) + require.NoError(t, err) + + currentAllowlist := onchainAllowlist.allowlist.Load() + require.Equal(t, &expectedAllowlist, currentAllowlist) + }) + + t.Run("OK-allowlist_size_smaller_than_batchsize", func(t *testing.T) { + ctx := context.Background() + config := OnchainAllowlistConfig{ + ContractAddress: testutils.NewAddress(), + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + StoredAllowlistBatchSize: 2, + OnchainAllowlistBatchSize: 100, + FetchingDelayInRangeSec: 1, + } + + // allowlistSize defines how big the mocked allowlist will be + allowlistSize := 50 + // allowlist represents the actual allowlist the tos contract will return + allowlist := make([]common.Address, 0, allowlistSize) + // expectedAllowlist will be used to compare the actual status with what we actually want + expectedAllowlist := make(map[common.Address]struct{}, 0) + + // we load both the expectedAllowlist and the allowlist the contract will return with some new addresses + for i := 0; i < allowlistSize; i++ { + addr := testutils.NewAddress() + allowlist = append(allowlist, addr) + expectedAllowlist[addr] = struct{}{} + } + + tosContract := NewTosContractMock(allowlist) + + // with the orm mock we can validate the actual order in which the allowlist is fetched giving priority to newest addresses + orm := amocks.NewORM(t) + orm.On("CreateAllowedSenders", context.Background(), allowlist[0:50]).Times(1).Return(nil) + + onchainAllowlist := &onchainAllowlist{ + config: config, + orm: orm, + blockConfirmations: big.NewInt(int64(config.BlockConfirmations)), + lggr: logger.TestLogger(t).Named("OnchainAllowlist"), + stopCh: make(services.StopChan), + } + + // we set the onchain allowlist to an empty state before updating it in batches + emptyMap := make(map[common.Address]struct{}) + onchainAllowlist.allowlist.Store(&emptyMap) + + err := onchainAllowlist.updateAllowedSendersInBatches(ctx, tosContract, big.NewInt(0)) + require.NoError(t, err) + + 
currentAllowlist := onchainAllowlist.allowlist.Load() + require.Equal(t, &expectedAllowlist, currentAllowlist) + }) +} + +type tosContractMock struct { + functions_allow_list.TermsOfServiceAllowListInterface + + onchainAllowlist []common.Address +} + +func NewTosContractMock(onchainAllowlist []common.Address) *tosContractMock { + return &tosContractMock{ + onchainAllowlist: onchainAllowlist, + } +} + +func (t *tosContractMock) GetAllowedSendersCount(opts *bind.CallOpts) (uint64, error) { + return uint64(len(t.onchainAllowlist)), nil +} + +func (t *tosContractMock) GetAllowedSendersInRange(opts *bind.CallOpts, allowedSenderIdxStart uint64, allowedSenderIdxEnd uint64) ([]common.Address, error) { + // we replicate the onchain behaviour of including start and end indexes + return t.onchainAllowlist[allowedSenderIdxStart : allowedSenderIdxEnd+1], nil +} From 203a95ed829bc37c3c89468850fabfecc9e7f3fd Mon Sep 17 00:00:00 2001 From: Bartek Tofel Date: Thu, 23 May 2024 17:41:52 +0200 Subject: [PATCH 12/16] [TT-1198] Fix missing logs (#13300) * dont fail with Failf(), use Errorf() instead; break loop when first concerning log is found * print test summary in automation nightly tests * remove unneeded file * do not add log stream clean up only when cleanup is set to none * fix VRFv2 smoke test * Fix log artifacts in github * TT-1198:fix TestVRFv2BatchFulfillmentEnabledDisabled test --------- Co-authored-by: lukaszcl <120112546+lukaszcl@users.noreply.github.com> Co-authored-by: Ilja Pavlovs --- .github/workflows/automation-nightly-tests.yml | 6 +++++- .../docker/test_env/test_env_builder.go | 14 +++++++++----- integration-tests/smoke/vrfv2_test.go | 18 ++++++++++-------- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/.github/workflows/automation-nightly-tests.yml b/.github/workflows/automation-nightly-tests.yml index f25700f3155..ae0f3526e99 100644 --- a/.github/workflows/automation-nightly-tests.yml +++ b/.github/workflows/automation-nightly-tests.yml @@ -129,6 +129,7 @@ jobs: cl_image_tag: 'latest' aws_registries: ${{ secrets.QA_AWS_ACCOUNT_NUMBER }} artifacts_location: ./integration-tests/${{ matrix.tests.suite }}/logs + artifacts_name: testcontainers-logs-${{ matrix.tests.name }} publish_check_name: Automation Results ${{ matrix.tests.name }} token: ${{ secrets.GITHUB_TOKEN }} go_mod_path: ./integration-tests/go.mod @@ -139,7 +140,7 @@ jobs: uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 if: failure() with: - name: test-log-${{ matrix.tests.name }} + name: gotest-logs-${{ matrix.tests.name }} path: /tmp/gotest.log retention-days: 7 continue-on-error: true @@ -155,6 +156,9 @@ jobs: this-job-name: Automation ${{ matrix.tests.name }} Test test-results-file: '{"testType":"go","filePath":"/tmp/gotest.log"}' continue-on-error: true + - name: Print failed test summary + if: always() + uses: smartcontractkit/chainlink-github-actions/chainlink-testing-framework/show-test-summary@b49a9d04744b0237908831730f8553f26d73a94b # v2.3.17 test-notify: name: Start Slack Thread diff --git a/integration-tests/docker/test_env/test_env_builder.go b/integration-tests/docker/test_env/test_env_builder.go index 852918cc7d4..f5a5e558572 100644 --- a/integration-tests/docker/test_env/test_env_builder.go +++ b/integration-tests/docker/test_env/test_env_builder.go @@ -279,7 +279,7 @@ func (b *CLTestEnvBuilder) Build() (*CLClusterTestEnv, error) { } // this clean up has to be added as the FIRST one, because cleanup functions are executed in reverse order (LIFO) - if b.t != nil && 
b.cleanUpType == CleanUpTypeStandard { + if b.t != nil && b.cleanUpType != CleanUpTypeNone { b.t.Cleanup(func() { b.l.Info().Msg("Shutting down LogStream") logPath, err := osutil.GetAbsoluteFolderPath("logs") @@ -306,21 +306,24 @@ func (b *CLTestEnvBuilder) Build() (*CLClusterTestEnv, error) { // we cannot do parallel processing here, because ProcessContainerLogs() locks a mutex that controls whether // new logs can be added to the log stream, so parallel processing would get stuck on waiting for it to be unlocked + LogScanningLoop: for i := 0; i < b.clNodesCount; i++ { // ignore count return, because we are only interested in the error _, err := logProcessor.ProcessContainerLogs(b.te.ClCluster.Nodes[i].ContainerName, processFn) if err != nil && !strings.Contains(err.Error(), testreporters.MultipleLogsAtLogLevelErr) && !strings.Contains(err.Error(), testreporters.OneLogAtLogLevelErr) { - b.l.Error().Err(err).Msg("Error processing logs") - return + b.l.Error().Err(err).Msg("Error processing CL node logs") + continue } else if err != nil && (strings.Contains(err.Error(), testreporters.MultipleLogsAtLogLevelErr) || strings.Contains(err.Error(), testreporters.OneLogAtLogLevelErr)) { flushLogStream = true - b.t.Fatalf("Found a concerning log in Chainklink Node logs: %v", err) + b.t.Errorf("Found a concerning log in Chainlink Node logs: %v", err) + break LogScanningLoop } } b.l.Info().Msg("Finished scanning Chainlink Node logs for concerning errors") } if flushLogStream { + b.l.Info().Msg("Flushing LogStream logs") // we can't do much if this fails, so we just log the error in LogStream if err := b.te.LogStream.FlushAndShutdown(); err != nil { b.l.Error().Err(err).Msg("Error flushing and shutting down LogStream") @@ -328,9 +331,10 @@ func (b *CLTestEnvBuilder) Build() (*CLClusterTestEnv, error) { b.te.LogStream.PrintLogTargetsLocations() b.te.LogStream.SaveLogLocationInTestSummary() } + b.l.Info().Msg("Finished shutting down LogStream") }) } else { - b.l.Warn().Msg("LogStream won't be cleaned up, because test instance is not set or cleanup type is not standard") + b.l.Warn().Msg("LogStream won't be cleaned up, because either test instance is not set or cleanup type is set to none") } }
diff --git a/integration-tests/smoke/vrfv2_test.go b/integration-tests/smoke/vrfv2_test.go index c3e516b093a..18a017110c7 100644 --- a/integration-tests/smoke/vrfv2_test.go +++ b/integration-tests/smoke/vrfv2_test.go @@ -861,10 +861,11 @@ func TestVRFV2WithBHS(t *testing.T) { CleanupFn: cleanupFn, } newEnvConfig := vrfcommon.NewEnvConfig{ - NodesToCreate: []vrfcommon.VRFNodeType{vrfcommon.VRF, vrfcommon.BHS}, - NumberOfTxKeysToCreate: 0, - UseVRFOwner: false, - UseTestCoordinator: false, + NodesToCreate: []vrfcommon.VRFNodeType{vrfcommon.VRF, vrfcommon.BHS}, + NumberOfTxKeysToCreate: 0, + UseVRFOwner: false, + UseTestCoordinator: false, + ChainlinkNodeLogScannerSettings: test_env.DefaultChainlinkNodeLogScannerSettings, } testEnv, vrfContracts, vrfKey, nodeTypeToNodeMap, err = vrfv2.SetupVRFV2Universe(testcontext.Get(t), t, vrfEnvConfig, newEnvConfig, l) require.NoError(t, err, "Error setting up VRFV2 universe") @@ -1266,10 +1267,11 @@ func TestVRFv2BatchFulfillmentEnabledDisabled(t *testing.T) { CleanupFn: cleanupFn, } newEnvConfig := vrfcommon.NewEnvConfig{ - NodesToCreate: []vrfcommon.VRFNodeType{vrfcommon.VRF}, - NumberOfTxKeysToCreate: 0, - UseVRFOwner: false, - UseTestCoordinator: false, + NodesToCreate: []vrfcommon.VRFNodeType{vrfcommon.VRF}, + NumberOfTxKeysToCreate: 0, + UseVRFOwner: false, +
UseTestCoordinator: false, + ChainlinkNodeLogScannerSettings: test_env.DefaultChainlinkNodeLogScannerSettings, } env, vrfContracts, vrfKey, nodeTypeToNodeMap, err = vrfv2.SetupVRFV2Universe(testcontext.Get(t), t, vrfEnvConfig, newEnvConfig, l) require.NoError(t, err, "Error setting up VRFv2 universe") From 2380c4114c9f98236463e921f0fdce748d55da33 Mon Sep 17 00:00:00 2001 From: Ilja Pavlovs Date: Thu, 23 May 2024 19:36:25 +0300 Subject: [PATCH 13/16] VRF-1112: remove unnecessary check from TestVRFv2PlusReplayAfterTimeout test to fix the test (#13306) --- integration-tests/smoke/vrfv2plus_test.go | 44 ++++------------------- 1 file changed, 7 insertions(+), 37 deletions(-) diff --git a/integration-tests/smoke/vrfv2plus_test.go b/integration-tests/smoke/vrfv2plus_test.go index a4b542004d2..473510d2d0c 100644 --- a/integration-tests/smoke/vrfv2plus_test.go +++ b/integration-tests/smoke/vrfv2plus_test.go @@ -1736,42 +1736,12 @@ func TestVRFv2PlusReplayAfterTimeout(t *testing.T) { ) require.NoError(t, err, "error requesting randomness and waiting for requested event") - // 3. create new request in a subscription with balance and wait for fulfilment - fundingLinkAmt := big.NewFloat(*configCopy.VRFv2Plus.General.SubscriptionRefundingAmountLink) - fundingNativeAmt := big.NewFloat(*configCopy.VRFv2Plus.General.SubscriptionRefundingAmountNative) - l.Info(). - Str("Coordinator", vrfContracts.CoordinatorV2Plus.Address()). - Int("Number of Subs to create", 1). - Msg("Creating and funding subscriptions, adding consumers") - fundedSubIDs, err := vrfv2plus.CreateFundSubsAndAddConsumers( - testcontext.Get(t), - env, - chainID, - fundingLinkAmt, - fundingNativeAmt, - vrfContracts.LinkToken, - vrfContracts.CoordinatorV2Plus, - []contracts.VRFv2PlusLoadTestConsumer{consumers[1]}, - 1, - ) - require.NoError(t, err, "error creating funded sub in replay test") - _, randomWordsFulfilledEvent, err := vrfv2plus.RequestRandomnessAndWaitForFulfillment( - consumers[1], - vrfContracts.CoordinatorV2Plus, - vrfKey, - fundedSubIDs[0], - isNativeBilling, - configCopy.VRFv2Plus.General, - l, - 0, - ) - require.NoError(t, err, "error requesting randomness and waiting for fulfilment") - require.True(t, randomWordsFulfilledEvent.Success, "RandomWordsFulfilled Event's `Success` field should be true") - - // 4. wait for the request timeout (1s more) duration + // 3. wait for the request timeout (1s more) duration time.Sleep(timeout + 1*time.Second) - // 5. fund sub so that node can fulfill request + fundingLinkAmt := big.NewFloat(*configCopy.VRFv2Plus.General.SubscriptionRefundingAmountLink) + fundingNativeAmt := big.NewFloat(*configCopy.VRFv2Plus.General.SubscriptionRefundingAmountNative) + // 4. fund sub so that node can fulfill request err = vrfv2plus.FundSubscriptions( fundingLinkAmt, fundingNativeAmt, @@ -1781,12 +1751,12 @@ func TestVRFv2PlusReplayAfterTimeout(t *testing.T) { ) require.NoError(t, err, "error funding subs after request timeout") - // 6. no fulfilment should happen since timeout+1 seconds passed in the job + // 5. no fulfilment should happen since timeout+1 seconds passed in the job pendingReqExists, err := vrfContracts.CoordinatorV2Plus.PendingRequestsExist(testcontext.Get(t), subID) require.NoError(t, err, "error fetching PendingRequestsExist from coordinator") require.True(t, pendingReqExists, "pendingRequest must exist since subID was underfunded till request timeout") - // 7. remove job and add new job with requestTimeout = 1 hour + // 6. 
remove job and add new job with requestTimeout = 1 hour vrfNode, exists := nodeTypeToNodeMap[vrfcommon.VRF] require.True(t, exists, "VRF Node does not exist") resp, err := vrfNode.CLNode.API.DeleteJob(vrfNode.Job.Data.ID) @@ -1821,7 +1791,7 @@ func TestVRFv2PlusReplayAfterTimeout(t *testing.T) { vrfNode.Job = job }() - // 8. Check if initial req in underfunded sub is fulfilled now, since it has been topped up and timeout increased + // 7. Check if initial req in underfunded sub is fulfilled now, since it has been topped up and timeout increased l.Info().Str("reqID", initialReqRandomWordsRequestedEvent.RequestId.String()). Str("subID", subID.String()). Msg("Waiting for initialReqRandomWordsFulfilledEvent")
From c15e9e59c2035407d61095bca56838baac7bec35 Mon Sep 17 00:00:00 2001 From: Ilja Pavlovs Date: Thu, 23 May 2024 19:37:33 +0300 Subject: [PATCH 14/16] VRF-1106: Add "vrf_job_simulation_block" to default.toml (#13296)
--- integration-tests/smoke/vrfv2plus_test.go | 2 ++ integration-tests/testconfig/vrfv2/vrfv2.toml | 1 + .../testconfig/vrfv2plus/vrfv2plus.toml | 35 +++++++++++-------- 3 files changed, 24 insertions(+), 14 deletions(-)
diff --git a/integration-tests/smoke/vrfv2plus_test.go b/integration-tests/smoke/vrfv2plus_test.go index 473510d2d0c..d1593373204 100644 --- a/integration-tests/smoke/vrfv2plus_test.go +++ b/integration-tests/smoke/vrfv2plus_test.go @@ -970,6 +970,7 @@ func TestVRFv2PlusMigration(t *testing.T) { BatchFulfillmentGasMultiplier: *configCopy.VRFv2Plus.General.VRFJobBatchFulfillmentGasMultiplier, PollPeriod: configCopy.VRFv2Plus.General.VRFJobPollPeriod.Duration, RequestTimeout: configCopy.VRFv2Plus.General.VRFJobRequestTimeout.Duration, + SimulationBlock: configCopy.VRFv2Plus.General.VRFJobSimulationBlock, } _, err = vrfv2plus.CreateVRFV2PlusJob( @@ -1141,6 +1142,7 @@ func TestVRFv2PlusMigration(t *testing.T) { BatchFulfillmentGasMultiplier: *configCopy.VRFv2Plus.General.VRFJobBatchFulfillmentGasMultiplier, PollPeriod: configCopy.VRFv2Plus.General.VRFJobPollPeriod.Duration, RequestTimeout: configCopy.VRFv2Plus.General.VRFJobRequestTimeout.Duration, + SimulationBlock: configCopy.VRFv2Plus.General.VRFJobSimulationBlock, } _, err = vrfv2plus.CreateVRFV2PlusJob(
diff --git a/integration-tests/testconfig/vrfv2/vrfv2.toml b/integration-tests/testconfig/vrfv2/vrfv2.toml index 3b447a082bf..4ff48a3181a 100644 --- a/integration-tests/testconfig/vrfv2/vrfv2.toml +++ b/integration-tests/testconfig/vrfv2/vrfv2.toml @@ -57,6 +57,7 @@ subscription_refunding_amount_link = 5.0 cl_node_max_gas_price_gwei = 10 link_native_feed_response = 1000000000000000000 +#todo - need to have separate minimum_confirmations config for Coordinator, CL Node and Consumer request minimum_confirmations = 3 number_of_words = 3
diff --git a/integration-tests/testconfig/vrfv2plus/vrfv2plus.toml b/integration-tests/testconfig/vrfv2plus/vrfv2plus.toml index 859945bad9a..717a62e997f 100644 --- a/integration-tests/testconfig/vrfv2plus/vrfv2plus.toml +++ b/integration-tests/testconfig/vrfv2plus/vrfv2plus.toml @@ -63,35 +63,40 @@ chainlink_node_funding = 0.5 [VRFv2Plus.General] cancel_subs_after_test_run = true use_existing_env = false -subscription_funding_amount_link = 5.0 -subscription_funding_amount_native=1 - -subscription_refunding_amount_link = 5.0 -subscription_refunding_amount_native = 1 - -cl_node_max_gas_price_gwei = 10 -link_native_feed_response = 1000000000000000000 +#todo - need to have separate minimum_confirmations config for Coordinator, CL Node and Consumer request minimum_confirmations = 3 +#
Can be "LINK", "NATIVE" or "LINK_AND_NATIVE" subscription_billing_type = "LINK_AND_NATIVE" +#CL Node config +cl_node_max_gas_price_gwei = 10 +number_of_sending_keys_to_create = 0 + +# Randomness Request Config +number_of_sub_to_create = 1 number_of_words = 3 callback_gas_limit = 1000000 -max_gas_limit_coordinator_config = 2500000 -fallback_wei_per_unit_link = "60000000000000000" -staleness_seconds = 86400 -gas_after_payment_calculation = 33825 -number_of_sub_to_create = 1 -number_of_sending_keys_to_create = 0 +subscription_funding_amount_link = 5.0 +subscription_funding_amount_native=1 +subscription_refunding_amount_link = 5.0 +subscription_refunding_amount_native = 1 randomness_request_count_per_request = 1 randomness_request_count_per_request_deviation = 0 random_words_fulfilled_event_timeout = "2m" wait_for_256_blocks_timeout = "280s" +# Coordinator config +link_native_feed_response = 1000000000000000000 +max_gas_limit_coordinator_config = 2500000 +fallback_wei_per_unit_link = "60000000000000000" +staleness_seconds = 86400 +gas_after_payment_calculation = 33825 fulfillment_flat_fee_native_ppm=0 fulfillment_flat_fee_link_discount_ppm=0 native_premium_percentage=24 link_premium_percentage=20 +# Wrapper config wrapped_gas_overhead = 50000 coordinator_gas_overhead_native = 52000 coordinator_gas_overhead_link = 74000 @@ -109,6 +114,8 @@ vrf_job_batch_fulfillment_enabled = true vrf_job_batch_fulfillment_gas_multiplier = 1.15 vrf_job_poll_period = "1s" vrf_job_request_timeout = "24h" +# should be "latest" if minimum_confirmations>0, "pending" if minimum_confirmations=0 +vrf_job_simulation_block="latest" # BHS Job config bhs_job_wait_blocks = 30 From e2bedae3594b273e58239343926144d1c160d689 Mon Sep 17 00:00:00 2001 From: Aaron Lu <50029043+aalu1418@users.noreply.github.com> Date: Thu, 23 May 2024 13:00:55 -0600 Subject: [PATCH 15/16] bump solana + cleanup types (#13253) * bump solana + cleanup types * bump solana with multierr delimiter * fix type * bump solana + fix parameters * bump solana * add more unwrapping * bump solana * update e2e test workflow * try without download * add gauntlet build * bump solana --- .github/workflows/integration-tests.yml | 10 +++++- core/cmd/shell_test.go | 32 +++++++++---------- core/cmd/solana_chains_commands_test.go | 4 +-- core/cmd/solana_node_commands_test.go | 7 ++-- core/cmd/solana_transaction_commands_test.go | 5 ++- core/config/docs/docs_test.go | 4 +-- core/scripts/go.mod | 2 +- core/scripts/go.sum | 4 +-- core/services/chainlink/config.go | 6 ++-- core/services/chainlink/config_general.go | 4 +-- core/services/chainlink/config_test.go | 7 ++-- .../chainlink/mocks/general_config.go | 10 +++--- .../relayer_chain_interoperators_test.go | 7 ++-- core/services/chainlink/relayer_factory.go | 10 +++--- core/services/chainlink/types.go | 4 +-- core/utils/config/validate.go | 4 ++- core/web/solana_chains_controller_test.go | 9 +++--- go.mod | 2 +- go.sum | 4 +-- integration-tests/go.mod | 2 +- integration-tests/go.sum | 4 +-- integration-tests/load/go.mod | 2 +- integration-tests/load/go.sum | 4 +-- 23 files changed, 75 insertions(+), 72 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index ad04a9d808c..9e944ead9bf 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1214,7 +1214,7 @@ jobs: QA_AWS_REGION: ${{ secrets.QA_AWS_REGION }} QA_AWS_ROLE_TO_ASSUME: ${{ secrets.QA_AWS_ROLE_TO_ASSUME }} QA_KUBECONFIG: ${{ secrets.QA_KUBECONFIG }} - - name: Pull 
Artfacts + - name: Pull Artifacts if: needs.changes.outputs.src == 'true' || github.event_name == 'workflow_dispatch' run: | IMAGE_NAME=${{ secrets.QA_AWS_ACCOUNT_NUMBER }}.dkr.ecr.${{ secrets.QA_AWS_REGION }}.amazonaws.com/chainlink-solana-tests:${{ needs.get_solana_sha.outputs.sha }} @@ -1232,12 +1232,20 @@ jobs: docker rm "$CONTAINER_ID" - name: Install Solana CLI # required for ensuring the local test validator is configured correctly run: ./scripts/install-solana-ci.sh + - name: Install gauntlet + run: | + yarn --cwd ./gauntlet install --frozen-lockfile + yarn --cwd ./gauntlet build + yarn --cwd ./gauntlet gauntlet - name: Generate config overrides run: | # https://github.com/smartcontractkit/chainlink-testing-framework/blob/main/config/README.md cat << EOF > config.toml [ChainlinkImage] image="${{ env.CHAINLINK_IMAGE }}" version="${{ inputs.evm-ref || github.sha }}" + [Common] + user="${{ github.actor }}" + internal_docker_repo = "${{ secrets.QA_AWS_ACCOUNT_NUMBER }}.dkr.ecr.${{ secrets.QA_AWS_REGION }}.amazonaws.com" EOF # shellcheck disable=SC2002 BASE64_CONFIG_OVERRIDE=$(cat config.toml | base64 -w 0) diff --git a/core/cmd/shell_test.go b/core/cmd/shell_test.go index 1e3b93851f3..6ecdc4a34de 100644 --- a/core/cmd/shell_test.go +++ b/core/cmd/shell_test.go @@ -18,9 +18,7 @@ import ( "github.com/urfave/cli" commoncfg "github.com/smartcontractkit/chainlink-common/pkg/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" - "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" "github.com/smartcontractkit/chainlink/v2/core/cmd" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" @@ -359,20 +357,20 @@ func TestSetupSolanaRelayer(t *testing.T) { // config 3 chains but only enable 2 => should only be 2 relayer nEnabledChains := 2 tConfig := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { - c.Solana = solana.TOMLConfigs{ - &solana.TOMLConfig{ + c.Solana = solcfg.TOMLConfigs{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("solana-id-1"), Enabled: ptr(true), Chain: solcfg.Chain{}, Nodes: []*solcfg.Node{}, }, - &solana.TOMLConfig{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("solana-id-2"), Enabled: ptr(true), Chain: solcfg.Chain{}, Nodes: []*solcfg.Node{}, }, - &solana.TOMLConfig{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("disabled-solana-id-1"), Enabled: ptr(false), Chain: solcfg.Chain{}, @@ -382,8 +380,8 @@ func TestSetupSolanaRelayer(t *testing.T) { }) t2Config := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { - c.Solana = solana.TOMLConfigs{ - &solana.TOMLConfig{ + c.Solana = solcfg.TOMLConfigs{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("solana-id-1"), Enabled: ptr(true), Chain: solcfg.Chain{}, @@ -420,14 +418,14 @@ func TestSetupSolanaRelayer(t *testing.T) { // test that duplicate enabled chains is an error when duplicateConfig := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { - c.Solana = solana.TOMLConfigs{ - &solana.TOMLConfig{ + c.Solana = solcfg.TOMLConfigs{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("dupe"), Enabled: ptr(true), Chain: solcfg.Chain{}, Nodes: []*solcfg.Node{}, }, - &solana.TOMLConfig{ + &solcfg.TOMLConfig{ ChainID: ptr[string]("dupe"), Enabled: ptr(true), Chain: solcfg.Chain{}, @@ -478,21 +476,21 @@ func TestSetupStarkNetRelayer(t *testing.T) { ChainID: 
ptr[string]("starknet-id-1"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, &stkcfg.TOMLConfig{ ChainID: ptr[string]("starknet-id-2"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, &stkcfg.TOMLConfig{ ChainID: ptr[string]("disabled-starknet-id-1"), Enabled: ptr(false), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, } @@ -504,7 +502,7 @@ func TestSetupStarkNetRelayer(t *testing.T) { ChainID: ptr[string]("starknet-id-3"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, } @@ -542,14 +540,14 @@ func TestSetupStarkNetRelayer(t *testing.T) { ChainID: ptr[string]("dupe"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, &stkcfg.TOMLConfig{ ChainID: ptr[string]("dupe"), Enabled: ptr(true), Chain: stkcfg.Chain{}, - Nodes: []*config.Node{}, + Nodes: []*stkcfg.Node{}, FeederURL: commoncfg.MustParseURL("https://feeder.url"), }, } diff --git a/core/cmd/solana_chains_commands_test.go b/core/cmd/solana_chains_commands_test.go index 88bc8049247..e374ba11c65 100644 --- a/core/cmd/solana_chains_commands_test.go +++ b/core/cmd/solana_chains_commands_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink/v2/core/cmd" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/solanatest" @@ -16,7 +16,7 @@ func TestShell_IndexSolanaChains(t *testing.T) { t.Parallel() id := solanatest.RandomChainID() - cfg := solana.TOMLConfig{ + cfg := solcfg.TOMLConfig{ ChainID: &id, Enabled: ptr(true), } diff --git a/core/cmd/solana_node_commands_test.go b/core/cmd/solana_node_commands_test.go index ebe9502d1fa..adc699de79b 100644 --- a/core/cmd/solana_node_commands_test.go +++ b/core/cmd/solana_node_commands_test.go @@ -12,14 +12,13 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/config" solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" "github.com/smartcontractkit/chainlink/v2/core/cmd" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/solanatest" "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" ) -func solanaStartNewApplication(t *testing.T, cfgs ...*solana.TOMLConfig) *cltest.TestApplication { +func solanaStartNewApplication(t *testing.T, cfgs ...*solcfg.TOMLConfig) *cltest.TestApplication { for i := range cfgs { cfgs[i].SetDefaults() } @@ -41,9 +40,9 @@ func TestShell_IndexSolanaNodes(t *testing.T) { Name: ptr("second"), URL: config.MustParseURL("https://solana2.example"), } - chain := solana.TOMLConfig{ + chain := solcfg.TOMLConfig{ ChainID: &id, - Nodes: solana.SolanaNodes{&node1, &node2}, + Nodes: solcfg.Nodes{&node1, &node2}, } app := solanaStartNewApplication(t, &chain) client, r := app.NewShellAndRenderer() diff 
--git a/core/cmd/solana_transaction_commands_test.go b/core/cmd/solana_transaction_commands_test.go index c26bd89ab94..79a5513f190 100644 --- a/core/cmd/solana_transaction_commands_test.go +++ b/core/cmd/solana_transaction_commands_test.go @@ -16,7 +16,6 @@ import ( "github.com/urfave/cli" "github.com/smartcontractkit/chainlink-common/pkg/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" solanaClient "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" @@ -32,9 +31,9 @@ func TestShell_SolanaSendSol(t *testing.T) { Name: ptr(t.Name()), URL: config.MustParseURL(url), } - cfg := solana.TOMLConfig{ + cfg := solcfg.TOMLConfig{ ChainID: &chainID, - Nodes: solana.SolanaNodes{&node}, + Nodes: solcfg.Nodes{&node}, Enabled: ptr(true), } app := solanaStartNewApplication(t, &cfg) diff --git a/core/config/docs/docs_test.go b/core/config/docs/docs_test.go index 8c6429fd0df..aceb77dc33a 100644 --- a/core/config/docs/docs_test.go +++ b/core/config/docs/docs_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" "github.com/smartcontractkit/chainlink-common/pkg/config" @@ -97,7 +97,7 @@ func TestDoc(t *testing.T) { }) t.Run("Solana", func(t *testing.T) { - var fallbackDefaults solana.TOMLConfig + var fallbackDefaults solcfg.TOMLConfig fallbackDefaults.SetDefaults() assertTOML(t, fallbackDefaults.Chain, defaults.Solana[0].Chain) diff --git a/core/scripts/go.mod b/core/scripts/go.mod index 18840a04cad..cea4ca21242 100644 --- a/core/scripts/go.mod +++ b/core/scripts/go.mod @@ -259,7 +259,7 @@ require ( github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240508101745-af1ed7bc8a69 // indirect github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 // indirect github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab // indirect - github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 // indirect + github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc // indirect github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 // indirect github.com/smartcontractkit/tdh2/go/ocr2/decryptionplugin v0.0.0-20230906073235-9e478e5e19f1 // indirect github.com/smartcontractkit/tdh2/go/tdh2 v0.0.0-20230906073235-9e478e5e19f1 // indirect diff --git a/core/scripts/go.sum b/core/scripts/go.sum index 48cff213690..cde855e2d2f 100644 --- a/core/scripts/go.sum +++ b/core/scripts/go.sum @@ -1193,8 +1193,8 @@ github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab/go.mod h1:RPUY7r8GxgzXxS1ijtU1P/fpJomOXztXgUbEziNmbCA= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 h1:f3W82k9V/XA6ZP/VQVJcGMVR6CrL3pQrPJSwyQWVFys= 
-github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83/go.mod h1:RdAtOeBUWq2zByw2kEbwPlXaPIb7YlaDOmnn+nVUBJI= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc h1:ZqgatXFWsJR/hkvm2mKAta6ivXZqTw7542Iz9ucoOq0= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc/go.mod h1:sR0dMjjpvvEpX3qH8DPRANauPkbO9jgUUGYK95xjLRU= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 h1:ssh/w3oXWu+C6bE88GuFRC1+0Bx/4ihsbc80XMLrl2k= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69/go.mod h1:VsfjhvWgjxqWja4q+FlXEtX5lu8BSxn10xRo6gi948g= github.com/smartcontractkit/chainlink-vrf v0.0.0-20240222010609-cd67d123c772 h1:LQmRsrzzaYYN3wEU1l5tWiccznhvbyGnu2N+wHSXZAo= diff --git a/core/services/chainlink/config.go b/core/services/chainlink/config.go index b77a54f39a8..7e6fa413c67 100644 --- a/core/services/chainlink/config.go +++ b/core/services/chainlink/config.go @@ -9,7 +9,7 @@ import ( gotoml "github.com/pelletier/go-toml/v2" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" commoncfg "github.com/smartcontractkit/chainlink/v2/common/config" @@ -39,7 +39,7 @@ type Config struct { Cosmos coscfg.TOMLConfigs `toml:",omitempty"` - Solana solana.TOMLConfigs `toml:",omitempty"` + Solana solcfg.TOMLConfigs `toml:",omitempty"` Starknet stkcfg.TOMLConfigs `toml:",omitempty"` } @@ -122,7 +122,7 @@ func (c *Config) setDefaults() { for i := range c.Solana { if c.Solana[i] == nil { - c.Solana[i] = new(solana.TOMLConfig) + c.Solana[i] = new(solcfg.TOMLConfig) } c.Solana[i].Chain.SetDefaults() } diff --git a/core/services/chainlink/config_general.go b/core/services/chainlink/config_general.go index cae01c01cb7..ce34cc47e47 100644 --- a/core/services/chainlink/config_general.go +++ b/core/services/chainlink/config_general.go @@ -14,7 +14,7 @@ import ( "go.uber.org/zap/zapcore" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" starknet "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" @@ -201,7 +201,7 @@ func (g *generalConfig) CosmosConfigs() coscfg.TOMLConfigs { return g.c.Cosmos } -func (g *generalConfig) SolanaConfigs() solana.TOMLConfigs { +func (g *generalConfig) SolanaConfigs() solcfg.TOMLConfigs { return g.c.Solana } diff --git a/core/services/chainlink/config_test.go b/core/services/chainlink/config_test.go index 11d286fbcd5..624a575d65c 100644 --- a/core/services/chainlink/config_test.go +++ b/core/services/chainlink/config_test.go @@ -22,7 +22,6 @@ import ( commoncfg "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/utils/hex" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" commonconfig "github.com/smartcontractkit/chainlink/v2/common/config" @@ -160,7 
+159,7 @@ var ( {Name: ptr("secondary"), TendermintURL: commoncfg.MustParseURL("http://bombay.cosmos.com")}, }}, }, - Solana: []*solana.TOMLConfig{ + Solana: []*solcfg.TOMLConfig{ { ChainID: ptr("mainnet"), Chain: solcfg.Chain{ @@ -632,7 +631,7 @@ func TestConfig_Marshal(t *testing.T) { }, }}, } - full.Solana = []*solana.TOMLConfig{ + full.Solana = []*solcfg.TOMLConfig{ { ChainID: ptr("mainnet"), Enabled: ptr(false), @@ -1546,7 +1545,7 @@ func TestConfig_setDefaults(t *testing.T) { var c Config c.EVM = evmcfg.EVMConfigs{{ChainID: ubig.NewI(99999133712345)}} c.Cosmos = coscfg.TOMLConfigs{{ChainID: ptr("unknown cosmos chain")}} - c.Solana = solana.TOMLConfigs{{ChainID: ptr("unknown solana chain")}} + c.Solana = solcfg.TOMLConfigs{{ChainID: ptr("unknown solana chain")}} c.Starknet = stkcfg.TOMLConfigs{{ChainID: ptr("unknown starknet chain")}} c.setDefaults() if s, err := c.TOMLString(); assert.NoError(t, err) { diff --git a/core/services/chainlink/mocks/general_config.go b/core/services/chainlink/mocks/general_config.go index a86753a59e3..c7e224f4f2a 100644 --- a/core/services/chainlink/mocks/general_config.go +++ b/core/services/chainlink/mocks/general_config.go @@ -10,7 +10,7 @@ import ( mock "github.com/stretchr/testify/mock" - solana "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solanaconfig "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" time "time" @@ -616,19 +616,19 @@ func (_m *GeneralConfig) ShutdownGracePeriod() time.Duration { } // SolanaConfigs provides a mock function with given fields: -func (_m *GeneralConfig) SolanaConfigs() solana.TOMLConfigs { +func (_m *GeneralConfig) SolanaConfigs() solanaconfig.TOMLConfigs { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for SolanaConfigs") } - var r0 solana.TOMLConfigs - if rf, ok := ret.Get(0).(func() solana.TOMLConfigs); ok { + var r0 solanaconfig.TOMLConfigs + if rf, ok := ret.Get(0).(func() solanaconfig.TOMLConfigs); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(solana.TOMLConfigs) + r0 = ret.Get(0).(solanaconfig.TOMLConfigs) } } diff --git a/core/services/chainlink/relayer_chain_interoperators_test.go b/core/services/chainlink/relayer_chain_interoperators_test.go index c6183cc1a34..c2baa1edcde 100644 --- a/core/services/chainlink/relayer_chain_interoperators_test.go +++ b/core/services/chainlink/relayer_chain_interoperators_test.go @@ -18,7 +18,6 @@ import ( solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" @@ -75,8 +74,8 @@ func TestCoreRelayerChainInteroperators(t *testing.T) { Nodes: evmcfg.EVMNodes{&node2_1}, }) - c.Solana = solana.TOMLConfigs{ - &solana.TOMLConfig{ + c.Solana = solcfg.TOMLConfigs{ + &solcfg.TOMLConfig{ ChainID: &solanaChainID1, Enabled: ptr(true), Chain: solcfg.Chain{}, @@ -85,7 +84,7 @@ func TestCoreRelayerChainInteroperators(t *testing.T) { URL: ((*commonconfig.URL)(commonconfig.MustParseURL("http://localhost:8547").URL())), }}, }, - &solana.TOMLConfig{ + &solcfg.TOMLConfig{ ChainID: &solanaChainID2, Enabled: ptr(true), Chain: solcfg.Chain{}, diff --git a/core/services/chainlink/relayer_factory.go b/core/services/chainlink/relayer_factory.go index 2aaeb253c0a..bcdb08b8026 
100644 --- a/core/services/chainlink/relayer_factory.go +++ b/core/services/chainlink/relayer_factory.go @@ -14,7 +14,7 @@ import ( "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana" - pkgsolana "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" pkgstarknet "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink" starkchain "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/chain" "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" @@ -92,10 +92,10 @@ func (r *RelayerFactory) NewEVM(ctx context.Context, config EVMFactoryConfig) (m type SolanaFactoryConfig struct { Keystore keystore.Solana - solana.TOMLConfigs + solcfg.TOMLConfigs } -func (r *RelayerFactory) NewSolana(ks keystore.Solana, chainCfgs solana.TOMLConfigs) (map[types.RelayID]loop.Relayer, error) { +func (r *RelayerFactory) NewSolana(ks keystore.Solana, chainCfgs solcfg.TOMLConfigs) (map[types.RelayID]loop.Relayer, error) { solanaRelayers := make(map[types.RelayID]loop.Relayer) var ( solLggr = r.Logger.Named("Solana") @@ -123,7 +123,7 @@ func (r *RelayerFactory) NewSolana(ks keystore.Solana, chainCfgs solana.TOMLConf if cmdName := env.SolanaPlugin.Cmd.Get(); cmdName != "" { // setup the solana relayer to be a LOOP cfgTOML, err := toml.Marshal(struct { - Solana solana.TOMLConfig + Solana solcfg.TOMLConfig }{Solana: *chainCfg}) if err != nil { @@ -154,7 +154,7 @@ func (r *RelayerFactory) NewSolana(ks keystore.Solana, chainCfgs solana.TOMLConf if err != nil { return nil, err } - solanaRelayers[relayID] = relay.NewServerAdapter(pkgsolana.NewRelayer(lggr, chain), chain) + solanaRelayers[relayID] = relay.NewServerAdapter(solana.NewRelayer(lggr, chain), chain) } } return solanaRelayers, nil diff --git a/core/services/chainlink/types.go b/core/services/chainlink/types.go index 72cad694167..4c7550142a2 100644 --- a/core/services/chainlink/types.go +++ b/core/services/chainlink/types.go @@ -2,7 +2,7 @@ package chainlink import ( coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" + solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" @@ -15,7 +15,7 @@ type GeneralConfig interface { config.AppConfig toml.HasEVMConfigs CosmosConfigs() coscfg.TOMLConfigs - SolanaConfigs() solana.TOMLConfigs + SolanaConfigs() solcfg.TOMLConfigs StarknetConfigs() stkcfg.TOMLConfigs // ConfigTOML returns both the user provided and effective configuration as TOML. 
ConfigTOML() (user, effective string) diff --git a/core/utils/config/validate.go b/core/utils/config/validate.go index 5fbae24ad53..f8508f27bf7 100644 --- a/core/utils/config/validate.go +++ b/core/utils/config/validate.go @@ -33,7 +33,9 @@ func validate(v reflect.Value, checkInterface bool) (err error) { if checkInterface { i := v.Interface() if vc, ok := i.(Validated); ok { - err = multierr.Append(err, vc.ValidateConfig()) + for _, e := range utils.UnwrapError(vc.ValidateConfig()) { + err = multierr.Append(err, e) + } } else if v.CanAddr() { i = v.Addr().Interface() if vc, ok := i.(Validated); ok { diff --git a/core/web/solana_chains_controller_test.go b/core/web/solana_chains_controller_test.go index 1377cb65aba..a2ac904b783 100644 --- a/core/web/solana_chains_controller_test.go +++ b/core/web/solana_chains_controller_test.go @@ -15,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" @@ -80,7 +79,7 @@ Nodes = [] t.Run(tc.name, func(t *testing.T) { t.Parallel() - controller := setupSolanaChainsControllerTestV2(t, &solana.TOMLConfig{ + controller := setupSolanaChainsControllerTestV2(t, &config.TOMLConfig{ ChainID: ptr(validId), Chain: config.Chain{ SkipPreflight: ptr(false), @@ -111,13 +110,13 @@ Nodes = [] func Test_SolanaChainsController_Index(t *testing.T) { t.Parallel() - chainA := &solana.TOMLConfig{ + chainA := &config.TOMLConfig{ ChainID: ptr(fmt.Sprintf("ChainlinktestA-%d", rand.Int31n(999999))), Chain: config.Chain{ TxTimeout: commoncfg.MustNewDuration(time.Hour), }, } - chainB := &solana.TOMLConfig{ + chainB := &config.TOMLConfig{ ChainID: ptr(fmt.Sprintf("ChainlinktestB-%d", rand.Int31n(999999))), Chain: config.Chain{ SkipPreflight: ptr(false), @@ -175,7 +174,7 @@ type TestSolanaChainsController struct { client cltest.HTTPClientCleaner } -func setupSolanaChainsControllerTestV2(t *testing.T, cfgs ...*solana.TOMLConfig) *TestSolanaChainsController { +func setupSolanaChainsControllerTestV2(t *testing.T, cfgs ...*config.TOMLConfig) *TestSolanaChainsController { for i := range cfgs { cfgs[i].SetDefaults() } diff --git a/go.mod b/go.mod index 9b72abb4db0..9ce8d98b91d 100644 --- a/go.mod +++ b/go.mod @@ -76,7 +76,7 @@ require ( github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240508101745-af1ed7bc8a69 github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab - github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 + github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 github.com/smartcontractkit/chainlink-vrf v0.0.0-20231120191722-fef03814f868 github.com/smartcontractkit/libocr v0.0.0-20240419185742-fd3cab206b2c diff --git a/go.sum b/go.sum index e452df784af..a41f19e2cb1 100644 --- a/go.sum +++ b/go.sum @@ -1179,8 +1179,8 @@ github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= 
github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab/go.mod h1:RPUY7r8GxgzXxS1ijtU1P/fpJomOXztXgUbEziNmbCA= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 h1:f3W82k9V/XA6ZP/VQVJcGMVR6CrL3pQrPJSwyQWVFys= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83/go.mod h1:RdAtOeBUWq2zByw2kEbwPlXaPIb7YlaDOmnn+nVUBJI= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc h1:ZqgatXFWsJR/hkvm2mKAta6ivXZqTw7542Iz9ucoOq0= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc/go.mod h1:sR0dMjjpvvEpX3qH8DPRANauPkbO9jgUUGYK95xjLRU= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 h1:ssh/w3oXWu+C6bE88GuFRC1+0Bx/4ihsbc80XMLrl2k= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69/go.mod h1:VsfjhvWgjxqWja4q+FlXEtX5lu8BSxn10xRo6gi948g= github.com/smartcontractkit/chainlink-vrf v0.0.0-20231120191722-fef03814f868 h1:FFdvEzlYwcuVHkdZ8YnZR/XomeMGbz5E2F2HZI3I3w8= diff --git a/integration-tests/go.mod b/integration-tests/go.mod index 144b1f62643..9442426eeb3 100644 --- a/integration-tests/go.mod +++ b/integration-tests/go.mod @@ -376,7 +376,7 @@ require ( github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240508101745-af1ed7bc8a69 // indirect github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 // indirect github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab // indirect - github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 // indirect + github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc // indirect github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 // indirect github.com/smartcontractkit/chainlink-testing-framework/grafana v0.0.0-20240328204215-ac91f55f1449 // indirect github.com/smartcontractkit/tdh2/go/ocr2/decryptionplugin v0.0.0-20230906073235-9e478e5e19f1 // indirect diff --git a/integration-tests/go.sum b/integration-tests/go.sum index 98d30f02d35..477edf4ab53 100644 --- a/integration-tests/go.sum +++ b/integration-tests/go.sum @@ -1520,8 +1520,8 @@ github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab/go.mod h1:RPUY7r8GxgzXxS1ijtU1P/fpJomOXztXgUbEziNmbCA= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 h1:f3W82k9V/XA6ZP/VQVJcGMVR6CrL3pQrPJSwyQWVFys= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83/go.mod h1:RdAtOeBUWq2zByw2kEbwPlXaPIb7YlaDOmnn+nVUBJI= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc h1:ZqgatXFWsJR/hkvm2mKAta6ivXZqTw7542Iz9ucoOq0= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc/go.mod h1:sR0dMjjpvvEpX3qH8DPRANauPkbO9jgUUGYK95xjLRU= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 
h1:ssh/w3oXWu+C6bE88GuFRC1+0Bx/4ihsbc80XMLrl2k= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69/go.mod h1:VsfjhvWgjxqWja4q+FlXEtX5lu8BSxn10xRo6gi948g= github.com/smartcontractkit/chainlink-testing-framework v1.28.15 h1:mga7N6jtXQ3UOCt43IdsEnCMBh9xjOWPaE9BiM6kr6Q= diff --git a/integration-tests/load/go.mod b/integration-tests/load/go.mod index 162e506196b..395199f599b 100644 --- a/integration-tests/load/go.mod +++ b/integration-tests/load/go.mod @@ -369,7 +369,7 @@ require ( github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240508101745-af1ed7bc8a69 // indirect github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 // indirect github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab // indirect - github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 // indirect + github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc // indirect github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 // indirect github.com/smartcontractkit/chainlink-testing-framework/grafana v0.0.0-20240328204215-ac91f55f1449 // indirect github.com/smartcontractkit/chainlink-vrf v0.0.0-20231120191722-fef03814f868 // indirect diff --git a/integration-tests/load/go.sum b/integration-tests/load/go.sum index 35c37e1ecf9..95624177f68 100644 --- a/integration-tests/load/go.sum +++ b/integration-tests/load/go.sum @@ -1510,8 +1510,8 @@ github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab/go.mod h1:RPUY7r8GxgzXxS1ijtU1P/fpJomOXztXgUbEziNmbCA= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83 h1:f3W82k9V/XA6ZP/VQVJcGMVR6CrL3pQrPJSwyQWVFys= -github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240510181707-46b1311a5a83/go.mod h1:RdAtOeBUWq2zByw2kEbwPlXaPIb7YlaDOmnn+nVUBJI= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc h1:ZqgatXFWsJR/hkvm2mKAta6ivXZqTw7542Iz9ucoOq0= +github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240523174813-45db170c1ccc/go.mod h1:sR0dMjjpvvEpX3qH8DPRANauPkbO9jgUUGYK95xjLRU= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69 h1:ssh/w3oXWu+C6bE88GuFRC1+0Bx/4ihsbc80XMLrl2k= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240508155030-1024f2b55c69/go.mod h1:VsfjhvWgjxqWja4q+FlXEtX5lu8BSxn10xRo6gi948g= github.com/smartcontractkit/chainlink-testing-framework v1.28.15 h1:mga7N6jtXQ3UOCt43IdsEnCMBh9xjOWPaE9BiM6kr6Q= From c7a6356f4903e919964ca91493f18e0ebf4eb08b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Friedemann=20F=C3=BCrst?= <59653747+friedemannf@users.noreply.github.com> Date: Thu, 23 May 2024 22:02:03 +0200 Subject: [PATCH 16/16] Decouple ChainType from config string [SHIP-2001] (#13272) * fix: Decouple ChainType from config string * fix: receiver name and failing test * test: enhance config test to test for xdai specifically * refactor: directly unmarshal into ChainType * fix: validation * test: fix TestDoc/EVM * test: add xdai to warnings.xtar --- .changeset/young-candles-brush.md | 
5 + common/config/chaintype.go | 117 ++++++++++++++---- core/chains/evm/client/config_builder.go | 3 +- core/chains/evm/client/pool_test.go | 2 - core/chains/evm/config/chain_scoped.go | 2 +- core/chains/evm/config/config_test.go | 4 +- core/chains/evm/config/toml/config.go | 26 ++-- core/chains/evm/config/toml/defaults.go | 5 +- .../evm/gas/block_history_estimator_test.go | 5 - core/chains/evm/gas/chain_specific.go | 2 +- core/config/docs/docs_test.go | 3 +- core/services/chainlink/config.go | 5 +- core/services/chainlink/config_test.go | 6 +- core/services/ocr/contract_tracker.go | 2 +- core/services/ocrcommon/block_translator.go | 2 +- testdata/scripts/node/validate/warnings.txtar | 110 +++++++++++++++- 16 files changed, 231 insertions(+), 68 deletions(-) create mode 100644 .changeset/young-candles-brush.md diff --git a/.changeset/young-candles-brush.md b/.changeset/young-candles-brush.md new file mode 100644 index 00000000000..5d10eaabf80 --- /dev/null +++ b/.changeset/young-candles-brush.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +#bugfix allow ChainType to be set to xdai diff --git a/common/config/chaintype.go b/common/config/chaintype.go index 73c48960a13..3f3150950d6 100644 --- a/common/config/chaintype.go +++ b/common/config/chaintype.go @@ -5,10 +5,8 @@ import ( "strings" ) -// ChainType denotes the chain or network to work with type ChainType string -// nolint const ( ChainArbitrum ChainType = "arbitrum" ChainCelo ChainType = "celo" @@ -18,11 +16,103 @@ const ( ChainOptimismBedrock ChainType = "optimismBedrock" ChainScroll ChainType = "scroll" ChainWeMix ChainType = "wemix" - ChainXDai ChainType = "xdai" // Deprecated: use ChainGnosis instead ChainXLayer ChainType = "xlayer" ChainZkSync ChainType = "zksync" ) +// IsL2 returns true if this chain is a Layer 2 chain. 
Notably: +// - the block numbers used for log searching are different from calling block.number +// - gas bumping is not supported, since there is no tx mempool +func (c ChainType) IsL2() bool { + switch c { + case ChainArbitrum, ChainMetis: + return true + default: + return false + } +} + +func (c ChainType) IsValid() bool { + switch c { + case "", ChainArbitrum, ChainCelo, ChainGnosis, ChainKroma, ChainMetis, ChainOptimismBedrock, ChainScroll, ChainWeMix, ChainXLayer, ChainZkSync: + return true + } + return false +} + +func ChainTypeFromSlug(slug string) ChainType { + switch slug { + case "arbitrum": + return ChainArbitrum + case "celo": + return ChainCelo + case "gnosis", "xdai": + return ChainGnosis + case "kroma": + return ChainKroma + case "metis": + return ChainMetis + case "optimismBedrock": + return ChainOptimismBedrock + case "scroll": + return ChainScroll + case "wemix": + return ChainWeMix + case "xlayer": + return ChainXLayer + case "zksync": + return ChainZkSync + default: + return ChainType(slug) + } +} + +type ChainTypeConfig struct { + value ChainType + slug string +} + +func NewChainTypeConfig(slug string) *ChainTypeConfig { + return &ChainTypeConfig{ + value: ChainTypeFromSlug(slug), + slug: slug, + } +} + +func (c *ChainTypeConfig) MarshalText() ([]byte, error) { + if c == nil { + return nil, nil + } + return []byte(c.slug), nil +} + +func (c *ChainTypeConfig) UnmarshalText(b []byte) error { + c.slug = string(b) + c.value = ChainTypeFromSlug(c.slug) + return nil +} + +func (c *ChainTypeConfig) Slug() string { + if c == nil { + return "" + } + return c.slug +} + +func (c *ChainTypeConfig) ChainType() ChainType { + if c == nil { + return "" + } + return c.value +} + +func (c *ChainTypeConfig) String() string { + if c == nil { + return "" + } + return string(c.value) +} + var ErrInvalidChainType = fmt.Errorf("must be one of %s or omitted", strings.Join([]string{ string(ChainArbitrum), string(ChainCelo), @@ -35,24 +125,3 @@ var ErrInvalidChainType = fmt.Errorf("must be one of %s or omitted", strings.Joi string(ChainXLayer), string(ChainZkSync), }, ", ")) - -// IsValid returns true if the ChainType value is known or empty. -func (c ChainType) IsValid() bool { - switch c { - case "", ChainArbitrum, ChainCelo, ChainGnosis, ChainKroma, ChainMetis, ChainOptimismBedrock, ChainScroll, ChainWeMix, ChainXDai, ChainXLayer, ChainZkSync: - return true - } - return false -} - -// IsL2 returns true if this chain is a Layer 2 chain. 
Notably: -// - the block numbers used for log searching are different from calling block.number -// - gas bumping is not supported, since there is no tx mempool -func (c ChainType) IsL2() bool { - switch c { - case ChainArbitrum, ChainMetis: - return true - default: - return false - } -} diff --git a/core/chains/evm/client/config_builder.go b/core/chains/evm/client/config_builder.go index d78a981b881..9817879b579 100644 --- a/core/chains/evm/client/config_builder.go +++ b/core/chains/evm/client/config_builder.go @@ -8,6 +8,7 @@ import ( "go.uber.org/multierr" commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink/v2/common/config" commonclient "github.com/smartcontractkit/chainlink/v2/common/client" evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" @@ -55,7 +56,7 @@ func NewClientConfigs( chainConfig := &evmconfig.EVMConfig{ C: &toml.EVMConfig{ Chain: toml.Chain{ - ChainType: &chainType, + ChainType: config.NewChainTypeConfig(chainType), FinalityDepth: finalityDepth, FinalityTagEnabled: finalityTagEnabled, NoNewHeadsThreshold: commonconfig.MustNewDuration(noNewHeadsThreshold), diff --git a/core/chains/evm/client/pool_test.go b/core/chains/evm/client/pool_test.go index 5f614b7ed24..5a2c13130d3 100644 --- a/core/chains/evm/client/pool_test.go +++ b/core/chains/evm/client/pool_test.go @@ -169,7 +169,6 @@ func TestPool_Dial(t *testing.T) { if err == nil { t.Cleanup(func() { assert.NoError(t, p.Close()) }) } - assert.True(t, p.ChainType().IsValid()) assert.False(t, p.ChainType().IsL2()) if test.errStr != "" { require.Error(t, err) @@ -333,7 +332,6 @@ func TestUnit_Pool_BatchCallContextAll(t *testing.T) { p := evmclient.NewPool(logger.Test(t), defaultConfig.NodeSelectionMode(), defaultConfig.LeaseDuration(), time.Second*0, nodes, sendonlys, &cltest.FixtureChainID, "") - assert.True(t, p.ChainType().IsValid()) assert.False(t, p.ChainType().IsL2()) require.NoError(t, p.BatchCallContextAll(ctx, b)) } diff --git a/core/chains/evm/config/chain_scoped.go b/core/chains/evm/config/chain_scoped.go index 8f94fef09f4..17d4120ddf6 100644 --- a/core/chains/evm/config/chain_scoped.go +++ b/core/chains/evm/config/chain_scoped.go @@ -128,7 +128,7 @@ func (e *EVMConfig) ChainType() commonconfig.ChainType { if e.C.ChainType == nil { return "" } - return commonconfig.ChainType(*e.C.ChainType) + return e.C.ChainType.ChainType() } func (e *EVMConfig) ChainID() *big.Int { diff --git a/core/chains/evm/config/config_test.go b/core/chains/evm/config/config_test.go index 9553f59ad61..ddf9817958d 100644 --- a/core/chains/evm/config/config_test.go +++ b/core/chains/evm/config/config_test.go @@ -406,7 +406,7 @@ func Test_chainScopedConfig_Validate(t *testing.T) { t.Run("arbitrum-estimator", func(t *testing.T) { t.Run("custom", func(t *testing.T) { cfg := configWithChains(t, 0, &toml.Chain{ - ChainType: ptr(string(commonconfig.ChainArbitrum)), + ChainType: commonconfig.NewChainTypeConfig(string(commonconfig.ChainArbitrum)), GasEstimator: toml.GasEstimator{ Mode: ptr("BlockHistory"), }, @@ -437,7 +437,7 @@ func Test_chainScopedConfig_Validate(t *testing.T) { t.Run("optimism-estimator", func(t *testing.T) { t.Run("custom", func(t *testing.T) { cfg := configWithChains(t, 0, &toml.Chain{ - ChainType: ptr(string(commonconfig.ChainOptimismBedrock)), + ChainType: commonconfig.NewChainTypeConfig(string(commonconfig.ChainOptimismBedrock)), GasEstimator: toml.GasEstimator{ Mode: ptr("BlockHistory"), }, diff --git a/core/chains/evm/config/toml/config.go 
b/core/chains/evm/config/toml/config.go index 1b1baf41094..a326881bdde 100644 --- a/core/chains/evm/config/toml/config.go +++ b/core/chains/evm/config/toml/config.go @@ -294,18 +294,14 @@ func (c *EVMConfig) ValidateConfig() (err error) { } else if c.ChainID.String() == "" { err = multierr.Append(err, commonconfig.ErrEmpty{Name: "ChainID", Msg: "required for all chains"}) } else if must, ok := ChainTypeForID(c.ChainID); ok { // known chain id - if c.ChainType == nil && must != "" { - err = multierr.Append(err, commonconfig.ErrMissing{Name: "ChainType", - Msg: fmt.Sprintf("only %q can be used with this chain id", must)}) - } else if c.ChainType != nil && *c.ChainType != string(must) { - if *c.ChainType == "" { - err = multierr.Append(err, commonconfig.ErrEmpty{Name: "ChainType", - Msg: fmt.Sprintf("only %q can be used with this chain id", must)}) - } else if must == "" { - err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: *c.ChainType, + // Check if the parsed value matched the expected value + is := c.ChainType.ChainType() + if is != must { + if must == "" { + err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: c.ChainType.ChainType(), Msg: "must not be set with this chain id"}) } else { - err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: *c.ChainType, + err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: c.ChainType.ChainType(), Msg: fmt.Sprintf("only %q can be used with this chain id", must)}) } } @@ -345,7 +341,7 @@ type Chain struct { AutoCreateKey *bool BlockBackfillDepth *uint32 BlockBackfillSkip *bool - ChainType *string + ChainType *config.ChainTypeConfig FinalityDepth *uint32 FinalityTagEnabled *bool FlagsContractAddress *types.EIP55Address @@ -375,12 +371,8 @@ type Chain struct { } func (c *Chain) ValidateConfig() (err error) { - var chainType config.ChainType - if c.ChainType != nil { - chainType = config.ChainType(*c.ChainType) - } - if !chainType.IsValid() { - err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: *c.ChainType, + if !c.ChainType.ChainType().IsValid() { + err = multierr.Append(err, commonconfig.ErrInvalid{Name: "ChainType", Value: c.ChainType.ChainType(), Msg: config.ErrInvalidChainType.Error()}) } diff --git a/core/chains/evm/config/toml/defaults.go b/core/chains/evm/config/toml/defaults.go index 951246eeb22..622ac132e13 100644 --- a/core/chains/evm/config/toml/defaults.go +++ b/core/chains/evm/config/toml/defaults.go @@ -94,10 +94,7 @@ func Defaults(chainID *big.Big, with ...*Chain) Chain { func ChainTypeForID(chainID *big.Big) (config.ChainType, bool) { s := chainID.String() if d, ok := defaults[s]; ok { - if d.ChainType == nil { - return "", true - } - return config.ChainType(*d.ChainType), true + return d.ChainType.ChainType(), true } return "", false } diff --git a/core/chains/evm/gas/block_history_estimator_test.go b/core/chains/evm/gas/block_history_estimator_test.go index 730bfcab7e1..b38cd069c69 100644 --- a/core/chains/evm/gas/block_history_estimator_test.go +++ b/core/chains/evm/gas/block_history_estimator_test.go @@ -994,11 +994,6 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { bhe.Recalculate(testutils.Head(0)) require.Equal(t, assets.NewWeiI(80), gas.GetGasPrice(bhe)) - // Same for xDai (deprecated) - cfg.ChainTypeF = string(config.ChainXDai) - bhe.Recalculate(testutils.Head(0)) - require.Equal(t, assets.NewWeiI(80), gas.GetGasPrice(bhe)) - // And for X Layer cfg.ChainTypeF = 
string(config.ChainXLayer) bhe.Recalculate(testutils.Head(0)) diff --git a/core/chains/evm/gas/chain_specific.go b/core/chains/evm/gas/chain_specific.go index 694411f164b..f9985a6fafc 100644 --- a/core/chains/evm/gas/chain_specific.go +++ b/core/chains/evm/gas/chain_specific.go @@ -9,7 +9,7 @@ import ( // chainSpecificIsUsable allows for additional logic specific to a particular // Config that determines whether a transaction should be used for gas estimation func chainSpecificIsUsable(tx evmtypes.Transaction, baseFee *assets.Wei, chainType config.ChainType, minGasPriceWei *assets.Wei) bool { - if chainType == config.ChainGnosis || chainType == config.ChainXDai || chainType == config.ChainXLayer { + if chainType == config.ChainGnosis || chainType == config.ChainXLayer { // GasPrice 0 on most chains is great since it indicates cheap/free transactions. // However, Gnosis and XLayer reserve a special type of "bridge" transaction with 0 gas // price that is always processed at top priority. Ordinary transactions diff --git a/core/config/docs/docs_test.go b/core/config/docs/docs_test.go index aceb77dc33a..1f76eedcc67 100644 --- a/core/config/docs/docs_test.go +++ b/core/config/docs/docs_test.go @@ -15,6 +15,7 @@ import ( stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" "github.com/smartcontractkit/chainlink-common/pkg/config" + commonconfig "github.com/smartcontractkit/chainlink/v2/common/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" evmcfg "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -46,7 +47,7 @@ func TestDoc(t *testing.T) { fallbackDefaults := evmcfg.Defaults(nil) docDefaults := defaults.EVM[0].Chain - require.Equal(t, "", *docDefaults.ChainType) + require.Equal(t, commonconfig.ChainType(""), docDefaults.ChainType.ChainType()) docDefaults.ChainType = nil // clean up KeySpecific as a special case diff --git a/core/services/chainlink/config.go b/core/services/chainlink/config.go index 7e6fa413c67..d0d25a5e461 100644 --- a/core/services/chainlink/config.go +++ b/core/services/chainlink/config.go @@ -12,7 +12,6 @@ import ( solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" - commoncfg "github.com/smartcontractkit/chainlink/v2/common/config" evmcfg "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" "github.com/smartcontractkit/chainlink/v2/core/config/docs" "github.com/smartcontractkit/chainlink/v2/core/config/env" @@ -79,10 +78,10 @@ func (c *Config) valueWarnings() (err error) { func (c *Config) deprecationWarnings() (err error) { // ChainType xdai is deprecated and has been renamed to gnosis for _, evm := range c.EVM { - if evm.ChainType != nil && *evm.ChainType == string(commoncfg.ChainXDai) { + if evm.ChainType != nil && evm.ChainType.Slug() == "xdai" { err = multierr.Append(err, config.ErrInvalid{ Name: "EVM.ChainType", - Value: *evm.ChainType, + Value: evm.ChainType.Slug(), Msg: "deprecated and will be removed in v2.13.0, use 'gnosis' instead", }) } diff --git a/core/services/chainlink/config_test.go b/core/services/chainlink/config_test.go index 624a575d65c..8119021b565 100644 --- a/core/services/chainlink/config_test.go +++ b/core/services/chainlink/config_test.go @@ -25,8 +25,8 @@ import ( solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" stkcfg 
"github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/config" commonconfig "github.com/smartcontractkit/chainlink/v2/common/config" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" evmcfg "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -494,7 +494,7 @@ func TestConfig_Marshal(t *testing.T) { }, BlockBackfillDepth: ptr[uint32](100), BlockBackfillSkip: ptr(true), - ChainType: ptr("Optimism"), + ChainType: commonconfig.NewChainTypeConfig("Optimism"), FinalityDepth: ptr[uint32](42), FinalityTagEnabled: ptr[bool](false), FlagsContractAddress: mustAddress("0xae4E781a6218A8031764928E88d457937A954fC3"), @@ -1625,7 +1625,7 @@ func TestConfig_warnings(t *testing.T) { { name: "Value warning - ChainType=xdai is deprecated", config: Config{ - EVM: evmcfg.EVMConfigs{{Chain: evmcfg.Chain{ChainType: ptr(string(commonconfig.ChainXDai))}}}, + EVM: evmcfg.EVMConfigs{{Chain: evmcfg.Chain{ChainType: commonconfig.NewChainTypeConfig("xdai")}}}, }, expectedErrors: []string{"EVM.ChainType: invalid value (xdai): deprecated and will be removed in v2.13.0, use 'gnosis' instead"}, }, diff --git a/core/services/ocr/contract_tracker.go b/core/services/ocr/contract_tracker.go index 1d9076b8322..94ad1237e90 100644 --- a/core/services/ocr/contract_tracker.go +++ b/core/services/ocr/contract_tracker.go @@ -400,7 +400,7 @@ func (t *OCRContractTracker) LatestBlockHeight(ctx context.Context) (blockheight // care about the block height; we have no way of getting the L1 block // height anyway return 0, nil - case "", config.ChainArbitrum, config.ChainCelo, config.ChainGnosis, config.ChainKroma, config.ChainOptimismBedrock, config.ChainScroll, config.ChainWeMix, config.ChainXDai, config.ChainXLayer, config.ChainZkSync: + case "", config.ChainArbitrum, config.ChainCelo, config.ChainGnosis, config.ChainKroma, config.ChainOptimismBedrock, config.ChainScroll, config.ChainWeMix, config.ChainXLayer, config.ChainZkSync: // continue } latestBlockHeight := t.getLatestBlockHeight() diff --git a/core/services/ocrcommon/block_translator.go b/core/services/ocrcommon/block_translator.go index 6ef64499fa9..06fd9941992 100644 --- a/core/services/ocrcommon/block_translator.go +++ b/core/services/ocrcommon/block_translator.go @@ -21,7 +21,7 @@ func NewBlockTranslator(cfg Config, client evmclient.Client, lggr logger.Logger) switch cfg.ChainType() { case config.ChainArbitrum: return NewArbitrumBlockTranslator(client, lggr) - case "", config.ChainCelo, config.ChainGnosis, config.ChainKroma, config.ChainMetis, config.ChainOptimismBedrock, config.ChainScroll, config.ChainWeMix, config.ChainXDai, config.ChainXLayer, config.ChainZkSync: + case "", config.ChainCelo, config.ChainGnosis, config.ChainKroma, config.ChainMetis, config.ChainOptimismBedrock, config.ChainScroll, config.ChainWeMix, config.ChainXLayer, config.ChainZkSync: fallthrough default: return &l1BlockTranslator{} diff --git a/testdata/scripts/node/validate/warnings.txtar b/testdata/scripts/node/validate/warnings.txtar index cf121e959e1..54de3227a9e 100644 --- a/testdata/scripts/node/validate/warnings.txtar +++ b/testdata/scripts/node/validate/warnings.txtar @@ -9,6 +9,15 @@ CollectorTarget = 'otel-collector:4317' TLSCertPath = 'something' Mode = 'unencrypted' +[[EVM]] +ChainID = '10200' +ChainType = 'xdai' + +[[EVM.Nodes]] +Name = 'fake' +WSURL = 'wss://foo.bar/ws' +HTTPURL = 'https://foo.bar' + -- 
secrets.toml -- [Database] URL = 'postgresql://user:pass1234567890abcd@localhost:5432/dbname?sslmode=disable' @@ -32,6 +41,15 @@ CollectorTarget = 'otel-collector:4317' Mode = 'unencrypted' TLSCertPath = 'something' +[[EVM]] +ChainID = '10200' +ChainType = 'xdai' + +[[EVM.Nodes]] +Name = 'fake' +WSURL = 'wss://foo.bar/ws' +HTTPURL = 'https://foo.bar' + # Effective Configuration, with defaults applied: InsecureFastScrypt = false RootDir = '~/.chainlink' @@ -284,6 +302,94 @@ DeltaDial = '15s' DeltaReconcile = '1m0s' ListenAddresses = [] +[[EVM]] +ChainID = '10200' +AutoCreateKey = true +BlockBackfillDepth = 10 +BlockBackfillSkip = false +ChainType = 'xdai' +FinalityDepth = 100 +FinalityTagEnabled = false +LogBackfillBatchSize = 1000 +LogPollInterval = '5s' +LogKeepBlocksDepth = 100000 +LogPrunePageSize = 0 +BackupLogPollerBlockDelay = 100 +MinIncomingConfirmations = 3 +MinContractPayment = '0.00001 link' +NonceAutoSync = true +NoNewHeadsThreshold = '3m0s' +RPCDefaultBatchSize = 250 +RPCBlockQueryDelay = 1 + +[EVM.Transactions] +ForwardersEnabled = false +MaxInFlight = 16 +MaxQueued = 250 +ReaperInterval = '1h0m0s' +ReaperThreshold = '168h0m0s' +ResendAfterThreshold = '1m0s' + +[EVM.BalanceMonitor] +Enabled = true + +[EVM.GasEstimator] +Mode = 'BlockHistory' +PriceDefault = '20 gwei' +PriceMax = '500 gwei' +PriceMin = '1 gwei' +LimitDefault = 500000 +LimitMax = 500000 +LimitMultiplier = '1' +LimitTransfer = 21000 +BumpMin = '5 gwei' +BumpPercent = 20 +BumpThreshold = 3 +EIP1559DynamicFees = true +FeeCapDefault = '100 gwei' +TipCapDefault = '1 wei' +TipCapMin = '1 wei' + +[EVM.GasEstimator.BlockHistory] +BatchSize = 25 +BlockHistorySize = 8 +CheckInclusionBlocks = 12 +CheckInclusionPercentile = 90 +TransactionPercentile = 60 + +[EVM.HeadTracker] +HistoryDepth = 100 +MaxBufferSize = 3 +SamplingInterval = '1s' + +[EVM.NodePool] +PollFailureThreshold = 5 +PollInterval = '10s' +SelectionMode = 'HighestHead' +SyncThreshold = 5 +LeaseDuration = '0s' +NodeIsSyncingEnabled = false +FinalizedBlockPollInterval = '5s' + +[EVM.OCR] +ContractConfirmations = 4 +ContractTransmitterTransmitTimeout = '10s' +DatabaseTimeout = '10s' +DeltaCOverride = '168h0m0s' +DeltaCJitterOverride = '1h0m0s' +ObservationGracePeriod = '1s' + +[EVM.OCR2] +[EVM.OCR2.Automation] +GasLimit = 5400000 + +[[EVM.Nodes]] +Name = 'fake' +WSURL = 'wss://foo.bar/ws' +HTTPURL = 'https://foo.bar' + # Configuration warning: -Tracing.TLSCertPath: invalid value (something): must be empty when Tracing.Mode is 'unencrypted' -Valid configuration. +2 errors: + - EVM.ChainType: invalid value (xdai): deprecated and will be removed in v2.13.0, use 'gnosis' instead + - Tracing.TLSCertPath: invalid value (something): must be empty when Tracing.Mode is 'unencrypted' +Valid configuration. \ No newline at end of file
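
The VRF test TOML above pairs vrf_job_simulation_block with minimum_confirmations: per the inline comment, the job should simulate against "latest" when requests wait for confirmations and against "pending" when minimum_confirmations is 0. A minimal sketch of that rule; chooseSimulationBlock is a hypothetical helper for illustration, not part of the patch:

package main

import "fmt"

// chooseSimulationBlock encodes the comment from the test TOML:
// "latest" if minimum_confirmations>0, "pending" if minimum_confirmations=0.
func chooseSimulationBlock(minimumConfirmations uint16) string {
	if minimumConfirmations > 0 {
		return "latest"
	}
	return "pending"
}

func main() {
	fmt.Println(chooseSimulationBlock(3)) // latest
	fmt.Println(chooseSimulationBlock(0)) // pending
}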
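
The chainlink-solana bump relocates the chain configuration types: TOMLConfig, TOMLConfigs, Nodes, and Node now live in pkg/solana/config rather than pkg/solana, which is why every call site in the diff swaps the solana import for solcfg (and solana.SolanaNodes for solcfg.Nodes). A sketch of constructing a config against the new paths, assuming the bumped module version is on the import path; ptr is a local generic helper, as in the tests:

package main

import (
	"fmt"

	"github.com/smartcontractkit/chainlink-common/pkg/config"
	solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
)

func ptr[T any](v T) *T { return &v }

func main() {
	node := solcfg.Node{
		Name: ptr("primary"),
		URL:  config.MustParseURL("https://solana.example"),
	}
	id := "devnet"
	cfg := solcfg.TOMLConfig{ // previously solana.TOMLConfig
		ChainID: &id,
		Enabled: ptr(true),
		Chain:   solcfg.Chain{},
		Nodes:   solcfg.Nodes{&node}, // previously solana.SolanaNodes
	}
	fmt.Println(*cfg.ChainID, len(cfg.Nodes))
}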
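
The core/utils/config/validate.go hunk changes how a Validated implementation's error is accumulated: instead of appending the (possibly combined) result of ValidateConfig as a single element, it unwraps it with the repo's utils.UnwrapError helper and appends each underlying error separately, so error counts and per-error formatting stay accurate. A standalone sketch of the same idea, using go.uber.org/multierr's Errors in place of that helper:

package main

import (
	"errors"
	"fmt"

	"go.uber.org/multierr"
)

func main() {
	// A ValidateConfig implementation may return several errors combined.
	nested := multierr.Combine(
		errors.New("ChainID: missing"),
		errors.New("Nodes: empty"),
	)

	// Unwrap before appending so each error stays a separate element.
	var err error
	for _, e := range multierr.Errors(nested) {
		err = multierr.Append(err, e)
	}
	fmt.Println(len(multierr.Errors(err))) // 2, not 1 wrapped blob
}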
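
The ChainType patch replaces the raw *string field with a ChainTypeConfig that separates the slug written in the TOML from the ChainType used at runtime: parsing maps the deprecated "xdai" slug onto ChainGnosis while keeping the slug for marshalling and warnings, and the accessors are nil-safe so callers such as ChainTypeForID no longer need explicit nil checks. A condensed, runnable sketch of the new type, trimmed to the gnosis/xdai case:

package main

import "fmt"

type ChainType string

const ChainGnosis ChainType = "gnosis"

func ChainTypeFromSlug(slug string) ChainType {
	switch slug {
	case "gnosis", "xdai": // legacy "xdai" resolves to gnosis
		return ChainGnosis
	default:
		return ChainType(slug)
	}
}

// ChainTypeConfig keeps the raw config slug alongside the parsed value.
type ChainTypeConfig struct {
	value ChainType
	slug  string
}

func (c *ChainTypeConfig) UnmarshalText(b []byte) error {
	c.slug = string(b)
	c.value = ChainTypeFromSlug(c.slug)
	return nil
}

// Nil-safe accessors, mirroring the patch.
func (c *ChainTypeConfig) Slug() string {
	if c == nil {
		return ""
	}
	return c.slug
}

func (c *ChainTypeConfig) ChainType() ChainType {
	if c == nil {
		return ""
	}
	return c.value
}

func main() {
	var c ChainTypeConfig
	_ = c.UnmarshalText([]byte("xdai"))
	fmt.Println(c.ChainType()) // gnosis: runtime behaves as Gnosis
	fmt.Println(c.Slug())      // xdai: kept for the deprecation warning
}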
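
This split is what lets the validation in core/services/chainlink/config.go warn on the deprecated spelling without changing behavior: deprecationWarnings keys off ChainType.Slug(), so a config that says "xdai" still produces the warning shown in warnings.txtar even though it now parses to gnosis. A small sketch of that check, with evmChain standing in for the real EVM config struct:

package main

import "fmt"

type evmChain struct{ chainTypeSlug string }

// deprecationWarnings mirrors the Slug-based check from config.go.
func deprecationWarnings(evms []evmChain) []string {
	var warnings []string
	for _, evm := range evms {
		if evm.chainTypeSlug == "xdai" {
			warnings = append(warnings,
				"EVM.ChainType: invalid value (xdai): deprecated and will be removed in v2.13.0, use 'gnosis' instead")
		}
	}
	return warnings
}

func main() {
	fmt.Println(deprecationWarnings([]evmChain{{chainTypeSlug: "xdai"}}))
}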