diff --git a/core/services/functions/connector_handler.go b/core/services/functions/connector_handler.go index 1594dc6eb56..18df644c876 100644 --- a/core/services/functions/connector_handler.go +++ b/core/services/functions/connector_handler.go @@ -6,6 +6,7 @@ import ( "crypto/ecdsa" "encoding/json" "fmt" + "strings" "sync" "time" @@ -25,34 +26,34 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" "github.com/smartcontractkit/chainlink/v2/core/services/s4" ) type functionsConnectorHandler struct { services.StateMachine - connector connector.GatewayConnector - signerKey *ecdsa.PrivateKey - nodeAddress string - storage s4.Storage - allowlist functions.OnchainAllowlist - rateLimiter *hc.RateLimiter - subscriptions functions.OnchainSubscriptions - minimumBalance assets.Link - listener FunctionsListener - offchainTransmitter OffchainTransmitter - heartbeatRequests map[RequestID]*HeartbeatResponse - orderedRequests []RequestID - mu sync.Mutex - chStop services.StopChan - shutdownWaitGroup sync.WaitGroup - lggr logger.Logger + connector connector.GatewayConnector + signerKey *ecdsa.PrivateKey + nodeAddress string + storage s4.Storage + allowlist functions.OnchainAllowlist + rateLimiter *hc.RateLimiter + subscriptions functions.OnchainSubscriptions + minimumBalance assets.Link + listener FunctionsListener + offchainTransmitter OffchainTransmitter + allowedHeartbeatInitiators map[string]struct{} + heartbeatRequests map[RequestID]*HeartbeatResponse + requestTimeoutSec uint32 + orderedRequests []RequestID + mu sync.Mutex + chStop services.StopChan + shutdownWaitGroup sync.WaitGroup + lggr logger.Logger } -const ( - HeartbeatRequestTimeoutSec = 240 - HeartbeatCacheSize = 1000 -) +const HeartbeatCacheSize = 1000 var ( _ connector.Signer = &functionsConnectorHandler{} @@ -71,23 +72,29 @@ func InternalId(sender []byte, requestId []byte) RequestID { return RequestID(crypto.Keccak256Hash(append(sender, requestId...)).Bytes()) } -func NewFunctionsConnectorHandler(nodeAddress string, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, listener FunctionsListener, offchainTransmitter OffchainTransmitter, minimumBalance assets.Link, lggr logger.Logger) (*functionsConnectorHandler, error) { +func NewFunctionsConnectorHandler(pluginConfig *config.PluginConfig, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, listener FunctionsListener, offchainTransmitter OffchainTransmitter, lggr logger.Logger) (*functionsConnectorHandler, error) { if signerKey == nil || storage == nil || allowlist == nil || rateLimiter == nil || subscriptions == nil || listener == nil || offchainTransmitter == nil { return nil, fmt.Errorf("all dependencies must be non-nil") } + allowedHeartbeatInitiators := make(map[string]struct{}) + for _, initiator := range pluginConfig.AllowedHeartbeatInitiators { + allowedHeartbeatInitiators[strings.ToLower(initiator)] = struct{}{} + } return &functionsConnectorHandler{ - nodeAddress: nodeAddress, - signerKey: signerKey, - storage: storage, - allowlist: allowlist, - rateLimiter: rateLimiter, - subscriptions: subscriptions, - minimumBalance: minimumBalance, - listener: listener, - offchainTransmitter: offchainTransmitter, - heartbeatRequests: make(map[RequestID]*HeartbeatResponse), - chStop: make(services.StopChan), - lggr: lggr.Named("FunctionsConnectorHandler"), + nodeAddress: pluginConfig.GatewayConnectorConfig.NodeAddress, + signerKey: signerKey, + storage: storage, + allowlist: allowlist, + rateLimiter: rateLimiter, + subscriptions: subscriptions, + minimumBalance: pluginConfig.MinimumSubscriptionBalance, + listener: listener, + offchainTransmitter: offchainTransmitter, + allowedHeartbeatInitiators: allowedHeartbeatInitiators, + heartbeatRequests: make(map[RequestID]*HeartbeatResponse), + requestTimeoutSec: pluginConfig.RequestTimeoutSec, + chStop: make(services.StopChan), + lggr: lggr.Named("FunctionsConnectorHandler"), }, nil } @@ -211,6 +218,10 @@ func (h *functionsConnectorHandler) handleHeartbeat(ctx context.Context, gateway h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse(fmt.Sprintf("failed to unmarshal request: %v", err))) return } + if _, ok := h.allowedHeartbeatInitiators[requestBody.Sender]; !ok { + h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("sender not allowed to send heartbeat requests")) + return + } if !bytes.Equal(request.RequestInitiator, fromAddr.Bytes()) { h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("RequestInitiator doesn't match sender")) return @@ -219,6 +230,10 @@ func (h *functionsConnectorHandler) handleHeartbeat(ctx context.Context, gateway h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("SubscriptionOwner doesn't match sender")) return } + if request.Timestamp < uint64(time.Now().Unix())-uint64(h.requestTimeoutSec) { + h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("Request is too old")) + return + } internalId := InternalId(fromAddr.Bytes(), request.RequestId) request.RequestId = internalId[:] @@ -250,7 +265,7 @@ func internalErrorResponse(internalError string) HeartbeatResponse { func (h *functionsConnectorHandler) handleOffchainRequest(request *OffchainRequest) { defer h.shutdownWaitGroup.Done() stopCtx, _ := h.chStop.NewCtx() - ctx, cancel := context.WithTimeout(stopCtx, time.Duration(HeartbeatRequestTimeoutSec)*time.Second) + ctx, cancel := context.WithTimeout(stopCtx, time.Duration(h.requestTimeoutSec)*time.Second) defer cancel() err := h.listener.HandleOffchainRequest(ctx, request) if err != nil { diff --git a/core/services/functions/connector_handler_test.go b/core/services/functions/connector_handler_test.go index 409f9cdcc56..9a5a9042693 100644 --- a/core/services/functions/connector_handler_test.go +++ b/core/services/functions/connector_handler_test.go @@ -10,6 +10,7 @@ import ( "time" geth_common "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/onsi/gomega" "github.com/smartcontractkit/chainlink-common/pkg/assets" @@ -19,9 +20,11 @@ import ( sfmocks "github.com/smartcontractkit/chainlink/v2/core/services/functions/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" + gwconnector "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" gfmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/mocks" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" "github.com/smartcontractkit/chainlink/v2/core/services/s4" s4mocks "github.com/smartcontractkit/chainlink/v2/core/services/s4/mocks" @@ -30,7 +33,7 @@ import ( "github.com/stretchr/testify/require" ) -func newOffchainRequest(t *testing.T, sender []byte) (*api.Message, functions.RequestID) { +func newOffchainRequest(t *testing.T, sender []byte, ageSec uint64) (*api.Message, functions.RequestID) { requestId := make([]byte, 32) _, err := rand.Read(requestId) require.NoError(t, err) @@ -39,6 +42,7 @@ func newOffchainRequest(t *testing.T, sender []byte) (*api.Message, functions.Re RequestInitiator: sender, SubscriptionId: 1, SubscriptionOwner: sender, + Timestamp: uint64(time.Now().Unix()) - ageSec, } internalId := functions.InternalId(request.RequestInitiator, request.RequestId) @@ -74,7 +78,15 @@ func TestFunctionsConnectorHandler(t *testing.T) { allowlist.On("Close", mock.Anything).Return(nil) subscriptions.On("Start", mock.Anything).Return(nil) subscriptions.On("Close", mock.Anything).Return(nil) - handler, err := functions.NewFunctionsConnectorHandler(addr.Hex(), privateKey, storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(100), logger) + config := &config.PluginConfig{ + GatewayConnectorConfig: &gwconnector.ConnectorConfig{ + NodeAddress: addr.Hex(), + }, + MinimumSubscriptionBalance: *assets.NewLinkFromJuels(100), + RequestTimeoutSec: 1_000, + AllowedHeartbeatInitiators: []string{crypto.PubkeyToAddress(privateKey.PublicKey).Hex()}, + } + handler, err := functions.NewFunctionsConnectorHandler(config, privateKey, storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger) require.NoError(t, err) handler.SetConnector(connector) @@ -257,7 +269,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { t.Run("heartbeat success", func(t *testing.T) { ctx := testutils.Context(t) - msg, internalId := newOffchainRequest(t, addr.Bytes()) + msg, internalId := newOffchainRequest(t, addr.Bytes(), 0) require.NoError(t, msg.Sign(privateKey)) // first call to trigger the request @@ -292,7 +304,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { t.Run("heartbeat internal error", func(t *testing.T) { ctx := testutils.Context(t) - msg, _ := newOffchainRequest(t, addr.Bytes()) + msg, _ := newOffchainRequest(t, addr.Bytes(), 0) require.NoError(t, msg.Sign(privateKey)) // first call to trigger the request @@ -319,7 +331,23 @@ func TestFunctionsConnectorHandler(t *testing.T) { t.Run("heartbeat sender address doesn't match", func(t *testing.T) { ctx := testutils.Context(t) - msg, _ := newOffchainRequest(t, geth_common.BytesToAddress([]byte("0x1234")).Bytes()) + msg, _ := newOffchainRequest(t, geth_common.BytesToAddress([]byte("0x1234")).Bytes(), 0) + require.NoError(t, msg.Sign(privateKey)) + + var response functions.HeartbeatResponse + allowlist.On("Allow", addr).Return(true).Once() + connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) { + respMsg, ok := args[2].(*api.Message) + require.True(t, ok) + require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response)) + require.Equal(t, functions.RequestStateInternalError, response.Status) + }).Return(nil).Once() + handler.HandleGatewayMessage(ctx, "gw1", msg) + }) + + t.Run("heartbeat request too old", func(t *testing.T) { + ctx := testutils.Context(t) + msg, _ := newOffchainRequest(t, addr.Bytes(), 10_000) require.NoError(t, msg.Sign(privateKey)) var response functions.HeartbeatResponse diff --git a/core/services/gateway/api/message.go b/core/services/gateway/api/message.go index c01d3bb9f2e..5e6c8e49247 100644 --- a/core/services/gateway/api/message.go +++ b/core/services/gateway/api/message.go @@ -20,6 +20,7 @@ const ( MessageMethodMaxLen = 64 MessageDonIdMaxLen = 64 MessageReceiverLen = 2 + 2*20 + NullChar = "\x00" ) /* @@ -56,12 +57,21 @@ func (m *Message) Validate() error { if len(m.Body.MessageId) == 0 || len(m.Body.MessageId) > MessageIdMaxLen { return errors.New("invalid message ID length") } + if strings.HasSuffix(m.Body.MessageId, NullChar) { + return errors.New("message ID ending with null bytes") + } if len(m.Body.Method) == 0 || len(m.Body.Method) > MessageMethodMaxLen { return errors.New("invalid method name length") } + if strings.HasSuffix(m.Body.Method, NullChar) { + return errors.New("method name ending with null bytes") + } if len(m.Body.DonId) == 0 || len(m.Body.DonId) > MessageDonIdMaxLen { return errors.New("invalid DON ID length") } + if strings.HasSuffix(m.Body.DonId, NullChar) { + return errors.New("DON ID ending with null bytes") + } if len(m.Body.Receiver) != 0 && len(m.Body.Receiver) != MessageReceiverLen { return errors.New("invalid Receiver length") } diff --git a/core/services/gateway/api/message_test.go b/core/services/gateway/api/message_test.go index a0835ea24bb..1f292db26b9 100644 --- a/core/services/gateway/api/message_test.go +++ b/core/services/gateway/api/message_test.go @@ -31,22 +31,38 @@ func TestMessage_Validate(t *testing.T) { // missing message ID msg.Body.MessageId = "" require.Error(t, msg.Validate()) + // message ID ending with null bytes + msg.Body.MessageId = "myid\x00\x00" + require.Error(t, msg.Validate()) msg.Body.MessageId = "abcd" + require.NoError(t, msg.Validate()) // missing DON ID msg.Body.DonId = "" require.Error(t, msg.Validate()) + // DON ID ending with null bytes + msg.Body.DonId = "mydon\x00\x00" + require.Error(t, msg.Validate()) msg.Body.DonId = "donA" + require.NoError(t, msg.Validate()) - // method too long + // method name too long msg.Body.Method = string(bytes.Repeat([]byte("a"), api.MessageMethodMaxLen+1)) require.Error(t, msg.Validate()) + // empty method name + msg.Body.Method = "" + require.Error(t, msg.Validate()) + // method name ending with null bytes + msg.Body.Method = "method\x00" + require.Error(t, msg.Validate()) msg.Body.Method = "request" + require.NoError(t, msg.Validate()) // incorrect receiver msg.Body.Receiver = "blah" require.Error(t, msg.Validate()) msg.Body.Receiver = "0x0000000000000000000000000000000000000000" + require.NoError(t, msg.Validate()) // invalid signature msg.Signature = "0x00" diff --git a/core/services/gateway/connectionmanager.go b/core/services/gateway/connectionmanager.go index 9f88b51e7b5..e5f7fb13afb 100644 --- a/core/services/gateway/connectionmanager.go +++ b/core/services/gateway/connectionmanager.go @@ -287,6 +287,10 @@ func (m *donConnectionManager) readLoop(nodeAddress string, nodeState *nodeState m.lggr.Errorw("message validation error when reading from node", "nodeAddress", nodeAddress, "err", err) break } + if msg.Body.Sender != nodeAddress { + m.lggr.Errorw("message sender mismatch when reading from node", "nodeAddress", nodeAddress, "sender", msg.Body.Sender) + break + } err = m.handler.HandleNodeMessage(ctx, msg, nodeAddress) if err != nil { m.lggr.Error("error when calling HandleNodeMessage ", err) diff --git a/core/services/gateway/integration_tests/gateway_integration_test.go b/core/services/gateway/integration_tests/gateway_integration_test.go index 9e4900efeee..a2064b7a591 100644 --- a/core/services/gateway/integration_tests/gateway_integration_test.go +++ b/core/services/gateway/integration_tests/gateway_integration_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ecdsa" "fmt" + "io" "net/http" "strings" "sync/atomic" @@ -36,18 +37,18 @@ Path = "/node" Port = 0 HandshakeTimeoutMillis = 2_000 MaxRequestBytes = 20_000 -ReadTimeoutMillis = 100 -RequestTimeoutMillis = 100 -WriteTimeoutMillis = 100 +ReadTimeoutMillis = 1000 +RequestTimeoutMillis = 1000 +WriteTimeoutMillis = 1000 [UserServerConfig] Path = "/user" Port = 0 ContentTypeHeader = "application/jsonrpc" MaxRequestBytes = 20_000 -ReadTimeoutMillis = 100 -RequestTimeoutMillis = 100 -WriteTimeoutMillis = 100 +ReadTimeoutMillis = 1000 +RequestTimeoutMillis = 1000 +WriteTimeoutMillis = 1000 [[Dons]] DonId = "test_don" @@ -72,6 +73,13 @@ Id = "test_gateway" URL = "%s" ` +const ( + messageId1 = "123" + messageId2 = "456" + + nodeResponsePayload = `{"response":"correct response"}` +) + func parseGatewayConfig(t *testing.T, tomlConfig string) *config.GatewayConfig { var cfg config.GatewayConfig err := toml.Unmarshal([]byte(tomlConfig), &cfg) @@ -94,6 +102,21 @@ type client struct { func (c *client) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) { c.done.Store(true) + // send back user's message without re-signing - should be ignored by the Gateway + _ = c.connector.SendToGateway(ctx, gatewayId, msg) + // send back a correct response + responseMsg := &api.Message{Body: api.MessageBody{ + MessageId: msg.Body.MessageId, + Method: "test", + DonId: "test_don", + Receiver: msg.Body.Sender, + Payload: []byte(nodeResponsePayload), + }} + err := responseMsg.Sign(c.privateKey) + if err != nil { + panic(err) + } + _ = c.connector.SendToGateway(ctx, gatewayId, responseMsg) } func (c *client) Sign(data ...[]byte) ([]byte, error) { @@ -111,7 +134,9 @@ func (*client) Close() error { func TestIntegration_Gateway_NoFullNodes_BasicConnectionAndMessage(t *testing.T) { t.Parallel() - nodeKeys := common.NewTestNodes(t, 1)[0] + testWallets := common.NewTestNodes(t, 2) + nodeKeys := testWallets[0] + userKeys := testWallets[1] // Verify that addresses in config are case-insensitive nodeKeys.Address = strings.ToUpper(nodeKeys.Address) @@ -132,17 +157,39 @@ func TestIntegration_Gateway_NoFullNodes_BasicConnectionAndMessage(t *testing.T) client.connector = connector servicetest.Run(t, connector) - // Send requests until one of them reaches Connector + // Send requests until one of them reaches Connector (i.e. the node) gomega.NewGomegaWithT(t).Eventually(func() bool { - msg := &api.Message{Body: api.MessageBody{MessageId: "123", Method: "test", DonId: "test_don"}} - require.NoError(t, msg.Sign(nodeKeys.PrivateKey)) - codec := api.JsonRPCCodec{} - rawMsg, err := codec.EncodeRequest(msg) - require.NoError(t, err) - req, err := http.NewRequestWithContext(testutils.Context(t), "POST", userUrl, bytes.NewBuffer(rawMsg)) - require.NoError(t, err) + req := newHttpRequestObject(t, messageId1, userUrl, userKeys.PrivateKey) httpClient := &http.Client{} _, _ = httpClient.Do(req) // could initially return error if Gateway is not fully initialized yet return client.done.Load() }, testutils.WaitTimeout(t), testutils.TestInterval).Should(gomega.Equal(true)) + + // Send another request and validate that response has correct content and sender + req := newHttpRequestObject(t, messageId2, userUrl, userKeys.PrivateKey) + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + rawResp, err := io.ReadAll(resp.Body) + require.NoError(t, err) + codec := api.JsonRPCCodec{} + respMsg, err := codec.DecodeResponse(rawResp) + require.NoError(t, err) + require.NoError(t, respMsg.Validate()) + require.Equal(t, strings.ToLower(nodeKeys.Address), respMsg.Body.Sender) + require.Equal(t, messageId2, respMsg.Body.MessageId) + require.Equal(t, nodeResponsePayload, string(respMsg.Body.Payload)) +} + +func newHttpRequestObject(t *testing.T, messageId string, userUrl string, signerKey *ecdsa.PrivateKey) *http.Request { + msg := &api.Message{Body: api.MessageBody{MessageId: messageId, Method: "test", DonId: "test_don"}} + require.NoError(t, msg.Sign(signerKey)) + codec := api.JsonRPCCodec{} + rawMsg, err := codec.EncodeRequest(msg) + require.NoError(t, err) + req, err := http.NewRequestWithContext(testutils.Context(t), "POST", userUrl, bytes.NewBuffer(rawMsg)) + require.NoError(t, err) + return req } diff --git a/core/services/ocr2/plugins/functions/config/config.go b/core/services/ocr2/plugins/functions/config/config.go index 13e02042506..38af7b8587f 100644 --- a/core/services/ocr2/plugins/functions/config/config.go +++ b/core/services/ocr2/plugins/functions/config/config.go @@ -41,6 +41,7 @@ type PluginConfig struct { MaxRequestSizesList []uint32 `json:"maxRequestSizesList"` MaxSecretsSizesList []uint32 `json:"maxSecretsSizesList"` MinimumSubscriptionBalance assets.Link `json:"minimumSubscriptionBalance"` + AllowedHeartbeatInitiators []string `json:"allowedHeartbeatInitiators"` GatewayConnectorConfig *connector.ConnectorConfig `json:"gatewayConnectorConfig"` OnchainAllowlist *functions.OnchainAllowlistConfig `json:"onchainAllowlist"` OnchainSubscriptions *functions.OnchainSubscriptionsConfig `json:"onchainSubscriptions"` diff --git a/core/services/ocr2/plugins/functions/plugin.go b/core/services/ocr2/plugins/functions/plugin.go index c6cfa946aba..fd72b6fd38e 100644 --- a/core/services/ocr2/plugins/functions/plugin.go +++ b/core/services/ocr2/plugins/functions/plugin.go @@ -13,7 +13,6 @@ import ( "github.com/smartcontractkit/libocr/commontypes" libocr2 "github.com/smartcontractkit/libocr/offchainreporting2plus" - "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/core/bridges" @@ -152,7 +151,7 @@ func NewFunctionsServices(functionsOracleArgs, thresholdOracleArgs, s4OracleArgs return nil, errors.Wrap(err, "failed to create a OnchainSubscriptions") } connectorLogger := conf.Logger.Named("GatewayConnector").With("jobName", conf.Job.PipelineSpec.JobName) - connector, err2 := NewConnector(pluginConfig.GatewayConnectorConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, pluginConfig.MinimumSubscriptionBalance, connectorLogger) + connector, err2 := NewConnector(&pluginConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, connectorLogger) if err2 != nil { return nil, errors.Wrap(err, "failed to create a GatewayConnector") } @@ -179,24 +178,26 @@ func NewFunctionsServices(functionsOracleArgs, thresholdOracleArgs, s4OracleArgs return allServices, nil } -func NewConnector(gwcCfg *connector.ConnectorConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwFunctions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwFunctions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, minimumBalance assets.Link, lggr logger.Logger) (connector.GatewayConnector, error) { +func NewConnector(pluginConfig *config.PluginConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwFunctions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwFunctions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, lggr logger.Logger) (connector.GatewayConnector, error) { enabledKeys, err := ethKeystore.EnabledKeysForChain(chainID) if err != nil { return nil, err } - configuredNodeAddress := common.HexToAddress(gwcCfg.NodeAddress) + configuredNodeAddress := common.HexToAddress(pluginConfig.GatewayConnectorConfig.NodeAddress) idx := slices.IndexFunc(enabledKeys, func(key ethkey.KeyV2) bool { return key.Address == configuredNodeAddress }) if idx == -1 { return nil, errors.New("key for configured node address not found") } signerKey := enabledKeys[idx].ToEcdsaPrivKey() - nodeAddress := enabledKeys[idx].ID() + if enabledKeys[idx].ID() != pluginConfig.GatewayConnectorConfig.NodeAddress { + return nil, errors.New("node address mismatch") + } - handler, err := functions.NewFunctionsConnectorHandler(nodeAddress, signerKey, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, minimumBalance, lggr) + handler, err := functions.NewFunctionsConnectorHandler(pluginConfig, signerKey, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, lggr) if err != nil { return nil, err } - connector, err := connector.NewGatewayConnector(gwcCfg, handler, handler, utils.NewRealClock(), lggr) + connector, err := connector.NewGatewayConnector(pluginConfig.GatewayConnectorConfig, handler, handler, utils.NewRealClock(), lggr) if err != nil { return nil, err } diff --git a/core/services/ocr2/plugins/functions/plugin_test.go b/core/services/ocr2/plugins/functions/plugin_test.go index d77fabcc437..6d3f57b086c 100644 --- a/core/services/ocr2/plugins/functions/plugin_test.go +++ b/core/services/ocr2/plugins/functions/plugin_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink/v2/core/logger" sfmocks "github.com/smartcontractkit/chainlink/v2/core/services/functions/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" @@ -17,6 +16,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" ksmocks "github.com/smartcontractkit/chainlink/v2/core/services/keystore/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" s4mocks "github.com/smartcontractkit/chainlink/v2/core/services/s4/mocks" ) @@ -39,7 +39,10 @@ func TestNewConnector_Success(t *testing.T) { listener := sfmocks.NewFunctionsListener(t) offchainTransmitter := sfmocks.NewOffchainTransmitter(t) ethKeystore.On("EnabledKeysForChain", mock.Anything).Return([]ethkey.KeyV2{keyV2}, nil) - _, err = functions.NewConnector(gwcCfg, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(0), logger.TestLogger(t)) + config := &config.PluginConfig{ + GatewayConnectorConfig: gwcCfg, + } + _, err = functions.NewConnector(config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t)) require.NoError(t, err) } @@ -64,6 +67,9 @@ func TestNewConnector_NoKeyForConfiguredAddress(t *testing.T) { listener := sfmocks.NewFunctionsListener(t) offchainTransmitter := sfmocks.NewOffchainTransmitter(t) ethKeystore.On("EnabledKeysForChain", mock.Anything).Return([]ethkey.KeyV2{{Address: common.HexToAddress(addresses[1])}}, nil) - _, err = functions.NewConnector(gwcCfg, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(0), logger.TestLogger(t)) + config := &config.PluginConfig{ + GatewayConnectorConfig: gwcCfg, + } + _, err = functions.NewConnector(config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t)) require.Error(t, err) }