Skip to content

Commit

Permalink
[Functions] Remove V0 support from allowlist reader, report encoding …
Browse files Browse the repository at this point in the history
…and contract transmitter (#11057)
  • Loading branch information
bolekk authored Oct 24, 2023
1 parent 848bfe2 commit a632a91
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 251 deletions.
28 changes: 3 additions & 25 deletions core/services/gateway/handlers/functions/allowlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client"
"github.com/smartcontractkit/chainlink/v2/core/gethwrappers/functions/generated/functions_allow_list"
"github.com/smartcontractkit/chainlink/v2/core/gethwrappers/functions/generated/functions_router"
"github.com/smartcontractkit/chainlink/v2/core/gethwrappers/functions/generated/ocr2dr_oracle"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/job"
"github.com/smartcontractkit/chainlink/v2/core/utils"
Expand Down Expand Up @@ -51,7 +50,6 @@ type onchainAllowlist struct {
config OnchainAllowlistConfig
allowlist atomic.Pointer[map[common.Address]struct{}]
client evmclient.Client
contractV0 *ocr2dr_oracle.OCR2DROracle
contractV1 *functions_router.FunctionsRouter
blockConfirmations *big.Int
lggr logger.Logger
Expand All @@ -66,9 +64,8 @@ func NewOnchainAllowlist(client evmclient.Client, config OnchainAllowlistConfig,
if lggr == nil {
return nil, errors.New("logger is nil")
}
contractV0, err := ocr2dr_oracle.NewOCR2DROracle(config.ContractAddress, client)
if err != nil {
return nil, fmt.Errorf("unexpected error during NewOCR2DROracle: %s", err)
if config.ContractVersion != 1 {
return nil, fmt.Errorf("unsupported contract version %d", config.ContractVersion)
}
contractV1, err := functions_router.NewFunctionsRouter(config.ContractAddress, client)
if err != nil {
Expand All @@ -77,7 +74,6 @@ func NewOnchainAllowlist(client evmclient.Client, config OnchainAllowlistConfig,
allowlist := &onchainAllowlist{
config: config,
client: client,
contractV0: contractV0,
contractV1: contractV1,
blockConfirmations: big.NewInt(int64(config.BlockConfirmations)),
lggr: lggr.Named("OnchainAllowlist"),
Expand Down Expand Up @@ -148,25 +144,7 @@ func (a *onchainAllowlist) UpdateFromContract(ctx context.Context) error {
return errors.New("LatestBlockHeight returned nil")
}
blockNum := big.NewInt(0).Sub(latestBlockHeight, a.blockConfirmations)
if a.config.ContractVersion == 0 {
return a.updateFromContractV0(ctx, blockNum)
} else if a.config.ContractVersion == 1 {
return a.updateFromContractV1(ctx, blockNum)
}
return fmt.Errorf("unknown contract version %d", a.config.ContractVersion)
}

func (a *onchainAllowlist) updateFromContractV0(ctx context.Context, blockNum *big.Int) error {
addrList, err := a.contractV0.GetAuthorizedSenders(&bind.CallOpts{
Pending: false,
BlockNumber: blockNum,
Context: ctx,
})
if err != nil {
return errors.Wrap(err, "error calling GetAuthorizedSenders")
}
a.update(addrList)
return nil
return a.updateFromContractV1(ctx, blockNum)
}

func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *big.Int) error {
Expand Down
56 changes: 33 additions & 23 deletions core/services/gateway/handlers/functions/allowlist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,41 @@ func sampleEncodedAllowlist(t *testing.T) []byte {
func TestAllowlist_UpdateAndCheck(t *testing.T) {
t.Parallel()

for _, version := range []uint32{0, 1} {
client := mocks.NewClient(t)
client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil)
// both contract versions have the same return type
client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(sampleEncodedAllowlist(t), nil)
config := functions.OnchainAllowlistConfig{
ContractVersion: version,
ContractAddress: common.Address{},
BlockConfirmations: 1,
}
allowlist, err := functions.NewOnchainAllowlist(client, config, logger.TestLogger(t))
require.NoError(t, err)
client := mocks.NewClient(t)
client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil)
client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(sampleEncodedAllowlist(t), nil)
config := functions.OnchainAllowlistConfig{
ContractVersion: 1,
ContractAddress: common.Address{},
BlockConfirmations: 1,
}
allowlist, err := functions.NewOnchainAllowlist(client, config, logger.TestLogger(t))
require.NoError(t, err)

err = allowlist.Start(testutils.Context(t))
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, allowlist.Close())
})

require.NoError(t, allowlist.UpdateFromContract(testutils.Context(t)))
require.False(t, allowlist.Allow(common.Address{}))
require.True(t, allowlist.Allow(common.HexToAddress(addr1)))
require.True(t, allowlist.Allow(common.HexToAddress(addr2)))
require.False(t, allowlist.Allow(common.HexToAddress(addr3)))
}

err = allowlist.Start(testutils.Context(t))
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, allowlist.Close())
})
func TestAllowlist_UnsupportedVersion(t *testing.T) {
t.Parallel()

require.NoError(t, allowlist.UpdateFromContract(testutils.Context(t)))
require.False(t, allowlist.Allow(common.Address{}))
require.True(t, allowlist.Allow(common.HexToAddress(addr1)))
require.True(t, allowlist.Allow(common.HexToAddress(addr2)))
require.False(t, allowlist.Allow(common.HexToAddress(addr3)))
client := mocks.NewClient(t)
config := functions.OnchainAllowlistConfig{
ContractVersion: 0,
ContractAddress: common.Address{},
BlockConfirmations: 1,
}
_, err := functions.NewOnchainAllowlist(client, config, logger.TestLogger(t))
require.Error(t, err)
}

func TestAllowlist_UpdatePeriodically(t *testing.T) {
Expand All @@ -77,7 +87,7 @@ func TestAllowlist_UpdatePeriodically(t *testing.T) {
}).Return(sampleEncodedAllowlist(t), nil)
config := functions.OnchainAllowlistConfig{
ContractAddress: common.Address{},
ContractVersion: 0,
ContractVersion: 1,
BlockConfirmations: 1,
UpdateFrequencySec: 1,
UpdateTimeoutSec: 1,
Expand Down
85 changes: 0 additions & 85 deletions core/services/ocr2/plugins/functions/encoding/abi_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,12 @@ type ReportCodec interface {
DecodeReport(raw []byte) ([]*ProcessedRequest, error)
}

type reportCodecV0 struct {
reportTypes abi.Arguments
}

type reportCodecV1 struct {
reportTypes abi.Arguments
}

func NewReportCodec(contractVersion uint32) (ReportCodec, error) {
switch contractVersion {
case 0: // deprecated
reportTypes, err := getReportTypesV0()
if err != nil {
return nil, err
}
return &reportCodecV0{reportTypes: reportTypes}, nil
case 1:
reportTypes, err := getReportTypesV1()
if err != nil {
Expand All @@ -48,81 +38,6 @@ func SliceToByte32(slice []byte) ([32]byte, error) {
return res, nil
}

func getReportTypesV0() (abi.Arguments, error) {
bytes32ArrType, err := abi.NewType("bytes32[]", "", []abi.ArgumentMarshaling{})
if err != nil {
return nil, fmt.Errorf("unable to create an ABI type object for bytes32[]")
}
bytesArrType, err := abi.NewType("bytes[]", "", []abi.ArgumentMarshaling{})
if err != nil {
return nil, fmt.Errorf("unable to create an ABI type object for bytes[]")
}
return abi.Arguments([]abi.Argument{
{Name: "ids", Type: bytes32ArrType},
{Name: "results", Type: bytesArrType},
{Name: "errors", Type: bytesArrType},
}), nil
}

func (c *reportCodecV0) EncodeReport(requests []*ProcessedRequest) ([]byte, error) {
size := len(requests)
if size == 0 {
return []byte{}, nil
}
ids := make([][32]byte, size)
results := make([][]byte, size)
errors := make([][]byte, size)
for i := 0; i < size; i++ {
var err error
ids[i], err = SliceToByte32(requests[i].RequestID)
if err != nil {
return nil, err
}
results[i] = requests[i].Result
errors[i] = requests[i].Error
}
return c.reportTypes.Pack(ids, results, errors)
}

func (c *reportCodecV0) DecodeReport(raw []byte) ([]*ProcessedRequest, error) {
reportElems := map[string]interface{}{}
if err := c.reportTypes.UnpackIntoMap(reportElems, raw); err != nil {
return nil, errors.WithMessage(err, "unable to unpack elements from raw report")
}

idsIface, idsOK := reportElems["ids"]
resultsIface, resultsOK := reportElems["results"]
errorsIface, errorsOK := reportElems["errors"]
if !idsOK || !resultsOK || !errorsOK {
return nil, fmt.Errorf("missing arrays in raw report, ids: %v, results: %v, errors: %v", idsOK, resultsOK, errorsOK)
}

ids, idsOK := idsIface.([][32]byte)
results, resultsOK := resultsIface.([][]byte)
errors, errorsOK := errorsIface.([][]byte)
if !idsOK || !resultsOK || !errorsOK {
return nil, fmt.Errorf("unable to cast part of raw report into array type, ids: %v, results: %v, errors: %v", idsOK, resultsOK, errorsOK)
}

size := len(ids)
if len(results) != size || len(errors) != size {
return nil, fmt.Errorf("unequal sizes of arrays parsed from raw report, ids: %v, results: %v, errors: %v", len(ids), len(results), len(errors))
}
if size == 0 {
return []*ProcessedRequest{}, nil
}

decoded := make([]*ProcessedRequest, size)
for i := 0; i < size; i++ {
decoded[i] = &ProcessedRequest{
RequestID: ids[i][:],
Result: results[i],
Error: errors[i],
}
}
return decoded, nil
}

func getReportTypesV1() (abi.Arguments, error) {
bytes32ArrType, err := abi.NewType("bytes32[]", "", []abi.ArgumentMarshaling{})
if err != nil {
Expand Down
31 changes: 0 additions & 31 deletions core/services/ocr2/plugins/functions/encoding/abi_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,6 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/encoding"
)

func TestABICodec_EncodeDecodeV0Success(t *testing.T) {
t.Parallel()
codec, err := encoding.NewReportCodec(0)
require.NoError(t, err)

var report = []*encoding.ProcessedRequest{
{
RequestID: []byte(fmt.Sprintf("%032d", 123)),
Result: []byte("abcd"),
Error: []byte("err string"),
},
{
RequestID: []byte(fmt.Sprintf("%032d", 4321)),
Result: []byte("0xababababab"),
Error: []byte(""),
},
}

encoded, err := codec.EncodeReport(report)
require.NoError(t, err)
decoded, err := codec.DecodeReport(encoded)
require.NoError(t, err)

require.Equal(t, len(report), len(decoded))
for i := 0; i < len(report); i++ {
require.Equal(t, report[i].RequestID, decoded[i].RequestID, "RequestIDs not equal at index %d", i)
require.Equal(t, report[i].Result, decoded[i].Result, "Results not equal at index %d", i)
require.Equal(t, report[i].Error, decoded[i].Error, "Errors not equal at index %d", i)
}
}

func TestABICodec_EncodeDecodeV1Success(t *testing.T) {
t.Parallel()
codec, err := encoding.NewReportCodec(1)
Expand Down
60 changes: 27 additions & 33 deletions core/services/ocr2/plugins/functions/reporting.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,17 +157,15 @@ func (r *functionsReporting) Query(ctx context.Context, ts types.ReportTimestamp
var reportCoordinator *common.Address
for _, result := range results {
result := result
if r.contractVersion == 1 {
reportCoordinator, err = ShouldIncludeCoordinator(result.CoordinatorContractAddress, reportCoordinator)
if err != nil {
r.logger.Debug("FunctionsReporting Query: skipping request with mismatched coordinator contract address", commontypes.LogFields{
"requestID": formatRequestId(result.RequestID[:]),
"requestCoordinator": result.CoordinatorContractAddress,
"reportCoordinator": reportCoordinator,
"error": err,
})
continue
}
reportCoordinator, err = ShouldIncludeCoordinator(result.CoordinatorContractAddress, reportCoordinator)
if err != nil {
r.logger.Debug("FunctionsReporting Query: skipping request with mismatched coordinator contract address", commontypes.LogFields{
"requestID": formatRequestId(result.RequestID[:]),
"requestCoordinator": result.CoordinatorContractAddress,
"reportCoordinator": reportCoordinator,
"error": err,
})
continue
}
queryProto.RequestIDs = append(queryProto.RequestIDs, result.RequestID[:])
idStrs = append(idStrs, formatRequestId(result.RequestID[:]))
Expand Down Expand Up @@ -236,16 +234,14 @@ func (r *functionsReporting) Observation(ctx context.Context, ts types.ReportTim
Error: localResult.Error,
OnchainMetadata: localResult.OnchainMetadata,
}
if r.contractVersion == 1 {
if localResult.CallbackGasLimit == nil || localResult.CoordinatorContractAddress == nil {
r.logger.Error("FunctionsReporting Observation missing required v1 fields", commontypes.LogFields{
"requestID": formatRequestId(id[:]),
})
continue
}
resultProto.CallbackGasLimit = *localResult.CallbackGasLimit
resultProto.CoordinatorContract = localResult.CoordinatorContractAddress[:]
if localResult.CallbackGasLimit == nil || localResult.CoordinatorContractAddress == nil {
r.logger.Error("FunctionsReporting Observation missing required v1 fields", commontypes.LogFields{
"requestID": formatRequestId(id[:]),
})
continue
}
resultProto.CallbackGasLimit = *localResult.CallbackGasLimit
resultProto.CoordinatorContract = localResult.CoordinatorContractAddress[:]
observationProto.ProcessedRequests = append(observationProto.ProcessedRequests, &resultProto)
idStrs = append(idStrs, formatRequestId(localResult.RequestID[:]))
}
Expand Down Expand Up @@ -367,19 +363,17 @@ func (r *functionsReporting) Report(ctx context.Context, ts types.ReportTimestam
"requestID": reqId,
"nObservations": len(observations),
})
if r.contractVersion == 1 {
var requestCoordinator common.Address
requestCoordinator.SetBytes(aggregated.CoordinatorContract)
reportCoordinator, err = ShouldIncludeCoordinator(&requestCoordinator, reportCoordinator)
if err != nil {
r.logger.Error("FunctionsReporting Report: skipping request with mismatched coordinator contract address", commontypes.LogFields{
"requestID": reqId,
"requestCoordinator": requestCoordinator,
"reportCoordinator": reportCoordinator,
"error": err,
})
continue
}
var requestCoordinator common.Address
requestCoordinator.SetBytes(aggregated.CoordinatorContract)
reportCoordinator, err = ShouldIncludeCoordinator(&requestCoordinator, reportCoordinator)
if err != nil {
r.logger.Error("FunctionsReporting Report: skipping request with mismatched coordinator contract address", commontypes.LogFields{
"requestID": reqId,
"requestCoordinator": requestCoordinator,
"reportCoordinator": reportCoordinator,
"error": err,
})
continue
}
allAggregated = append(allAggregated, aggregated)
allIdStrs = append(allIdStrs, reqId)
Expand Down
Loading

0 comments on commit a632a91

Please sign in to comment.