Skip to content

Commit

Permalink
[Functions] Add extra validations for offchain heartbeats (#11783)
Browse files Browse the repository at this point in the history
1. Add AllowedHeartbeatInitiators list to node's config and validate senders of incoming requests against it (same logic as in Gateway).
2. Validate Sender value in nodes' reponses to make sure it matches the expected node. Extend an integration test to cover this change.
3. Validate age of incoming requests against RequestTimeoutSec from job config to avoid processing ones that already timed out.
4. Disallow null-byte suffixes in message fields to avoid any potential confusion with default padding.
  • Loading branch information
bolekk authored Jan 23, 2024
1 parent c45ff89 commit 388e779
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 65 deletions.
83 changes: 49 additions & 34 deletions core/services/functions/connector_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/ecdsa"
"encoding/json"
"fmt"
"strings"
"sync"
"time"

Expand All @@ -25,34 +26,34 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector"
hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions"
"github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config"
"github.com/smartcontractkit/chainlink/v2/core/services/s4"
)

type functionsConnectorHandler struct {
services.StateMachine

connector connector.GatewayConnector
signerKey *ecdsa.PrivateKey
nodeAddress string
storage s4.Storage
allowlist functions.OnchainAllowlist
rateLimiter *hc.RateLimiter
subscriptions functions.OnchainSubscriptions
minimumBalance assets.Link
listener FunctionsListener
offchainTransmitter OffchainTransmitter
heartbeatRequests map[RequestID]*HeartbeatResponse
orderedRequests []RequestID
mu sync.Mutex
chStop services.StopChan
shutdownWaitGroup sync.WaitGroup
lggr logger.Logger
connector connector.GatewayConnector
signerKey *ecdsa.PrivateKey
nodeAddress string
storage s4.Storage
allowlist functions.OnchainAllowlist
rateLimiter *hc.RateLimiter
subscriptions functions.OnchainSubscriptions
minimumBalance assets.Link
listener FunctionsListener
offchainTransmitter OffchainTransmitter
allowedHeartbeatInitiators map[string]struct{}
heartbeatRequests map[RequestID]*HeartbeatResponse
requestTimeoutSec uint32
orderedRequests []RequestID
mu sync.Mutex
chStop services.StopChan
shutdownWaitGroup sync.WaitGroup
lggr logger.Logger
}

const (
HeartbeatRequestTimeoutSec = 240
HeartbeatCacheSize = 1000
)
const HeartbeatCacheSize = 1000

var (
_ connector.Signer = &functionsConnectorHandler{}
Expand All @@ -71,23 +72,29 @@ func InternalId(sender []byte, requestId []byte) RequestID {
return RequestID(crypto.Keccak256Hash(append(sender, requestId...)).Bytes())
}

func NewFunctionsConnectorHandler(nodeAddress string, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, listener FunctionsListener, offchainTransmitter OffchainTransmitter, minimumBalance assets.Link, lggr logger.Logger) (*functionsConnectorHandler, error) {
func NewFunctionsConnectorHandler(pluginConfig *config.PluginConfig, signerKey *ecdsa.PrivateKey, storage s4.Storage, allowlist functions.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions functions.OnchainSubscriptions, listener FunctionsListener, offchainTransmitter OffchainTransmitter, lggr logger.Logger) (*functionsConnectorHandler, error) {
if signerKey == nil || storage == nil || allowlist == nil || rateLimiter == nil || subscriptions == nil || listener == nil || offchainTransmitter == nil {
return nil, fmt.Errorf("all dependencies must be non-nil")
}
allowedHeartbeatInitiators := make(map[string]struct{})
for _, initiator := range pluginConfig.AllowedHeartbeatInitiators {
allowedHeartbeatInitiators[strings.ToLower(initiator)] = struct{}{}
}
return &functionsConnectorHandler{
nodeAddress: nodeAddress,
signerKey: signerKey,
storage: storage,
allowlist: allowlist,
rateLimiter: rateLimiter,
subscriptions: subscriptions,
minimumBalance: minimumBalance,
listener: listener,
offchainTransmitter: offchainTransmitter,
heartbeatRequests: make(map[RequestID]*HeartbeatResponse),
chStop: make(services.StopChan),
lggr: lggr.Named("FunctionsConnectorHandler"),
nodeAddress: pluginConfig.GatewayConnectorConfig.NodeAddress,
signerKey: signerKey,
storage: storage,
allowlist: allowlist,
rateLimiter: rateLimiter,
subscriptions: subscriptions,
minimumBalance: pluginConfig.MinimumSubscriptionBalance,
listener: listener,
offchainTransmitter: offchainTransmitter,
allowedHeartbeatInitiators: allowedHeartbeatInitiators,
heartbeatRequests: make(map[RequestID]*HeartbeatResponse),
requestTimeoutSec: pluginConfig.RequestTimeoutSec,
chStop: make(services.StopChan),
lggr: lggr.Named("FunctionsConnectorHandler"),
}, nil
}

Expand Down Expand Up @@ -211,6 +218,10 @@ func (h *functionsConnectorHandler) handleHeartbeat(ctx context.Context, gateway
h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse(fmt.Sprintf("failed to unmarshal request: %v", err)))
return
}
if _, ok := h.allowedHeartbeatInitiators[requestBody.Sender]; !ok {
h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("sender not allowed to send heartbeat requests"))
return
}
if !bytes.Equal(request.RequestInitiator, fromAddr.Bytes()) {
h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("RequestInitiator doesn't match sender"))
return
Expand All @@ -219,6 +230,10 @@ func (h *functionsConnectorHandler) handleHeartbeat(ctx context.Context, gateway
h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("SubscriptionOwner doesn't match sender"))
return
}
if request.Timestamp < uint64(time.Now().Unix())-uint64(h.requestTimeoutSec) {
h.sendResponseAndLog(ctx, gatewayId, requestBody, internalErrorResponse("Request is too old"))
return
}

internalId := InternalId(fromAddr.Bytes(), request.RequestId)
request.RequestId = internalId[:]
Expand Down Expand Up @@ -250,7 +265,7 @@ func internalErrorResponse(internalError string) HeartbeatResponse {
func (h *functionsConnectorHandler) handleOffchainRequest(request *OffchainRequest) {
defer h.shutdownWaitGroup.Done()
stopCtx, _ := h.chStop.NewCtx()
ctx, cancel := context.WithTimeout(stopCtx, time.Duration(HeartbeatRequestTimeoutSec)*time.Second)
ctx, cancel := context.WithTimeout(stopCtx, time.Duration(h.requestTimeoutSec)*time.Second)
defer cancel()
err := h.listener.HandleOffchainRequest(ctx, request)
if err != nil {
Expand Down
38 changes: 33 additions & 5 deletions core/services/functions/connector_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

geth_common "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/onsi/gomega"

"github.com/smartcontractkit/chainlink-common/pkg/assets"
Expand All @@ -19,9 +20,11 @@ import (
sfmocks "github.com/smartcontractkit/chainlink/v2/core/services/functions/mocks"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/api"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/common"
gwconnector "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector"
gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks"
hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common"
gfmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/mocks"
"github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config"
"github.com/smartcontractkit/chainlink/v2/core/services/s4"
s4mocks "github.com/smartcontractkit/chainlink/v2/core/services/s4/mocks"

Expand All @@ -30,7 +33,7 @@ import (
"github.com/stretchr/testify/require"
)

func newOffchainRequest(t *testing.T, sender []byte) (*api.Message, functions.RequestID) {
func newOffchainRequest(t *testing.T, sender []byte, ageSec uint64) (*api.Message, functions.RequestID) {
requestId := make([]byte, 32)
_, err := rand.Read(requestId)
require.NoError(t, err)
Expand All @@ -39,6 +42,7 @@ func newOffchainRequest(t *testing.T, sender []byte) (*api.Message, functions.Re
RequestInitiator: sender,
SubscriptionId: 1,
SubscriptionOwner: sender,
Timestamp: uint64(time.Now().Unix()) - ageSec,
}

internalId := functions.InternalId(request.RequestInitiator, request.RequestId)
Expand Down Expand Up @@ -74,7 +78,15 @@ func TestFunctionsConnectorHandler(t *testing.T) {
allowlist.On("Close", mock.Anything).Return(nil)
subscriptions.On("Start", mock.Anything).Return(nil)
subscriptions.On("Close", mock.Anything).Return(nil)
handler, err := functions.NewFunctionsConnectorHandler(addr.Hex(), privateKey, storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, *assets.NewLinkFromJuels(100), logger)
config := &config.PluginConfig{
GatewayConnectorConfig: &gwconnector.ConnectorConfig{
NodeAddress: addr.Hex(),
},
MinimumSubscriptionBalance: *assets.NewLinkFromJuels(100),
RequestTimeoutSec: 1_000,
AllowedHeartbeatInitiators: []string{crypto.PubkeyToAddress(privateKey.PublicKey).Hex()},
}
handler, err := functions.NewFunctionsConnectorHandler(config, privateKey, storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger)
require.NoError(t, err)

handler.SetConnector(connector)
Expand Down Expand Up @@ -257,7 +269,7 @@ func TestFunctionsConnectorHandler(t *testing.T) {

t.Run("heartbeat success", func(t *testing.T) {
ctx := testutils.Context(t)
msg, internalId := newOffchainRequest(t, addr.Bytes())
msg, internalId := newOffchainRequest(t, addr.Bytes(), 0)
require.NoError(t, msg.Sign(privateKey))

// first call to trigger the request
Expand Down Expand Up @@ -292,7 +304,7 @@ func TestFunctionsConnectorHandler(t *testing.T) {

t.Run("heartbeat internal error", func(t *testing.T) {
ctx := testutils.Context(t)
msg, _ := newOffchainRequest(t, addr.Bytes())
msg, _ := newOffchainRequest(t, addr.Bytes(), 0)
require.NoError(t, msg.Sign(privateKey))

// first call to trigger the request
Expand All @@ -319,7 +331,23 @@ func TestFunctionsConnectorHandler(t *testing.T) {

t.Run("heartbeat sender address doesn't match", func(t *testing.T) {
ctx := testutils.Context(t)
msg, _ := newOffchainRequest(t, geth_common.BytesToAddress([]byte("0x1234")).Bytes())
msg, _ := newOffchainRequest(t, geth_common.BytesToAddress([]byte("0x1234")).Bytes(), 0)
require.NoError(t, msg.Sign(privateKey))

var response functions.HeartbeatResponse
allowlist.On("Allow", addr).Return(true).Once()
connector.On("SendToGateway", mock.Anything, "gw1", mock.Anything).Run(func(args mock.Arguments) {
respMsg, ok := args[2].(*api.Message)
require.True(t, ok)
require.NoError(t, json.Unmarshal(respMsg.Body.Payload, &response))
require.Equal(t, functions.RequestStateInternalError, response.Status)
}).Return(nil).Once()
handler.HandleGatewayMessage(ctx, "gw1", msg)
})

t.Run("heartbeat request too old", func(t *testing.T) {
ctx := testutils.Context(t)
msg, _ := newOffchainRequest(t, addr.Bytes(), 10_000)
require.NoError(t, msg.Sign(privateKey))

var response functions.HeartbeatResponse
Expand Down
10 changes: 10 additions & 0 deletions core/services/gateway/api/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
MessageMethodMaxLen = 64
MessageDonIdMaxLen = 64
MessageReceiverLen = 2 + 2*20
NullChar = "\x00"
)

/*
Expand Down Expand Up @@ -56,12 +57,21 @@ func (m *Message) Validate() error {
if len(m.Body.MessageId) == 0 || len(m.Body.MessageId) > MessageIdMaxLen {
return errors.New("invalid message ID length")
}
if strings.HasSuffix(m.Body.MessageId, NullChar) {
return errors.New("message ID ending with null bytes")
}
if len(m.Body.Method) == 0 || len(m.Body.Method) > MessageMethodMaxLen {
return errors.New("invalid method name length")
}
if strings.HasSuffix(m.Body.Method, NullChar) {
return errors.New("method name ending with null bytes")
}
if len(m.Body.DonId) == 0 || len(m.Body.DonId) > MessageDonIdMaxLen {
return errors.New("invalid DON ID length")
}
if strings.HasSuffix(m.Body.DonId, NullChar) {
return errors.New("DON ID ending with null bytes")
}
if len(m.Body.Receiver) != 0 && len(m.Body.Receiver) != MessageReceiverLen {
return errors.New("invalid Receiver length")
}
Expand Down
18 changes: 17 additions & 1 deletion core/services/gateway/api/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,38 @@ func TestMessage_Validate(t *testing.T) {
// missing message ID
msg.Body.MessageId = ""
require.Error(t, msg.Validate())
// message ID ending with null bytes
msg.Body.MessageId = "myid\x00\x00"
require.Error(t, msg.Validate())
msg.Body.MessageId = "abcd"
require.NoError(t, msg.Validate())

// missing DON ID
msg.Body.DonId = ""
require.Error(t, msg.Validate())
// DON ID ending with null bytes
msg.Body.DonId = "mydon\x00\x00"
require.Error(t, msg.Validate())
msg.Body.DonId = "donA"
require.NoError(t, msg.Validate())

// method too long
// method name too long
msg.Body.Method = string(bytes.Repeat([]byte("a"), api.MessageMethodMaxLen+1))
require.Error(t, msg.Validate())
// empty method name
msg.Body.Method = ""
require.Error(t, msg.Validate())
// method name ending with null bytes
msg.Body.Method = "method\x00"
require.Error(t, msg.Validate())
msg.Body.Method = "request"
require.NoError(t, msg.Validate())

// incorrect receiver
msg.Body.Receiver = "blah"
require.Error(t, msg.Validate())
msg.Body.Receiver = "0x0000000000000000000000000000000000000000"
require.NoError(t, msg.Validate())

// invalid signature
msg.Signature = "0x00"
Expand Down
4 changes: 4 additions & 0 deletions core/services/gateway/connectionmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ func (m *donConnectionManager) readLoop(nodeAddress string, nodeState *nodeState
m.lggr.Errorw("message validation error when reading from node", "nodeAddress", nodeAddress, "err", err)
break
}
if msg.Body.Sender != nodeAddress {
m.lggr.Errorw("message sender mismatch when reading from node", "nodeAddress", nodeAddress, "sender", msg.Body.Sender)
break
}
err = m.handler.HandleNodeMessage(ctx, msg, nodeAddress)
if err != nil {
m.lggr.Error("error when calling HandleNodeMessage ", err)
Expand Down
Loading

0 comments on commit 388e779

Please sign in to comment.