From 08c9f89bcabbde28d40349e010f629ea840bfd9a Mon Sep 17 00:00:00 2001 From: Dimitris Grigoriou Date: Wed, 1 Nov 2023 22:41:45 +0200 Subject: [PATCH] Introduce generalized multi node client (#10907) * Introduce generalized multi node client * Export EVM RPC client * Unexport clientAPI interface * Add BlockDifficulty to mocks * Unexport node state * Rename error parsing * Nit fixes * Rename error classification functions * Update NodeSelection names * Deprecate StartStopOnce --- common/chains/client/models.go | 22 + common/client/multi_node.go | 669 +++++++++++ common/client/node.go | 282 +++++ common/client/node_fsm.go | 266 +++++ common/client/node_lifecycle.go | 431 +++++++ common/client/node_selector_highest_head.go | 41 + common/client/node_selector_priority_level.go | 129 ++ common/client/node_selector_round_robin.go | 50 + .../client/node_selector_total_difficulty.go | 54 + common/client/send_only_node.go | 183 +++ common/client/send_only_node_lifecycle.go | 66 ++ common/client/types.go | 133 +++ common/headtracker/types/mocks/head.go | 17 + common/types/head.go | 6 + common/types/mocks/head.go | 17 + common/types/receipt.go | 14 + core/chains/evm/chain.go | 18 + core/chains/evm/client/chain_client.go | 274 +++++ core/chains/evm/client/client.go | 2 +- core/chains/evm/client/client_test.go | 442 ++++--- core/chains/evm/client/errors.go | 14 +- core/chains/evm/client/helpers_test.go | 65 + core/chains/evm/client/rpc_client.go | 1046 +++++++++++++++++ core/chains/evm/txmgr/client.go | 2 +- core/chains/evm/types/models.go | 4 + 25 files changed, 4110 insertions(+), 137 deletions(-) create mode 100644 common/client/multi_node.go create mode 100644 common/client/node.go create mode 100644 common/client/node_fsm.go create mode 100644 common/client/node_lifecycle.go create mode 100644 common/client/node_selector_highest_head.go create mode 100644 common/client/node_selector_priority_level.go create mode 100644 common/client/node_selector_round_robin.go create mode 100644 common/client/node_selector_total_difficulty.go create mode 100644 common/client/send_only_node.go create mode 100644 common/client/send_only_node_lifecycle.go create mode 100644 common/client/types.go create mode 100644 common/types/receipt.go create mode 100644 core/chains/evm/client/chain_client.go create mode 100644 core/chains/evm/client/rpc_client.go diff --git a/common/chains/client/models.go b/common/chains/client/models.go index ebe7bb7576d..bd974f901fc 100644 --- a/common/chains/client/models.go +++ b/common/chains/client/models.go @@ -1,5 +1,9 @@ package client +import ( + "fmt" +) + type SendTxReturnCode int // SendTxReturnCode is a generalized client error that dictates what should be the next action, depending on the RPC error response. @@ -15,3 +19,21 @@ const ( ExceedsMaxFee // Attempt's fee was higher than the node's limit and got rejected. FeeOutOfValidRange // This error is returned when we use a fee price suggested from an RPC, but the network rejects the attempt due to an invalid range(mostly used by L2 chains). Retry by requesting a new suggested fee price. 
) + +type NodeTier int + +const ( + Primary = NodeTier(iota) + Secondary +) + +func (n NodeTier) String() string { + switch n { + case Primary: + return "primary" + case Secondary: + return "secondary" + default: + return fmt.Sprintf("NodeTier(%d)", n) + } +} diff --git a/common/client/multi_node.go b/common/client/multi_node.go new file mode 100644 index 00000000000..0da3b89076b --- /dev/null +++ b/common/client/multi_node.go @@ -0,0 +1,669 @@ +package client + +import ( + "context" + "fmt" + "math/big" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/smartcontractkit/chainlink-relay/pkg/services" + + "github.com/smartcontractkit/chainlink/v2/common/chains/client" + feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" + "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink/v2/core/assets" + "github.com/smartcontractkit/chainlink/v2/core/config" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +var ( + // PromMultiNodeRPCNodeStates reports current RPC node state + PromMultiNodeRPCNodeStates = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "multi_node_states", + Help: "The number of RPC nodes currently in the given state for the given chain", + }, []string{"network", "chainId", "state"}) + ErroringNodeError = fmt.Errorf("no live nodes available") +) + +const ( + NodeSelectionModeHighestHead = "HighestHead" + NodeSelectionModeRoundRobin = "RoundRobin" + NodeSelectionModeTotalDifficulty = "TotalDifficulty" + NodeSelectionModePriorityLevel = "PriorityLevel" +) + +type NodeSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +] interface { + // Select returns a Node, or nil if none can be selected. + // Implementation must be thread-safe. + Select() Node[CHAIN_ID, HEAD, RPC] + // Name returns the strategy name, e.g. "HighestHead" or "RoundRobin" + Name() string +} + +// MultiNode is a generalized multi node client interface that includes methods to interact with different chains. +// It also handles multiple node RPC connections simultaneously. 
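+//
+// Primary nodes are health-checked and chosen through the configured NodeSelector strategy, while send-only nodes
+// participate only in transaction broadcasting (SendTransaction and BatchCallContextAll). Client calls are routed
+// to the currently selected healthy node.
+//
+// Illustrative construction (argument order follows NewMultiNode below; generic type arguments and variable names
+// are placeholders):
+//
+//	multiNode := NewMultiNode[...](lggr, selectionMode, leaseDuration, noNewHeadsThreshold,
+//		nodes, sendonlys, chainID, chainType, chainFamily, sendOnlyErrorParser)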
+type MultiNode[ + CHAIN_ID types.ID, + SEQ types.Sequence, + ADDR types.Hashable, + BLOCK_HASH types.Hashable, + TX any, + TX_HASH types.Hashable, + EVENT any, + EVENT_OPS any, + TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH], + FEE feetypes.Fee, + HEAD types.Head[BLOCK_HASH], + RPC_CLIENT RPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD], +] interface { + clientAPI[ + CHAIN_ID, + SEQ, + ADDR, + BLOCK_HASH, + TX, + TX_HASH, + EVENT, + EVENT_OPS, + TX_RECEIPT, + FEE, + HEAD, + ] + Close() error + NodeStates() map[string]string + SelectNodeRPC() (RPC_CLIENT, error) + + BatchCallContextAll(ctx context.Context, b []any) error + ConfiguredChainID() CHAIN_ID + IsL2() bool +} + +type multiNode[ + CHAIN_ID types.ID, + SEQ types.Sequence, + ADDR types.Hashable, + BLOCK_HASH types.Hashable, + TX any, + TX_HASH types.Hashable, + EVENT any, + EVENT_OPS any, + TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH], + FEE feetypes.Fee, + HEAD types.Head[BLOCK_HASH], + RPC_CLIENT RPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD], +] struct { + services.StateMachine + nodes []Node[CHAIN_ID, HEAD, RPC_CLIENT] + sendonlys []SendOnlyNode[CHAIN_ID, RPC_CLIENT] + chainID CHAIN_ID + chainType config.ChainType + logger logger.Logger + selectionMode string + noNewHeadsThreshold time.Duration + nodeSelector NodeSelector[CHAIN_ID, HEAD, RPC_CLIENT] + leaseDuration time.Duration + leaseTicker *time.Ticker + chainFamily string + + activeMu sync.RWMutex + activeNode Node[CHAIN_ID, HEAD, RPC_CLIENT] + + chStop utils.StopChan + wg sync.WaitGroup + + sendOnlyErrorParser func(err error) client.SendTxReturnCode +} + +func NewMultiNode[ + CHAIN_ID types.ID, + SEQ types.Sequence, + ADDR types.Hashable, + BLOCK_HASH types.Hashable, + TX any, + TX_HASH types.Hashable, + EVENT any, + EVENT_OPS any, + TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH], + FEE feetypes.Fee, + HEAD types.Head[BLOCK_HASH], + RPC_CLIENT RPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD], +]( + logger logger.Logger, + selectionMode string, + leaseDuration time.Duration, + noNewHeadsThreshold time.Duration, + nodes []Node[CHAIN_ID, HEAD, RPC_CLIENT], + sendonlys []SendOnlyNode[CHAIN_ID, RPC_CLIENT], + chainID CHAIN_ID, + chainType config.ChainType, + chainFamily string, + sendOnlyErrorParser func(err error) client.SendTxReturnCode, +) MultiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT] { + nodeSelector := func() NodeSelector[CHAIN_ID, HEAD, RPC_CLIENT] { + switch selectionMode { + case NodeSelectionModeHighestHead: + return NewHighestHeadNodeSelector[CHAIN_ID, HEAD, RPC_CLIENT](nodes) + case NodeSelectionModeRoundRobin: + return NewRoundRobinSelector[CHAIN_ID, HEAD, RPC_CLIENT](nodes) + case NodeSelectionModeTotalDifficulty: + return NewTotalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC_CLIENT](nodes) + case NodeSelectionModePriorityLevel: + return NewPriorityLevelNodeSelector[CHAIN_ID, HEAD, RPC_CLIENT](nodes) + default: + panic(fmt.Sprintf("unsupported NodeSelectionMode: %s", selectionMode)) + } + }() + + lggr := logger.Named("MultiNode").With("chainID", chainID.String()) + + c := &multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]{ + nodes: nodes, + sendonlys: sendonlys, + chainID: chainID, + chainType: chainType, + logger: lggr, + selectionMode: selectionMode, + noNewHeadsThreshold: noNewHeadsThreshold, + nodeSelector: nodeSelector, 
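+		// chStop signals the MultiNode's background goroutines (runLoop and checkLeaseLoop) to exit when Close is called.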
+ chStop: make(chan struct{}), + leaseDuration: leaseDuration, + chainFamily: chainFamily, + sendOnlyErrorParser: sendOnlyErrorParser, + } + + c.logger.Debugf("The MultiNode is configured to use NodeSelectionMode: %s", selectionMode) + + return c +} + +// Dial starts every node in the pool +// +// Nodes handle their own redialing and runloops, so this function does not +// return any error if the nodes aren't available +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) Dial(ctx context.Context) error { + return c.StartOnce("MultiNode", func() (merr error) { + if len(c.nodes) == 0 { + return errors.Errorf("no available nodes for chain %s", c.chainID.String()) + } + var ms services.MultiStart + for _, n := range c.nodes { + if n.ConfiguredChainID().String() != c.chainID.String() { + return ms.CloseBecause(errors.Errorf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", n.String(), n.ConfiguredChainID().String(), c.chainID.String())) + } + rawNode, ok := n.(*node[CHAIN_ID, HEAD, RPC_CLIENT]) + if ok { + // This is a bit hacky but it allows the node to be aware of + // pool state and prevent certain state transitions that might + // otherwise leave no nodes available. It is better to have one + // node in a degraded state than no nodes at all. + rawNode.nLiveNodes = c.nLiveNodes + } + // node will handle its own redialing and automatic recovery + if err := ms.Start(ctx, n); err != nil { + return err + } + } + for _, s := range c.sendonlys { + if s.ConfiguredChainID().String() != c.chainID.String() { + return ms.CloseBecause(errors.Errorf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", s.String(), s.ConfiguredChainID().String(), c.chainID.String())) + } + if err := ms.Start(ctx, s); err != nil { + return err + } + } + c.wg.Add(1) + go c.runLoop() + + if c.leaseDuration.Seconds() > 0 && c.selectionMode != NodeSelectionModeRoundRobin { + c.logger.Infof("The MultiNode will switch to best node every %s", c.leaseDuration.String()) + c.wg.Add(1) + go c.checkLeaseLoop() + } else { + c.logger.Info("Best node switching is disabled") + } + + return nil + }) +} + +// Close tears down the MultiNode and closes all nodes +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) Close() error { + return c.StopOnce("MultiNode", func() error { + close(c.chStop) + c.wg.Wait() + + return services.CloseAll(services.MultiCloser(c.nodes), services.MultiCloser(c.sendonlys)) + }) +} + +// SelectNodeRPC returns an RPC of an active node. If there are no active nodes it returns an error. +// Call this method from your chain-specific client implementation to access any chain-specific rpc calls. +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) SelectNodeRPC() (rpc RPC_CLIENT, err error) { + n, err := c.selectNode() + if err != nil { + return rpc, err + } + return n.RPC(), nil + +} + +// selectNode returns the active Node, if it is still nodeStateAlive, otherwise it selects a new one from the NodeSelector. 
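+// It takes a read-lock fast path first and re-checks the active node under the write lock before asking the
+// NodeSelector, so concurrent callers do not race to replace a node that another goroutine has already selected.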
+func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) selectNode() (node Node[CHAIN_ID, HEAD, RPC_CLIENT], err error) { + c.activeMu.RLock() + node = c.activeNode + c.activeMu.RUnlock() + if node != nil && node.State() == nodeStateAlive { + return // still alive + } + + // select a new one + c.activeMu.Lock() + defer c.activeMu.Unlock() + node = c.activeNode + if node != nil && node.State() == nodeStateAlive { + return // another goroutine beat us here + } + + c.activeNode = c.nodeSelector.Select() + + if c.activeNode == nil { + c.logger.Criticalw("No live RPC nodes available", "NodeSelectionMode", c.nodeSelector.Name()) + errmsg := fmt.Errorf("no live nodes available for chain %s", c.chainID.String()) + c.SvcErrBuffer.Append(errmsg) + err = ErroringNodeError + } + + return c.activeNode, err +} + +// nLiveNodes returns the number of currently alive nodes, as well as the highest block number and greatest total difficulty. +// totalDifficulty will be 0 if all nodes return nil. +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) nLiveNodes() (nLiveNodes int, blockNumber int64, totalDifficulty *utils.Big) { + totalDifficulty = utils.NewBigI(0) + for _, n := range c.nodes { + if s, num, td := n.StateAndLatest(); s == nodeStateAlive { + nLiveNodes++ + if num > blockNumber { + blockNumber = num + } + if td != nil && td.Cmp(totalDifficulty) > 0 { + totalDifficulty = td + } + } + } + return +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) checkLease() { + bestNode := c.nodeSelector.Select() + for _, n := range c.nodes { + // Terminate client subscriptions. Services are responsible for reconnecting, which will be routed to the new + // best node. 
Only terminate connections with more than 1 subscription to account for the aliveLoop subscription + if n.State() == nodeStateAlive && n != bestNode && n.SubscribersCount() > 1 { + c.logger.Infof("Switching to best node from %q to %q", n.String(), bestNode.String()) + n.UnsubscribeAllExceptAliveLoop() + } + } + + c.activeMu.Lock() + if bestNode != c.activeNode { + c.activeNode = bestNode + } + c.activeMu.Unlock() +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) checkLeaseLoop() { + defer c.wg.Done() + c.leaseTicker = time.NewTicker(c.leaseDuration) + defer c.leaseTicker.Stop() + + for { + select { + case <-c.leaseTicker.C: + c.checkLease() + case <-c.chStop: + return + } + } +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) runLoop() { + defer c.wg.Done() + + c.report() + + // Prometheus' default interval is 15s, set this to under 7.5s to avoid + // aliasing (see: https://en.wikipedia.org/wiki/Nyquist_frequency) + reportInterval := 6500 * time.Millisecond + monitor := time.NewTicker(utils.WithJitter(reportInterval)) + defer monitor.Stop() + + for { + select { + case <-monitor.C: + c.report() + case <-c.chStop: + return + } + } +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) report() { + type nodeWithState struct { + Node string + State string + } + + var total, dead int + counts := make(map[nodeState]int) + nodeStates := make([]nodeWithState, len(c.nodes)) + for i, n := range c.nodes { + state := n.State() + nodeStates[i] = nodeWithState{n.String(), state.String()} + total++ + if state != nodeStateAlive { + dead++ + } + counts[state]++ + } + for _, state := range allNodeStates { + count := counts[state] + PromMultiNodeRPCNodeStates.WithLabelValues(c.chainFamily, c.chainID.String(), state.String()).Set(float64(count)) + } + + live := total - dead + c.logger.Tracew(fmt.Sprintf("MultiNode state: %d/%d nodes are alive", live, total), "nodeStates", nodeStates) + if total == dead { + rerr := fmt.Errorf("no primary nodes available: 0/%d nodes are alive", total) + c.logger.Criticalw(rerr.Error(), "nodeStates", nodeStates) + c.SvcErrBuffer.Append(rerr) + } else if dead > 0 { + c.logger.Errorw(fmt.Sprintf("At least one primary node is dead: %d/%d nodes are alive", live, total), "nodeStates", nodeStates) + } +} + +// ClientAPI methods +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) BalanceAt(ctx context.Context, account ADDR, blockNumber *big.Int) (*big.Int, error) { + n, err := c.selectNode() + if err != nil { + return nil, err + } + return n.RPC().BalanceAt(ctx, account, blockNumber) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) BatchCallContext(ctx context.Context, b []any) error { + n, err := c.selectNode() + if err != nil { + return err + } + return n.RPC().BatchCallContext(ctx, b) +} + +// BatchCallContextAll calls BatchCallContext for every single node including +// sendonlys. +// CAUTION: This should only be used for mass re-transmitting transactions, it +// might have unexpected effects to use it for anything else. 
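+// The batch is fanned out concurrently to every node except the currently selected one; their results are only
+// logged. The selected (main) node is called last, and its error (or the node-selection error) is what the caller sees.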
+func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) BatchCallContextAll(ctx context.Context, b []any) error { + var wg sync.WaitGroup + defer wg.Wait() + + main, selectionErr := c.selectNode() + var all []SendOnlyNode[CHAIN_ID, RPC_CLIENT] + for _, n := range c.nodes { + all = append(all, n) + } + all = append(all, c.sendonlys...) + for _, n := range all { + if n == main { + // main node is used at the end for the return value + continue + } + // Parallel call made to all other nodes with ignored return value + wg.Add(1) + go func(n SendOnlyNode[CHAIN_ID, RPC_CLIENT]) { + defer wg.Done() + err := n.RPC().BatchCallContext(ctx, b) + if err != nil { + c.logger.Debugw("Secondary node BatchCallContext failed", "err", err) + } else { + c.logger.Trace("Secondary node BatchCallContext success") + } + }(n) + } + + if selectionErr != nil { + return selectionErr + } + return main.RPC().BatchCallContext(ctx, b) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) BlockByHash(ctx context.Context, hash BLOCK_HASH) (h HEAD, err error) { + n, err := c.selectNode() + if err != nil { + return h, err + } + return n.RPC().BlockByHash(ctx, hash) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) BlockByNumber(ctx context.Context, number *big.Int) (h HEAD, err error) { + n, err := c.selectNode() + if err != nil { + return h, err + } + return n.RPC().BlockByNumber(ctx, number) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { + n, err := c.selectNode() + if err != nil { + return err + } + return n.RPC().CallContext(ctx, result, method, args...) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) CallContract( + ctx context.Context, + attempt interface{}, + blockNumber *big.Int, +) (rpcErr []byte, extractErr error) { + n, err := c.selectNode() + if err != nil { + return rpcErr, err + } + return n.RPC().CallContract(ctx, attempt, blockNumber) +} + +// ChainID makes a direct RPC call. In most cases it should be better to use the configured chain id instead by +// calling ConfiguredChainID. 
+func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) ChainID(ctx context.Context) (id CHAIN_ID, err error) { + n, err := c.selectNode() + if err != nil { + return id, err + } + return n.RPC().ChainID(ctx) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) ChainType() config.ChainType { + return c.chainType +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) CodeAt(ctx context.Context, account ADDR, blockNumber *big.Int) (code []byte, err error) { + n, err := c.selectNode() + if err != nil { + return code, err + } + return n.RPC().CodeAt(ctx, account, blockNumber) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) ConfiguredChainID() CHAIN_ID { + return c.chainID +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) EstimateGas(ctx context.Context, call any) (gas uint64, err error) { + n, err := c.selectNode() + if err != nil { + return gas, err + } + return n.RPC().EstimateGas(ctx, call) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) FilterEvents(ctx context.Context, query EVENT_OPS) (e []EVENT, err error) { + n, err := c.selectNode() + if err != nil { + return e, err + } + return n.RPC().FilterEvents(ctx, query) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) IsL2() bool { + return c.ChainType().IsL2() +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) LatestBlockHeight(ctx context.Context) (h *big.Int, err error) { + n, err := c.selectNode() + if err != nil { + return h, err + } + return n.RPC().LatestBlockHeight(ctx) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) LINKBalance(ctx context.Context, accountAddress ADDR, linkAddress ADDR) (b *assets.Link, err error) { + n, err := c.selectNode() + if err != nil { + return b, err + } + return n.RPC().LINKBalance(ctx, accountAddress, linkAddress) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) NodeStates() (states map[string]string) { + states = make(map[string]string) + for _, n := range c.nodes { + states[n.Name()] = n.State().String() + } + for _, s := range c.sendonlys { + states[s.Name()] = s.State().String() + } + return +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) PendingSequenceAt(ctx context.Context, addr ADDR) (s SEQ, err error) { + n, err := c.selectNode() + if err != nil { + return s, err + } + return n.RPC().PendingSequenceAt(ctx, addr) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) SendEmptyTransaction( + ctx context.Context, + newTxAttempt func(seq SEQ, feeLimit uint32, fee FEE, fromAddress ADDR) (attempt any, err error), + seq SEQ, + gasLimit uint32, + fee FEE, + fromAddress ADDR, +) (txhash string, err error) { + n, err := c.selectNode() + if err != nil { + return txhash, err + } + return n.RPC().SendEmptyTransaction(ctx, newTxAttempt, seq, 
gasLimit, fee, fromAddress) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) SendTransaction(ctx context.Context, tx TX) error { + main, nodeError := c.selectNode() + var all []SendOnlyNode[CHAIN_ID, RPC_CLIENT] + for _, n := range c.nodes { + all = append(all, n) + } + all = append(all, c.sendonlys...) + for _, n := range all { + if n == main { + // main node is used at the end for the return value + continue + } + // Parallel send to all other nodes with ignored return value + // Async - we do not want to block the main thread with secondary nodes + // in case they are unreliable/slow. + // It is purely a "best effort" send. + // Resource is not unbounded because the default context has a timeout. + ok := c.IfNotStopped(func() { + // Must wrap inside IfNotStopped to avoid waitgroup racing with Close + c.wg.Add(1) + go func(n SendOnlyNode[CHAIN_ID, RPC_CLIENT]) { + defer c.wg.Done() + + txErr := n.RPC().SendTransaction(ctx, tx) + c.logger.Debugw("Sendonly node sent transaction", "name", n.String(), "tx", tx, "err", txErr) + sendOnlyError := c.sendOnlyErrorParser(txErr) + if sendOnlyError != client.Successful { + c.logger.Warnw("RPC returned error", "name", n.String(), "tx", tx, "err", txErr) + } + }(n) + }) + if !ok { + c.logger.Debug("Cannot send transaction on sendonly node; MultiNode is stopped", "node", n.String()) + } + } + if nodeError != nil { + return nodeError + } + return main.RPC().SendTransaction(ctx, tx) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) SequenceAt(ctx context.Context, account ADDR, blockNumber *big.Int) (s SEQ, err error) { + n, err := c.selectNode() + if err != nil { + return s, err + } + return n.RPC().SequenceAt(ctx, account, blockNumber) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) SimulateTransaction(ctx context.Context, tx TX) error { + n, err := c.selectNode() + if err != nil { + return err + } + return n.RPC().SimulateTransaction(ctx, tx) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) Subscribe(ctx context.Context, channel chan<- HEAD, args ...interface{}) (s types.Subscription, err error) { + n, err := c.selectNode() + if err != nil { + return s, err + } + return n.RPC().Subscribe(ctx, channel, args...) 
+} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) TokenBalance(ctx context.Context, account ADDR, tokenAddr ADDR) (b *big.Int, err error) { + n, err := c.selectNode() + if err != nil { + return b, err + } + return n.RPC().TokenBalance(ctx, account, tokenAddr) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) TransactionByHash(ctx context.Context, txHash TX_HASH) (tx TX, err error) { + n, err := c.selectNode() + if err != nil { + return tx, err + } + return n.RPC().TransactionByHash(ctx, txHash) +} + +func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]) TransactionReceipt(ctx context.Context, txHash TX_HASH) (txr TX_RECEIPT, err error) { + n, err := c.selectNode() + if err != nil { + return txr, err + } + return n.RPC().TransactionReceipt(ctx, txHash) +} diff --git a/common/client/node.go b/common/client/node.go new file mode 100644 index 00000000000..71b34452f02 --- /dev/null +++ b/common/client/node.go @@ -0,0 +1,282 @@ +package client + +import ( + "context" + "fmt" + "net/url" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/smartcontractkit/chainlink-relay/pkg/services" + + "github.com/smartcontractkit/chainlink/v2/common/chains/client" + "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +const QueryTimeout = 10 * time.Second + +var errInvalidChainID = errors.New("invalid chain id") + +var ( + promPoolRPCNodeVerifies = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_verifies", + Help: "The total number of chain ID verifications for the given RPC node", + }, []string{"network", "chainID", "nodeName"}) + promPoolRPCNodeVerifiesFailed = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_verifies_failed", + Help: "The total number of failed chain ID verifications for the given RPC node", + }, []string{"network", "chainID", "nodeName"}) + promPoolRPCNodeVerifiesSuccess = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_verifies_success", + Help: "The total number of successful chain ID verifications for the given RPC node", + }, []string{"network", "chainID", "nodeName"}) +) + +type NodeConfig interface { + PollFailureThreshold() uint32 + PollInterval() time.Duration + SelectionMode() string + SyncThreshold() uint32 +} + +type Node[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +] interface { + // State returns nodeState + State() nodeState + // StateAndLatest returns nodeState with the latest received block number & total difficulty. + StateAndLatest() (nodeState, int64, *utils.Big) + // Name is a unique identifier for this node. 
+ Name() string + String() string + RPC() RPC + SubscribersCount() int32 + UnsubscribeAllExceptAliveLoop() + ConfiguredChainID() CHAIN_ID + Order() int32 + Start(context.Context) error + Close() error +} + +type node[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +] struct { + services.StateMachine + lfcLog logger.Logger + name string + id int32 + chainID CHAIN_ID + nodePoolCfg NodeConfig + noNewHeadsThreshold time.Duration + order int32 + chainFamily string + + ws url.URL + http *url.URL + + rpc RPC + + stateMu sync.RWMutex // protects state* fields + state nodeState + // Each node is tracking the last received head number and total difficulty + stateLatestBlockNumber int64 + stateLatestTotalDifficulty *utils.Big + + // nodeCtx is the node lifetime's context + nodeCtx context.Context + // cancelNodeCtx cancels nodeCtx when stopping the node + cancelNodeCtx context.CancelFunc + // wg waits for subsidiary goroutines + wg sync.WaitGroup + + // nLiveNodes is a passed in function that allows this node to: + // 1. see how many live nodes there are in total, so we can prevent the last alive node in a pool from being + // moved to out-of-sync state. It is better to have one out-of-sync node than no nodes at all. + // 2. compare against the highest head (by number or difficulty) to ensure we don't fall behind too far. + nLiveNodes func() (count int, blockNumber int64, totalDifficulty *utils.Big) +} + +func NewNode[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +]( + nodeCfg NodeConfig, + noNewHeadsThreshold time.Duration, + lggr logger.Logger, + wsuri url.URL, + httpuri *url.URL, + name string, + id int32, + chainID CHAIN_ID, + nodeOrder int32, + rpc RPC, + chainFamily string, +) Node[CHAIN_ID, HEAD, RPC] { + n := new(node[CHAIN_ID, HEAD, RPC]) + n.name = name + n.id = id + n.chainID = chainID + n.nodePoolCfg = nodeCfg + n.noNewHeadsThreshold = noNewHeadsThreshold + n.ws = wsuri + n.order = nodeOrder + if httpuri != nil { + n.http = httpuri + } + n.nodeCtx, n.cancelNodeCtx = context.WithCancel(context.Background()) + lggr = lggr.Named("Node").With( + "nodeTier", client.Primary.String(), + "nodeName", name, + "node", n.String(), + "chainID", chainID, + "nodeOrder", n.order, + ) + n.lfcLog = lggr.Named("Lifecycle") + n.stateLatestBlockNumber = -1 + n.rpc = rpc + n.chainFamily = chainFamily + return n +} + +func (n *node[CHAIN_ID, HEAD, RPC]) String() string { + s := fmt.Sprintf("(%s)%s:%s", client.Primary.String(), n.name, n.ws.String()) + if n.http != nil { + s = s + fmt.Sprintf(":%s", n.http.String()) + } + return s +} + +func (n *node[CHAIN_ID, HEAD, RPC]) ConfiguredChainID() (chainID CHAIN_ID) { + return n.chainID +} + +func (n *node[CHAIN_ID, HEAD, RPC]) Name() string { + return n.name +} + +func (n *node[CHAIN_ID, HEAD, RPC]) RPC() RPC { + return n.rpc +} + +func (n *node[CHAIN_ID, HEAD, RPC]) SubscribersCount() int32 { + return n.rpc.SubscribersCount() +} + +func (n *node[CHAIN_ID, HEAD, RPC]) UnsubscribeAllExceptAliveLoop() { + n.rpc.UnsubscribeAllExceptAliveLoop() +} + +func (n *node[CHAIN_ID, HEAD, RPC]) Close() error { + return n.StopOnce(n.name, func() error { + defer func() { + n.wg.Wait() + n.rpc.Close() + }() + + n.stateMu.Lock() + defer n.stateMu.Unlock() + + n.cancelNodeCtx() + n.state = nodeStateClosed + return nil + }) +} + +// Start dials and verifies the node +// Should only be called once in a node's lifecycle +// Return value is necessary to conform to interface but this will never +// actually return an error. 
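+// Dial or chain ID verification failures are not returned; they move the node into the Unreachable or
+// InvalidChainID state, and the corresponding lifecycle loop keeps retrying in the background.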
+func (n *node[CHAIN_ID, HEAD, RPC]) Start(startCtx context.Context) error { + return n.StartOnce(n.name, func() error { + n.start(startCtx) + return nil + }) +} + +// start initially dials the node and verifies chain ID +// This spins off lifecycle goroutines. +// Not thread-safe. +// Node lifecycle is synchronous: only one goroutine should be running at a +// time. +func (n *node[CHAIN_ID, HEAD, RPC]) start(startCtx context.Context) { + if n.state != nodeStateUndialed { + panic(fmt.Sprintf("cannot dial node with state %v", n.state)) + } + + if err := n.rpc.Dial(startCtx); err != nil { + n.lfcLog.Errorw("Dial failed: Node is unreachable", "err", err) + n.declareUnreachable() + return + } + n.setState(nodeStateDialed) + + if err := n.verify(startCtx); errors.Is(err, errInvalidChainID) { + n.lfcLog.Errorw("Verify failed: Node has the wrong chain ID", "err", err) + n.declareInvalidChainID() + return + } else if err != nil { + n.lfcLog.Errorw(fmt.Sprintf("Verify failed: %v", err), "err", err) + n.declareUnreachable() + return + } + + n.declareAlive() +} + +// verify checks that all connections to eth nodes match the given chain ID +// Not thread-safe +// Pure verify: does not mutate node "state" field. +func (n *node[CHAIN_ID, HEAD, RPC]) verify(callerCtx context.Context) (err error) { + promPoolRPCNodeVerifies.WithLabelValues(n.chainFamily, n.chainID.String(), n.name).Inc() + promFailed := func() { + promPoolRPCNodeVerifiesFailed.WithLabelValues(n.chainFamily, n.chainID.String(), n.name).Inc() + } + + st := n.State() + switch st { + case nodeStateDialed, nodeStateOutOfSync, nodeStateInvalidChainID: + default: + panic(fmt.Sprintf("cannot verify node in state %v", st)) + } + + var chainID CHAIN_ID + if chainID, err = n.rpc.ChainID(callerCtx); err != nil { + promFailed() + return errors.Wrapf(err, "failed to verify chain ID for node %s", n.name) + } else if chainID.String() != n.chainID.String() { + promFailed() + return errors.Wrapf( + errInvalidChainID, + "rpc ChainID doesn't match local chain ID: RPC ID=%s, local ID=%s, node name=%s", + chainID.String(), + n.chainID.String(), + n.name, + ) + } + + promPoolRPCNodeVerifiesSuccess.WithLabelValues(n.chainFamily, n.chainID.String(), n.name).Inc() + + return nil +} + +// disconnectAll disconnects all clients connected to the node +// WARNING: NOT THREAD-SAFE +// This must be called from within the n.stateMu lock +func (n *node[CHAIN_ID, HEAD, RPC]) disconnectAll() { + n.rpc.DisconnectAll() +} + +func (n *node[CHAIN_ID, HEAD, RPC]) Order() int32 { + return n.order +} diff --git a/common/client/node_fsm.go b/common/client/node_fsm.go new file mode 100644 index 00000000000..d4fc19140e9 --- /dev/null +++ b/common/client/node_fsm.go @@ -0,0 +1,266 @@ +package client + +import ( + "fmt" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +var ( + promPoolRPCNodeTransitionsToAlive = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_num_transitions_to_alive", + Help: transitionString(nodeStateAlive), + }, []string{"chainID", "nodeName"}) + promPoolRPCNodeTransitionsToInSync = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_num_transitions_to_in_sync", + Help: fmt.Sprintf("%s to %s", transitionString(nodeStateOutOfSync), nodeStateAlive), + }, []string{"chainID", "nodeName"}) + promPoolRPCNodeTransitionsToOutOfSync = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: 
"pool_rpc_node_num_transitions_to_out_of_sync", + Help: transitionString(nodeStateOutOfSync), + }, []string{"chainID", "nodeName"}) + promPoolRPCNodeTransitionsToUnreachable = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_num_transitions_to_unreachable", + Help: transitionString(nodeStateUnreachable), + }, []string{"chainID", "nodeName"}) + promPoolRPCNodeTransitionsToInvalidChainID = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_num_transitions_to_invalid_chain_id", + Help: transitionString(nodeStateInvalidChainID), + }, []string{"chainID", "nodeName"}) + promPoolRPCNodeTransitionsToUnusable = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_num_transitions_to_unusable", + Help: transitionString(nodeStateUnusable), + }, []string{"chainID", "nodeName"}) +) + +// nodeState represents the current state of the node +// Node is a FSM (finite state machine) +type nodeState int + +func (n nodeState) String() string { + switch n { + case nodeStateUndialed: + return "Undialed" + case nodeStateDialed: + return "Dialed" + case nodeStateInvalidChainID: + return "InvalidChainID" + case nodeStateAlive: + return "Alive" + case nodeStateUnreachable: + return "Unreachable" + case nodeStateUnusable: + return "Unusable" + case nodeStateOutOfSync: + return "OutOfSync" + case nodeStateClosed: + return "Closed" + default: + return fmt.Sprintf("nodeState(%d)", n) + } +} + +// GoString prints a prettier state +func (n nodeState) GoString() string { + return fmt.Sprintf("nodeState%s(%d)", n.String(), n) +} + +const ( + // nodeStateUndialed is the first state of a virgin node + nodeStateUndialed = nodeState(iota) + // nodeStateDialed is after a node has successfully dialed but before it has verified the correct chain ID + nodeStateDialed + // nodeStateInvalidChainID is after chain ID verification failed + nodeStateInvalidChainID + // nodeStateAlive is a healthy node after chain ID verification succeeded + nodeStateAlive + // nodeStateUnreachable is a node that cannot be dialed or has disconnected + nodeStateUnreachable + // nodeStateOutOfSync is a node that is accepting connections but exceeded + // the failure threshold without sending any new heads. It will be + // disconnected, then put into a revive loop and re-awakened after redial + // if a new head arrives + nodeStateOutOfSync + // nodeStateUnusable is a sendonly node that has an invalid URL that can never be reached + nodeStateUnusable + // nodeStateClosed is after the connection has been closed and the node is at the end of its lifecycle + nodeStateClosed + // nodeStateLen tracks the number of states + nodeStateLen +) + +// allNodeStates represents all possible states a node can be in +var allNodeStates []nodeState + +func init() { + for s := nodeState(0); s < nodeStateLen; s++ { + allNodeStates = append(allNodeStates, s) + } +} + +// FSM methods + +// State allows reading the current state of the node. +func (n *node[CHAIN_ID, HEAD, RPC]) State() nodeState { + n.stateMu.RLock() + defer n.stateMu.RUnlock() + return n.state +} + +func (n *node[CHAIN_ID, HEAD, RPC]) StateAndLatest() (nodeState, int64, *utils.Big) { + n.stateMu.RLock() + defer n.stateMu.RUnlock() + return n.state, n.stateLatestBlockNumber, n.stateLatestTotalDifficulty +} + +// setState is only used by internal state management methods. +// This is low-level; care should be taken by the caller to ensure the new state is a valid transition. 
+// State changes should always be synchronous: only one goroutine at a time should change state.
+// n.stateMu should not be locked for long periods of time because external clients expect a timely response from n.State()
+func (n *node[CHAIN_ID, HEAD, RPC]) setState(s nodeState) {
+	n.stateMu.Lock()
+	defer n.stateMu.Unlock()
+	n.state = s
+}
+
+// declareXXX methods change the state and pass control off to the new state
+// management goroutine
+
+func (n *node[CHAIN_ID, HEAD, RPC]) declareAlive() {
+	n.transitionToAlive(func() {
+		n.lfcLog.Infow("RPC Node is online", "nodeState", n.state)
+		n.wg.Add(1)
+		go n.aliveLoop()
+	})
+}
+
+func (n *node[CHAIN_ID, HEAD, RPC]) transitionToAlive(fn func()) {
+	promPoolRPCNodeTransitionsToAlive.WithLabelValues(n.chainID.String(), n.name).Inc()
+	n.stateMu.Lock()
+	defer n.stateMu.Unlock()
+	if n.state == nodeStateClosed {
+		return
+	}
+	switch n.state {
+	case nodeStateDialed, nodeStateInvalidChainID:
+		n.state = nodeStateAlive
+	default:
+		panic(transitionFail(n.state, nodeStateAlive))
+	}
+	fn()
+}
+
+// declareInSync puts a node back into Alive state, allowing it to be used by
+// pool consumers again
+func (n *node[CHAIN_ID, HEAD, RPC]) declareInSync() {
+	n.transitionToInSync(func() {
+		n.lfcLog.Infow("RPC Node is back in sync", "nodeState", n.state)
+		n.wg.Add(1)
+		go n.aliveLoop()
+	})
+}
+
+func (n *node[CHAIN_ID, HEAD, RPC]) transitionToInSync(fn func()) {
+	promPoolRPCNodeTransitionsToAlive.WithLabelValues(n.chainID.String(), n.name).Inc()
+	promPoolRPCNodeTransitionsToInSync.WithLabelValues(n.chainID.String(), n.name).Inc()
+	n.stateMu.Lock()
+	defer n.stateMu.Unlock()
+	if n.state == nodeStateClosed {
+		return
+	}
+	switch n.state {
+	case nodeStateOutOfSync:
+		n.state = nodeStateAlive
+	default:
+		panic(transitionFail(n.state, nodeStateAlive))
+	}
+	fn()
+}
+
+// declareOutOfSync puts a node into OutOfSync state, disconnecting all current
+// clients and making it unavailable for use until back in-sync.
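+// The isOutOfSync callback is later evaluated by outOfSyncLoop against each newly received head (block number and
+// total difficulty) to decide when the node can be declared back in sync.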
+func (n *node[CHAIN_ID, HEAD, RPC]) declareOutOfSync(isOutOfSync func(num int64, td *utils.Big) bool) { + n.transitionToOutOfSync(func() { + n.lfcLog.Errorw("RPC Node is out of sync", "nodeState", n.state) + n.wg.Add(1) + go n.outOfSyncLoop(isOutOfSync) + }) +} + +func (n *node[CHAIN_ID, HEAD, RPC]) transitionToOutOfSync(fn func()) { + promPoolRPCNodeTransitionsToOutOfSync.WithLabelValues(n.chainID.String(), n.name).Inc() + n.stateMu.Lock() + defer n.stateMu.Unlock() + if n.state == nodeStateClosed { + return + } + switch n.state { + case nodeStateAlive: + n.disconnectAll() + n.state = nodeStateOutOfSync + default: + panic(transitionFail(n.state, nodeStateOutOfSync)) + } + fn() +} + +func (n *node[CHAIN_ID, HEAD, RPC]) declareUnreachable() { + n.transitionToUnreachable(func() { + n.lfcLog.Errorw("RPC Node is unreachable", "nodeState", n.state) + n.wg.Add(1) + go n.unreachableLoop() + }) +} + +func (n *node[CHAIN_ID, HEAD, RPC]) transitionToUnreachable(fn func()) { + promPoolRPCNodeTransitionsToUnreachable.WithLabelValues(n.chainID.String(), n.name).Inc() + n.stateMu.Lock() + defer n.stateMu.Unlock() + if n.state == nodeStateClosed { + return + } + switch n.state { + case nodeStateUndialed, nodeStateDialed, nodeStateAlive, nodeStateOutOfSync, nodeStateInvalidChainID: + n.disconnectAll() + n.state = nodeStateUnreachable + default: + panic(transitionFail(n.state, nodeStateUnreachable)) + } + fn() +} + +func (n *node[CHAIN_ID, HEAD, RPC]) declareInvalidChainID() { + n.transitionToInvalidChainID(func() { + n.lfcLog.Errorw("RPC Node has the wrong chain ID", "nodeState", n.state) + n.wg.Add(1) + go n.invalidChainIDLoop() + }) +} + +func (n *node[CHAIN_ID, HEAD, RPC]) transitionToInvalidChainID(fn func()) { + promPoolRPCNodeTransitionsToInvalidChainID.WithLabelValues(n.chainID.String(), n.name).Inc() + n.stateMu.Lock() + defer n.stateMu.Unlock() + if n.state == nodeStateClosed { + return + } + switch n.state { + case nodeStateDialed, nodeStateOutOfSync: + n.disconnectAll() + n.state = nodeStateInvalidChainID + default: + panic(transitionFail(n.state, nodeStateInvalidChainID)) + } + fn() +} + +func transitionString(state nodeState) string { + return fmt.Sprintf("Total number of times node has transitioned to %s", state) +} + +func transitionFail(from nodeState, to nodeState) string { + return fmt.Sprintf("cannot transition from %#v to %#v", from, to) +} diff --git a/common/client/node_lifecycle.go b/common/client/node_lifecycle.go new file mode 100644 index 00000000000..149c5f01a6d --- /dev/null +++ b/common/client/node_lifecycle.go @@ -0,0 +1,431 @@ +package client + +import ( + "context" + "fmt" + "math" + "time" + + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +var ( + promPoolRPCNodeHighestSeenBlock = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "pool_rpc_node_highest_seen_block", + Help: "The highest seen block for the given RPC node", + }, []string{"chainID", "nodeName"}) + promPoolRPCNodeNumSeenBlocks = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_num_seen_blocks", + Help: "The total number of new blocks seen by the given RPC node", + }, []string{"chainID", "nodeName"}) + promPoolRPCNodePolls = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_polls_total", + Help: "The total number of poll checks for the given RPC node", + }, []string{"chainID", "nodeName"}) + promPoolRPCNodePollsFailed = 
promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_polls_failed", + Help: "The total number of failed poll checks for the given RPC node", + }, []string{"chainID", "nodeName"}) + promPoolRPCNodePollsSuccess = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "pool_rpc_node_polls_success", + Help: "The total number of successful poll checks for the given RPC node", + }, []string{"chainID", "nodeName"}) +) + +// zombieNodeCheckInterval controls how often to re-check to see if we need to +// state change in case we have to force a state transition due to no available +// nodes. +// NOTE: This only applies to out-of-sync nodes if they are the last available node +func zombieNodeCheckInterval(noNewHeadsThreshold time.Duration) time.Duration { + interval := noNewHeadsThreshold + if interval <= 0 || interval > QueryTimeout { + interval = QueryTimeout + } + return utils.WithJitter(interval) +} + +func (n *node[CHAIN_ID, HEAD, RPC]) setLatestReceived(blockNumber int64, totalDifficulty *utils.Big) { + n.stateMu.Lock() + defer n.stateMu.Unlock() + n.stateLatestBlockNumber = blockNumber + n.stateLatestTotalDifficulty = totalDifficulty +} + +const ( + msgCannotDisable = "but cannot disable this connection because there are no other RPC endpoints, or all other RPC endpoints are dead." + msgDegradedState = "Chainlink is now operating in a degraded state and urgent action is required to resolve the issue" +) + +// Node is a FSM +// Each state has a loop that goes with it, which monitors the node and moves it into another state as necessary. +// Only one loop must run at a time. +// Each loop passes control onto the next loop as it exits, except when the node is Closed which terminates the loop permanently. + +// This handles node lifecycle for the ALIVE state +// Should only be run ONCE per node, after a successful Dial +func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { + defer n.wg.Done() + + { + // sanity check + state := n.State() + switch state { + case nodeStateAlive: + case nodeStateClosed: + return + default: + panic(fmt.Sprintf("aliveLoop can only run for node in Alive state, got: %s", state)) + } + } + + noNewHeadsTimeoutThreshold := n.noNewHeadsThreshold + pollFailureThreshold := n.nodePoolCfg.PollFailureThreshold() + pollInterval := n.nodePoolCfg.PollInterval() + + lggr := n.lfcLog.Named("Alive").With("noNewHeadsTimeoutThreshold", noNewHeadsTimeoutThreshold, "pollInterval", pollInterval, "pollFailureThreshold", pollFailureThreshold) + lggr.Tracew("Alive loop starting", "nodeState", n.State()) + + headsC := make(chan HEAD) + sub, err := n.rpc.Subscribe(n.nodeCtx, headsC, "newHeads") + if err != nil { + lggr.Errorw("Initial subscribe for heads failed", "nodeState", n.State()) + n.declareUnreachable() + return + } + n.rpc.SetAliveLoopSub(sub) + defer sub.Unsubscribe() + + var outOfSyncT *time.Ticker + var outOfSyncTC <-chan time.Time + if noNewHeadsTimeoutThreshold > 0 { + lggr.Debugw("Head liveness checking enabled", "nodeState", n.State()) + outOfSyncT = time.NewTicker(noNewHeadsTimeoutThreshold) + defer outOfSyncT.Stop() + outOfSyncTC = outOfSyncT.C + } else { + lggr.Debug("Head liveness checking disabled") + } + + var pollCh <-chan time.Time + if pollInterval > 0 { + lggr.Debug("Polling enabled") + pollT := time.NewTicker(pollInterval) + defer pollT.Stop() + pollCh = pollT.C + if pollFailureThreshold > 0 { + // polling can be enabled with no threshold to enable polling but + // the node will not be marked offline regardless of the number of + // poll failures + 
lggr.Debug("Polling liveness checking enabled") + } + } else { + lggr.Debug("Polling disabled") + } + + _, highestReceivedBlockNumber, _ := n.StateAndLatest() + var pollFailures uint32 + + for { + select { + case <-n.nodeCtx.Done(): + return + case <-pollCh: + var version string + promPoolRPCNodePolls.WithLabelValues(n.chainID.String(), n.name).Inc() + lggr.Tracew("Polling for version", "nodeState", n.State(), "pollFailures", pollFailures) + ctx, cancel := context.WithTimeout(n.nodeCtx, pollInterval) + version, err := n.RPC().ClientVersion(ctx) + cancel() + if err != nil { + // prevent overflow + if pollFailures < math.MaxUint32 { + promPoolRPCNodePollsFailed.WithLabelValues(n.chainID.String(), n.name).Inc() + pollFailures++ + } + lggr.Warnw(fmt.Sprintf("Poll failure, RPC endpoint %s failed to respond properly", n.String()), "err", err, "pollFailures", pollFailures, "nodeState", n.State()) + } else { + lggr.Debugw("Version poll successful", "nodeState", n.State(), "clientVersion", version) + promPoolRPCNodePollsSuccess.WithLabelValues(n.chainID.String(), n.name).Inc() + pollFailures = 0 + } + if pollFailureThreshold > 0 && pollFailures >= pollFailureThreshold { + lggr.Errorw(fmt.Sprintf("RPC endpoint failed to respond to %d consecutive polls", pollFailures), "pollFailures", pollFailures, "nodeState", n.State()) + if n.nLiveNodes != nil { + if l, _, _ := n.nLiveNodes(); l < 2 { + lggr.Criticalf("RPC endpoint failed to respond to polls; %s %s", msgCannotDisable, msgDegradedState) + continue + } + } + n.declareUnreachable() + return + } + _, num, td := n.StateAndLatest() + if outOfSync, liveNodes := n.syncStatus(num, td); outOfSync { + // note: there must be another live node for us to be out of sync + lggr.Errorw("RPC endpoint has fallen behind", "blockNumber", num, "totalDifficulty", td, "nodeState", n.State()) + if liveNodes < 2 { + lggr.Criticalf("RPC endpoint has fallen behind; %s %s", msgCannotDisable, msgDegradedState) + continue + } + n.declareOutOfSync(n.isOutOfSync) + return + } + case bh, open := <-headsC: + if !open { + lggr.Errorw("Subscription channel unexpectedly closed", "nodeState", n.State()) + n.declareUnreachable() + return + } + promPoolRPCNodeNumSeenBlocks.WithLabelValues(n.chainID.String(), n.name).Inc() + lggr.Tracew("Got head", "head", bh) + if bh.BlockNumber() > highestReceivedBlockNumber { + promPoolRPCNodeHighestSeenBlock.WithLabelValues(n.chainID.String(), n.name).Set(float64(bh.BlockNumber())) + lggr.Tracew("Got higher block number, resetting timer", "latestReceivedBlockNumber", highestReceivedBlockNumber, "blockNumber", bh.BlockNumber(), "nodeState", n.State()) + highestReceivedBlockNumber = bh.BlockNumber() + } else { + lggr.Tracew("Ignoring previously seen block number", "latestReceivedBlockNumber", highestReceivedBlockNumber, "blockNumber", bh.BlockNumber(), "nodeState", n.State()) + } + if outOfSyncT != nil { + outOfSyncT.Reset(noNewHeadsTimeoutThreshold) + } + n.setLatestReceived(bh.BlockNumber(), bh.BlockDifficulty()) + case err := <-sub.Err(): + lggr.Errorw("Subscription was terminated", "err", err, "nodeState", n.State()) + n.declareUnreachable() + return + case <-outOfSyncTC: + // We haven't received a head on the channel for at least the + // threshold amount of time, mark it broken + lggr.Errorw(fmt.Sprintf("RPC endpoint detected out of sync; no new heads received for %s (last head received was %v)", noNewHeadsTimeoutThreshold, highestReceivedBlockNumber), "nodeState", n.State(), "latestReceivedBlockNumber", highestReceivedBlockNumber, 
"noNewHeadsTimeoutThreshold", noNewHeadsTimeoutThreshold) + if n.nLiveNodes != nil { + if l, _, _ := n.nLiveNodes(); l < 2 { + lggr.Criticalf("RPC endpoint detected out of sync; %s %s", msgCannotDisable, msgDegradedState) + // We don't necessarily want to wait the full timeout to check again, we should + // check regularly and log noisily in this state + outOfSyncT.Reset(zombieNodeCheckInterval(n.noNewHeadsThreshold)) + continue + } + } + n.declareOutOfSync(func(num int64, td *utils.Big) bool { return num < highestReceivedBlockNumber }) + return + } + } +} + +func (n *node[CHAIN_ID, HEAD, RPC]) isOutOfSync(num int64, td *utils.Big) (outOfSync bool) { + outOfSync, _ = n.syncStatus(num, td) + return +} + +// syncStatus returns outOfSync true if num or td is more than SyncThresold behind the best node. +// Always returns outOfSync false for SyncThreshold 0. +// liveNodes is only included when outOfSync is true. +func (n *node[CHAIN_ID, HEAD, RPC]) syncStatus(num int64, td *utils.Big) (outOfSync bool, liveNodes int) { + if n.nLiveNodes == nil { + return // skip for tests + } + threshold := n.nodePoolCfg.SyncThreshold() + if threshold == 0 { + return // disabled + } + // Check against best node + ln, highest, greatest := n.nLiveNodes() + mode := n.nodePoolCfg.SelectionMode() + switch mode { + case NodeSelectionModeHighestHead, NodeSelectionModeRoundRobin, NodeSelectionModePriorityLevel: + return num < highest-int64(threshold), ln + case NodeSelectionModeTotalDifficulty: + bigThreshold := utils.NewBigI(int64(threshold)) + return td.Cmp(greatest.Sub(bigThreshold)) < 0, ln + default: + panic("unrecognized NodeSelectionMode: " + mode) + } +} + +const ( + msgReceivedBlock = "Received block for RPC node, waiting until back in-sync to mark as live again" + msgInSync = "RPC node back in sync" +) + +// outOfSyncLoop takes an OutOfSync node and waits until isOutOfSync returns false to go back to live status +func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td *utils.Big) bool) { + defer n.wg.Done() + + { + // sanity check + state := n.State() + switch state { + case nodeStateOutOfSync: + case nodeStateClosed: + return + default: + panic(fmt.Sprintf("outOfSyncLoop can only run for node in OutOfSync state, got: %s", state)) + } + } + + outOfSyncAt := time.Now() + + lggr := n.lfcLog.Named("OutOfSync") + lggr.Debugw("Trying to revive out-of-sync RPC node", "nodeState", n.State()) + + // Need to redial since out-of-sync nodes are automatically disconnected + if err := n.rpc.Dial(n.nodeCtx); err != nil { + lggr.Errorw("Failed to dial out-of-sync RPC node", "nodeState", n.State()) + n.declareUnreachable() + return + } + + // Manually re-verify since out-of-sync nodes are automatically disconnected + if err := n.verify(n.nodeCtx); err != nil { + lggr.Errorw(fmt.Sprintf("Failed to verify out-of-sync RPC node: %v", err), "err", err) + n.declareInvalidChainID() + return + } + + lggr.Tracew("Successfully subscribed to heads feed on out-of-sync RPC node", "nodeState", n.State()) + + ch := make(chan HEAD) + sub, err := n.rpc.Subscribe(n.nodeCtx, ch, "newHeads") + if err != nil { + lggr.Errorw("Failed to subscribe heads on out-of-sync RPC node", "nodeState", n.State(), "err", err) + n.declareUnreachable() + return + } + defer sub.Unsubscribe() + + for { + select { + case <-n.nodeCtx.Done(): + return + case head, open := <-ch: + if !open { + lggr.Error("Subscription channel unexpectedly closed", "nodeState", n.State()) + n.declareUnreachable() + return + } + 
n.setLatestReceived(head.BlockNumber(), head.BlockDifficulty()) + if !isOutOfSync(head.BlockNumber(), head.BlockDifficulty()) { + // back in-sync! flip back into alive loop + lggr.Infow(fmt.Sprintf("%s: %s. Node was out-of-sync for %s", msgInSync, n.String(), time.Since(outOfSyncAt)), "blockNumber", head.BlockNumber(), "totalDifficulty", "nodeState", n.State()) + n.declareInSync() + return + } + lggr.Debugw(msgReceivedBlock, "blockNumber", head.BlockNumber(), "totalDifficulty", "nodeState", n.State()) + case <-time.After(zombieNodeCheckInterval(n.noNewHeadsThreshold)): + if n.nLiveNodes != nil { + if l, _, _ := n.nLiveNodes(); l < 1 { + lggr.Critical("RPC endpoint is still out of sync, but there are no other available nodes. This RPC node will be forcibly moved back into the live pool in a degraded state") + n.declareInSync() + return + } + } + case err := <-sub.Err(): + lggr.Errorw("Subscription was terminated", "nodeState", n.State(), "err", err) + n.declareUnreachable() + return + } + } +} + +func (n *node[CHAIN_ID, HEAD, RPC]) unreachableLoop() { + defer n.wg.Done() + + { + // sanity check + state := n.State() + switch state { + case nodeStateUnreachable: + case nodeStateClosed: + return + default: + panic(fmt.Sprintf("unreachableLoop can only run for node in Unreachable state, got: %s", state)) + } + } + + unreachableAt := time.Now() + + lggr := n.lfcLog.Named("Unreachable") + lggr.Debugw("Trying to revive unreachable RPC node", "nodeState", n.State()) + + dialRetryBackoff := utils.NewRedialBackoff() + + for { + select { + case <-n.nodeCtx.Done(): + return + case <-time.After(dialRetryBackoff.Duration()): + lggr.Tracew("Trying to re-dial RPC node", "nodeState", n.State()) + + err := n.rpc.Dial(n.nodeCtx) + if err != nil { + lggr.Errorw(fmt.Sprintf("Failed to redial RPC node; still unreachable: %v", err), "err", err, "nodeState", n.State()) + continue + } + + n.setState(nodeStateDialed) + + err = n.verify(n.nodeCtx) + + if errors.Is(err, errInvalidChainID) { + lggr.Errorw("Failed to redial RPC node; remote endpoint returned the wrong chain ID", "err", err) + n.declareInvalidChainID() + return + } else if err != nil { + lggr.Errorw(fmt.Sprintf("Failed to redial RPC node; verify failed: %v", err), "err", err) + n.declareUnreachable() + return + } + + lggr.Infow(fmt.Sprintf("Successfully redialled and verified RPC node %s. 
Node was offline for %s", n.String(), time.Since(unreachableAt)), "nodeState", n.State()) + n.declareAlive() + return + } + } +} + +func (n *node[CHAIN_ID, HEAD, RPC]) invalidChainIDLoop() { + defer n.wg.Done() + + { + // sanity check + state := n.State() + switch state { + case nodeStateInvalidChainID: + case nodeStateClosed: + return + default: + panic(fmt.Sprintf("invalidChainIDLoop can only run for node in InvalidChainID state, got: %s", state)) + } + } + + invalidAt := time.Now() + + lggr := n.lfcLog.Named("InvalidChainID") + lggr.Debugw(fmt.Sprintf("Periodically re-checking RPC node %s with invalid chain ID", n.String()), "nodeState", n.State()) + + chainIDRecheckBackoff := utils.NewRedialBackoff() + + for { + select { + case <-n.nodeCtx.Done(): + return + case <-time.After(chainIDRecheckBackoff.Duration()): + err := n.verify(n.nodeCtx) + if errors.Is(err, errInvalidChainID) { + lggr.Errorw("Failed to verify RPC node; remote endpoint returned the wrong chain ID", "err", err) + continue + } else if err != nil { + lggr.Errorw(fmt.Sprintf("Unexpected error while verifying RPC node chain ID; %v", err), "err", err) + n.declareUnreachable() + return + } + lggr.Infow(fmt.Sprintf("Successfully verified RPC node. Node was offline for %s", time.Since(invalidAt)), "nodeState", n.State()) + n.declareAlive() + return + } + } +} diff --git a/common/client/node_selector_highest_head.go b/common/client/node_selector_highest_head.go new file mode 100644 index 00000000000..99a130004a9 --- /dev/null +++ b/common/client/node_selector_highest_head.go @@ -0,0 +1,41 @@ +package client + +import ( + "math" + + "github.com/smartcontractkit/chainlink/v2/common/types" +) + +type highestHeadNodeSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +] []Node[CHAIN_ID, HEAD, RPC] + +func NewHighestHeadNodeSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +](nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { + return highestHeadNodeSelector[CHAIN_ID, HEAD, RPC](nodes) +} + +func (s highestHeadNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { + var highestHeadNumber int64 = math.MinInt64 + var highestHeadNodes []Node[CHAIN_ID, HEAD, RPC] + for _, n := range s { + state, currentHeadNumber, _ := n.StateAndLatest() + if state == nodeStateAlive && currentHeadNumber >= highestHeadNumber { + if highestHeadNumber < currentHeadNumber { + highestHeadNumber = currentHeadNumber + highestHeadNodes = nil + } + highestHeadNodes = append(highestHeadNodes, n) + } + } + return firstOrHighestPriority(highestHeadNodes) +} + +func (s highestHeadNodeSelector[CHAIN_ID, HEAD, RPC]) Name() string { + return NodeSelectionModeHighestHead +} diff --git a/common/client/node_selector_priority_level.go b/common/client/node_selector_priority_level.go new file mode 100644 index 00000000000..45cc62de077 --- /dev/null +++ b/common/client/node_selector_priority_level.go @@ -0,0 +1,129 @@ +package client + +import ( + "math" + "sort" + "sync/atomic" + + "github.com/smartcontractkit/chainlink/v2/common/types" +) + +type priorityLevelNodeSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +] struct { + nodes []Node[CHAIN_ID, HEAD, RPC] + roundRobinCount []atomic.Uint32 +} + +type nodeWithPriority[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +] struct { + node Node[CHAIN_ID, HEAD, RPC] + priority int32 +} + +func NewPriorityLevelNodeSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], 
+](nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { + return &priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]{ + nodes: nodes, + roundRobinCount: make([]atomic.Uint32, nrOfPriorityTiers(nodes)), + } +} + +func (s priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { + nodes := s.getHighestPriorityAliveTier() + + if len(nodes) == 0 { + return nil + } + priorityLevel := nodes[len(nodes)-1].priority + + // NOTE: Inc returns the number after addition, so we must -1 to get the "current" counter + count := s.roundRobinCount[priorityLevel].Add(1) - 1 + idx := int(count % uint32(len(nodes))) + + return nodes[idx].node +} + +func (s priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]) Name() string { + return NodeSelectionModePriorityLevel +} + +// getHighestPriorityAliveTier filters nodes that are not in state nodeStateAlive and +// returns only the highest tier of alive nodes +func (s priorityLevelNodeSelector[CHAIN_ID, HEAD, RPC]) getHighestPriorityAliveTier() []nodeWithPriority[CHAIN_ID, HEAD, RPC] { + var nodes []nodeWithPriority[CHAIN_ID, HEAD, RPC] + for _, n := range s.nodes { + if n.State() == nodeStateAlive { + nodes = append(nodes, nodeWithPriority[CHAIN_ID, HEAD, RPC]{n, n.Order()}) + } + } + + if len(nodes) == 0 { + return nil + } + + return removeLowerTiers(nodes) +} + +// removeLowerTiers take a slice of nodeWithPriority[CHAIN_ID, BLOCK_HASH, HEAD, RPC] and keeps only the highest tier +func removeLowerTiers[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +](nodes []nodeWithPriority[CHAIN_ID, HEAD, RPC]) []nodeWithPriority[CHAIN_ID, HEAD, RPC] { + sort.SliceStable(nodes, func(i, j int) bool { + return nodes[i].priority > nodes[j].priority + }) + + var nodes2 []nodeWithPriority[CHAIN_ID, HEAD, RPC] + currentPriority := nodes[len(nodes)-1].priority + + for _, n := range nodes { + if n.priority == currentPriority { + nodes2 = append(nodes2, n) + } + } + + return nodes2 +} + +// nrOfPriorityTiers calculates the total number of priority tiers +func nrOfPriorityTiers[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +](nodes []Node[CHAIN_ID, HEAD, RPC]) int32 { + highestPriority := int32(0) + for _, n := range nodes { + priority := n.Order() + if highestPriority < priority { + highestPriority = priority + } + } + return highestPriority + 1 +} + +// firstOrHighestPriority takes a list of nodes and returns the first one with the highest priority +func firstOrHighestPriority[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +](nodes []Node[CHAIN_ID, HEAD, RPC]) Node[CHAIN_ID, HEAD, RPC] { + hp := int32(math.MaxInt32) + var node Node[CHAIN_ID, HEAD, RPC] + for _, n := range nodes { + if n.Order() < hp { + hp = n.Order() + node = n + } + } + return node +} diff --git a/common/client/node_selector_round_robin.go b/common/client/node_selector_round_robin.go new file mode 100644 index 00000000000..5cdad7f52ee --- /dev/null +++ b/common/client/node_selector_round_robin.go @@ -0,0 +1,50 @@ +package client + +import ( + "sync/atomic" + + "github.com/smartcontractkit/chainlink/v2/common/types" +) + +type roundRobinSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +] struct { + nodes []Node[CHAIN_ID, HEAD, RPC] + roundRobinCount atomic.Uint32 +} + +func NewRoundRobinSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +](nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { + return &roundRobinSelector[CHAIN_ID, HEAD, RPC]{ + 
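+		// roundRobinCount is deliberately left at its zero value here; Select
+		// advances it atomically on every call.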
nodes: nodes, + } +} + +func (s *roundRobinSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { + var liveNodes []Node[CHAIN_ID, HEAD, RPC] + for _, n := range s.nodes { + if n.State() == nodeStateAlive { + liveNodes = append(liveNodes, n) + } + } + + nNodes := len(liveNodes) + if nNodes == 0 { + return nil + } + + // NOTE: Inc returns the number after addition, so we must -1 to get the "current" counter + count := s.roundRobinCount.Add(1) - 1 + idx := int(count % uint32(nNodes)) + + return liveNodes[idx] +} + +func (s *roundRobinSelector[CHAIN_ID, HEAD, RPC]) Name() string { + return NodeSelectionModeRoundRobin +} diff --git a/common/client/node_selector_total_difficulty.go b/common/client/node_selector_total_difficulty.go new file mode 100644 index 00000000000..9b29642d033 --- /dev/null +++ b/common/client/node_selector_total_difficulty.go @@ -0,0 +1,54 @@ +package client + +import ( + "github.com/smartcontractkit/chainlink/v2/core/utils" + + "github.com/smartcontractkit/chainlink/v2/common/types" +) + +type totalDifficultyNodeSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +] []Node[CHAIN_ID, HEAD, RPC] + +func NewTotalDifficultyNodeSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +](nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { + return totalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC](nodes) +} + +func (s totalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { + // NodeNoNewHeadsThreshold may not be enabled, in this case all nodes have td == nil + var highestTD *utils.Big + var nodes []Node[CHAIN_ID, HEAD, RPC] + var aliveNodes []Node[CHAIN_ID, HEAD, RPC] + + for _, n := range s { + state, _, currentTD := n.StateAndLatest() + if state != nodeStateAlive { + continue + } + + aliveNodes = append(aliveNodes, n) + if currentTD != nil && (highestTD == nil || currentTD.Cmp(highestTD) >= 0) { + if highestTD == nil || currentTD.Cmp(highestTD) > 0 { + highestTD = currentTD + nodes = nil + } + nodes = append(nodes, n) + } + } + + //If all nodes have td == nil pick one from the nodes that are alive + if len(nodes) == 0 { + return firstOrHighestPriority(aliveNodes) + } + return firstOrHighestPriority(nodes) +} + +func (s totalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC]) Name() string { + return NodeSelectionModeTotalDifficulty +} diff --git a/common/client/send_only_node.go b/common/client/send_only_node.go new file mode 100644 index 00000000000..3b382b2dcb0 --- /dev/null +++ b/common/client/send_only_node.go @@ -0,0 +1,183 @@ +package client + +import ( + "context" + "fmt" + "net/url" + "sync" + + "github.com/smartcontractkit/chainlink-relay/pkg/services" + + "github.com/smartcontractkit/chainlink/v2/common/chains/client" + "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +type sendOnlyClient[ + CHAIN_ID types.ID, +] interface { + Close() + ChainID(context.Context) (CHAIN_ID, error) + DialHTTP() error +} + +// SendOnlyNode represents one node used as a sendonly +type SendOnlyNode[ + CHAIN_ID types.ID, + RPC sendOnlyClient[CHAIN_ID], +] interface { + // Start may attempt to connect to the node, but should only return error for misconfiguration - never for temporary errors. 
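+	// Dial or chain ID verification failures are handled internally by parking the
+	// node in an unhealthy state and retrying in the background.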
+ Start(context.Context) error + Close() error + + ConfiguredChainID() CHAIN_ID + RPC() RPC + + String() string + // State returns nodeState + State() nodeState + // Name is a unique identifier for this node. + Name() string +} + +// It only supports sending transactions +// It must use an http(s) url +type sendOnlyNode[ + CHAIN_ID types.ID, + RPC sendOnlyClient[CHAIN_ID], +] struct { + services.StateMachine + + stateMu sync.RWMutex // protects state* fields + state nodeState + + rpc RPC + uri url.URL + log logger.Logger + name string + chainID CHAIN_ID + chStop utils.StopChan + wg sync.WaitGroup +} + +// NewSendOnlyNode returns a new sendonly node +func NewSendOnlyNode[ + CHAIN_ID types.ID, + RPC sendOnlyClient[CHAIN_ID], +]( + lggr logger.Logger, + httpuri url.URL, + name string, + chainID CHAIN_ID, + rpc RPC, +) SendOnlyNode[CHAIN_ID, RPC] { + s := new(sendOnlyNode[CHAIN_ID, RPC]) + s.name = name + s.log = lggr.Named("SendOnlyNode").Named(name).With( + "nodeTier", "sendonly", + ) + s.rpc = rpc + s.uri = httpuri + s.chainID = chainID + s.chStop = make(chan struct{}) + return s +} + +func (s *sendOnlyNode[CHAIN_ID, RPC]) Start(ctx context.Context) error { + return s.StartOnce(s.name, func() error { + s.start(ctx) + return nil + }) +} + +// Start setups up and verifies the sendonly node +// Should only be called once in a node's lifecycle +func (s *sendOnlyNode[CHAIN_ID, RPC]) start(startCtx context.Context) { + if s.State() != nodeStateUndialed { + panic(fmt.Sprintf("cannot dial node with state %v", s.state)) + } + + err := s.rpc.DialHTTP() + if err != nil { + promPoolRPCNodeTransitionsToUnusable.WithLabelValues(s.chainID.String(), s.name).Inc() + s.log.Errorw("Dial failed: SendOnly Node is unusable", "err", err) + s.setState(nodeStateUnusable) + return + } + s.setState(nodeStateDialed) + + if s.chainID.String() == "0" { + // Skip verification if chainID is zero + s.log.Warn("sendonly rpc ChainID verification skipped") + } else { + chainID, err := s.rpc.ChainID(startCtx) + if err != nil || chainID.String() != s.chainID.String() { + promPoolRPCNodeTransitionsToUnreachable.WithLabelValues(s.chainID.String(), s.name).Inc() + if err != nil { + promPoolRPCNodeTransitionsToUnreachable.WithLabelValues(s.chainID.String(), s.name).Inc() + s.log.Errorw(fmt.Sprintf("Verify failed: %v", err), "err", err) + s.setState(nodeStateUnreachable) + } else { + promPoolRPCNodeTransitionsToInvalidChainID.WithLabelValues(s.chainID.String(), s.name).Inc() + s.log.Errorf( + "sendonly rpc ChainID doesn't match local chain ID: RPC ID=%s, local ID=%s, node name=%s", + chainID.String(), + s.chainID.String(), + s.name, + ) + s.setState(nodeStateInvalidChainID) + } + // Since it has failed, spin up the verifyLoop that will keep + // retrying until success + s.wg.Add(1) + go s.verifyLoop() + return + } + } + + promPoolRPCNodeTransitionsToAlive.WithLabelValues(s.chainID.String(), s.name).Inc() + s.setState(nodeStateAlive) + s.log.Infow("Sendonly RPC Node is online", "nodeState", s.state) +} + +func (s *sendOnlyNode[CHAIN_ID, RPC]) Close() error { + return s.StopOnce(s.name, func() error { + s.rpc.Close() + s.wg.Wait() + s.setState(nodeStateClosed) + return nil + }) +} + +func (s *sendOnlyNode[CHAIN_ID, RPC]) ConfiguredChainID() CHAIN_ID { + return s.chainID +} + +func (s *sendOnlyNode[CHAIN_ID, RPC]) RPC() RPC { + return s.rpc +} + +func (s *sendOnlyNode[CHAIN_ID, RPC]) String() string { + return fmt.Sprintf("(%s)%s:%s", client.Secondary.String(), s.name, s.uri.Redacted()) +} + +func (s *sendOnlyNode[CHAIN_ID, RPC]) 
setState(state nodeState) (changed bool) { + s.stateMu.Lock() + defer s.stateMu.Unlock() + if s.state == state { + return false + } + s.state = state + return true +} + +func (s *sendOnlyNode[CHAIN_ID, RPC]) State() nodeState { + s.stateMu.RLock() + defer s.stateMu.RUnlock() + return s.state +} + +func (s *sendOnlyNode[CHAIN_ID, RPC]) Name() string { + return s.name +} diff --git a/common/client/send_only_node_lifecycle.go b/common/client/send_only_node_lifecycle.go new file mode 100644 index 00000000000..0f663eab30e --- /dev/null +++ b/common/client/send_only_node_lifecycle.go @@ -0,0 +1,66 @@ +package client + +import ( + "context" + "fmt" + "time" + + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +// verifyLoop may only be triggered once, on Start, if initial chain ID check +// fails. +// +// It will continue checking until success and then exit permanently. +func (s *sendOnlyNode[CHAIN_ID, RPC]) verifyLoop() { + defer s.wg.Done() + + backoff := utils.NewRedialBackoff() + for { + select { + case <-s.chStop: + return + case <-time.After(backoff.Duration()): + } + chainID, err := s.rpc.ChainID(context.Background()) + if err != nil { + ok := s.IfStarted(func() { + if changed := s.setState(nodeStateUnreachable); changed { + promPoolRPCNodeTransitionsToUnreachable.WithLabelValues(s.chainID.String(), s.name).Inc() + } + }) + if !ok { + return + } + s.log.Errorw(fmt.Sprintf("Verify failed: %v", err), "err", err) + continue + } else if chainID.String() != s.chainID.String() { + ok := s.IfStarted(func() { + if changed := s.setState(nodeStateInvalidChainID); changed { + promPoolRPCNodeTransitionsToInvalidChainID.WithLabelValues(s.chainID.String(), s.name).Inc() + } + }) + if !ok { + return + } + s.log.Errorf( + "sendonly rpc ChainID doesn't match local chain ID: RPC ID=%s, local ID=%s, node name=%s", + chainID.String(), + s.chainID.String(), + s.name, + ) + + continue + } + ok := s.IfStarted(func() { + if changed := s.setState(nodeStateAlive); changed { + promPoolRPCNodeTransitionsToAlive.WithLabelValues(s.chainID.String(), s.name).Inc() + } + }) + if !ok { + return + } + s.log.Infow("Sendonly RPC Node is online", "nodeState", s.state) + return + } +} diff --git a/common/client/types.go b/common/client/types.go new file mode 100644 index 00000000000..f3a6029a9e8 --- /dev/null +++ b/common/client/types.go @@ -0,0 +1,133 @@ +package client + +import ( + "context" + "math/big" + + feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" + "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink/v2/core/assets" + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +// RPC includes all the necessary methods for a multi-node client to interact directly with any RPC endpoint. +type RPC[ + CHAIN_ID types.ID, + SEQ types.Sequence, + ADDR types.Hashable, + BLOCK_HASH types.Hashable, + TX any, + TX_HASH types.Hashable, + EVENT any, + EVENT_OPS any, + TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH], + FEE feetypes.Fee, + HEAD types.Head[BLOCK_HASH], + +] interface { + NodeClient[ + CHAIN_ID, + HEAD, + ] + clientAPI[ + CHAIN_ID, + SEQ, + ADDR, + BLOCK_HASH, + TX, + TX_HASH, + EVENT, + EVENT_OPS, + TX_RECEIPT, + FEE, + HEAD, + ] +} + +// Head is the interface required by the NodeClient +type Head interface { + BlockNumber() int64 + BlockDifficulty() *utils.Big +} + +// NodeClient includes all the necessary RPC methods required by a node. 
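+// On top of the basic connection methods it exposes the lifecycle hooks
+// (DialHTTP, DisconnectAll, Close) and the subscription bookkeeping used by the
+// node lifecycle loops.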
+type NodeClient[ + CHAIN_ID types.ID, + HEAD Head, +] interface { + connection[CHAIN_ID, HEAD] + + DialHTTP() error + DisconnectAll() + Close() + ClientVersion(context.Context) (string, error) + SubscribersCount() int32 + SetAliveLoopSub(types.Subscription) + UnsubscribeAllExceptAliveLoop() +} + +// clientAPI includes all the direct RPC methods required by the generalized common client to implement its own. +type clientAPI[ + CHAIN_ID types.ID, + SEQ types.Sequence, + ADDR types.Hashable, + BLOCK_HASH types.Hashable, + TX any, + TX_HASH types.Hashable, + EVENT any, + EVENT_OPS any, // event filter query options + TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH], + FEE feetypes.Fee, + HEAD types.Head[BLOCK_HASH], +] interface { + connection[CHAIN_ID, HEAD] + + // Account + BalanceAt(ctx context.Context, accountAddress ADDR, blockNumber *big.Int) (*big.Int, error) + TokenBalance(ctx context.Context, accountAddress ADDR, tokenAddress ADDR) (*big.Int, error) + SequenceAt(ctx context.Context, accountAddress ADDR, blockNumber *big.Int) (SEQ, error) + LINKBalance(ctx context.Context, accountAddress ADDR, linkAddress ADDR) (*assets.Link, error) + PendingSequenceAt(ctx context.Context, addr ADDR) (SEQ, error) + EstimateGas(ctx context.Context, call any) (gas uint64, err error) + + // Transactions + SendTransaction(ctx context.Context, tx TX) error + SimulateTransaction(ctx context.Context, tx TX) error + TransactionByHash(ctx context.Context, txHash TX_HASH) (TX, error) + TransactionReceipt(ctx context.Context, txHash TX_HASH) (TX_RECEIPT, error) + SendEmptyTransaction( + ctx context.Context, + newTxAttempt func(seq SEQ, feeLimit uint32, fee FEE, fromAddress ADDR) (attempt any, err error), + seq SEQ, + gasLimit uint32, + fee FEE, + fromAddress ADDR, + ) (txhash string, err error) + + // Blocks + BlockByNumber(ctx context.Context, number *big.Int) (HEAD, error) + BlockByHash(ctx context.Context, hash BLOCK_HASH) (HEAD, error) + LatestBlockHeight(context.Context) (*big.Int, error) + + // Events + FilterEvents(ctx context.Context, query EVENT_OPS) ([]EVENT, error) + + // Misc + BatchCallContext(ctx context.Context, b []any) error + CallContract( + ctx context.Context, + msg interface{}, + blockNumber *big.Int, + ) (rpcErr []byte, extractErr error) + CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error + CodeAt(ctx context.Context, account ADDR, blockNumber *big.Int) ([]byte, error) +} + +type connection[ + CHAIN_ID types.ID, + HEAD Head, +] interface { + ChainID(ctx context.Context) (CHAIN_ID, error) + Dial(ctx context.Context) error + Subscribe(ctx context.Context, channel chan<- HEAD, args ...interface{}) (types.Subscription, error) +} diff --git a/common/headtracker/types/mocks/head.go b/common/headtracker/types/mocks/head.go index edda18d57e8..a56590b6ef3 100644 --- a/common/headtracker/types/mocks/head.go +++ b/common/headtracker/types/mocks/head.go @@ -4,6 +4,7 @@ package mocks import ( types "github.com/smartcontractkit/chainlink/v2/common/types" + utils "github.com/smartcontractkit/chainlink/v2/core/utils" mock "github.com/stretchr/testify/mock" ) @@ -12,6 +13,22 @@ type Head[BLOCK_HASH types.Hashable, CHAIN_ID types.ID] struct { mock.Mock } +// BlockDifficulty provides a mock function with given fields: +func (_m *Head[BLOCK_HASH, CHAIN_ID]) BlockDifficulty() *utils.Big { + ret := _m.Called() + + var r0 *utils.Big + if rf, ok := ret.Get(0).(func() *utils.Big); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*utils.Big) + } + } + 
+ return r0 +} + // BlockHash provides a mock function with given fields: func (_m *Head[BLOCK_HASH, CHAIN_ID]) BlockHash() BLOCK_HASH { ret := _m.Called() diff --git a/common/types/head.go b/common/types/head.go index 4d339b1cddb..bef9c30d9e9 100644 --- a/common/types/head.go +++ b/common/types/head.go @@ -1,5 +1,7 @@ package types +import "github.com/smartcontractkit/chainlink/v2/core/utils" + // Head provides access to a chain's head, as needed by the TxManager. // This is a generic interface which ALL chains will implement. // @@ -24,4 +26,8 @@ type Head[BLOCK_HASH Hashable] interface { // HashAtHeight returns the hash of the block at the given height, if it is in the chain. // If not in chain, returns the zero hash HashAtHeight(blockNum int64) BLOCK_HASH + + // Returns the total difficulty of the block. For chains who do not have a concept of block + // difficulty, return 0. + BlockDifficulty() *utils.Big } diff --git a/common/types/mocks/head.go b/common/types/mocks/head.go index 3cb303ef267..816a9234a3c 100644 --- a/common/types/mocks/head.go +++ b/common/types/mocks/head.go @@ -4,6 +4,7 @@ package mocks import ( types "github.com/smartcontractkit/chainlink/v2/common/types" + utils "github.com/smartcontractkit/chainlink/v2/core/utils" mock "github.com/stretchr/testify/mock" ) @@ -12,6 +13,22 @@ type Head[BLOCK_HASH types.Hashable] struct { mock.Mock } +// BlockDifficulty provides a mock function with given fields: +func (_m *Head[BLOCK_HASH]) BlockDifficulty() *utils.Big { + ret := _m.Called() + + var r0 *utils.Big + if rf, ok := ret.Get(0).(func() *utils.Big); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*utils.Big) + } + } + + return r0 +} + // BlockHash provides a mock function with given fields: func (_m *Head[BLOCK_HASH]) BlockHash() BLOCK_HASH { ret := _m.Called() diff --git a/common/types/receipt.go b/common/types/receipt.go new file mode 100644 index 00000000000..01d5a72def5 --- /dev/null +++ b/common/types/receipt.go @@ -0,0 +1,14 @@ +package types + +import "math/big" + +type Receipt[TX_HASH Hashable, BLOCK_HASH Hashable] interface { + GetStatus() uint64 + GetTxHash() TX_HASH + GetBlockNumber() *big.Int + IsZero() bool + IsUnmined() bool + GetFeeUsed() uint64 + GetTransactionIndex() uint + GetBlockHash() BLOCK_HASH +} diff --git a/core/chains/evm/chain.go b/core/chains/evm/chain.go index 6eed13271e3..936abc6216c 100644 --- a/core/chains/evm/chain.go +++ b/core/chains/evm/chain.go @@ -498,3 +498,21 @@ func newPrimary(cfg evmconfig.NodePool, noNewHeadsThreshold time.Duration, lggr return evmclient.NewNode(cfg, noNewHeadsThreshold, lggr, (url.URL)(*n.WSURL), (*url.URL)(n.HTTPURL), *n.Name, id, chainID, *n.Order), nil } + +// TODO-1663: replace newEthClientFromChain with the function below once client.go is deprecated. 
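+// The sketch below (kept commented out until then) builds one RPCClient per
+// configured node, wraps send-only nodes and primaries separately, and hands
+// both sets to NewChainClient.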
+//func newEthClientFromChain(cfg evmconfig.NodePool, noNewHeadsThreshold time.Duration, lggr logger.Logger, chainID *big.Int, chainType config.ChainType, nodes []*toml.Node) evmclient.Client { +// var empty url.URL +// var primaries []commonclient.Node[*big.Int, *evmtypes.Head, evmclient.RPCCLient] +// var sendonlys []commonclient.SendOnlyNode[*big.Int, evmclient.RPCCLient] +// for i, node := range nodes { +// if node.SendOnly != nil && *node.SendOnly { +// rpc := evmclient.NewRPCClient(lggr, empty, (*url.URL)(node.HTTPURL), fmt.Sprintf("eth-sendonly-rpc-%d", i), int32(i), chainID, commontypes.Primary) +// sendonly := commonclient.NewSendOnlyNode[*big.Int, evmclient.RPCCLient](lggr, (url.URL)(*node.HTTPURL), *node.Name, chainID, rpc) +// sendonlys = append(sendonlys, sendonly) +// } else { +// rpc := evmclient.NewRPCClient(lggr, (url.URL)(*node.WSURL), (*url.URL)(node.HTTPURL), fmt.Sprintf("eth-sendonly-rpc-%d", i), int32(i), chainID, commontypes.Primary) +// primaries = append(primaries, commonclient.NewNode[*big.Int, *evmtypes.Head, evmclient.RPCCLient](cfg, noNewHeadsThreshold, lggr, (url.URL)(*node.WSURL), (*url.URL)(node.HTTPURL), *node.Name, int32(i), chainID, *node.Order, rpc, "EVM")) +// } +// } +// return evmclient.NewChainClient(lggr, cfg.SelectionMode(), cfg.LeaseDuration(), noNewHeadsThreshold, primaries, sendonlys, chainID, chainType) +//} diff --git a/core/chains/evm/client/chain_client.go b/core/chains/evm/client/chain_client.go new file mode 100644 index 00000000000..bda028cbf33 --- /dev/null +++ b/core/chains/evm/client/chain_client.go @@ -0,0 +1,274 @@ +package client + +import ( + "context" + "math/big" + "time" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/rpc" + + commontypes "github.com/smartcontractkit/chainlink/v2/common/chains/client" + commonclient "github.com/smartcontractkit/chainlink/v2/common/client" + "github.com/smartcontractkit/chainlink/v2/core/assets" + evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + "github.com/smartcontractkit/chainlink/v2/core/config" + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +var _ Client = (*chainClient)(nil) + +// TODO-1663: rename this to client, once the client.go file is deprecated. 
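+// chainClient adapts the generalized commonclient.MultiNode to the existing
+// EVM Client interface. A minimal construction sketch, assuming the caller
+// already holds a logger, node pool config, primaries and sendonlys
+// (placeholder names, not part of this change):
+//
+//	c := NewChainClient(lggr, cfg.SelectionMode(), cfg.LeaseDuration(),
+//		noNewHeadsThreshold, primaries, sendonlys, chainID, chainType)
+//	if err := c.Dial(ctx); err != nil {
+//		// handle startup error
+//	}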
+type chainClient struct { + multiNode commonclient.MultiNode[ + *big.Int, + evmtypes.Nonce, + common.Address, + common.Hash, + *types.Transaction, + common.Hash, + types.Log, + ethereum.FilterQuery, + *evmtypes.Receipt, + *assets.Wei, + *evmtypes.Head, + RPCCLient, + ] + logger logger.Logger +} + +func NewChainClient( + logger logger.Logger, + selectionMode string, + leaseDuration time.Duration, + noNewHeadsThreshold time.Duration, + nodes []commonclient.Node[*big.Int, *evmtypes.Head, RPCCLient], + sendonlys []commonclient.SendOnlyNode[*big.Int, RPCCLient], + chainID *big.Int, + chainType config.ChainType, +) Client { + multiNode := commonclient.NewMultiNode[ + *big.Int, + evmtypes.Nonce, + common.Address, + common.Hash, + *types.Transaction, + common.Hash, + types.Log, + ethereum.FilterQuery, + *evmtypes.Receipt, + *assets.Wei, + *evmtypes.Head, + RPCCLient, + ]( + logger, + selectionMode, + leaseDuration, + noNewHeadsThreshold, + nodes, + sendonlys, + chainID, + chainType, + "EVM", + ClassifySendOnlyError, + ) + return &chainClient{ + multiNode: multiNode, + logger: logger, + } +} + +func (c *chainClient) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { + return c.multiNode.BalanceAt(ctx, account, blockNumber) +} + +func (c *chainClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error { + batch := make([]any, len(b)) + for i, arg := range b { + batch[i] = any(arg) + } + return c.multiNode.BatchCallContext(ctx, batch) +} + +func (c *chainClient) BatchCallContextAll(ctx context.Context, b []rpc.BatchElem) error { + batch := make([]any, len(b)) + for i, arg := range b { + batch[i] = any(arg) + } + return c.multiNode.BatchCallContextAll(ctx, batch) +} + +// TODO-1663: return custom Block type instead of geth's once client.go is deprecated. +func (c *chainClient) BlockByHash(ctx context.Context, hash common.Hash) (b *types.Block, err error) { + rpc, err := c.multiNode.SelectNodeRPC() + if err != nil { + return b, err + } + return rpc.BlockByHashGeth(ctx, hash) +} + +// TODO-1663: return custom Block type instead of geth's once client.go is deprecated. +func (c *chainClient) BlockByNumber(ctx context.Context, number *big.Int) (b *types.Block, err error) { + rpc, err := c.multiNode.SelectNodeRPC() + if err != nil { + return b, err + } + return rpc.BlockByNumberGeth(ctx, number) +} + +func (c *chainClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { + return c.multiNode.CallContext(ctx, result, method) +} + +func (c *chainClient) CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { + return c.multiNode.CallContract(ctx, msg, blockNumber) +} + +// TODO-1663: change this to actual ChainID() call once client.go is deprecated. 
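+// ChainID currently reports the locally configured chain ID instead of
+// querying an RPC, so it succeeds even when no nodes are live (see
+// TestEthClient_ErroringClient).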
+func (c *chainClient) ChainID() (*big.Int, error) { + //return c.multiNode.ChainID(ctx), nil + return c.multiNode.ConfiguredChainID(), nil +} + +func (c *chainClient) Close() { + c.multiNode.Close() +} + +func (c *chainClient) CodeAt(ctx context.Context, account common.Address, blockNumber *big.Int) ([]byte, error) { + return c.multiNode.CodeAt(ctx, account, blockNumber) +} + +func (c *chainClient) ConfiguredChainID() *big.Int { + return c.multiNode.ConfiguredChainID() +} + +func (c *chainClient) Dial(ctx context.Context) error { + return c.multiNode.Dial(ctx) +} + +func (c *chainClient) EstimateGas(ctx context.Context, call ethereum.CallMsg) (uint64, error) { + return c.multiNode.EstimateGas(ctx, call) +} +func (c *chainClient) FilterLogs(ctx context.Context, q ethereum.FilterQuery) ([]types.Log, error) { + return c.multiNode.FilterEvents(ctx, q) +} + +func (c *chainClient) HeaderByHash(ctx context.Context, h common.Hash) (head *types.Header, err error) { + rpc, err := c.multiNode.SelectNodeRPC() + if err != nil { + return head, err + } + return rpc.HeaderByHash(ctx, h) +} + +func (c *chainClient) HeaderByNumber(ctx context.Context, n *big.Int) (head *types.Header, err error) { + rpc, err := c.multiNode.SelectNodeRPC() + if err != nil { + return head, err + } + return rpc.HeaderByNumber(ctx, n) +} + +func (c *chainClient) HeadByHash(ctx context.Context, h common.Hash) (*evmtypes.Head, error) { + return c.multiNode.BlockByHash(ctx, h) +} + +func (c *chainClient) HeadByNumber(ctx context.Context, n *big.Int) (*evmtypes.Head, error) { + return c.multiNode.BlockByNumber(ctx, n) +} + +func (c *chainClient) IsL2() bool { + return c.multiNode.IsL2() +} + +func (c *chainClient) LINKBalance(ctx context.Context, address common.Address, linkAddress common.Address) (*assets.Link, error) { + return c.multiNode.LINKBalance(ctx, address, linkAddress) +} + +func (c *chainClient) LatestBlockHeight(ctx context.Context) (*big.Int, error) { + return c.multiNode.LatestBlockHeight(ctx) +} + +func (c *chainClient) NodeStates() map[string]string { + return c.multiNode.NodeStates() +} + +func (c *chainClient) PendingCodeAt(ctx context.Context, account common.Address) (b []byte, err error) { + rpc, err := c.multiNode.SelectNodeRPC() + if err != nil { + return b, err + } + return rpc.PendingCodeAt(ctx, account) +} + +// TODO-1663: change this to evmtypes.Nonce(int64) once client.go is deprecated. 
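+// PendingNonceAt converts the generalized sequence type returned by
+// PendingSequenceAt back to the uint64 the legacy interface expects.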
+func (c *chainClient) PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) { + n, err := c.multiNode.PendingSequenceAt(ctx, account) + return uint64(n), err +} + +func (c *chainClient) SendTransaction(ctx context.Context, tx *types.Transaction) error { + return c.multiNode.SendTransaction(ctx, tx) +} + +func (c *chainClient) SendTransactionReturnCode(ctx context.Context, tx *types.Transaction, fromAddress common.Address) (commontypes.SendTxReturnCode, error) { + err := c.SendTransaction(ctx, tx) + return ClassifySendError(err, c.logger, tx, fromAddress, c.IsL2()) +} + +func (c *chainClient) SequenceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (evmtypes.Nonce, error) { + return c.multiNode.SequenceAt(ctx, account, blockNumber) +} + +func (c *chainClient) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (s ethereum.Subscription, err error) { + rpc, err := c.multiNode.SelectNodeRPC() + if err != nil { + return s, err + } + return rpc.SubscribeFilterLogs(ctx, q, ch) +} + +func (c *chainClient) SubscribeNewHead(ctx context.Context, ch chan<- *evmtypes.Head) (ethereum.Subscription, error) { + csf := newChainIDSubForwarder(c.ConfiguredChainID(), ch) + err := csf.start(c.multiNode.Subscribe(ctx, csf.srcCh, "newHeads")) + if err != nil { + return nil, err + } + return csf, nil +} + +func (c *chainClient) SuggestGasPrice(ctx context.Context) (p *big.Int, err error) { + rpc, err := c.multiNode.SelectNodeRPC() + if err != nil { + return p, err + } + return rpc.SuggestGasPrice(ctx) +} + +func (c *chainClient) SuggestGasTipCap(ctx context.Context) (t *big.Int, err error) { + rpc, err := c.multiNode.SelectNodeRPC() + if err != nil { + return t, err + } + return rpc.SuggestGasTipCap(ctx) +} + +func (c *chainClient) TokenBalance(ctx context.Context, address common.Address, contractAddress common.Address) (*big.Int, error) { + return c.multiNode.TokenBalance(ctx, address, contractAddress) +} + +func (c *chainClient) TransactionByHash(ctx context.Context, txHash common.Hash) (*types.Transaction, error) { + return c.multiNode.TransactionByHash(ctx, txHash) +} + +// TODO-1663: return custom Receipt type instead of geth's once client.go is deprecated. 
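+// TransactionReceipt pins the call to a single node via SelectNodeRPC and
+// returns the geth receipt type for now.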
+func (c *chainClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (r *types.Receipt, err error) { + rpc, err := c.multiNode.SelectNodeRPC() + if err != nil { + return r, err + } + //return rpc.TransactionReceipt(ctx, txHash) + return rpc.TransactionReceiptGeth(ctx, txHash) +} diff --git a/core/chains/evm/client/client.go b/core/chains/evm/client/client.go index 3a3b8b23a92..af03720ced9 100644 --- a/core/chains/evm/client/client.go +++ b/core/chains/evm/client/client.go @@ -213,7 +213,7 @@ func (client *client) HeaderByHash(ctx context.Context, h common.Hash) (*types.H func (client *client) SendTransactionReturnCode(ctx context.Context, tx *types.Transaction, fromAddress common.Address) (clienttypes.SendTxReturnCode, error) { err := client.SendTransaction(ctx, tx) - return NewSendErrorReturnCode(err, client.logger, tx, fromAddress, client.pool.ChainType().IsL2()) + return ClassifySendError(err, client.logger, tx, fromAddress, client.pool.ChainType().IsL2()) } // SendTransaction also uses the sendonly HTTP RPC URLs if set diff --git a/core/chains/evm/client/client_test.go b/core/chains/evm/client/client_test.go index 88bc37411c6..81a82d20fa7 100644 --- a/core/chains/evm/client/client_test.go +++ b/core/chains/evm/client/client_test.go @@ -22,7 +22,9 @@ import ( "github.com/stretchr/testify/require" "github.com/tidwall/gjson" - clienttypes "github.com/smartcontractkit/chainlink/v2/common/chains/client" + commonclient "github.com/smartcontractkit/chainlink/v2/common/client" + + commontypes "github.com/smartcontractkit/chainlink/v2/common/chains/client" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" @@ -43,6 +45,33 @@ func mustNewClientWithChainID(t *testing.T, wsURL string, chainID *big.Int, send return c } +func mustNewChainClient(t *testing.T, wsURL string, sendonlys ...url.URL) evmclient.Client { + return mustNewChainClientWithChainID(t, wsURL, testutils.FixtureChainID, sendonlys...) 
+} + +func mustNewChainClientWithChainID(t *testing.T, wsURL string, chainID *big.Int, sendonlys ...url.URL) evmclient.Client { + cfg := evmclient.TestNodePoolConfig{ + NodeSelectionMode: evmclient.NodeSelectionMode_RoundRobin, + } + c, err := evmclient.NewChainClientWithTestNode(t, cfg, time.Second*0, cfg.NodeLeaseDuration, wsURL, nil, sendonlys, 42, chainID) + require.NoError(t, err) + return c +} + +func mustNewClients(t *testing.T, wsURL string, sendonlys ...url.URL) []evmclient.Client { + var clients []evmclient.Client + clients = append(clients, mustNewClient(t, wsURL, sendonlys...)) + clients = append(clients, mustNewChainClient(t, wsURL, sendonlys...)) + return clients +} + +func mustNewClientsWithChainID(t *testing.T, wsURL string, chainID *big.Int, sendonlys ...url.URL) []evmclient.Client { + var clients []evmclient.Client + clients = append(clients, mustNewClientWithChainID(t, wsURL, chainID, sendonlys...)) + clients = append(clients, mustNewChainClientWithChainID(t, wsURL, chainID, sendonlys...)) + return clients +} + func TestEthClient_TransactionReceipt(t *testing.T) { t.Parallel() @@ -78,15 +107,17 @@ func TestEthClient_TransactionReceipt(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - hash := common.HexToHash(txHash) - receipt, err := ethClient.TransactionReceipt(testutils.Context(t), hash) - require.NoError(t, err) - assert.Equal(t, hash, receipt.TxHash) - assert.Equal(t, big.NewInt(11), receipt.BlockNumber) + hash := common.HexToHash(txHash) + receipt, err := ethClient.TransactionReceipt(testutils.Context(t), hash) + require.NoError(t, err) + assert.Equal(t, hash, receipt.TxHash) + assert.Equal(t, big.NewInt(11), receipt.BlockNumber) + } }) t.Run("no tx hash, returns ethereum.NotFound", func(t *testing.T) { @@ -108,13 +139,15 @@ func TestEthClient_TransactionReceipt(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - hash := common.HexToHash(txHash) - _, err = ethClient.TransactionReceipt(testutils.Context(t), hash) - require.Equal(t, ethereum.NotFound, errors.Cause(err)) + hash := common.HexToHash(txHash) + _, err = ethClient.TransactionReceipt(testutils.Context(t), hash) + require.Equal(t, ethereum.NotFound, errors.Cause(err)) + } }) } @@ -144,15 +177,17 @@ func TestEthClient_PendingNonceAt(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - result, err := ethClient.PendingNonceAt(testutils.Context(t), address) - require.NoError(t, err) + result, err := ethClient.PendingNonceAt(testutils.Context(t), address) + require.NoError(t, err) - var expected uint64 = 256 - require.Equal(t, result, expected) + var expected uint64 = 256 + require.Equal(t, result, expected) + } } func TestEthClient_BalanceAt(t *testing.T) { @@ -189,13 +224,15 @@ func TestEthClient_BalanceAt(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - 
require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - result, err := ethClient.BalanceAt(testutils.Context(t), address, nil) - require.NoError(t, err) - assert.Equal(t, test.balance, result) + result, err := ethClient.BalanceAt(testutils.Context(t), address, nil) + require.NoError(t, err) + assert.Equal(t, test.balance, result) + } }) } } @@ -220,13 +257,15 @@ func TestEthClient_LatestBlockHeight(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - result, err := ethClient.LatestBlockHeight(testutils.Context(t)) - require.NoError(t, err) - require.Equal(t, big.NewInt(256), result) + result, err := ethClient.LatestBlockHeight(testutils.Context(t)) + require.NoError(t, err) + require.Equal(t, big.NewInt(256), result) + } } func TestEthClient_GetERC20Balance(t *testing.T) { @@ -277,13 +316,15 @@ func TestEthClient_GetERC20Balance(t *testing.T) { }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - result, err := ethClient.TokenBalance(ctx, userAddress, contractAddress) - require.NoError(t, err) - assert.Equal(t, test.balance, result) + result, err := ethClient.TokenBalance(ctx, userAddress, contractAddress) + require.NoError(t, err) + assert.Equal(t, test.balance, result) + } }) } } @@ -354,20 +395,22 @@ func TestEthClient_HeaderByNumber(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) - - ctx, cancel := context.WithTimeout(testutils.Context(t), 5*time.Second) - defer cancel() - result, err := ethClient.HeadByNumber(ctx, expectedBlockNum) - if test.error != nil { - require.Error(t, err, test.error) - } else { + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) require.NoError(t, err) - require.Equal(t, expectedBlockHash, result.Hash.Hex()) - require.Equal(t, test.expectedResponseBlock, result.Number) - require.Zero(t, cltest.FixtureChainID.Cmp(result.EVMChainID.ToInt())) + + ctx, cancel := context.WithTimeout(testutils.Context(t), 5*time.Second) + result, err := ethClient.HeadByNumber(ctx, expectedBlockNum) + if test.error != nil { + require.Error(t, err, test.error) + } else { + require.NoError(t, err) + require.Equal(t, expectedBlockHash, result.Hash.Hex()) + require.Equal(t, test.expectedResponseBlock, result.Number) + require.Zero(t, cltest.FixtureChainID.Cmp(result.EVMChainID.ToInt())) + } + cancel() } }) } @@ -395,12 +438,14 @@ func TestEthClient_SendTransaction_NoSecondaryURL(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - err = ethClient.SendTransaction(testutils.Context(t), tx) - assert.NoError(t, err) + err = ethClient.SendTransaction(testutils.Context(t), tx) + assert.NoError(t, err) + } } func 
TestEthClient_SendTransaction_WithSecondaryURLs(t *testing.T) { @@ -432,16 +477,19 @@ func TestEthClient_SendTransaction_WithSecondaryURLs(t *testing.T) { t.Cleanup(ts.Close) sendonlyURL := *cltest.MustParseURL(t, ts.URL) - ethClient := mustNewClient(t, wsURL, sendonlyURL, sendonlyURL) - err = ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) - err = ethClient.SendTransaction(testutils.Context(t), tx) - require.NoError(t, err) + clients := mustNewClients(t, wsURL, sendonlyURL, sendonlyURL) + for _, ethClient := range clients { + err = ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) + + err = ethClient.SendTransaction(testutils.Context(t), tx) + require.NoError(t, err) + } // Unfortunately it's a bit tricky to test this, since there is no // synchronization. We have to rely on timing instead. - require.Eventually(t, func() bool { return service.sentCount.Load() == int32(2) }, testutils.WaitTimeout(t), 500*time.Millisecond) + require.Eventually(t, func() bool { return service.sentCount.Load() == int32(len(clients)*2) }, testutils.WaitTimeout(t), 500*time.Millisecond) } func TestEthClient_SendTransactionReturnCode(t *testing.T) { @@ -467,13 +515,15 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) - assert.Error(t, err) - assert.Equal(t, errType, clienttypes.Fatal) + errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) + assert.Error(t, err) + assert.Equal(t, errType, commontypes.Fatal) + } }) t.Run("returns TransactionAlreadyKnown error type when error message is nonce too low", func(t *testing.T) { @@ -493,13 +543,15 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) - assert.Error(t, err) - assert.Equal(t, errType, clienttypes.TransactionAlreadyKnown) + errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) + assert.Error(t, err) + assert.Equal(t, errType, commontypes.TransactionAlreadyKnown) + } }) t.Run("returns Successful error type when there is no error message", func(t *testing.T) { @@ -518,13 +570,15 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) - assert.NoError(t, err) - assert.Equal(t, errType, clienttypes.Successful) + errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) + assert.NoError(t, err) + assert.Equal(t, errType, commontypes.Successful) + } }) t.Run("returns Underpriced error type when transaction is terminally underpriced", 
func(t *testing.T) { @@ -544,13 +598,15 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) - assert.Error(t, err) - assert.Equal(t, errType, clienttypes.Underpriced) + errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) + assert.Error(t, err) + assert.Equal(t, errType, commontypes.Underpriced) + } }) t.Run("returns Unsupported error type when error message is queue full", func(t *testing.T) { @@ -570,13 +626,15 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) - assert.Error(t, err) - assert.Equal(t, errType, clienttypes.Unsupported) + errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) + assert.Error(t, err) + assert.Equal(t, errType, commontypes.Unsupported) + } }) t.Run("returns Retryable error type when there is a transaction gap", func(t *testing.T) { @@ -596,13 +654,15 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) - assert.Error(t, err) - assert.Equal(t, errType, clienttypes.Retryable) + errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) + assert.Error(t, err) + assert.Equal(t, errType, commontypes.Retryable) + } }) t.Run("returns InsufficientFunds error type when the sender address doesn't have enough funds", func(t *testing.T) { @@ -622,13 +682,15 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) - assert.Error(t, err) - assert.Equal(t, errType, clienttypes.InsufficientFunds) + errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) + assert.Error(t, err) + assert.Equal(t, errType, commontypes.InsufficientFunds) + } }) t.Run("returns ExceedsFeeCap error type when gas price is too high for the node", func(t *testing.T) { @@ -648,13 +710,15 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + 
require.NoError(t, err) - errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) - assert.Error(t, err) - assert.Equal(t, errType, clienttypes.ExceedsMaxFee) + errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) + assert.Error(t, err) + assert.Equal(t, errType, commontypes.ExceedsMaxFee) + } }) t.Run("returns Unknown error type when the error can't be categorized", func(t *testing.T) { @@ -674,13 +738,15 @@ func TestEthClient_SendTransactionReturnCode(t *testing.T) { return }) - ethClient := mustNewClient(t, wsURL) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClients(t, wsURL) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) - assert.Error(t, err) - assert.Equal(t, errType, clienttypes.Unknown) + errType, err := ethClient.SendTransactionReturnCode(testutils.Context(t), tx, fromAddress) + assert.Error(t, err) + assert.Equal(t, errType, commontypes.Unknown) + } }) } @@ -718,24 +784,132 @@ func TestEthClient_SubscribeNewHead(t *testing.T) { return }) - ethClient := mustNewClientWithChainID(t, wsURL, chainId) - err := ethClient.Dial(testutils.Context(t)) - require.NoError(t, err) + clients := mustNewClientsWithChainID(t, wsURL, chainId) + for _, ethClient := range clients { + err := ethClient.Dial(testutils.Context(t)) + require.NoError(t, err) - headCh := make(chan *evmtypes.Head) - sub, err := ethClient.SubscribeNewHead(ctx, headCh) - require.NoError(t, err) - defer sub.Unsubscribe() - - select { - case err := <-sub.Err(): - t.Fatal(err) - case <-ctx.Done(): - t.Fatal(ctx.Err()) - case h := <-headCh: - require.NotNil(t, h.EVMChainID) - require.Zero(t, chainId.Cmp(h.EVMChainID.ToInt())) + headCh := make(chan *evmtypes.Head) + sub, err := ethClient.SubscribeNewHead(ctx, headCh) + require.NoError(t, err) + + select { + case err := <-sub.Err(): + t.Fatal(err) + case <-ctx.Done(): + t.Fatal(ctx.Err()) + case h := <-headCh: + require.NotNil(t, h.EVMChainID) + require.Zero(t, chainId.Cmp(h.EVMChainID.ToInt())) + } + sub.Unsubscribe() } } +func TestEthClient_ErroringClient(t *testing.T) { + t.Parallel() + ctx := testutils.Context(t) + + // Empty node means there are no active nodes to select from, causing client to always return error. + erroringClient := evmclient.NewChainClientWithEmptyNode(t, commonclient.NodeSelectionModeRoundRobin, time.Second*0, time.Second*0, testutils.FixtureChainID) + + _, err := erroringClient.BalanceAt(ctx, common.Address{}, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + err = erroringClient.BatchCallContext(ctx, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + err = erroringClient.BatchCallContextAll(ctx, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.BlockByHash(ctx, common.Hash{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.BlockByNumber(ctx, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + err = erroringClient.CallContext(ctx, nil, "") + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.CallContract(ctx, ethereum.CallMsg{}, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + // TODO-1663: test actual ChainID() call once client.go is deprecated. 
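+	// ChainID does not hit the RPC yet (it returns the configured ID), so it is
+	// expected to succeed even though every other call in this test errors out.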
+ id, err := erroringClient.ChainID() + require.Equal(t, id, testutils.FixtureChainID) + //require.Equal(t, err, commonclient.ErroringNodeError) + require.Equal(t, err, nil) + + _, err = erroringClient.CodeAt(ctx, common.Address{}, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + id = erroringClient.ConfiguredChainID() + require.Equal(t, id, testutils.FixtureChainID) + + err = erroringClient.Dial(ctx) + require.ErrorContains(t, err, "no available nodes for chain") + + _, err = erroringClient.EstimateGas(ctx, ethereum.CallMsg{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.FilterLogs(ctx, ethereum.FilterQuery{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.HeaderByHash(ctx, common.Hash{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.HeaderByNumber(ctx, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.HeadByHash(ctx, common.Hash{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.HeadByNumber(ctx, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.LINKBalance(ctx, common.Address{}, common.Address{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.LatestBlockHeight(ctx) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.PendingCodeAt(ctx, common.Address{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.PendingNonceAt(ctx, common.Address{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + err = erroringClient.SendTransaction(ctx, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + code, err := erroringClient.SendTransactionReturnCode(ctx, nil, common.Address{}) + require.Equal(t, code, commontypes.Unknown) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.SequenceAt(ctx, common.Address{}, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.SubscribeFilterLogs(ctx, ethereum.FilterQuery{}, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.SubscribeNewHead(ctx, nil) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.SuggestGasPrice(ctx) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.SuggestGasTipCap(ctx) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.TokenBalance(ctx, common.Address{}, common.Address{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.TransactionByHash(ctx, common.Hash{}) + require.Equal(t, err, commonclient.ErroringNodeError) + + _, err = erroringClient.TransactionReceipt(ctx, common.Hash{}) + require.Equal(t, err, commonclient.ErroringNodeError) + +} + const headResult = evmclient.HeadResult diff --git a/core/chains/evm/client/errors.go b/core/chains/evm/client/errors.go index 7b89e7b92d1..7197d77b3d9 100644 --- a/core/chains/evm/client/errors.go +++ b/core/chains/evm/client/errors.go @@ -397,7 +397,7 @@ func ExtractRPCError(baseErr error) (*JsonError, error) { return &jErr, nil } -func NewSendErrorReturnCode(err error, lggr logger.Logger, tx *types.Transaction, fromAddress common.Address, isL2 bool) (clienttypes.SendTxReturnCode, error) { +func ClassifySendError(err error, lggr logger.Logger, tx *types.Transaction, fromAddress common.Address, isL2 bool) 
(clienttypes.SendTxReturnCode, error) { sendError := NewSendError(err) if sendError == nil { return clienttypes.Successful, err @@ -465,3 +465,15 @@ func NewSendErrorReturnCode(err error, lggr logger.Logger, tx *types.Transaction } return clienttypes.Unknown, err } + +// ClassifySendOnlyError handles SendOnly nodes error codes. In that case, we don't assume there is another transaction that will be correctly +// priced. +func ClassifySendOnlyError(err error) clienttypes.SendTxReturnCode { + sendError := NewSendError(err) + if sendError == nil || sendError.IsNonceTooLowError() || sendError.IsTransactionAlreadyMined() || sendError.IsTransactionAlreadyInMempool() { + // Nonce too low or transaction known errors are expected since + // the primary SendTransaction may well have succeeded already + return clienttypes.Successful + } + return clienttypes.Fatal +} diff --git a/core/chains/evm/client/helpers_test.go b/core/chains/evm/client/helpers_test.go index 342a9143432..8552b2c0a06 100644 --- a/core/chains/evm/client/helpers_test.go +++ b/core/chains/evm/client/helpers_test.go @@ -9,7 +9,11 @@ import ( "github.com/pkg/errors" + clienttypes "github.com/smartcontractkit/chainlink/v2/common/chains/client" + commonclient "github.com/smartcontractkit/chainlink/v2/common/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" + evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + commonconfig "github.com/smartcontractkit/chainlink/v2/core/config" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -64,6 +68,67 @@ func Wrap(err error, s string) error { return wrap(err, s) } +func NewChainClientWithTestNode( + t *testing.T, + nodeCfg commonclient.NodeConfig, + noNewHeadsThreshold time.Duration, + leaseDuration time.Duration, + rpcUrl string, + rpcHTTPURL *url.URL, + sendonlyRPCURLs []url.URL, + id int32, + chainID *big.Int, +) (Client, error) { + parsed, err := url.ParseRequestURI(rpcUrl) + if err != nil { + return nil, err + } + + if parsed.Scheme != "ws" && parsed.Scheme != "wss" { + return nil, errors.Errorf("ethereum url scheme must be websocket: %s", parsed.String()) + } + + lggr := logger.TestLogger(t) + rpc := NewRPCClient(lggr, *parsed, rpcHTTPURL, "eth-primary-rpc-0", id, chainID, clienttypes.Primary) + + n := commonclient.NewNode[*big.Int, *evmtypes.Head, RPCCLient]( + nodeCfg, noNewHeadsThreshold, lggr, *parsed, rpcHTTPURL, "eth-primary-node-0", id, chainID, 1, rpc, "EVM") + primaries := []commonclient.Node[*big.Int, *evmtypes.Head, RPCCLient]{n} + + var sendonlys []commonclient.SendOnlyNode[*big.Int, RPCCLient] + for i, u := range sendonlyRPCURLs { + if u.Scheme != "http" && u.Scheme != "https" { + return nil, errors.Errorf("sendonly ethereum rpc url scheme must be http(s): %s", u.String()) + } + var empty url.URL + rpc := NewRPCClient(lggr, empty, &sendonlyRPCURLs[i], fmt.Sprintf("eth-sendonly-rpc-%d", i), id, chainID, clienttypes.Secondary) + s := commonclient.NewSendOnlyNode[*big.Int, RPCCLient]( + lggr, u, fmt.Sprintf("eth-sendonly-%d", i), chainID, rpc) + sendonlys = append(sendonlys, s) + } + + var chainType commonconfig.ChainType + c := NewChainClient(lggr, nodeCfg.SelectionMode(), leaseDuration, noNewHeadsThreshold, primaries, sendonlys, chainID, chainType) + t.Cleanup(c.Close) + return c, nil +} + +func NewChainClientWithEmptyNode( + t *testing.T, + selectionMode string, + leaseDuration time.Duration, + noNewHeadsThreshold time.Duration, + chainID *big.Int, +) Client { + + lggr := 
logger.TestLogger(t) + + var chainType commonconfig.ChainType + c := NewChainClient(lggr, selectionMode, leaseDuration, noNewHeadsThreshold, nil, nil, chainID, chainType) + t.Cleanup(c.Close) + return c +} + type TestableSendOnlyNode interface { SendOnlyNode SetEthClient(newBatchSender BatchSender, newSender TxSender) diff --git a/core/chains/evm/client/rpc_client.go b/core/chains/evm/client/rpc_client.go new file mode 100644 index 00000000000..b6ed84eee4e --- /dev/null +++ b/core/chains/evm/client/rpc_client.go @@ -0,0 +1,1046 @@ +package client + +import ( + "context" + "fmt" + "math/big" + "net/url" + "strconv" + "sync" + "time" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethclient" + "github.com/ethereum/go-ethereum/rpc" + "github.com/google/uuid" + "github.com/pkg/errors" + + clienttypes "github.com/smartcontractkit/chainlink/v2/common/chains/client" + commonclient "github.com/smartcontractkit/chainlink/v2/common/client" + commontypes "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink/v2/core/assets" + evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +// RPCCLient includes all the necessary generalized RPC methods along with any additional chain-specific methods. +type RPCCLient interface { + commonclient.RPC[ + *big.Int, + evmtypes.Nonce, + common.Address, + common.Hash, + *types.Transaction, + common.Hash, + types.Log, + ethereum.FilterQuery, + *evmtypes.Receipt, + *assets.Wei, + *evmtypes.Head, + ] + BlockByHashGeth(ctx context.Context, hash common.Hash) (b *types.Block, err error) + BlockByNumberGeth(ctx context.Context, number *big.Int) (b *types.Block, err error) + HeaderByHash(ctx context.Context, h common.Hash) (head *types.Header, err error) + HeaderByNumber(ctx context.Context, n *big.Int) (head *types.Header, err error) + PendingCodeAt(ctx context.Context, account common.Address) (b []byte, err error) + SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (s ethereum.Subscription, err error) + SuggestGasPrice(ctx context.Context) (p *big.Int, err error) + SuggestGasTipCap(ctx context.Context) (t *big.Int, err error) + TransactionReceiptGeth(ctx context.Context, txHash common.Hash) (r *types.Receipt, err error) +} + +type rpcClient struct { + rpcLog logger.Logger + name string + id int32 + chainID *big.Int + tier clienttypes.NodeTier + + ws rawclient + http *rawclient + + stateMu sync.RWMutex // protects state* fields + + // Need to track subscriptions because closing the RPC does not (always?) + // close the underlying subscription + subs []ethereum.Subscription + + // Need to track the aliveLoop subscription, so we do not cancel it when checking lease on the MultiNode + aliveLoopSub ethereum.Subscription + + // chStopInFlight can be closed to immediately cancel all in-flight requests on + // this rpcClient. Closing and replacing should be serialized through + // stateMu since it can happen on state transitions as well as rpcClient Close. 
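+	// Request contexts are derived from this channel (see makeQueryCtx and
+	// makeLiveQueryCtxAndSafeGetClients below), so closing it cancels every
+	// in-flight call at once instead of waiting on individual deadlines.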
+ chStopInFlight chan struct{} +} + +// NewRPCCLient returns a new *rpcClient as commonclient.RPC +func NewRPCClient( + lggr logger.Logger, + wsuri url.URL, + httpuri *url.URL, + name string, + id int32, + chainID *big.Int, + tier clienttypes.NodeTier, +) RPCCLient { + r := new(rpcClient) + r.name = name + r.id = id + r.chainID = chainID + r.tier = tier + r.ws.uri = wsuri + if httpuri != nil { + r.http = &rawclient{uri: *httpuri} + } + r.chStopInFlight = make(chan struct{}) + lggr = lggr.Named("Client").With( + "clientTier", tier.String(), + "clientName", name, + "client", r.String(), + "evmChainID", chainID, + ) + r.rpcLog = lggr.Named("RPC") + + return r +} + +// Not thread-safe, pure dial. +func (r *rpcClient) Dial(callerCtx context.Context) error { + ctx, cancel := r.makeQueryCtx(callerCtx) + defer cancel() + + promEVMPoolRPCNodeDials.WithLabelValues(r.chainID.String(), r.name).Inc() + lggr := r.rpcLog.With("wsuri", r.ws.uri.Redacted()) + if r.http != nil { + lggr = lggr.With("httpuri", r.http.uri.Redacted()) + } + lggr.Debugw("RPC dial: evmclient.Client#dial") + + wsrpc, err := rpc.DialWebsocket(ctx, r.ws.uri.String(), "") + if err != nil { + promEVMPoolRPCNodeDialsFailed.WithLabelValues(r.chainID.String(), r.name).Inc() + return errors.Wrapf(err, "error while dialing websocket: %v", r.ws.uri.Redacted()) + } + + r.ws.rpc = wsrpc + r.ws.geth = ethclient.NewClient(wsrpc) + + if r.http != nil { + if err := r.DialHTTP(); err != nil { + return err + } + } + + promEVMPoolRPCNodeDialsSuccess.WithLabelValues(r.chainID.String(), r.name).Inc() + + return nil +} + +// Not thread-safe, pure dial. +// DialHTTP doesn't actually make any external HTTP calls +// It can only return error if the URL is malformed. +func (r *rpcClient) DialHTTP() error { + promEVMPoolRPCNodeDials.WithLabelValues(r.chainID.String(), r.name).Inc() + lggr := r.rpcLog.With("httpuri", r.ws.uri.Redacted()) + lggr.Debugw("RPC dial: evmclient.Client#dial") + + var httprpc *rpc.Client + httprpc, err := rpc.DialHTTP(r.http.uri.String()) + if err != nil { + promEVMPoolRPCNodeDialsFailed.WithLabelValues(r.chainID.String(), r.name).Inc() + return errors.Wrapf(err, "error while dialing HTTP: %v", r.http.uri.Redacted()) + } + + r.http.rpc = httprpc + r.http.geth = ethclient.NewClient(httprpc) + + promEVMPoolRPCNodeDialsSuccess.WithLabelValues(r.chainID.String(), r.name).Inc() + + return nil +} + +func (r *rpcClient) Close() { + defer func() { + if r.ws.rpc != nil { + r.ws.rpc.Close() + } + }() + + r.stateMu.Lock() + defer r.stateMu.Unlock() + r.cancelInflightRequests() +} + +// cancelInflightRequests closes and replaces the chStopInFlight +// WARNING: NOT THREAD-SAFE +// This must be called from within the r.stateMu lock +func (r *rpcClient) cancelInflightRequests() { + close(r.chStopInFlight) + r.chStopInFlight = make(chan struct{}) +} + +func (r *rpcClient) String() string { + s := fmt.Sprintf("(%s)%s:%s", r.tier.String(), r.name, r.ws.uri.Redacted()) + if r.http != nil { + s = s + fmt.Sprintf(":%s", r.http.uri.Redacted()) + } + return s +} + +func (r *rpcClient) logResult( + lggr logger.Logger, + err error, + callDuration time.Duration, + rpcDomain, + callName string, + results ...interface{}, +) { + lggr = lggr.With("duration", callDuration, "rpcDomain", rpcDomain, "callName", callName) + promEVMPoolRPCNodeCalls.WithLabelValues(r.chainID.String(), r.name).Inc() + if err == nil { + promEVMPoolRPCNodeCallsSuccess.WithLabelValues(r.chainID.String(), r.name).Inc() + lggr.Tracew( + fmt.Sprintf("evmclient.Client#%s RPC call success", 
callName), + results..., + ) + } else { + promEVMPoolRPCNodeCallsFailed.WithLabelValues(r.chainID.String(), r.name).Inc() + lggr.Debugw( + fmt.Sprintf("evmclient.Client#%s RPC call failure", callName), + append(results, "err", err)..., + ) + } + promEVMPoolRPCCallTiming. + WithLabelValues( + r.chainID.String(), // chain id + r.name, // rpcClient name + rpcDomain, // rpc domain + "false", // is send only + strconv.FormatBool(err == nil), // is successful + callName, // rpc call name + ). + Observe(float64(callDuration)) +} + +func (r *rpcClient) getRPCDomain() string { + if r.http != nil { + return r.http.uri.Host + } + return r.ws.uri.Host +} + +// registerSub adds the sub to the rpcClient list +func (r *rpcClient) registerSub(sub ethereum.Subscription) { + r.stateMu.Lock() + defer r.stateMu.Unlock() + r.subs = append(r.subs, sub) +} + +// disconnectAll disconnects all clients connected to the rpcClient +// WARNING: NOT THREAD-SAFE +// This must be called from within the r.stateMu lock +func (r *rpcClient) DisconnectAll() { + if r.ws.rpc != nil { + r.ws.rpc.Close() + } + r.cancelInflightRequests() + r.unsubscribeAll() +} + +// unsubscribeAll unsubscribes all subscriptions +// WARNING: NOT THREAD-SAFE +// This must be called from within the r.stateMu lock +func (r *rpcClient) unsubscribeAll() { + for _, sub := range r.subs { + sub.Unsubscribe() + } + r.subs = nil +} +func (r *rpcClient) SetAliveLoopSub(sub commontypes.Subscription) { + r.stateMu.Lock() + defer r.stateMu.Unlock() + + r.aliveLoopSub = sub +} + +// SubscribersCount returns the number of client subscribed to the node +func (r *rpcClient) SubscribersCount() int32 { + r.stateMu.RLock() + defer r.stateMu.RUnlock() + return int32(len(r.subs)) +} + +// UnsubscribeAllExceptAliveLoop disconnects all subscriptions to the node except the alive loop subscription +// while holding the n.stateMu lock +func (r *rpcClient) UnsubscribeAllExceptAliveLoop() { + r.stateMu.Lock() + defer r.stateMu.Unlock() + + for _, s := range r.subs { + if s != r.aliveLoopSub { + s.Unsubscribe() + } + } +} + +// RPC wrappers + +// CallContext implementation +func (r *rpcClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return err + } + defer cancel() + lggr := r.newRqLggr().With( + "method", method, + "args", args, + ) + + lggr.Debug("RPC call: evmclient.Client#CallContext") + start := time.Now() + if http != nil { + err = r.wrapHTTP(http.rpc.CallContext(ctx, result, method, args...)) + } else { + err = r.wrapWS(ws.rpc.CallContext(ctx, result, method, args...)) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "CallContext") + + return err +} + +func (r *rpcClient) BatchCallContext(ctx context.Context, b []any) error { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return err + } + batch := make([]rpc.BatchElem, len(b)) + for i, arg := range b { + batch[i] = arg.(rpc.BatchElem) + } + defer cancel() + lggr := r.newRqLggr().With("nBatchElems", len(b), "batchElems", b) + + lggr.Trace("RPC call: evmclient.Client#BatchCallContext") + start := time.Now() + if http != nil { + err = r.wrapHTTP(http.rpc.BatchCallContext(ctx, batch)) + } else { + err = r.wrapWS(ws.rpc.BatchCallContext(ctx, batch)) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "BatchCallContext") + + return err +} + +func (r 
*rpcClient) Subscribe(ctx context.Context, channel chan<- *evmtypes.Head, args ...interface{}) (commontypes.Subscription, error) { + ctx, cancel, ws, _, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("args", args) + + lggr.Debug("RPC call: evmclient.Client#EthSubscribe") + start := time.Now() + sub, err := ws.rpc.EthSubscribe(ctx, channel, args...) + if err == nil { + r.registerSub(sub) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "EthSubscribe") + + return sub, err +} + +// GethClient wrappers + +func (r *rpcClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (receipt *evmtypes.Receipt, err error) { + err = r.CallContext(ctx, &receipt, "eth_getTransactionReceipt", txHash, false) + if err != nil { + return nil, err + } + if receipt == nil { + err = ethereum.NotFound + return + } + return +} + +func (r *rpcClient) TransactionReceiptGeth(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("txHash", txHash) + + lggr.Debug("RPC call: evmclient.Client#TransactionReceipt") + + start := time.Now() + if http != nil { + receipt, err = http.geth.TransactionReceipt(ctx, txHash) + err = r.wrapHTTP(err) + } else { + receipt, err = ws.geth.TransactionReceipt(ctx, txHash) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "TransactionReceipt", + "receipt", receipt, + ) + + return +} +func (r *rpcClient) TransactionByHash(ctx context.Context, txHash common.Hash) (tx *types.Transaction, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("txHash", txHash) + + lggr.Debug("RPC call: evmclient.Client#TransactionByHash") + + start := time.Now() + if http != nil { + tx, _, err = http.geth.TransactionByHash(ctx, txHash) + err = r.wrapHTTP(err) + } else { + tx, _, err = ws.geth.TransactionByHash(ctx, txHash) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "TransactionByHash", + "receipt", tx, + ) + + return +} + +func (r *rpcClient) HeaderByNumber(ctx context.Context, number *big.Int) (header *types.Header, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("number", number) + + lggr.Debug("RPC call: evmclient.Client#HeaderByNumber") + start := time.Now() + if http != nil { + header, err = http.geth.HeaderByNumber(ctx, number) + err = r.wrapHTTP(err) + } else { + header, err = ws.geth.HeaderByNumber(ctx, number) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "HeaderByNumber", "header", header) + + return +} + +func (r *rpcClient) HeaderByHash(ctx context.Context, hash common.Hash) (header *types.Header, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("hash", hash) + + lggr.Debug("RPC call: evmclient.Client#HeaderByHash") + start := time.Now() + if http != nil { + header, err = http.geth.HeaderByHash(ctx, hash) + err = r.wrapHTTP(err) + } 
else { + header, err = ws.geth.HeaderByHash(ctx, hash) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "HeaderByHash", + "header", header, + ) + + return +} + +func (r *rpcClient) BlockByNumber(ctx context.Context, number *big.Int) (head *evmtypes.Head, err error) { + hex := ToBlockNumArg(number) + err = r.CallContext(ctx, &head, "eth_getBlockByNumber", hex, false) + if err != nil { + return nil, err + } + if head == nil { + err = ethereum.NotFound + return + } + head.EVMChainID = utils.NewBig(r.chainID) + return +} + +func (r *rpcClient) BlockByHash(ctx context.Context, hash common.Hash) (head *evmtypes.Head, err error) { + err = r.CallContext(ctx, &head, "eth_getBlockByHash", hash.Hex(), false) + if err != nil { + return nil, err + } + if head == nil { + err = ethereum.NotFound + return + } + head.EVMChainID = utils.NewBig(r.chainID) + return +} + +func (r *rpcClient) BlockByHashGeth(ctx context.Context, hash common.Hash) (block *types.Block, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("hash", hash) + + lggr.Debug("RPC call: evmclient.Client#BlockByHash") + start := time.Now() + if http != nil { + block, err = http.geth.BlockByHash(ctx, hash) + err = r.wrapHTTP(err) + } else { + block, err = ws.geth.BlockByHash(ctx, hash) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "BlockByHash", + "block", block, + ) + + return +} + +func (r *rpcClient) BlockByNumberGeth(ctx context.Context, number *big.Int) (block *types.Block, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("number", number) + + lggr.Debug("RPC call: evmclient.Client#BlockByNumber") + start := time.Now() + if http != nil { + block, err = http.geth.BlockByNumber(ctx, number) + err = r.wrapHTTP(err) + } else { + block, err = ws.geth.BlockByNumber(ctx, number) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "BlockByNumber", + "block", block, + ) + + return +} + +func (r *rpcClient) SendTransaction(ctx context.Context, tx *types.Transaction) error { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return err + } + defer cancel() + lggr := r.newRqLggr().With("tx", tx) + + lggr.Debug("RPC call: evmclient.Client#SendTransaction") + start := time.Now() + if http != nil { + err = r.wrapHTTP(http.geth.SendTransaction(ctx, tx)) + } else { + err = r.wrapWS(ws.geth.SendTransaction(ctx, tx)) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "SendTransaction") + + return err +} + +func (r *rpcClient) SimulateTransaction(ctx context.Context, tx *types.Transaction) error { + // Not Implemented + return errors.New("SimulateTransaction not implemented") +} + +func (r *rpcClient) SendEmptyTransaction( + ctx context.Context, + newTxAttempt func(nonce evmtypes.Nonce, feeLimit uint32, fee *assets.Wei, fromAddress common.Address) (attempt any, err error), + nonce evmtypes.Nonce, + gasLimit uint32, + fee *assets.Wei, + fromAddress common.Address, +) (txhash string, err error) { + // Not Implemented + return "", errors.New("SendEmptyTransaction not implemented") +} + +// PendingSequenceAt returns one higher than the highest nonce from 
both mempool and mined transactions +func (r *rpcClient) PendingSequenceAt(ctx context.Context, account common.Address) (nonce evmtypes.Nonce, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return 0, err + } + defer cancel() + lggr := r.newRqLggr().With("account", account) + + lggr.Debug("RPC call: evmclient.Client#PendingNonceAt") + start := time.Now() + var n uint64 + if http != nil { + n, err = http.geth.PendingNonceAt(ctx, account) + nonce = evmtypes.Nonce(int64(n)) + err = r.wrapHTTP(err) + } else { + n, err = ws.geth.PendingNonceAt(ctx, account) + nonce = evmtypes.Nonce(int64(n)) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "PendingNonceAt", + "nonce", nonce, + ) + + return +} + +// SequenceAt is a bit of a misnomer. You might expect it to return the highest +// mined nonce at the given block number, but it actually returns the total +// transaction count which is the highest mined nonce + 1 +func (r *rpcClient) SequenceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (nonce evmtypes.Nonce, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return 0, err + } + defer cancel() + lggr := r.newRqLggr().With("account", account, "blockNumber", blockNumber) + + lggr.Debug("RPC call: evmclient.Client#NonceAt") + start := time.Now() + var n uint64 + if http != nil { + n, err = http.geth.NonceAt(ctx, account, blockNumber) + nonce = evmtypes.Nonce(int64(n)) + err = r.wrapHTTP(err) + } else { + n, err = ws.geth.NonceAt(ctx, account, blockNumber) + nonce = evmtypes.Nonce(int64(n)) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "NonceAt", + "nonce", nonce, + ) + + return +} + +func (r *rpcClient) PendingCodeAt(ctx context.Context, account common.Address) (code []byte, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("account", account) + + lggr.Debug("RPC call: evmclient.Client#PendingCodeAt") + start := time.Now() + if http != nil { + code, err = http.geth.PendingCodeAt(ctx, account) + err = r.wrapHTTP(err) + } else { + code, err = ws.geth.PendingCodeAt(ctx, account) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "PendingCodeAt", + "code", code, + ) + + return +} + +func (r *rpcClient) CodeAt(ctx context.Context, account common.Address, blockNumber *big.Int) (code []byte, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("account", account, "blockNumber", blockNumber) + + lggr.Debug("RPC call: evmclient.Client#CodeAt") + start := time.Now() + if http != nil { + code, err = http.geth.CodeAt(ctx, account, blockNumber) + err = r.wrapHTTP(err) + } else { + code, err = ws.geth.CodeAt(ctx, account, blockNumber) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "CodeAt", + "code", code, + ) + + return +} + +func (r *rpcClient) EstimateGas(ctx context.Context, c interface{}) (gas uint64, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return 0, err + } + defer cancel() + call := c.(ethereum.CallMsg) + lggr := 
r.newRqLggr().With("call", call) + + lggr.Debug("RPC call: evmclient.Client#EstimateGas") + start := time.Now() + if http != nil { + gas, err = http.geth.EstimateGas(ctx, call) + err = r.wrapHTTP(err) + } else { + gas, err = ws.geth.EstimateGas(ctx, call) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "EstimateGas", + "gas", gas, + ) + + return +} + +func (r *rpcClient) SuggestGasPrice(ctx context.Context) (price *big.Int, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr() + + lggr.Debug("RPC call: evmclient.Client#SuggestGasPrice") + start := time.Now() + if http != nil { + price, err = http.geth.SuggestGasPrice(ctx) + err = r.wrapHTTP(err) + } else { + price, err = ws.geth.SuggestGasPrice(ctx) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "SuggestGasPrice", + "price", price, + ) + + return +} + +func (r *rpcClient) CallContract(ctx context.Context, msg interface{}, blockNumber *big.Int) (val []byte, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("callMsg", msg, "blockNumber", blockNumber) + message := msg.(ethereum.CallMsg) + + lggr.Debug("RPC call: evmclient.Client#CallContract") + start := time.Now() + if http != nil { + val, err = http.geth.CallContract(ctx, message, blockNumber) + err = r.wrapHTTP(err) + } else { + val, err = ws.geth.CallContract(ctx, message, blockNumber) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "CallContract", + "val", val, + ) + + return + +} + +func (r *rpcClient) LatestBlockHeight(ctx context.Context) (*big.Int, error) { + var height big.Int + h, err := r.BlockNumber(ctx) + return height.SetUint64(h), err +} + +func (r *rpcClient) BlockNumber(ctx context.Context) (height uint64, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return 0, err + } + defer cancel() + lggr := r.newRqLggr() + + lggr.Debug("RPC call: evmclient.Client#BlockNumber") + start := time.Now() + if http != nil { + height, err = http.geth.BlockNumber(ctx) + err = r.wrapHTTP(err) + } else { + height, err = ws.geth.BlockNumber(ctx) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "BlockNumber", + "height", height, + ) + + return +} + +func (r *rpcClient) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (balance *big.Int, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("account", account.Hex(), "blockNumber", blockNumber) + + lggr.Debug("RPC call: evmclient.Client#BalanceAt") + start := time.Now() + if http != nil { + balance, err = http.geth.BalanceAt(ctx, account, blockNumber) + err = r.wrapHTTP(err) + } else { + balance, err = ws.geth.BalanceAt(ctx, account, blockNumber) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "BalanceAt", + "balance", balance, + ) + + return +} + +// TokenBalance returns the balance of the given address for the token contract address. 
+func (r *rpcClient) TokenBalance(ctx context.Context, address common.Address, contractAddress common.Address) (*big.Int, error) { + result := "" + numLinkBigInt := new(big.Int) + functionSelector := evmtypes.HexToFunctionSelector(BALANCE_OF_ADDRESS_FUNCTION_SELECTOR) // balanceOf(address) + data := utils.ConcatBytes(functionSelector.Bytes(), common.LeftPadBytes(address.Bytes(), utils.EVMWordByteLen)) + args := CallArgs{ + To: contractAddress, + Data: data, + } + err := r.CallContext(ctx, &result, "eth_call", args, "latest") + if err != nil { + return numLinkBigInt, err + } + numLinkBigInt.SetString(result, 0) + return numLinkBigInt, nil +} + +// LINKBalance returns the balance of LINK at the given address +func (r *rpcClient) LINKBalance(ctx context.Context, address common.Address, linkAddress common.Address) (*assets.Link, error) { + balance, err := r.TokenBalance(ctx, address, linkAddress) + if err != nil { + return assets.NewLinkFromJuels(0), err + } + return (*assets.Link)(balance), nil +} + +func (r *rpcClient) FilterEvents(ctx context.Context, q ethereum.FilterQuery) ([]types.Log, error) { + return r.FilterLogs(ctx, q) +} + +func (r *rpcClient) FilterLogs(ctx context.Context, q ethereum.FilterQuery) (l []types.Log, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("q", q) + + lggr.Debug("RPC call: evmclient.Client#FilterLogs") + start := time.Now() + if http != nil { + l, err = http.geth.FilterLogs(ctx, q) + err = r.wrapHTTP(err) + } else { + l, err = ws.geth.FilterLogs(ctx, q) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "FilterLogs", + "log", l, + ) + + return +} + +func (r *rpcClient) ClientVersion(ctx context.Context) (version string, err error) { + err = r.CallContext(ctx, &version, "web3_clientVersion") + return +} + +func (r *rpcClient) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (sub ethereum.Subscription, err error) { + ctx, cancel, ws, _, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr().With("q", q) + + lggr.Debug("RPC call: evmclient.Client#SubscribeFilterLogs") + start := time.Now() + sub, err = ws.geth.SubscribeFilterLogs(ctx, q, ch) + if err == nil { + r.registerSub(sub) + } + err = r.wrapWS(err) + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "SubscribeFilterLogs") + + return +} + +func (r *rpcClient) SuggestGasTipCap(ctx context.Context) (tipCap *big.Int, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + if err != nil { + return nil, err + } + defer cancel() + lggr := r.newRqLggr() + + lggr.Debug("RPC call: evmclient.Client#SuggestGasTipCap") + start := time.Now() + if http != nil { + tipCap, err = http.geth.SuggestGasTipCap(ctx) + err = r.wrapHTTP(err) + } else { + tipCap, err = ws.geth.SuggestGasTipCap(ctx) + err = r.wrapWS(err) + } + duration := time.Since(start) + + r.logResult(lggr, err, duration, r.getRPCDomain(), "SuggestGasTipCap", + "tipCap", tipCap, + ) + + return +} + +// Returns the ChainID according to the geth client. This is useful for functions like verify() +// the common node. 
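+// In practice the caller (e.g. node verification in the common client) compares the
+// reported chain ID with the configured one. Unlike the other wrappers in this file,
+// this call emits no per-request logging or metrics; errors are only wrapped via
+// wrapHTTP/wrapWS before being returned.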
+func (r *rpcClient) ChainID(ctx context.Context) (chainID *big.Int, err error) { + ctx, cancel, ws, http, err := r.makeLiveQueryCtxAndSafeGetClients(ctx) + + defer cancel() + + if http != nil { + chainID, err = http.geth.ChainID(ctx) + err = r.wrapHTTP(err) + } else { + chainID, err = ws.geth.ChainID(ctx) + err = r.wrapWS(err) + } + return +} + +// newRqLggr generates a new logger with a unique request ID +func (r *rpcClient) newRqLggr() logger.Logger { + return r.rpcLog.With( + "requestID", uuid.New(), + ) +} + +func wrapCallError(err error, tp string) error { + if err == nil { + return nil + } + if errors.Cause(err).Error() == "context deadline exceeded" { + err = errors.Wrap(err, "remote node timed out") + } + return errors.Wrapf(err, "%s call failed", tp) +} + +func (r *rpcClient) wrapWS(err error) error { + err = wrapCallError(err, fmt.Sprintf("%s websocket (%s)", r.tier.String(), r.ws.uri.Redacted())) + return err +} + +func (r *rpcClient) wrapHTTP(err error) error { + err = wrapCallError(err, fmt.Sprintf("%s http (%s)", r.tier.String(), r.http.uri.Redacted())) + if err != nil { + r.rpcLog.Debugw("Call failed", "err", err) + } else { + r.rpcLog.Trace("Call succeeded") + } + return err +} + +// makeLiveQueryCtxAndSafeGetClients wraps makeQueryCtx +func (r *rpcClient) makeLiveQueryCtxAndSafeGetClients(parentCtx context.Context) (ctx context.Context, cancel context.CancelFunc, ws rawclient, http *rawclient, err error) { + // Need to wrap in mutex because state transition can cancel and replace the + // context + r.stateMu.RLock() + cancelCh := r.chStopInFlight + ws = r.ws + if r.http != nil { + cp := *r.http + http = &cp + } + r.stateMu.RUnlock() + ctx, cancel = makeQueryCtx(parentCtx, cancelCh) + return +} + +func (r *rpcClient) makeQueryCtx(ctx context.Context) (context.Context, context.CancelFunc) { + return makeQueryCtx(ctx, r.getChStopInflight()) +} + +// getChStopInflight provides a convenience helper that mutex wraps a +// read to the chStopInFlight +func (r *rpcClient) getChStopInflight() chan struct{} { + r.stateMu.RLock() + defer r.stateMu.RUnlock() + return r.chStopInFlight +} + +func (r *rpcClient) Name() string { + return r.name +} + +func Name(r *rpcClient) string { + return r.name +} diff --git a/core/chains/evm/txmgr/client.go b/core/chains/evm/txmgr/client.go index 150ee277577..e1b12577749 100644 --- a/core/chains/evm/txmgr/client.go +++ b/core/chains/evm/txmgr/client.go @@ -80,7 +80,7 @@ func (c *evmTxmClient) BatchSendTransactions( processingErr[i] = fmt.Errorf("failed to process tx (index %d): %w", i, signedErr) return } - codes[i], txErrs[i] = evmclient.NewSendErrorReturnCode(reqs[i].Error, lggr, tx, attempts[i].Tx.FromAddress, c.client.IsL2()) + codes[i], txErrs[i] = evmclient.ClassifySendError(reqs[i].Error, lggr, tx, attempts[i].Tx.FromAddress, c.client.IsL2()) }(index) } wg.Wait() diff --git a/core/chains/evm/types/models.go b/core/chains/evm/types/models.go index 6210226120f..a71c9e8716c 100644 --- a/core/chains/evm/types/models.go +++ b/core/chains/evm/types/models.go @@ -76,6 +76,10 @@ func (h *Head) GetParent() commontypes.Head[common.Hash] { return h.Parent } +func (h *Head) BlockDifficulty() *utils.Big { + return h.Difficulty +} + // EarliestInChain recurses through parents until it finds the earliest one func (h *Head) EarliestInChain() *Head { for h.Parent != nil {