Skip to content

Commit

Permalink
Add returnVal any param validation to chainReader GetLatestValue
Browse files Browse the repository at this point in the history
  • Loading branch information
ilija42 committed Oct 24, 2023
1 parent 0e354da commit 1530d3d
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 2 deletions.
9 changes: 7 additions & 2 deletions pkg/loop/internal/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/smartcontractkit/chainlink-relay/pkg/loop/internal/pb"
"github.com/smartcontractkit/chainlink-relay/pkg/types"
"github.com/smartcontractkit/chainlink-relay/pkg/utils"
)

var _ types.ChainReader = (*chainReaderClient)(nil)
Expand All @@ -15,7 +16,11 @@ type chainReaderClient struct {
grpc pb.ChainReaderClient
}

func (c *chainReaderClient) GetLatestValue(ctx context.Context, bc types.BoundContract, method string, params, retVal any) error {
func (c *chainReaderClient) GetLatestValue(ctx context.Context, bc types.BoundContract, method string, params, returnVal any) error {
if err := utils.ValidateStructPtr(returnVal); err != nil {
return err
}

boundContract := pb.BoundContract{Name: bc.Name, Address: bc.Address, Pending: bc.Pending}
jsonParams, err := json.Marshal(params)
if err != nil {
Expand All @@ -26,7 +31,7 @@ func (c *chainReaderClient) GetLatestValue(ctx context.Context, bc types.BoundCo
if err != nil {
return err
}
return json.Unmarshal(reply.RetVal, &retVal)
return json.Unmarshal(reply.RetVal, &returnVal)
}

var _ pb.ChainReaderServer = (*chainReaderServer)(nil)
Expand Down
5 changes: 5 additions & 0 deletions pkg/loop/internal/test/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"

"github.com/smartcontractkit/chainlink-relay/pkg/types"
"github.com/smartcontractkit/chainlink-relay/pkg/utils"
)

type staticConfigProvider struct{}
Expand Down Expand Up @@ -93,6 +94,10 @@ func (s staticContractTransmitter) FromAccount() (libocr.Account, error) {
type staticChainReader struct{}

func (c staticChainReader) GetLatestValue(ctx context.Context, bc types.BoundContract, method string, params, returnVal any) error {
if err := utils.ValidateStructPtr(returnVal); err != nil {
return err
}

if !assert.ObjectsAreEqual(bc, boundContract) {
return fmt.Errorf("expected report context %v but got %v", boundContract, bc)
}
Expand Down
17 changes: 17 additions & 0 deletions pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package utils

import (
"context"
"fmt"
"math"
mrand "math/rand"
"reflect"
"time"
)

Expand Down Expand Up @@ -38,6 +40,21 @@ func ContextFromChan(chStop <-chan struct{}) (context.Context, context.CancelFun
return ctx, cancel
}

// ValidateStructPtr checks if the input value is a pointer to a non-nil struct.
func ValidateStructPtr(val any) error {
rv := reflect.ValueOf(val)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return fmt.Errorf("value of type %T is not a non-nil struct pointer", val)
}

rv = rv.Elem()
if !rv.IsValid() || rv.Kind() != reflect.Struct {
return fmt.Errorf("dereferenced value of type %T is not a valid struct", val)
}

return nil
}

// ContextWithDeadlineFn returns a copy of the parent context with the deadline modified by deadlineFn.
// deadlineFn will only be called if the parent has a deadline.
// The new deadline must be sooner than the old to have an effect.
Expand Down
61 changes: 61 additions & 0 deletions pkg/utils/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package utils_test

import (
"errors"
"testing"

"github.com/stretchr/testify/assert"

"github.com/smartcontractkit/chainlink-relay/pkg/utils"
)

type ExampleStruct struct {
Name string
}

func TestValidateStructPtr(t *testing.T) {
testNum := 42
testCases := []struct {
name string
input interface{}
expect error
}{
{
name: "struct pointer",
input: &ExampleStruct{"Bob"},
expect: nil,
},
{
name: "struct",
input: ExampleStruct{"Alice"},
expect: errors.New("value of type utils_test.ExampleStruct is not a non-nil struct pointer"),
},
{
name: "nil struct pointer",
input: (*ExampleStruct)(nil),
expect: errors.New("value of type *utils_test.ExampleStruct is not a non-nil struct pointer"),
},
{
name: "nil",
input: nil,
expect: errors.New("value of type <nil> is not a non-nil struct pointer"),
},
{
name: "non struct pointer",
input: &testNum,
expect: errors.New("dereferenced value of type *int is not a valid struct"),
},
{
name: "non struct type",
input: testNum,
expect: errors.New("value of type int is not a non-nil struct pointer"),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := utils.ValidateStructPtr(tc.input)
assert.Equal(t, tc.expect, result)
})
}
}

0 comments on commit 1530d3d

Please sign in to comment.