From ead3beb91de7ad2922cb63f9b31132c85264a64d Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Mon, 27 Nov 2023 13:37:43 -0800 Subject: [PATCH] [Functions] Heartbeat request support in Gateway handlers (#11345) 1. Functions Handler - add a new method "heartbeat" - add a configurable list of allowed heartbeat senders - collect results from first F+1 nodes and send back in raw form 2. Connector Handler - asynchronously forward requests to Listener and cache results - run a loop to collect OCR reports from Offchain Transmitter 3. Listener - add Timestampi field and validate it --- core/services/functions/connector_handler.go | 182 ++++++++++++++++-- .../functions/connector_handler_test.go | 109 ++++++++++- core/services/functions/listener.go | 3 + core/services/functions/listener_test.go | 8 + core/services/functions/request.go | 17 +- .../gateway/handlers/functions/api.go | 1 + .../handlers/functions/handler.functions.go | 104 +++++++--- .../functions/handler.functions_test.go | 54 +++++- .../services/ocr2/plugins/functions/plugin.go | 6 +- .../ocr2/plugins/functions/plugin_test.go | 9 +- 10 files changed, 432 insertions(+), 61 deletions(-) diff --git a/core/services/functions/connector_handler.go b/core/services/functions/connector_handler.go index 76608b8ada3..5496bbdefc1 100644 --- a/core/services/functions/connector_handler.go +++ b/core/services/functions/connector_handler.go @@ -1,14 +1,18 @@ package functions import ( + "bytes" "context" "crypto/ecdsa" "encoding/json" "fmt" + "sync" + "time" "go.uber.org/multierr" ethCommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -25,35 +29,56 @@ import ( type functionsConnectorHandler struct { services.StateMachine - connector connector.GatewayConnector - signerKey *ecdsa.PrivateKey - nodeAddress string - storage s4.Storage - allowlist functions.OnchainAllowlist - rateLimiter *hc.RateLimiter - subscriptions functions.OnchainSubscriptions - minimumBalance assets.Link - lggr logger.Logger + connector connector.GatewayConnector + signerKey *ecdsa.PrivateKey + nodeAddress string + storage s4.Storage + allowlist functions.OnchainAllowlist + rateLimiter *hc.RateLimiter + subscriptions functions.OnchainSubscriptions + minimumBalance assets.Link + listener FunctionsListener + offchainTransmitter OffchainTransmitter + heartbeatRequests map[RequestID]*HeartbeatResponse + orderedRequests []RequestID + mu sync.Mutex + chStop services.StopChan + shutdownWaitGroup sync.WaitGroup + lggr logger.Logger } +const ( + HeartbeatRequestTimeoutSec = 240 + HeartbeatCacheSize = 1000 +) + var ( _ connector.Signer = &functionsConnectorHandler{} _ connector.GatewayConnectorHandler = &functionsConnectorHandler{} ) -func NewFunctionsConnectorHandler(nodeAddress string, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, minimumBalance assets.Link, lggr logger.Logger) (*functionsConnectorHandler, error) { - if signerKey == nil || storage == nil || allowlist == nil || rateLimiter == nil || subscriptions == nil { - return nil, fmt.Errorf("signerKey, storage, allowlist, rateLimiter and subscriptions must be non-nil") +// internal request ID is a hash of (sender, requestID) +func InternalId(sender []byte, requestId []byte) RequestID { + return RequestID(crypto.Keccak256Hash(append(sender, requestId...)).Bytes()) +} + +func NewFunctionsConnectorHandler(nodeAddress string, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, listener FunctionsListener, offchainTransmitter OffchainTransmitter, minimumBalance assets.Link, lggr logger.Logger) (*functionsConnectorHandler, error) { + if signerKey == nil || storage == nil || allowlist == nil || rateLimiter == nil || subscriptions == nil || listener == nil || offchainTransmitter == nil { + return nil, fmt.Errorf("all dependencies must be non-nil") } return &functionsConnectorHandler{ - nodeAddress: nodeAddress, - signerKey: signerKey, - storage: storage, - allowlist: allowlist, - rateLimiter: rateLimiter, - subscriptions: subscriptions, - minimumBalance: minimumBalance, - lggr: lggr.Named("FunctionsConnectorHandler"), + nodeAddress: nodeAddress, + signerKey: signerKey, + storage: storage, + allowlist: allowlist, + rateLimiter: rateLimiter, + subscriptions: subscriptions, + minimumBalance: minimumBalance, + listener: listener, + offchainTransmitter: offchainTransmitter, + heartbeatRequests: make(map[RequestID]*HeartbeatResponse), + chStop: make(services.StopChan), + lggr: lggr.Named("FunctionsConnectorHandler"), }, nil } @@ -92,6 +117,8 @@ func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, ga return } h.handleSecretsSet(ctx, gatewayId, body, fromAddr) + case functions.MethodHeartbeat: + h.handleHeartbeat(ctx, gatewayId, body, fromAddr) default: h.lggr.Errorw("unsupported method", "id", gatewayId, "method", body.Method) } @@ -102,14 +129,21 @@ func (h *functionsConnectorHandler) Start(ctx context.Context) error { if err := h.allowlist.Start(ctx); err != nil { return err } - return h.subscriptions.Start(ctx) + if err := h.subscriptions.Start(ctx); err != nil { + return err + } + h.shutdownWaitGroup.Add(1) + go h.reportLoop() + return nil }) } func (h *functionsConnectorHandler) Close() error { return h.StopOnce("FunctionsConnectorHandler", func() (err error) { + close(h.chStop) err = multierr.Combine(err, h.allowlist.Close()) err = multierr.Combine(err, h.subscriptions.Close()) + h.shutdownWaitGroup.Wait() return }) } @@ -160,6 +194,112 @@ func (h *functionsConnectorHandler) handleSecretsSet(ctx context.Context, gatewa h.sendResponseAndLog(ctx, gatewayId, body, response) } +func (h *functionsConnectorHandler) handleHeartbeat(ctx context.Context, gatewayId string, requestBody *api.MessageBody, fromAddr ethCommon.Address) { + var request *OffchainRequest + err := json.Unmarshal(requestBody.Payload, &request) + if err != nil { + h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse(fmt.Sprintf("failed to unmarshal request: %v", err))) + return + } + if !bytes.Equal(request.RequestInitiator, fromAddr.Bytes()) { + h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("RequestInitiator doesn't match sender")) + return + } + if !bytes.Equal(request.SubscriptionOwner, fromAddr.Bytes()) { + h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("SubscriptionOwner doesn't match sender")) + return + } + + internalId := InternalId(fromAddr.Bytes(), request.RequestId) + request.RequestId = internalId[:] + h.lggr.Infow("handling offchain heartbeat", "messageId", requestBody.MessageId, "internalId", internalId, "sender", requestBody.Sender) + h.mu.Lock() + response, ok := h.heartbeatRequests[internalId] + if !ok { // new request + response = &HeartbeatResponse{ + Status: RequestStatePending, + ReceivedTs: uint64(time.Now().Unix()), + } + h.cacheNewRequestLocked(internalId, response) + h.shutdownWaitGroup.Add(1) + go h.handleOffchainRequest(request) + } + responseToSend := *response + h.mu.Unlock() + requestBody.Receiver = requestBody.Sender + h.sendResponseAndLog(ctx, gatewayId, requestBody, responseToSend) +} + +func internalErrorResponse(internalError string) HeartbeatResponse { + return HeartbeatResponse{ + Status: RequestStateInternalError, + InternalError: internalError, + } +} + +func (h *functionsConnectorHandler) handleOffchainRequest(request *OffchainRequest) { + defer h.shutdownWaitGroup.Done() + stopCtx, _ := h.chStop.NewCtx() + ctx, cancel := context.WithTimeout(stopCtx, time.Duration(HeartbeatRequestTimeoutSec)*time.Second) + defer cancel() + err := h.listener.HandleOffchainRequest(ctx, request) + if err != nil { + h.lggr.Errorw("internal error while processing", "id", request.RequestId, "error", err) + h.mu.Lock() + defer h.mu.Unlock() + state, ok := h.heartbeatRequests[RequestID(request.RequestId)] + if !ok { + h.lggr.Errorw("request unexpectedly disappeared from local cache", "id", request.RequestId) + return + } + state.CompletedTs = uint64(time.Now().Unix()) + state.Status = RequestStateInternalError + state.InternalError = err.Error() + } else { + // no error - results will be sent to OCR aggregation and returned via reportLoop() + h.lggr.Infow("request processed successfully, waiting for aggregation ...", "id", request.RequestId) + } +} + +// Listen to OCR reports passed from the plugin and process them against a local cache of requests. +func (h *functionsConnectorHandler) reportLoop() { + defer h.shutdownWaitGroup.Done() + for { + select { + case report := <-h.offchainTransmitter.ReportChannel(): + h.lggr.Infow("received report", "requestId", report.RequestId, "resultLen", len(report.Result), "errorLen", len(report.Error)) + if len(report.RequestId) != RequestIDLength { + h.lggr.Errorw("report has invalid requestId", "requestId", report.RequestId) + continue + } + h.mu.Lock() + cachedResponse, ok := h.heartbeatRequests[RequestID(report.RequestId)] + if !ok { + h.lggr.Infow("received report for unknown request, caching it", "id", report.RequestId) + cachedResponse = &HeartbeatResponse{} + h.cacheNewRequestLocked(RequestID(report.RequestId), cachedResponse) + } + cachedResponse.CompletedTs = uint64(time.Now().Unix()) + cachedResponse.Status = RequestStateComplete + cachedResponse.Response = report + h.mu.Unlock() + case <-h.chStop: + h.lggr.Info("exiting reportLoop") + return + } + } +} + +func (h *functionsConnectorHandler) cacheNewRequestLocked(requestId RequestID, response *HeartbeatResponse) { + // remove oldest requests + for len(h.orderedRequests) >= HeartbeatCacheSize { + delete(h.heartbeatRequests, h.orderedRequests[0]) + h.orderedRequests = h.orderedRequests[1:] + } + h.heartbeatRequests[requestId] = response + h.orderedRequests = append(h.orderedRequests, requestId) +} + func (h *functionsConnectorHandler) sendResponseAndLog(ctx context.Context, gatewayId string, requestBody *api.MessageBody, payload any) { err := h.sendResponse(ctx, gatewayId, requestBody, payload) if err != nil { diff --git a/core/services/functions/connector_handler_test.go b/core/services/functions/connector_handler_test.go index 82c3dab3afc..fe1a1baa6fc 100644 --- a/core/services/functions/connector_handler_test.go +++ b/core/services/functions/connector_handler_test.go @@ -1,16 +1,20 @@ package functions_test import ( + "crypto/rand" "encoding/base64" "encoding/json" "errors" "math/big" "testing" + geth_common "github.com/ethereum/go-ethereum/common" + "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/functions" + sfmocks "github.com/smartcontractkit/chainlink/v2/core/services/functions/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" @@ -24,6 +28,31 @@ import ( "github.com/stretchr/testify/require" ) +func newOffchainRequest(t *testing.T, sender []byte) (*api.Message, functions.RequestID) { + requestId := make([]byte, 32) + _, err := rand.Read(requestId) + require.NoError(t, err) + request := &functions.OffchainRequest{ + RequestId: requestId, + RequestInitiator: sender, + SubscriptionId: 1, + SubscriptionOwner: sender, + } + + internalId := functions.InternalId(request.RequestInitiator, request.RequestId) + req, err := json.Marshal(request) + require.NoError(t, err) + msg := &api.Message{ + Body: api.MessageBody{ + DonId: "fun4", + MessageId: "1", + Method: "heartbeat", + Payload: req, + }, + } + return msg, internalId +} + func TestFunctionsConnectorHandler(t *testing.T) { t.Parallel() @@ -34,12 +63,16 @@ func TestFunctionsConnectorHandler(t *testing.T) { allowlist := gfmocks.NewOnchainAllowlist(t) rateLimiter, err := hc.NewRateLimiter(hc.RateLimiterConfig{GlobalRPS: 100.0, GlobalBurst: 100, PerSenderRPS: 100.0, PerSenderBurst: 100}) subscriptions := gfmocks.NewOnchainSubscriptions(t) + reportCh := make(chan *functions.OffchainResponse) + offchainTransmitter := sfmocks.NewOffchainTransmitter(t) + offchainTransmitter.On("ReportChannel", mock.Anything).Return(reportCh) + listener := sfmocks.NewFunctionsListener(t) require.NoError(t, err) allowlist.On("Start", mock.Anything).Return(nil) allowlist.On("Close", mock.Anything).Return(nil) subscriptions.On("Start", mock.Anything).Return(nil) subscriptions.On("Close", mock.Anything).Return(nil) - handler, err := functions.NewFunctionsConnectorHandler(addr.Hex(), privateKey, storage, allowlist, rateLimiter, subscriptions, *assets.NewLinkFromJuels(100), logger) + handler, err := functions.NewFunctionsConnectorHandler(addr.Hex(), privateKey, storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(100), logger) require.NoError(t, err) handler.SetConnector(connector) @@ -219,4 +252,78 @@ func TestFunctionsConnectorHandler(t *testing.T) { handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg) }) }) + + t.Run("heartbeat success", func(t *testing.T) { + ctx := testutils.Context(t) + msg, internalId := newOffchainRequest(t, addr.Bytes()) + require.NoError(t, msg.Sign(privateKey)) + + // first call to trigger the request + var response functions.HeartbeatResponse + allowlist.On("Allow", addr).Return(true).Once() + listener.On("HandleOffchainRequest", mock.Anything, mock.Anything).Return(nil).Once() + connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) { + respMsg, ok := args[2].(*api.Message) + require.True(t, ok) + require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response)) + require.Equal(t, functions.RequestStatePending, response.Status) + }).Return(nil).Once() + handler.HandleGatewayMessage(ctx, "gw1", msg) + + // async response computation + reportCh <- &functions.OffchainResponse{ + RequestId: internalId[:], + Result: []byte("ok!"), + } + reportCh <- &functions.OffchainResponse{} // sending second item to make sure the first one got processed + + // second call to collect the response + allowlist.On("Allow", addr).Return(true).Once() + connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) { + respMsg, ok := args[2].(*api.Message) + require.True(t, ok) + require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response)) + require.Equal(t, functions.RequestStateComplete, response.Status) + }).Return(nil).Once() + handler.HandleGatewayMessage(ctx, "gw1", msg) + }) + + t.Run("heartbeat internal error", func(t *testing.T) { + ctx := testutils.Context(t) + msg, _ := newOffchainRequest(t, addr.Bytes()) + require.NoError(t, msg.Sign(privateKey)) + + // first call to trigger the request + var response functions.HeartbeatResponse + allowlist.On("Allow", addr).Return(true).Once() + listener.On("HandleOffchainRequest", mock.Anything, mock.Anything).Return(errors.New("boom")).Once() + connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Return(nil).Once() + handler.HandleGatewayMessage(ctx, "gw1", msg) + + // second call to collect the response + allowlist.On("Allow", addr).Return(true).Once() + connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) { + respMsg, ok := args[2].(*api.Message) + require.True(t, ok) + require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response)) + require.Equal(t, functions.RequestStateInternalError, response.Status) + }).Return(nil).Once() + handler.HandleGatewayMessage(ctx, "gw1", msg) + }) + + t.Run("heartbeat sender address doesn't match", func(t *testing.T) { + ctx := testutils.Context(t) + msg, _ := newOffchainRequest(t, geth_common.BytesToAddress([]byte("0x1234")).Bytes()) + require.NoError(t, msg.Sign(privateKey)) + + var response functions.HeartbeatResponse + allowlist.On("Allow", addr).Return(true).Once() + connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) { + respMsg, ok := args[2].(*api.Message) + require.True(t, ok) + require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response)) + require.Equal(t, functions.RequestStateInternalError, response.Status) + }).Return(nil).Once() + handler.HandleGatewayMessage(ctx, "gw1", msg) + }) } diff --git a/core/services/functions/listener.go b/core/services/functions/listener.go index 3a308431807..65c364adb7c 100644 --- a/core/services/functions/listener.go +++ b/core/services/functions/listener.go @@ -300,6 +300,9 @@ func (l *functionsListener) HandleOffchainRequest(ctx context.Context, request * if len(request.SubscriptionOwner) != common.AddressLength || len(request.RequestInitiator) != common.AddressLength { return fmt.Errorf("HandleOffchainRequest: SubscriptionOwner and RequestInitiator must be set to valid addresses") } + if request.Timestamp < uint64(time.Now().Unix()-int64(l.pluginConfig.RequestTimeoutSec)) { + return fmt.Errorf("HandleOffchainRequest: request timestamp is too old") + } var requestId RequestID copy(requestId[:], request.RequestId[:32]) diff --git a/core/services/functions/listener_test.go b/core/services/functions/listener_test.go index ecad9e4cceb..0fcc9c65599 100644 --- a/core/services/functions/listener_test.go +++ b/core/services/functions/listener_test.go @@ -7,6 +7,7 @@ import ( "math/big" "sync" "testing" + "time" "github.com/ethereum/go-ethereum/common" "github.com/fxamacker/cbor/v2" @@ -195,6 +196,7 @@ func TestFunctionsListener_HandleOffchainRequest_Success(t *testing.T) { RequestInitiator: SubscriptionOwner.Bytes(), SubscriptionId: uint64(SubscriptionID), SubscriptionOwner: SubscriptionOwner.Bytes(), + Timestamp: uint64(time.Now().Unix()), Data: functions_service.RequestData{}, } require.NoError(t, uni.service.HandleOffchainRequest(testutils.Context(t), request)) @@ -210,6 +212,7 @@ func TestFunctionsListener_HandleOffchainRequest_Invalid(t *testing.T) { RequestInitiator: []byte("invalid_address"), SubscriptionId: uint64(SubscriptionID), SubscriptionOwner: SubscriptionOwner.Bytes(), + Timestamp: uint64(time.Now().Unix()), Data: functions_service.RequestData{}, } require.Error(t, uni.service.HandleOffchainRequest(testutils.Context(t), request)) @@ -217,6 +220,10 @@ func TestFunctionsListener_HandleOffchainRequest_Invalid(t *testing.T) { request.RequestInitiator = SubscriptionOwner.Bytes() request.SubscriptionOwner = []byte("invalid_address") require.Error(t, uni.service.HandleOffchainRequest(testutils.Context(t), request)) + + request.SubscriptionOwner = SubscriptionOwner.Bytes() + request.Timestamp = 1 + require.Error(t, uni.service.HandleOffchainRequest(testutils.Context(t), request)) } func TestFunctionsListener_HandleOffchainRequest_InternalError(t *testing.T) { @@ -233,6 +240,7 @@ func TestFunctionsListener_HandleOffchainRequest_InternalError(t *testing.T) { RequestInitiator: SubscriptionOwner.Bytes(), SubscriptionId: uint64(SubscriptionID), SubscriptionOwner: SubscriptionOwner.Bytes(), + Timestamp: uint64(time.Now().Unix()), Data: functions_service.RequestData{}, } require.Error(t, uni.service.HandleOffchainRequest(testutils.Context(t), request)) diff --git a/core/services/functions/request.go b/core/services/functions/request.go index 14c0b0d0e5a..eaa92fc8088 100644 --- a/core/services/functions/request.go +++ b/core/services/functions/request.go @@ -5,6 +5,10 @@ const ( LocationRemote = 1 LocationDONHosted = 2 LanguageJavaScript = 0 + + RequestStatePending = 1 + RequestStateComplete = 2 + RequestStateInternalError = 3 ) type RequestFlags [32]byte @@ -14,6 +18,7 @@ type OffchainRequest struct { RequestInitiator []byte `json:"requestInitiator"` SubscriptionId uint64 `json:"subscriptionId"` SubscriptionOwner []byte `json:"subscriptionOwner"` + Timestamp uint64 `json:"timestamp"` Data RequestData `json:"data"` } @@ -30,8 +35,16 @@ type RequestData struct { // NOTE: to be extended with raw report and signatures when needed type OffchainResponse struct { RequestId []byte `json:"requestId"` - Result []byte `json:"result"` - Error []byte `json:"error"` + Result []byte `json:"result,omitempty"` + Error []byte `json:"error,omitempty"` +} + +type HeartbeatResponse struct { + Status int `json:"status"` + InternalError string `json:"internalError,omitempty"` + ReceivedTs uint64 `json:"receivedTs"` + CompletedTs uint64 `json:"completedTs"` + Response *OffchainResponse `json:"response,omitempty"` } type DONHostedSecrets struct { diff --git a/core/services/gateway/handlers/functions/api.go b/core/services/gateway/handlers/functions/api.go index 202fa99e414..36db1943931 100644 --- a/core/services/gateway/handlers/functions/api.go +++ b/core/services/gateway/handlers/functions/api.go @@ -5,6 +5,7 @@ import "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" const ( MethodSecretsSet = "secrets_set" MethodSecretsList = "secrets_list" + MethodHeartbeat = "heartbeat" ) type SecretsSetRequest struct { diff --git a/core/services/gateway/handlers/functions/handler.functions.go b/core/services/gateway/handlers/functions/handler.functions.go index 3269caa2d6a..b52c866a131 100644 --- a/core/services/gateway/handlers/functions/handler.functions.go +++ b/core/services/gateway/handlers/functions/handler.functions.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "math/big" + "strings" "time" "github.com/ethereum/go-ethereum/common" @@ -62,26 +63,28 @@ type FunctionsHandlerConfig struct { OnchainSubscriptions *OnchainSubscriptionsConfig `json:"onchainSubscriptions"` MinimumSubscriptionBalance *assets.Link `json:"minimumSubscriptionBalance"` // Not specifying RateLimiter config disables rate limiting - UserRateLimiter *hc.RateLimiterConfig `json:"userRateLimiter"` - NodeRateLimiter *hc.RateLimiterConfig `json:"nodeRateLimiter"` - MaxPendingRequests uint32 `json:"maxPendingRequests"` - RequestTimeoutMillis int64 `json:"requestTimeoutMillis"` + UserRateLimiter *hc.RateLimiterConfig `json:"userRateLimiter"` + NodeRateLimiter *hc.RateLimiterConfig `json:"nodeRateLimiter"` + MaxPendingRequests uint32 `json:"maxPendingRequests"` + RequestTimeoutMillis int64 `json:"requestTimeoutMillis"` + AllowedHeartbeatInitiators []string `json:"allowedHeartbeatInitiators"` } type functionsHandler struct { services.StateMachine - handlerConfig FunctionsHandlerConfig - donConfig *config.DONConfig - don handlers.DON - pendingRequests hc.RequestCache[PendingRequest] - allowlist OnchainAllowlist - subscriptions OnchainSubscriptions - minimumBalance *assets.Link - userRateLimiter *hc.RateLimiter - nodeRateLimiter *hc.RateLimiter - chStop services.StopChan - lggr logger.Logger + handlerConfig FunctionsHandlerConfig + donConfig *config.DONConfig + don handlers.DON + pendingRequests hc.RequestCache[PendingRequest] + allowlist OnchainAllowlist + subscriptions OnchainSubscriptions + minimumBalance *assets.Link + userRateLimiter *hc.RateLimiter + nodeRateLimiter *hc.RateLimiter + allowedHeartbeatInitiators map[string]struct{} + chStop services.StopChan + lggr logger.Logger } type PendingRequest struct { @@ -135,8 +138,12 @@ func NewFunctionsHandlerFromConfig(handlerConfig json.RawMessage, donConfig *con return nil, err2 } } + allowedHeartbeatInitiators := make(map[string]struct{}) + for _, initiator := range cfg.AllowedHeartbeatInitiators { + allowedHeartbeatInitiators[strings.ToLower(initiator)] = struct{}{} + } pendingRequestsCache := hc.NewRequestCache[PendingRequest](time.Millisecond*time.Duration(cfg.RequestTimeoutMillis), cfg.MaxPendingRequests) - return NewFunctionsHandler(cfg, donConfig, don, pendingRequestsCache, allowlist, subscriptions, cfg.MinimumSubscriptionBalance, userRateLimiter, nodeRateLimiter, lggr), nil + return NewFunctionsHandler(cfg, donConfig, don, pendingRequestsCache, allowlist, subscriptions, cfg.MinimumSubscriptionBalance, userRateLimiter, nodeRateLimiter, allowedHeartbeatInitiators, lggr), nil } func NewFunctionsHandler( @@ -149,19 +156,21 @@ func NewFunctionsHandler( minimumBalance *assets.Link, userRateLimiter *hc.RateLimiter, nodeRateLimiter *hc.RateLimiter, + allowedHeartbeatInitiators map[string]struct{}, lggr logger.Logger) handlers.Handler { return &functionsHandler{ - handlerConfig: cfg, - donConfig: donConfig, - don: don, - pendingRequests: pendingRequestsCache, - allowlist: allowlist, - subscriptions: subscriptions, - minimumBalance: minimumBalance, - userRateLimiter: userRateLimiter, - nodeRateLimiter: nodeRateLimiter, - chStop: make(services.StopChan), - lggr: lggr, + handlerConfig: cfg, + donConfig: donConfig, + don: don, + pendingRequests: pendingRequestsCache, + allowlist: allowlist, + subscriptions: subscriptions, + minimumBalance: minimumBalance, + userRateLimiter: userRateLimiter, + nodeRateLimiter: nodeRateLimiter, + allowedHeartbeatInitiators: allowedHeartbeatInitiators, + chStop: make(services.StopChan), + lggr: lggr, } } @@ -193,6 +202,13 @@ func (h *functionsHandler) HandleUserMessage(ctx context.Context, msg *api.Messa switch msg.Body.Method { case MethodSecretsSet, MethodSecretsList: return h.handleRequest(ctx, msg, callbackCh) + case MethodHeartbeat: + if _, ok := h.allowedHeartbeatInitiators[msg.Body.Sender]; !ok { + h.lggr.Debugw("received heartbeat request from a non-allowed sender", "sender", msg.Body.Sender) + promHandlerError.WithLabelValues(h.donConfig.DonId, ErrNotAllowlisted.Error()).Inc() + return ErrUnsupportedMethod + } + return h.handleRequest(ctx, msg, callbackCh) default: h.lggr.Debugw("unsupported method", "method", msg.Body.Method) promHandlerError.WithLabelValues(h.donConfig.DonId, ErrUnsupportedMethod.Error()).Inc() @@ -227,6 +243,8 @@ func (h *functionsHandler) HandleNodeMessage(ctx context.Context, msg *api.Messa switch msg.Body.Method { case MethodSecretsSet, MethodSecretsList: return h.pendingRequests.ProcessResponse(msg, h.processSecretsResponse) + case MethodHeartbeat: + return h.pendingRequests.ProcessResponse(msg, h.processHeartbeatResponse) default: h.lggr.Debugw("unsupported method", "method", msg.Body.Method) return ErrUnsupportedMethod @@ -295,6 +313,38 @@ func newSecretsResponse(request *api.Message, success bool, responses []*api.Mes return &handlers.UserCallbackPayload{Msg: &userResponse, ErrCode: api.NoError, ErrMsg: ""}, nil } +// Conforms to ResponseProcessor[*PendingRequest] +func (h *functionsHandler) processHeartbeatResponse(response *api.Message, responseData *PendingRequest) (*handlers.UserCallbackPayload, *PendingRequest, error) { + if _, exists := responseData.responses[response.Body.Sender]; exists { + return nil, nil, errors.New("duplicate response") + } + if response.Body.Method != responseData.request.Body.Method { + return nil, responseData, errors.New("invalid method") + } + responseData.responses[response.Body.Sender] = response + + // user response is ready with F+1 node responses + if len(responseData.responses) >= h.donConfig.F+1 { + var responseList []*api.Message + for _, response := range responseData.responses { + responseList = append(responseList, response) + } + userResponse := *responseData.request + userResponse.Body.Receiver = responseData.request.Body.Sender + // success = true only means that we got F+1 responses + // it's up to the heartbeat sender to validate computation results + payload := CombinedResponse{ResponseBase: ResponseBase{Success: true}, NodeResponses: responseList} + payloadJson, err := json.Marshal(payload) + if err != nil { + return &handlers.UserCallbackPayload{Msg: &userResponse, ErrCode: api.NodeReponseEncodingError, ErrMsg: ""}, nil, nil + } + userResponse.Body.Payload = payloadJson + return &handlers.UserCallbackPayload{Msg: &userResponse, ErrCode: api.NoError, ErrMsg: ""}, nil, nil + } + // not ready to be processed yet + return nil, responseData, nil +} + func (h *functionsHandler) Start(ctx context.Context) error { return h.StartOnce("FunctionsHandler", func() error { h.lggr.Info("starting FunctionsHandler") diff --git a/core/services/gateway/handlers/functions/handler.functions_test.go b/core/services/gateway/handlers/functions/handler.functions_test.go index 402823df173..f36b64709a2 100644 --- a/core/services/gateway/handlers/functions/handler.functions_test.go +++ b/core/services/gateway/handlers/functions/handler.functions_test.go @@ -25,7 +25,7 @@ import ( handlers_mocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/mocks" ) -func newFunctionsHandlerForATestDON(t *testing.T, nodes []gc.TestNode, requestTimeout time.Duration) (handlers.Handler, *handlers_mocks.DON, *functions_mocks.OnchainAllowlist, *functions_mocks.OnchainSubscriptions) { +func newFunctionsHandlerForATestDON(t *testing.T, nodes []gc.TestNode, requestTimeout time.Duration, heartbeatSender string) (handlers.Handler, *handlers_mocks.DON, *functions_mocks.OnchainAllowlist, *functions_mocks.OnchainSubscriptions) { cfg := functions.FunctionsHandlerConfig{} donConfig := &config.DONConfig{ Members: []config.NodeConfig{}, @@ -48,7 +48,8 @@ func newFunctionsHandlerForATestDON(t *testing.T, nodes []gc.TestNode, requestTi nodeRateLimiter, err := hc.NewRateLimiter(hc.RateLimiterConfig{GlobalRPS: 100.0, GlobalBurst: 100, PerSenderRPS: 100.0, PerSenderBurst: 100}) require.NoError(t, err) pendingRequestsCache := hc.NewRequestCache[functions.PendingRequest](requestTimeout, 1000) - handler := functions.NewFunctionsHandler(cfg, donConfig, don, pendingRequestsCache, allowlist, subscriptions, minBalance, userRateLimiter, nodeRateLimiter, logger.TestLogger(t)) + allowedHeartbeatInititors := map[string]struct{}{heartbeatSender: {}} + handler := functions.NewFunctionsHandler(cfg, donConfig, don, pendingRequestsCache, allowlist, subscriptions, minBalance, userRateLimiter, nodeRateLimiter, allowedHeartbeatInititors, logger.TestLogger(t)) return handler, don, allowlist, subscriptions } @@ -117,7 +118,7 @@ func TestFunctionsHandler_HandleUserMessage_SecretsSet(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { nodes, user := gc.NewTestNodes(t, 4), gc.NewTestNodes(t, 1)[0] - handler, don, allowlist, subscriptions := newFunctionsHandlerForATestDON(t, nodes, time.Hour*24) + handler, don, allowlist, subscriptions := newFunctionsHandlerForATestDON(t, nodes, time.Hour*24, user.Address) userRequestMsg := newSignedMessage(t, "1234", "secrets_set", "don_id", user.PrivateKey) callbachCh := make(chan handlers.UserCallbackPayload) @@ -144,11 +145,54 @@ func TestFunctionsHandler_HandleUserMessage_SecretsSet(t *testing.T) { } } +func TestFunctionsHandler_HandleUserMessage_Heartbeat(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + nodeResults []bool + expectedGatewayResult bool + expectedNodeMessageCount int + }{ + {"three successful", []bool{true, true, true, false}, true, 2}, + {"two successful", []bool{false, true, false, true}, true, 2}, + {"one successful", []bool{false, true, false, false}, true, 2}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + nodes, user := gc.NewTestNodes(t, 4), gc.NewTestNodes(t, 1)[0] + handler, don, allowlist, _ := newFunctionsHandlerForATestDON(t, nodes, time.Hour*24, user.Address) + userRequestMsg := newSignedMessage(t, "1234", "heartbeat", "don_id", user.PrivateKey) + + callbachCh := make(chan handlers.UserCallbackPayload) + done := make(chan struct{}) + go func() { + defer close(done) + // wait on a response from Gateway to the user + response := <-callbachCh + require.Equal(t, api.NoError, response.ErrCode) + require.Equal(t, userRequestMsg.Body.MessageId, response.Msg.Body.MessageId) + var payload functions.CombinedResponse + require.NoError(t, json.Unmarshal(response.Msg.Body.Payload, &payload)) + require.Equal(t, test.expectedGatewayResult, payload.Success) + require.Equal(t, test.expectedNodeMessageCount, len(payload.NodeResponses)) + }() + + allowlist.On("Allow", common.HexToAddress(user.Address)).Return(true, nil) + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) + require.NoError(t, handler.HandleUserMessage(testutils.Context(t), &userRequestMsg, callbachCh)) + sendNodeReponses(t, handler, userRequestMsg, nodes, test.nodeResults) + <-done + }) + } +} + func TestFunctionsHandler_HandleUserMessage_InvalidMethod(t *testing.T) { t.Parallel() nodes, user := gc.NewTestNodes(t, 4), gc.NewTestNodes(t, 1)[0] - handler, _, allowlist, _ := newFunctionsHandlerForATestDON(t, nodes, time.Hour*24) + handler, _, allowlist, _ := newFunctionsHandlerForATestDON(t, nodes, time.Hour*24, user.Address) userRequestMsg := newSignedMessage(t, "1234", "secrets_reveal_all_please", "don_id", user.PrivateKey) allowlist.On("Allow", common.HexToAddress(user.Address)).Return(true, nil) @@ -160,7 +204,7 @@ func TestFunctionsHandler_HandleUserMessage_Timeout(t *testing.T) { t.Parallel() nodes, user := gc.NewTestNodes(t, 4), gc.NewTestNodes(t, 1)[0] - handler, don, allowlist, subscriptions := newFunctionsHandlerForATestDON(t, nodes, time.Millisecond*10) + handler, don, allowlist, subscriptions := newFunctionsHandlerForATestDON(t, nodes, time.Millisecond*10, user.Address) userRequestMsg := newSignedMessage(t, "1234", "secrets_set", "don_id", user.PrivateKey) callbachCh := make(chan handlers.UserCallbackPayload) diff --git a/core/services/ocr2/plugins/functions/plugin.go b/core/services/ocr2/plugins/functions/plugin.go index 475cf0a2af7..82280f527cd 100644 --- a/core/services/ocr2/plugins/functions/plugin.go +++ b/core/services/ocr2/plugins/functions/plugin.go @@ -146,7 +146,7 @@ func NewFunctionsServices(functionsOracleArgs, thresholdOracleArgs, s4OracleArgs return nil, errors.Wrap(err, "failed to create a OnchainSubscriptions") } connectorLogger := conf.Logger.Named("GatewayConnector").With("jobName", conf.Job.PipelineSpec.JobName) - connector, err2 := NewConnector(pluginConfig.GatewayConnectorConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, pluginConfig.MinimumSubscriptionBalance, connectorLogger) + connector, err2 := NewConnector(pluginConfig.GatewayConnectorConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, pluginConfig.MinimumSubscriptionBalance, connectorLogger) if err2 != nil { return nil, errors.Wrap(err, "failed to create a GatewayConnector") } @@ -173,7 +173,7 @@ func NewFunctionsServices(functionsOracleArgs, thresholdOracleArgs, s4OracleArgs return allServices, nil } -func NewConnector(gwcCfg *connector.ConnectorConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwFunctions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwFunctions.OnchainSubscriptions, minimumBalance assets.Link, lggr logger.Logger) (connector.GatewayConnector, error) { +func NewConnector(gwcCfg *connector.ConnectorConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwFunctions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwFunctions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, minimumBalance assets.Link, lggr logger.Logger) (connector.GatewayConnector, error) { enabledKeys, err := ethKeystore.EnabledKeysForChain(chainID) if err != nil { return nil, err @@ -186,7 +186,7 @@ func NewConnector(gwcCfg *connector.ConnectorConfig, ethKeystore keystore.Eth, c signerKey := enabledKeys[idx].ToEcdsaPrivKey() nodeAddress := enabledKeys[idx].ID() - handler, err := functions.NewFunctionsConnectorHandler(nodeAddress, signerKey, s4Storage, allowlist, rateLimiter, subscriptions, minimumBalance, lggr) + handler, err := functions.NewFunctionsConnectorHandler(nodeAddress, signerKey, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, minimumBalance, lggr) if err != nil { return nil, err } diff --git a/core/services/ocr2/plugins/functions/plugin_test.go b/core/services/ocr2/plugins/functions/plugin_test.go index 453d4b67aa8..d77fabcc437 100644 --- a/core/services/ocr2/plugins/functions/plugin_test.go +++ b/core/services/ocr2/plugins/functions/plugin_test.go @@ -10,6 +10,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink/v2/core/logger" + sfmocks "github.com/smartcontractkit/chainlink/v2/core/services/functions/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" gfmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/mocks" @@ -35,8 +36,10 @@ func TestNewConnector_Success(t *testing.T) { subscriptions := gfmocks.NewOnchainSubscriptions(t) rateLimiter, err := hc.NewRateLimiter(hc.RateLimiterConfig{GlobalRPS: 100.0, GlobalBurst: 100, PerSenderRPS: 100.0, PerSenderBurst: 100}) require.NoError(t, err) + listener := sfmocks.NewFunctionsListener(t) + offchainTransmitter := sfmocks.NewOffchainTransmitter(t) ethKeystore.On("EnabledKeysForChain", mock.Anything).Return([]ethkey.KeyV2{keyV2}, nil) - _, err = functions.NewConnector(gwcCfg, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, *assets.NewLinkFromJuels(0), logger.TestLogger(t)) + _, err = functions.NewConnector(gwcCfg, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(0), logger.TestLogger(t)) require.NoError(t, err) } @@ -58,7 +61,9 @@ func TestNewConnector_NoKeyForConfiguredAddress(t *testing.T) { subscriptions := gfmocks.NewOnchainSubscriptions(t) rateLimiter, err := hc.NewRateLimiter(hc.RateLimiterConfig{GlobalRPS: 100.0, GlobalBurst: 100, PerSenderRPS: 100.0, PerSenderBurst: 100}) require.NoError(t, err) + listener := sfmocks.NewFunctionsListener(t) + offchainTransmitter := sfmocks.NewOffchainTransmitter(t) ethKeystore.On("EnabledKeysForChain", mock.Anything).Return([]ethkey.KeyV2{{Address: common.HexToAddress(addresses[1])}}, nil) - _, err = functions.NewConnector(gwcCfg, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, *assets.NewLinkFromJuels(0), logger.TestLogger(t)) + _, err = functions.NewConnector(gwcCfg, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(0), logger.TestLogger(t)) require.Error(t, err) }