Skip to content

Commit

Permalink
[Functions] Heartbeat request support in Gateway handlers
Browse files Browse the repository at this point in the history
1. Functions Handler
  - add a new method "heartbeat"
  - add a configurable list of allowed heartbeat senders
  - collect results from first F+1 nodes and send back in raw form
2. Connector Handler
  - asynchronously forward requests to Listener and cache results
  - run a loop to collect OCR reports from Offchain Transmitter
3. Listener
  - remove SubscriptionOwner field (must be equal to sender) and add Timestamp
  • Loading branch information
bolekk committed Nov 20, 2023
1 parent c9312c6 commit 2ac4bfc
Show file tree
Hide file tree
Showing 10 changed files with 450 additions and 92 deletions.
181 changes: 160 additions & 21 deletions core/services/functions/connector_handler.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package functions

import (
"bytes"
"context"
"crypto/ecdsa"
"encoding/json"
"fmt"
"sync"
"time"

"go.uber.org/multierr"

ethCommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"

"github.com/smartcontractkit/chainlink-common/pkg/assets"
"github.com/smartcontractkit/chainlink-common/pkg/services"
Expand All @@ -25,35 +29,60 @@ import (
type functionsConnectorHandler struct {
services.StateMachine

connector connector.GatewayConnector
signerKey *ecdsa.PrivateKey
nodeAddress string
storage s4.Storage
allowlist functions.OnchainAllowlist
rateLimiter *hc.RateLimiter
subscriptions functions.OnchainSubscriptions
minimumBalance assets.Link
lggr logger.Logger
connector connector.GatewayConnector
signerKey *ecdsa.PrivateKey
nodeAddress string
storage s4.Storage
allowlist functions.OnchainAllowlist
rateLimiter *hc.RateLimiter
subscriptions functions.OnchainSubscriptions
minimumBalance assets.Link
listener FunctionsListener
offchainTransmitter OffchainTransmitter
heartbeatRequests map[RequestID]*HeartbeatResponse
orderedRequests []RequestID
mu sync.Mutex
chStop chan struct{}
shutdownWaitGroup sync.WaitGroup
serviceContext context.Context
serviceCancel context.CancelFunc
lggr logger.Logger
}

const (
HeartbeatRequestTimeoutSec = 240
HeartbeatCacheSize = 1000
)

var (
_ connector.Signer = &functionsConnectorHandler{}
_ connector.GatewayConnectorHandler = &functionsConnectorHandler{}
)

func NewFunctionsConnectorHandler(nodeAddress string, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, minimumBalance assets.Link, lggr logger.Logger) (*functionsConnectorHandler, error) {
if signerKey == nil || storage == nil || allowlist == nil || rateLimiter == nil || subscriptions == nil {
return nil, fmt.Errorf("signerKey, storage, allowlist, rateLimiter and subscriptions must be non-nil")
// internal request ID is a hash of (sender, requestID)
func InternalId(sender []byte, requestId []byte) RequestID {
idPair := sender
idPair = append(idPair, requestId...)
return RequestID(crypto.Keccak256Hash(idPair).Bytes())
}

func NewFunctionsConnectorHandler(nodeAddress string, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, listener FunctionsListener, offchainTransmitter OffchainTransmitter, minimumBalance assets.Link, lggr logger.Logger) (*functionsConnectorHandler, error) {
if signerKey == nil || storage == nil || allowlist == nil || rateLimiter == nil || subscriptions == nil || listener == nil || offchainTransmitter == nil {
return nil, fmt.Errorf("all dependencies must be non-nil")
}
return &functionsConnectorHandler{
nodeAddress: nodeAddress,
signerKey: signerKey,
storage: storage,
allowlist: allowlist,
rateLimiter: rateLimiter,
subscriptions: subscriptions,
minimumBalance: minimumBalance,
lggr: lggr.Named("FunctionsConnectorHandler"),
nodeAddress: nodeAddress,
signerKey: signerKey,
storage: storage,
allowlist: allowlist,
rateLimiter: rateLimiter,
subscriptions: subscriptions,
minimumBalance: minimumBalance,
listener: listener,
offchainTransmitter: offchainTransmitter,
heartbeatRequests: make(map[RequestID]*HeartbeatResponse),
chStop: make(chan struct{}),
lggr: lggr.Named("FunctionsConnectorHandler"),
}, nil
}

Expand Down Expand Up @@ -92,24 +121,35 @@ func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, ga
return
}
h.handleSecretsSet(ctx, gatewayId, body, fromAddr)
case functions.MethodHeartbeat:
h.handleHeartbeat(ctx, gatewayId, body, fromAddr)
default:
h.lggr.Errorw("unsupported method", "id", gatewayId, "method", body.Method)
}
}

func (h *functionsConnectorHandler) Start(ctx context.Context) error {
return h.StartOnce("FunctionsConnectorHandler", func() error {
h.serviceContext, h.serviceCancel = context.WithCancel(context.Background())
if err := h.allowlist.Start(ctx); err != nil {
return err
}
return h.subscriptions.Start(ctx)
if err := h.subscriptions.Start(ctx); err != nil {
return err
}
h.shutdownWaitGroup.Add(1)
go h.reportLoop()
return nil
})
}

func (h *functionsConnectorHandler) Close() error {
return h.StopOnce("FunctionsConnectorHandler", func() (err error) {
h.serviceCancel()
close(h.chStop)
err = multierr.Combine(err, h.allowlist.Close())
err = multierr.Combine(err, h.subscriptions.Close())
h.shutdownWaitGroup.Wait()
return
})
}
Expand Down Expand Up @@ -160,6 +200,105 @@ func (h *functionsConnectorHandler) handleSecretsSet(ctx context.Context, gatewa
h.sendResponseAndLog(ctx, gatewayId, body, response)
}

func (h *functionsConnectorHandler) handleHeartbeat(ctx context.Context, gatewayId string, requestBody *api.MessageBody, fromAddr ethCommon.Address) {
var request *OffchainRequest
err := json.Unmarshal(requestBody.Payload, &request)
if err != nil {
h.sendResponseAndLog(ctx, gatewayId, requestBody, HeartbeatResponse{
Status: RequestStateInternalError,
InternalError: fmt.Sprintf("failed to unmarshal request: %v", err),
})
return
}
if !bytes.Equal(request.RequestInitiator, fromAddr.Bytes()) {
h.sendResponseAndLog(ctx, gatewayId, requestBody, HeartbeatResponse{
Status: RequestStateInternalError,
InternalError: "requestInitiator doesn't match sender",
})
return
}

internalId := InternalId(fromAddr.Bytes(), request.RequestId)
request.RequestId = internalId[:]
h.lggr.Infow("handling offchain heartbeat", "messageId", requestBody.MessageId, "id", internalId, "sender", requestBody.Sender)
h.mu.Lock()
response, ok := h.heartbeatRequests[internalId]
if !ok { // new request
response = &HeartbeatResponse{
Status: RequestStatePending,
ReceivedTs: uint64(time.Now().Unix()),
}
h.cacheNewRequestLocked(internalId, response)
h.shutdownWaitGroup.Add(1)
go h.handleOffchainRequest(request)
}
responseToSend := *response
h.mu.Unlock()
requestBody.Receiver = requestBody.Sender
h.sendResponseAndLog(ctx, gatewayId, requestBody, responseToSend)
}

func (h *functionsConnectorHandler) handleOffchainRequest(request *OffchainRequest) {
defer h.shutdownWaitGroup.Done()
ctx, cancel := context.WithTimeout(h.serviceContext, time.Duration(HeartbeatRequestTimeoutSec)*time.Second)
defer cancel()
err := h.listener.HandleOffchainRequest(ctx, request)
if err != nil {
h.lggr.Infow("internal error while processing", "id", request.RequestId)
h.mu.Lock()
defer h.mu.Unlock()
state, ok := h.heartbeatRequests[RequestID(request.RequestId)]
if !ok {
h.lggr.Errorw("request unexpectedly disappeared from local cache", "id", request.RequestId)
return
}
state.CompletedTs = uint64(time.Now().Unix())
state.Status = RequestStateInternalError
state.InternalError = err.Error()
} else {
// no error - results will be sent to OCR aggregation and returned via reportLoop()
h.lggr.Infow("request processed successfully, waiting for aggregation ...", "id", request.RequestId)
}
}

// Listen to OCR reports passed from the plugin and process them against a local cache of requests.
func (h *functionsConnectorHandler) reportLoop() {
defer h.shutdownWaitGroup.Done()
for {
select {
case report := <-h.offchainTransmitter.ReportChannel():
h.lggr.Infow("received report", "requestId", report.RequestId, "resultLen", len(report.Result), "errorLen", len(report.Error))
if len(report.RequestId) != RequestIDLength {
h.lggr.Errorw("report has invalid requestId", "requestId", report.RequestId)
continue
}
h.mu.Lock()
cachedResponse, ok := h.heartbeatRequests[RequestID(report.RequestId)]
if !ok {
h.lggr.Infow("received report for unknown request, caching it", "id", report.RequestId)
cachedResponse = &HeartbeatResponse{}
h.cacheNewRequestLocked(RequestID(report.RequestId), cachedResponse)
}
cachedResponse.CompletedTs = uint64(time.Now().Unix())
cachedResponse.Status = RequestStateComplete
cachedResponse.Response = report
h.mu.Unlock()
case <-h.chStop:
h.lggr.Info("exiting reportLoop")
return
}
}
}

func (h *functionsConnectorHandler) cacheNewRequestLocked(requestId RequestID, response *HeartbeatResponse) {
// remove oldest requests
for len(h.orderedRequests) >= HeartbeatCacheSize {
delete(h.heartbeatRequests, h.orderedRequests[0])
h.orderedRequests = h.orderedRequests[1:]
}
h.heartbeatRequests[requestId] = response
}

func (h *functionsConnectorHandler) sendResponseAndLog(ctx context.Context, gatewayId string, requestBody *api.MessageBody, payload any) {
err := h.sendResponse(ctx, gatewayId, requestBody, payload)
if err != nil {
Expand Down
108 changes: 107 additions & 1 deletion core/services/functions/connector_handler_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package functions_test

import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"math/big"
"testing"

geth_common "github.com/ethereum/go-ethereum/common"

"github.com/smartcontractkit/chainlink-common/pkg/assets"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/functions"
sfmocks "github.com/smartcontractkit/chainlink/v2/core/services/functions/mocks"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/api"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/common"
gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks"
Expand All @@ -24,6 +28,30 @@ import (
"github.com/stretchr/testify/require"
)

func newOffchainRequest(t *testing.T, sender []byte) (*api.Message, functions.RequestID) {
requestId := make([]byte, 32)
_, err := rand.Read(requestId)
require.NoError(t, err)
request := &functions.OffchainRequest{
RequestId: requestId,
RequestInitiator: sender,
SubscriptionId: 1,
}

internalId := functions.InternalId(request.RequestInitiator, request.RequestId)
req, err := json.Marshal(request)
require.NoError(t, err)
msg := &api.Message{
Body: api.MessageBody{
DonId: "fun4",
MessageId: "1",
Method: "heartbeat",
Payload: req,
},
}
return msg, internalId
}

func TestFunctionsConnectorHandler(t *testing.T) {
t.Parallel()

Expand All @@ -34,12 +62,16 @@ func TestFunctionsConnectorHandler(t *testing.T) {
allowlist := gfmocks.NewOnchainAllowlist(t)
rateLimiter, err := hc.NewRateLimiter(hc.RateLimiterConfig{GlobalRPS: 100.0, GlobalBurst: 100, PerSenderRPS: 100.0, PerSenderBurst: 100})
subscriptions := gfmocks.NewOnchainSubscriptions(t)
reportCh := make(chan *functions.OffchainResponse)
offchainTransmitter := sfmocks.NewOffchainTransmitter(t)
offchainTransmitter.On("ReportChannel", mock.Anything).Return(reportCh)
listener := sfmocks.NewFunctionsListener(t)
require.NoError(t, err)
allowlist.On("Start", mock.Anything).Return(nil)
allowlist.On("Close", mock.Anything).Return(nil)
subscriptions.On("Start", mock.Anything).Return(nil)
subscriptions.On("Close", mock.Anything).Return(nil)
handler, err := functions.NewFunctionsConnectorHandler(addr.Hex(), privateKey, storage, allowlist, rateLimiter, subscriptions, *assets.NewLinkFromJuels(100), logger)
handler, err := functions.NewFunctionsConnectorHandler(addr.Hex(), privateKey, storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(100), logger)
require.NoError(t, err)

handler.SetConnector(connector)
Expand Down Expand Up @@ -219,4 +251,78 @@ func TestFunctionsConnectorHandler(t *testing.T) {
handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg)
})
})

t.Run("heartbeat success", func(t *testing.T) {
ctx := testutils.Context(t)
msg, internalId := newOffchainRequest(t, addr.Bytes())
require.NoError(t, msg.Sign(privateKey))

// first call to trigger the request
var response functions.HeartbeatResponse
allowlist.On("Allow", addr).Return(true).Once()
listener.On("HandleOffchainRequest", mock.Anything, mock.Anything).Return(nil).Once()
connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) {
respMsg, ok := args[2].(*api.Message)
require.True(t, ok)
require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response))
require.Equal(t, functions.RequestStatePending, response.Status)
}).Return(nil).Once()
handler.HandleGatewayMessage(ctx, "gw1", msg)

// async response computation
reportCh <- &functions.OffchainResponse{
RequestId: internalId[:],
Result: []byte("ok!"),
}
reportCh <- &functions.OffchainResponse{} // sending second item to make sure the first one got processed

// second call to collect the response
allowlist.On("Allow", addr).Return(true).Once()
connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) {
respMsg, ok := args[2].(*api.Message)
require.True(t, ok)
require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response))
require.Equal(t, functions.RequestStateComplete, response.Status)
}).Return(nil).Once()
handler.HandleGatewayMessage(ctx, "gw1", msg)
})

t.Run("heartbeat internal error", func(t *testing.T) {
ctx := testutils.Context(t)
msg, _ := newOffchainRequest(t, addr.Bytes())
require.NoError(t, msg.Sign(privateKey))

// first call to trigger the request
var response functions.HeartbeatResponse
allowlist.On("Allow", addr).Return(true).Once()
listener.On("HandleOffchainRequest", mock.Anything, mock.Anything).Return(errors.New("boom")).Once()
connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Return(nil).Once()
handler.HandleGatewayMessage(ctx, "gw1", msg)

// second call to collect the response
allowlist.On("Allow", addr).Return(true).Once()
connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) {
respMsg, ok := args[2].(*api.Message)
require.True(t, ok)
require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response))
require.Equal(t, functions.RequestStateInternalError, response.Status)
}).Return(nil).Once()
handler.HandleGatewayMessage(ctx, "gw1", msg)
})

t.Run("heartbeat sender address doesn't match", func(t *testing.T) {
ctx := testutils.Context(t)
msg, _ := newOffchainRequest(t, geth_common.BytesToAddress([]byte("0x1234")).Bytes())
require.NoError(t, msg.Sign(privateKey))

var response functions.HeartbeatResponse
allowlist.On("Allow", addr).Return(true).Once()
connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) {
respMsg, ok := args[2].(*api.Message)
require.True(t, ok)
require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response))
require.Equal(t, functions.RequestStateInternalError, response.Status)
}).Return(nil).Once()
handler.HandleGatewayMessage(ctx, "gw1", msg)
})
}
Loading

0 comments on commit 2ac4bfc

Please sign in to comment.