From 388e7794dffc9f228e04aec264ff5bcac7908520 Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Mon, 22 Jan 2024 19:06:34 -0800 Subject: [PATCH] [Functions] Add extra validations for offchain heartbeats (#11783) 1. Add AllowedHeartbeatInitiators list to node's config and validate senders of incoming requests against it (same logic as in Gateway). 2. Validate Sender value in nodes' reponses to make sure it matches the expected node. Extend an integration test to cover this change. 3. Validate age of incoming requests against RequestTimeoutSec from job config to avoid processing ones that already timed out. 4. Disallow null-byte suffixes in message fields to avoid any potential confusion with default padding. --- core/services/functions/connector_handler.go | 83 +++++++++++-------- .../functions/connector_handler_test.go | 38 +++++++-- core/services/gateway/api/message.go | 10 +++ core/services/gateway/api/message_test.go | 18 +++- core/services/gateway/connectionmanager.go | 4 + .../gateway_integration_test.go | 77 +++++++++++++---- .../ocr2/plugins/functions/config/config.go | 1 + .../services/ocr2/plugins/functions/plugin.go | 15 ++-- .../ocr2/plugins/functions/plugin_test.go | 12 ++- 9 files changed, 193 insertions(+), 65 deletions(-) diff --git a/core/services/functions/connector_handler.go b/core/services/functions/connector_handler.go index 1594dc6eb56..18df644c876 100644 --- a/core/services/functions/connector_handler.go +++ b/core/services/functions/connector_handler.go @@ -6,6 +6,7 @@ import ( "crypto/ecdsa" "encoding/json" "fmt" + "strings" "sync" "time" @@ -25,34 +26,34 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" "github.com/smartcontractkit/chainlink/v2/core/services/s4" ) type functionsConnectorHandler struct { services.StateMachine - connector connector.GatewayConnector - signerKey *ecdsa.PrivateKey - nodeAddress string - storage s4.Storage - allowlist functions.OnchainAllowlist - rateLimiter *hc.RateLimiter - subscriptions functions.OnchainSubscriptions - minimumBalance assets.Link - listener FunctionsListener - offchainTransmitter OffchainTransmitter - heartbeatRequests map[RequestID]*HeartbeatResponse - orderedRequests []RequestID - mu sync.Mutex - chStop services.StopChan - shutdownWaitGroup sync.WaitGroup - lggr logger.Logger + connector connector.GatewayConnector + signerKey *ecdsa.PrivateKey + nodeAddress string + storage s4.Storage + allowlist functions.OnchainAllowlist + rateLimiter *hc.RateLimiter + subscriptions functions.OnchainSubscriptions + minimumBalance assets.Link + listener FunctionsListener + offchainTransmitter OffchainTransmitter + allowedHeartbeatInitiators map[string]struct{} + heartbeatRequests map[RequestID]*HeartbeatResponse + requestTimeoutSec uint32 + orderedRequests []RequestID + mu sync.Mutex + chStop services.StopChan + shutdownWaitGroup sync.WaitGroup + lggr logger.Logger } -const ( - HeartbeatRequestTimeoutSec = 240 - HeartbeatCacheSize = 1000 -) +const HeartbeatCacheSize = 1000 var ( _ connector.Signer = &functionsConnectorHandler{} @@ -71,23 +72,29 @@ func InternalId(sender []byte, requestId []byte) RequestID { return RequestID(crypto.Keccak256Hash(append(sender, requestId...)).Bytes()) } -func NewFunctionsConnectorHandler(nodeAddress string, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, listener FunctionsListener, offchainTransmitter OffchainTransmitter, minimumBalance assets.Link, lggr logger.Logger) (*functionsConnectorHandler, error) { +func NewFunctionsConnectorHandler(pluginConfig *config.PluginConfig, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, listener FunctionsListener, offchainTransmitter OffchainTransmitter, lggr logger.Logger) (*functionsConnectorHandler, error) { if signerKey == nil || storage == nil || allowlist == nil || rateLimiter == nil || subscriptions == nil || listener == nil || offchainTransmitter == nil { return nil, fmt.Errorf("all dependencies must be non-nil") } + allowedHeartbeatInitiators := make(map[string]struct{}) + for _, initiator := range pluginConfig.AllowedHeartbeatInitiators { + allowedHeartbeatInitiators[strings.ToLower(initiator)] = struct{}{} + } return &functionsConnectorHandler{ - nodeAddress: nodeAddress, - signerKey: signerKey, - storage: storage, - allowlist: allowlist, - rateLimiter: rateLimiter, - subscriptions: subscriptions, - minimumBalance: minimumBalance, - listener: listener, - offchainTransmitter: offchainTransmitter, - heartbeatRequests: make(map[RequestID]*HeartbeatResponse), - chStop: make(services.StopChan), - lggr: lggr.Named("FunctionsConnectorHandler"), + nodeAddress: pluginConfig.GatewayConnectorConfig.NodeAddress, + signerKey: signerKey, + storage: storage, + allowlist: allowlist, + rateLimiter: rateLimiter, + subscriptions: subscriptions, + minimumBalance: pluginConfig.MinimumSubscriptionBalance, + listener: listener, + offchainTransmitter: offchainTransmitter, + allowedHeartbeatInitiators: allowedHeartbeatInitiators, + heartbeatRequests: make(map[RequestID]*HeartbeatResponse), + requestTimeoutSec: pluginConfig.RequestTimeoutSec, + chStop: make(services.StopChan), + lggr: lggr.Named("FunctionsConnectorHandler"), }, nil } @@ -211,6 +218,10 @@ func (h *functionsConnectorHandler) handleHeartbeat(ctx context.Context, gateway h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse(fmt.Sprintf("failed to unmarshal request: %v", err))) return } + if _, ok := h.allowedHeartbeatInitiators[requestBody.Sender]; !ok { + h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("sender not allowed to send heartbeat requests")) + return + } if !bytes.Equal(request.RequestInitiator, fromAddr.Bytes()) { h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("RequestInitiator doesn't match sender")) return @@ -219,6 +230,10 @@ func (h *functionsConnectorHandler) handleHeartbeat(ctx context.Context, gateway h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("SubscriptionOwner doesn't match sender")) return } + if request.Timestamp < uint64(time.Now().Unix())-uint64(h.requestTimeoutSec) { + h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("Request is too old")) + return + } internalId := InternalId(fromAddr.Bytes(), request.RequestId) request.RequestId = internalId[:] @@ -250,7 +265,7 @@ func internalErrorResponse(internalError string) HeartbeatResponse { func (h *functionsConnectorHandler) handleOffchainRequest(request *OffchainRequest) { defer h.shutdownWaitGroup.Done() stopCtx, _ := h.chStop.NewCtx() - ctx, cancel := context.WithTimeout(stopCtx, time.Duration(HeartbeatRequestTimeoutSec)*time.Second) + ctx, cancel := context.WithTimeout(stopCtx, time.Duration(h.requestTimeoutSec)*time.Second) defer cancel() err := h.listener.HandleOffchainRequest(ctx, request) if err != nil { diff --git a/core/services/functions/connector_handler_test.go b/core/services/functions/connector_handler_test.go index 409f9cdcc56..9a5a9042693 100644 --- a/core/services/functions/connector_handler_test.go +++ b/core/services/functions/connector_handler_test.go @@ -10,6 +10,7 @@ import ( "time" geth_common "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/onsi/gomega" "github.com/smartcontractkit/chainlink-common/pkg/assets" @@ -19,9 +20,11 @@ import ( sfmocks "github.com/smartcontractkit/chainlink/v2/core/services/functions/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" + gwconnector "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" gfmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/mocks" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" "github.com/smartcontractkit/chainlink/v2/core/services/s4" s4mocks "github.com/smartcontractkit/chainlink/v2/core/services/s4/mocks" @@ -30,7 +33,7 @@ import ( "github.com/stretchr/testify/require" ) -func newOffchainRequest(t *testing.T, sender []byte) (*api.Message, functions.RequestID) { +func newOffchainRequest(t *testing.T, sender []byte, ageSec uint64) (*api.Message, functions.RequestID) { requestId := make([]byte, 32) _, err := rand.Read(requestId) require.NoError(t, err) @@ -39,6 +42,7 @@ func newOffchainRequest(t *testing.T, sender []byte) (*api.Message, functions.Re RequestInitiator: sender, SubscriptionId: 1, SubscriptionOwner: sender, + Timestamp: uint64(time.Now().Unix()) - ageSec, } internalId := functions.InternalId(request.RequestInitiator, request.RequestId) @@ -74,7 +78,15 @@ func TestFunctionsConnectorHandler(t *testing.T) { allowlist.On("Close", mock.Anything).Return(nil) subscriptions.On("Start", mock.Anything).Return(nil) subscriptions.On("Close", mock.Anything).Return(nil) - handler, err := functions.NewFunctionsConnectorHandler(addr.Hex(), privateKey, storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(100), logger) + config := &config.PluginConfig{ + GatewayConnectorConfig: &gwconnector.ConnectorConfig{ + NodeAddress: addr.Hex(), + }, + MinimumSubscriptionBalance: *assets.NewLinkFromJuels(100), + RequestTimeoutSec: 1_000, + AllowedHeartbeatInitiators: []string{crypto.PubkeyToAddress(privateKey.PublicKey).Hex()}, + } + handler, err := functions.NewFunctionsConnectorHandler(config, privateKey, storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger) require.NoError(t, err) handler.SetConnector(connector) @@ -257,7 +269,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { t.Run("heartbeat success", func(t *testing.T) { ctx := testutils.Context(t) - msg, internalId := newOffchainRequest(t, addr.Bytes()) + msg, internalId := newOffchainRequest(t, addr.Bytes(), 0) require.NoError(t, msg.Sign(privateKey)) // first call to trigger the request @@ -292,7 +304,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { t.Run("heartbeat internal error", func(t *testing.T) { ctx := testutils.Context(t) - msg, _ := newOffchainRequest(t, addr.Bytes()) + msg, _ := newOffchainRequest(t, addr.Bytes(), 0) require.NoError(t, msg.Sign(privateKey)) // first call to trigger the request @@ -319,7 +331,23 @@ func TestFunctionsConnectorHandler(t *testing.T) { t.Run("heartbeat sender address doesn't match", func(t *testing.T) { ctx := testutils.Context(t) - msg, _ := newOffchainRequest(t, geth_common.BytesToAddress([]byte("0x1234")).Bytes()) + msg, _ := newOffchainRequest(t, geth_common.BytesToAddress([]byte("0x1234")).Bytes(), 0) + require.NoError(t, msg.Sign(privateKey)) + + var response functions.HeartbeatResponse + allowlist.On("Allow", addr).Return(true).Once() + connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) { + respMsg, ok := args[2].(*api.Message) + require.True(t, ok) + require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response)) + require.Equal(t, functions.RequestStateInternalError, response.Status) + }).Return(nil).Once() + handler.HandleGatewayMessage(ctx, "gw1", msg) + }) + + t.Run("heartbeat request too old", func(t *testing.T) { + ctx := testutils.Context(t) + msg, _ := newOffchainRequest(t, addr.Bytes(), 10_000) require.NoError(t, msg.Sign(privateKey)) var response functions.HeartbeatResponse diff --git a/core/services/gateway/api/message.go b/core/services/gateway/api/message.go index c01d3bb9f2e..5e6c8e49247 100644 --- a/core/services/gateway/api/message.go +++ b/core/services/gateway/api/message.go @@ -20,6 +20,7 @@ const ( MessageMethodMaxLen = 64 MessageDonIdMaxLen = 64 MessageReceiverLen = 2 + 2*20 + NullChar = "\x00" ) /* @@ -56,12 +57,21 @@ func (m *Message) Validate() error { if len(m.Body.MessageId) == 0 || len(m.Body.MessageId) > MessageIdMaxLen { return errors.New("invalid message ID length") } + if strings.HasSuffix(m.Body.MessageId, NullChar) { + return errors.New("message ID ending with null bytes") + } if len(m.Body.Method) == 0 || len(m.Body.Method) > MessageMethodMaxLen { return errors.New("invalid method name length") } + if strings.HasSuffix(m.Body.Method, NullChar) { + return errors.New("method name ending with null bytes") + } if len(m.Body.DonId) == 0 || len(m.Body.DonId) > MessageDonIdMaxLen { return errors.New("invalid DON ID length") } + if strings.HasSuffix(m.Body.DonId, NullChar) { + return errors.New("DON ID ending with null bytes") + } if len(m.Body.Receiver) != 0 && len(m.Body.Receiver) != MessageReceiverLen { return errors.New("invalid Receiver length") } diff --git a/core/services/gateway/api/message_test.go b/core/services/gateway/api/message_test.go index a0835ea24bb..1f292db26b9 100644 --- a/core/services/gateway/api/message_test.go +++ b/core/services/gateway/api/message_test.go @@ -31,22 +31,38 @@ func TestMessage_Validate(t *testing.T) { // missing message ID msg.Body.MessageId = "" require.Error(t, msg.Validate()) + // message ID ending with null bytes + msg.Body.MessageId = "myid\x00\x00" + require.Error(t, msg.Validate()) msg.Body.MessageId = "abcd" + require.NoError(t, msg.Validate()) // missing DON ID msg.Body.DonId = "" require.Error(t, msg.Validate()) + // DON ID ending with null bytes + msg.Body.DonId = "mydon\x00\x00" + require.Error(t, msg.Validate()) msg.Body.DonId = "donA" + require.NoError(t, msg.Validate()) - // method too long + // method name too long msg.Body.Method = string(bytes.Repeat([]byte("a"), api.MessageMethodMaxLen+1)) require.Error(t, msg.Validate()) + // empty method name + msg.Body.Method = "" + require.Error(t, msg.Validate()) + // method name ending with null bytes + msg.Body.Method = "method\x00" + require.Error(t, msg.Validate()) msg.Body.Method = "request" + require.NoError(t, msg.Validate()) // incorrect receiver msg.Body.Receiver = "blah" require.Error(t, msg.Validate()) msg.Body.Receiver = "0x0000000000000000000000000000000000000000" + require.NoError(t, msg.Validate()) // invalid signature msg.Signature = "0x00" diff --git a/core/services/gateway/connectionmanager.go b/core/services/gateway/connectionmanager.go index 9f88b51e7b5..e5f7fb13afb 100644 --- a/core/services/gateway/connectionmanager.go +++ b/core/services/gateway/connectionmanager.go @@ -287,6 +287,10 @@ func (m *donConnectionManager) readLoop(nodeAddress string, nodeState *nodeState m.lggr.Errorw("message validation error when reading from node", "nodeAddress", nodeAddress, "err", err) break } + if msg.Body.Sender != nodeAddress { + m.lggr.Errorw("message sender mismatch when reading from node", "nodeAddress", nodeAddress, "sender", msg.Body.Sender) + break + } err = m.handler.HandleNodeMessage(ctx, msg, nodeAddress) if err != nil { m.lggr.Error("error when calling HandleNodeMessage ", err) diff --git a/core/services/gateway/integration_tests/gateway_integration_test.go b/core/services/gateway/integration_tests/gateway_integration_test.go index 9e4900efeee..a2064b7a591 100644 --- a/core/services/gateway/integration_tests/gateway_integration_test.go +++ b/core/services/gateway/integration_tests/gateway_integration_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ecdsa" "fmt" + "io" "net/http" "strings" "sync/atomic" @@ -36,18 +37,18 @@ Path = "/node" Port = 0 HandshakeTimeoutMillis = 2_000 MaxRequestBytes = 20_000 -ReadTimeoutMillis = 100 -RequestTimeoutMillis = 100 -WriteTimeoutMillis = 100 +ReadTimeoutMillis = 1000 +RequestTimeoutMillis = 1000 +WriteTimeoutMillis = 1000 [UserServerConfig] Path = "/user" Port = 0 ContentTypeHeader = "application/jsonrpc" MaxRequestBytes = 20_000 -ReadTimeoutMillis = 100 -RequestTimeoutMillis = 100 -WriteTimeoutMillis = 100 +ReadTimeoutMillis = 1000 +RequestTimeoutMillis = 1000 +WriteTimeoutMillis = 1000 [[Dons]] DonId = "test_don" @@ -72,6 +73,13 @@ Id = "test_gateway" URL = "%s" ` +const ( + messageId1 = "123" + messageId2 = "456" + + nodeResponsePayload = `{"response":"correct response"}` +) + func parseGatewayConfig(t *testing.T, tomlConfig string) *config.GatewayConfig { var cfg config.GatewayConfig err := toml.Unmarshal([]byte(tomlConfig), &cfg) @@ -94,6 +102,21 @@ type client struct { func (c *client) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) { c.done.Store(true) + // send back user's message without re-signing - should be ignored by the Gateway + _ = c.connector.SendToGateway(ctx, gatewayId, msg) + // send back a correct response + responseMsg := &api.Message{Body: api.MessageBody{ + MessageId: msg.Body.MessageId, + Method: "test", + DonId: "test_don", + Receiver: msg.Body.Sender, + Payload: []byte(nodeResponsePayload), + }} + err := responseMsg.Sign(c.privateKey) + if err != nil { + panic(err) + } + _ = c.connector.SendToGateway(ctx, gatewayId, responseMsg) } func (c *client) Sign(data ...[]byte) ([]byte, error) { @@ -111,7 +134,9 @@ func (*client) Close() error { func TestIntegration_Gateway_NoFullNodes_BasicConnectionAndMessage(t *testing.T) { t.Parallel() - nodeKeys := common.NewTestNodes(t, 1)[0] + testWallets := common.NewTestNodes(t, 2) + nodeKeys := testWallets[0] + userKeys := testWallets[1] // Verify that addresses in config are case-insensitive nodeKeys.Address = strings.ToUpper(nodeKeys.Address) @@ -132,17 +157,39 @@ func TestIntegration_Gateway_NoFullNodes_BasicConnectionAndMessage(t *testing.T) client.connector = connector servicetest.Run(t, connector) - // Send requests until one of them reaches Connector + // Send requests until one of them reaches Connector (i.e. the node) gomega.NewGomegaWithT(t).Eventually(func() bool { - msg := &api.Message{Body: api.MessageBody{MessageId: "123", Method: "test", DonId: "test_don"}} - require.NoError(t, msg.Sign(nodeKeys.PrivateKey)) - codec := api.JsonRPCCodec{} - rawMsg, err := codec.EncodeRequest(msg) - require.NoError(t, err) - req, err := http.NewRequestWithContext(testutils.Context(t), "POST", userUrl, bytes.NewBuffer(rawMsg)) - require.NoError(t, err) + req := newHttpRequestObject(t, messageId1, userUrl, userKeys.PrivateKey) httpClient := &http.Client{} _, _ = httpClient.Do(req) // could initially return error if Gateway is not fully initialized yet return client.done.Load() }, testutils.WaitTimeout(t), testutils.TestInterval).Should(gomega.Equal(true)) + + // Send another request and validate that response has correct content and sender + req := newHttpRequestObject(t, messageId2, userUrl, userKeys.PrivateKey) + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + rawResp, err := io.ReadAll(resp.Body) + require.NoError(t, err) + codec := api.JsonRPCCodec{} + respMsg, err := codec.DecodeResponse(rawResp) + require.NoError(t, err) + require.NoError(t, respMsg.Validate()) + require.Equal(t, strings.ToLower(nodeKeys.Address), respMsg.Body.Sender) + require.Equal(t, messageId2, respMsg.Body.MessageId) + require.Equal(t, nodeResponsePayload, string(respMsg.Body.Payload)) +} + +func newHttpRequestObject(t *testing.T, messageId string, userUrl string, signerKey *ecdsa.PrivateKey) *http.Request { + msg := &api.Message{Body: api.MessageBody{MessageId: messageId, Method: "test", DonId: "test_don"}} + require.NoError(t, msg.Sign(signerKey)) + codec := api.JsonRPCCodec{} + rawMsg, err := codec.EncodeRequest(msg) + require.NoError(t, err) + req, err := http.NewRequestWithContext(testutils.Context(t), "POST", userUrl, bytes.NewBuffer(rawMsg)) + require.NoError(t, err) + return req } diff --git a/core/services/ocr2/plugins/functions/config/config.go b/core/services/ocr2/plugins/functions/config/config.go index 13e02042506..38af7b8587f 100644 --- a/core/services/ocr2/plugins/functions/config/config.go +++ b/core/services/ocr2/plugins/functions/config/config.go @@ -41,6 +41,7 @@ type PluginConfig struct { MaxRequestSizesList []uint32 `json:"maxRequestSizesList"` MaxSecretsSizesList []uint32 `json:"maxSecretsSizesList"` MinimumSubscriptionBalance assets.Link `json:"minimumSubscriptionBalance"` + AllowedHeartbeatInitiators []string `json:"allowedHeartbeatInitiators"` GatewayConnectorConfig *connector.ConnectorConfig `json:"gatewayConnectorConfig"` OnchainAllowlist *functions.OnchainAllowlistConfig `json:"onchainAllowlist"` OnchainSubscriptions *functions.OnchainSubscriptionsConfig `json:"onchainSubscriptions"` diff --git a/core/services/ocr2/plugins/functions/plugin.go b/core/services/ocr2/plugins/functions/plugin.go index c6cfa946aba..fd72b6fd38e 100644 --- a/core/services/ocr2/plugins/functions/plugin.go +++ b/core/services/ocr2/plugins/functions/plugin.go @@ -13,7 +13,6 @@ import ( "github.com/smartcontractkit/libocr/commontypes" libocr2 "github.com/smartcontractkit/libocr/offchainreporting2plus" - "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/core/bridges" @@ -152,7 +151,7 @@ func NewFunctionsServices(functionsOracleArgs, thresholdOracleArgs, s4OracleArgs return nil, errors.Wrap(err, "failed to create a OnchainSubscriptions") } connectorLogger := conf.Logger.Named("GatewayConnector").With("jobName", conf.Job.PipelineSpec.JobName) - connector, err2 := NewConnector(pluginConfig.GatewayConnectorConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, pluginConfig.MinimumSubscriptionBalance, connectorLogger) + connector, err2 := NewConnector(&pluginConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, connectorLogger) if err2 != nil { return nil, errors.Wrap(err, "failed to create a GatewayConnector") } @@ -179,24 +178,26 @@ func NewFunctionsServices(functionsOracleArgs, thresholdOracleArgs, s4OracleArgs return allServices, nil } -func NewConnector(gwcCfg *connector.ConnectorConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwFunctions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwFunctions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, minimumBalance assets.Link, lggr logger.Logger) (connector.GatewayConnector, error) { +func NewConnector(pluginConfig *config.PluginConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwFunctions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwFunctions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, lggr logger.Logger) (connector.GatewayConnector, error) { enabledKeys, err := ethKeystore.EnabledKeysForChain(chainID) if err != nil { return nil, err } - configuredNodeAddress := common.HexToAddress(gwcCfg.NodeAddress) + configuredNodeAddress := common.HexToAddress(pluginConfig.GatewayConnectorConfig.NodeAddress) idx := slices.IndexFunc(enabledKeys, func(key ethkey.KeyV2) bool { return key.Address == configuredNodeAddress }) if idx == -1 { return nil, errors.New("key for configured node address not found") } signerKey := enabledKeys[idx].ToEcdsaPrivKey() - nodeAddress := enabledKeys[idx].ID() + if enabledKeys[idx].ID() != pluginConfig.GatewayConnectorConfig.NodeAddress { + return nil, errors.New("node address mismatch") + } - handler, err := functions.NewFunctionsConnectorHandler(nodeAddress, signerKey, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, minimumBalance, lggr) + handler, err := functions.NewFunctionsConnectorHandler(pluginConfig, signerKey, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, lggr) if err != nil { return nil, err } - connector, err := connector.NewGatewayConnector(gwcCfg, handler, handler, utils.NewRealClock(), lggr) + connector, err := connector.NewGatewayConnector(pluginConfig.GatewayConnectorConfig, handler, handler, utils.NewRealClock(), lggr) if err != nil { return nil, err } diff --git a/core/services/ocr2/plugins/functions/plugin_test.go b/core/services/ocr2/plugins/functions/plugin_test.go index d77fabcc437..6d3f57b086c 100644 --- a/core/services/ocr2/plugins/functions/plugin_test.go +++ b/core/services/ocr2/plugins/functions/plugin_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink/v2/core/logger" sfmocks "github.com/smartcontractkit/chainlink/v2/core/services/functions/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" @@ -17,6 +16,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" ksmocks "github.com/smartcontractkit/chainlink/v2/core/services/keystore/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" s4mocks "github.com/smartcontractkit/chainlink/v2/core/services/s4/mocks" ) @@ -39,7 +39,10 @@ func TestNewConnector_Success(t *testing.T) { listener := sfmocks.NewFunctionsListener(t) offchainTransmitter := sfmocks.NewOffchainTransmitter(t) ethKeystore.On("EnabledKeysForChain", mock.Anything).Return([]ethkey.KeyV2{keyV2}, nil) - _, err = functions.NewConnector(gwcCfg, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(0), logger.TestLogger(t)) + config := &config.PluginConfig{ + GatewayConnectorConfig: gwcCfg, + } + _, err = functions.NewConnector(config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t)) require.NoError(t, err) } @@ -64,6 +67,9 @@ func TestNewConnector_NoKeyForConfiguredAddress(t *testing.T) { listener := sfmocks.NewFunctionsListener(t) offchainTransmitter := sfmocks.NewOffchainTransmitter(t) ethKeystore.On("EnabledKeysForChain", mock.Anything).Return([]ethkey.KeyV2{{Address: common.HexToAddress(addresses[1])}}, nil) - _, err = functions.NewConnector(gwcCfg, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(0), logger.TestLogger(t)) + config := &config.PluginConfig{ + GatewayConnectorConfig: gwcCfg, + } + _, err = functions.NewConnector(config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t)) require.Error(t, err) }