Skip to content

Commit

Permalink
[Functions] Enforce uniform coordinator addresses in OCR batches (#10612
Browse files Browse the repository at this point in the history
)

1. Ensure that all requests in an OCR query point to the same destination coordinator.
2. Filter out mismatching coordinators in report phase (to protect from malicious leaders).
3. Run an extra sanity check inside the transmitter - it's too late to do anything about it so only log an error.
  • Loading branch information
bolekk authored Sep 13, 2023
1 parent e027262 commit 4a2fc41
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 7 deletions.
45 changes: 45 additions & 0 deletions core/services/ocr2/plugins/functions/reporting.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"
"fmt"

"github.com/ethereum/go-ethereum/common"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -117,6 +119,21 @@ func (f FunctionsReportingPluginFactory) NewReportingPlugin(rpConfig types.Repor
return &plugin, info, nil
}

// Check if requestCoordinator can be included together with reportCoordinator.
// Return new reportCoordinator (if previous was nil) and error.
func ShouldIncludeCoordinator(requestCoordinator *common.Address, reportCoordinator *common.Address) (*common.Address, error) {
if requestCoordinator == nil || *requestCoordinator == (common.Address{}) {
return reportCoordinator, errors.New("missing/zero request coordinator address")
}
if reportCoordinator == nil {
return requestCoordinator, nil
}
if *reportCoordinator != *requestCoordinator {
return reportCoordinator, errors.New("coordinator contract address mismatch")
}
return reportCoordinator, nil
}

// Query() complies with ReportingPlugin
func (r *functionsReporting) Query(ctx context.Context, ts types.ReportTimestamp) (types.Query, error) {
r.logger.Debug("FunctionsReporting Query start", commontypes.LogFields{
Expand All @@ -132,8 +149,21 @@ func (r *functionsReporting) Query(ctx context.Context, ts types.ReportTimestamp

queryProto := encoding.Query{}
var idStrs []string
var reportCoordinator *common.Address
for _, result := range results {
result := result
if r.contractVersion == 1 {
reportCoordinator, err = ShouldIncludeCoordinator(result.CoordinatorContractAddress, reportCoordinator)
if err != nil {
r.logger.Debug("FunctionsReporting Query: skipping request with mismatched coordinator contract address", commontypes.LogFields{
"requestID": formatRequestId(result.RequestID[:]),
"requestCoordinator": result.CoordinatorContractAddress,
"reportCoordinator": reportCoordinator,
"error": err,
})
continue
}
}
queryProto.RequestIDs = append(queryProto.RequestIDs, result.RequestID[:])
idStrs = append(idStrs, formatRequestId(result.RequestID[:]))
}
Expand Down Expand Up @@ -288,6 +318,7 @@ func (r *functionsReporting) Report(ctx context.Context, ts types.ReportTimestam
var allAggregated []*encoding.ProcessedRequest
var allIdStrs []string
var totalCallbackGas uint32
var reportCoordinator *common.Address
for _, reqId := range uniqueQueryIds {
observations := reqIdToObservationList[reqId]
if !CanAggregate(r.genericConfig.N, r.genericConfig.F, observations) {
Expand Down Expand Up @@ -330,6 +361,20 @@ func (r *functionsReporting) Report(ctx context.Context, ts types.ReportTimestam
"requestID": reqId,
"nObservations": len(observations),
})
if r.contractVersion == 1 {
var requestCoordinator common.Address
requestCoordinator.SetBytes(aggregated.CoordinatorContract)
reportCoordinator, err = ShouldIncludeCoordinator(&requestCoordinator, reportCoordinator)
if err != nil {
r.logger.Error("FunctionsReporting Report: skipping request with mismatched coordinator contract address", commontypes.LogFields{
"requestID": reqId,
"requestCoordinator": requestCoordinator,
"reportCoordinator": reportCoordinator,
"error": err,
})
continue
}
}
allAggregated = append(allAggregated, aggregated)
allIdStrs = append(allIdStrs, reqId)
}
Expand Down
91 changes: 86 additions & 5 deletions core/services/ocr2/plugins/functions/reporting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,26 @@ func TestFunctionsReporting_Query(t *testing.T) {
require.Equal(t, reqs[1].RequestID[:], queryProto.RequestIDs[1])
}

func TestFunctionsReporting_Query_HandleCoordinatorMismatch(t *testing.T) {
t.Parallel()
const batchSize = 10
plugin, orm, _ := preparePlugin(t, batchSize, 1, 1000000)
reqs := []functions_srv.Request{newRequest(), newRequest()}
reqs[0].CoordinatorContractAddress = &common.Address{1}
reqs[1].CoordinatorContractAddress = &common.Address{2}
orm.On("FindOldestEntriesByState", functions_srv.RESULT_READY, uint32(batchSize), mock.Anything).Return(reqs, nil)

q, err := plugin.Query(testutils.Context(t), types.ReportTimestamp{})
require.NoError(t, err)

queryProto := &encoding.Query{}
err = proto.Unmarshal(q, queryProto)
require.NoError(t, err)
require.Equal(t, 1, len(queryProto.RequestIDs))
require.Equal(t, reqs[0].RequestID[:], queryProto.RequestIDs[0])
// reqs[1] should be excluded from this query because it has a different coordinator address
}

func TestFunctionsReporting_Observation(t *testing.T) {
t.Parallel()
plugin, orm, _ := preparePlugin(t, 10, 0, 0)
Expand Down Expand Up @@ -239,10 +259,10 @@ func TestFunctionsReporting_Report_WithGasLimitAndMetadata(t *testing.T) {
reqId1, reqId2, reqId3 := newRequestID(), newRequestID(), newRequestID()
compResult := []byte("aaa")
gasLimit1, gasLimit2 := uint32(100_000), uint32(200_000)
coordinatorContract1, coordinatorContract2 := common.Address{1}, common.Address{2}
coordinatorContract := common.Address{1}
meta1, meta2 := []byte("meta1"), []byte("meta2")
procReq1 := newProcessedRequestWithMeta(reqId1, compResult, []byte{}, gasLimit1, coordinatorContract1[:], meta1)
procReq2 := newProcessedRequestWithMeta(reqId2, compResult, []byte{}, gasLimit2, coordinatorContract2[:], meta2)
procReq1 := newProcessedRequestWithMeta(reqId1, compResult, []byte{}, gasLimit1, coordinatorContract[:], meta1)
procReq2 := newProcessedRequestWithMeta(reqId2, compResult, []byte{}, gasLimit2, coordinatorContract[:], meta2)

query := newMarshalledQuery(t, reqId1, reqId2, reqId3, reqId1, reqId2) // duplicates should be ignored
obs := []types.AttributedObservation{
Expand All @@ -262,18 +282,48 @@ func TestFunctionsReporting_Report_WithGasLimitAndMetadata(t *testing.T) {
require.Equal(t, reqId1[:], decoded[0].RequestID)
require.Equal(t, compResult, decoded[0].Result)
require.Equal(t, []byte{}, decoded[0].Error)
require.Equal(t, coordinatorContract1[:], decoded[0].CoordinatorContract)
require.Equal(t, coordinatorContract[:], decoded[0].CoordinatorContract)
require.Equal(t, meta1, decoded[0].OnchainMetadata)
// CallbackGasLimit is not ABI-encoded

require.Equal(t, reqId2[:], decoded[1].RequestID)
require.Equal(t, compResult, decoded[1].Result)
require.Equal(t, []byte{}, decoded[1].Error)
require.Equal(t, coordinatorContract2[:], decoded[1].CoordinatorContract)
require.Equal(t, coordinatorContract[:], decoded[1].CoordinatorContract)
require.Equal(t, meta2, decoded[1].OnchainMetadata)
// CallbackGasLimit is not ABI-encoded
}

func TestFunctionsReporting_Report_HandleCoordinatorMismatch(t *testing.T) {
t.Parallel()
plugin, _, codec := preparePlugin(t, 10, 1, 300000)
reqId1, reqId2, reqId3 := newRequestID(), newRequestID(), newRequestID()
compResult, meta := []byte("aaa"), []byte("meta")
coordinatorContractA, coordinatorContractB := common.Address{1}, common.Address{2}
procReq1 := newProcessedRequestWithMeta(reqId1, compResult, []byte{}, 0, coordinatorContractA[:], meta)
procReq2 := newProcessedRequestWithMeta(reqId2, compResult, []byte{}, 0, coordinatorContractB[:], meta)
procReq3 := newProcessedRequestWithMeta(reqId3, compResult, []byte{}, 0, coordinatorContractA[:], meta)

query := newMarshalledQuery(t, reqId1, reqId2, reqId3, reqId1, reqId2) // duplicates should be ignored
obs := []types.AttributedObservation{
newObservation(t, 1, procReq2, procReq3, procReq1),
newObservation(t, 2, procReq1, procReq2, procReq3),
newObservation(t, 3, procReq3, procReq1, procReq2),
}

produced, reportBytes, err := plugin.Report(testutils.Context(t), types.ReportTimestamp{}, query, obs)
require.True(t, produced)
require.NoError(t, err)

decoded, err := codec.DecodeReport(reportBytes)
require.NoError(t, err)
require.Equal(t, 2, len(decoded))

require.Equal(t, reqId1[:], decoded[0].RequestID)
require.Equal(t, reqId3[:], decoded[1].RequestID)
// reqId2 should be excluded from this report because it has a different coordinator address
}

func TestFunctionsReporting_Report_CallbackGasLimitExceeded(t *testing.T) {
t.Parallel()
plugin, _, codec := preparePlugin(t, 10, 1, 200000)
Expand Down Expand Up @@ -438,3 +488,34 @@ func TestFunctionsReporting_ShouldTransmitAcceptedReport(t *testing.T) {
require.NoError(t, err)
require.True(t, should)
}

func TestFunctionsReporting_ShouldIncludeCoordinator(t *testing.T) {
t.Parallel()

zeroAddr, coord1, coord2 := &common.Address{}, &common.Address{1}, &common.Address{2}

// should never pass nil requestCoordinator
newCoord, err := functions.ShouldIncludeCoordinator(nil, nil)
require.Error(t, err)
require.Nil(t, newCoord)

// should never pass zero requestCoordinator
newCoord, err = functions.ShouldIncludeCoordinator(zeroAddr, nil)
require.Error(t, err)
require.Nil(t, newCoord)

// overwrite nil reportCoordinator
newCoord, err = functions.ShouldIncludeCoordinator(coord1, nil)
require.NoError(t, err)
require.Equal(t, coord1, newCoord)

// same address is fine
newCoord, err = functions.ShouldIncludeCoordinator(coord1, newCoord)
require.NoError(t, err)
require.Equal(t, coord1, newCoord)

// different address is not accepted
newCoord, err = functions.ShouldIncludeCoordinator(coord2, newCoord)
require.Error(t, err)
require.Equal(t, coord1, newCoord)
}
17 changes: 15 additions & 2 deletions core/services/relay/evm/functions/contract_transmitter.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package functions

import (
"bytes"
"context"
"database/sql"
"encoding/hex"
"fmt"
"math/big"
"sync/atomic"
Expand Down Expand Up @@ -137,9 +139,20 @@ func (oc *contractTransmitter) Transmit(ctx context.Context, reportCtx ocrtypes.
if len(requests[0].CoordinatorContract) != common.AddressLength {
return fmt.Errorf("FunctionsContractTransmitter: incorrect length of CoordinatorContract field: %d", len(requests[0].CoordinatorContract))
}
// NOTE: this is incorrect if batch contains requests destined for different contracts (unlikely)
// it will be fixed when we get rid of batching
destinationContract.SetBytes(requests[0].CoordinatorContract)
if destinationContract == (common.Address{}) {
return errors.New("FunctionsContractTransmitter: destination coordinator contract is zero")
}
// Sanity check - every report should contain requests with the same coordinator contract.
for _, req := range requests[1:] {
if !bytes.Equal(req.CoordinatorContract, destinationContract.Bytes()) {
oc.lggr.Errorw("FunctionsContractTransmitter: non-uniform coordinator addresses in a batch - still sending to a single destination",
"requestID", hex.EncodeToString(req.RequestID),
"destinationContract", destinationContract,
"requestCoordinator", hex.EncodeToString(req.CoordinatorContract),
)
}
}
oc.lggr.Debugw("FunctionsContractTransmitter: ready", "nRequests", len(requests), "coordinatorContract", destinationContract.Hex())
} else {
return fmt.Errorf("unsupported contract version: %d", oc.contractVersion)
Expand Down
41 changes: 41 additions & 0 deletions core/services/relay/evm/functions/contract_transmitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,44 @@ func TestContractTransmitter_Transmit_V1(t *testing.T) {
}
require.Error(t, ot.Transmit(testutils.Context(t), ocrtypes.ReportContext{}, reportBytes, signatures))
}

func TestContractTransmitter_Transmit_V1_CoordinatorMismatch(t *testing.T) {
t.Parallel()

contractVersion := uint32(1)
configuredDestAddress, coordinatorAddress1, coordinatorAddress2 := testutils.NewAddress(), testutils.NewAddress(), testutils.NewAddress()
lggr := logger.TestLogger(t)
c := evmclimocks.NewClient(t)
lp := lpmocks.NewLogPoller(t)
contractABI, _ := abi.JSON(strings.NewReader(ocr2aggregator.OCR2AggregatorABI))
lp.On("RegisterFilter", mock.Anything).Return(nil)

ocrTransmitter := mockTransmitter{}
ot, err := functions.NewFunctionsContractTransmitter(c, contractABI, &ocrTransmitter, lp, lggr, func(b []byte) (*txmgr.TxMeta, error) {
return &txmgr.TxMeta{}, nil
}, contractVersion)
require.NoError(t, err)
require.NoError(t, ot.UpdateRoutes(configuredDestAddress, configuredDestAddress))

reqId1, err := hex.DecodeString("110102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f")
require.NoError(t, err)
reqId2, err := hex.DecodeString("220102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f")
require.NoError(t, err)
processedRequests := []*encoding.ProcessedRequest{
{
RequestID: reqId1,
CoordinatorContract: coordinatorAddress1.Bytes(),
},
{
RequestID: reqId2,
CoordinatorContract: coordinatorAddress2.Bytes(),
},
}
codec, err := encoding.NewReportCodec(contractVersion)
require.NoError(t, err)
reportBytes, err := codec.EncodeReport(processedRequests)
require.NoError(t, err)

require.NoError(t, ot.Transmit(testutils.Context(t), ocrtypes.ReportContext{}, reportBytes, []ocrtypes.AttributedOnchainSignature{}))
require.Equal(t, coordinatorAddress1, ocrTransmitter.toAddress)
}

0 comments on commit 4a2fc41

Please sign in to comment.