diff --git a/pkg/loop/internal/chain_reader.go b/pkg/loop/internal/chain_reader.go index 143a8ed17f..d8c111b595 100644 --- a/pkg/loop/internal/chain_reader.go +++ b/pkg/loop/internal/chain_reader.go @@ -6,6 +6,7 @@ import ( "github.com/smartcontractkit/chainlink-relay/pkg/loop/internal/pb" "github.com/smartcontractkit/chainlink-relay/pkg/types" + "github.com/smartcontractkit/chainlink-relay/pkg/utils" ) var _ types.ChainReader = (*chainReaderClient)(nil) @@ -15,7 +16,11 @@ type chainReaderClient struct { grpc pb.ChainReaderClient } -func (c *chainReaderClient) GetLatestValue(ctx context.Context, bc types.BoundContract, method string, params, retVal any) error { +func (c *chainReaderClient) GetLatestValue(ctx context.Context, bc types.BoundContract, method string, params, returnVal any) error { + if err := utils.ValidateStructPtr(returnVal); err != nil { + return err + } + boundContract := pb.BoundContract{Name: bc.Name, Address: bc.Address, Pending: bc.Pending} jsonParams, err := json.Marshal(params) if err != nil { @@ -26,7 +31,7 @@ func (c *chainReaderClient) GetLatestValue(ctx context.Context, bc types.BoundCo if err != nil { return err } - return json.Unmarshal(reply.RetVal, &retVal) + return json.Unmarshal(reply.RetVal, &returnVal) } var _ pb.ChainReaderServer = (*chainReaderServer)(nil) diff --git a/pkg/loop/internal/test/config.go b/pkg/loop/internal/test/config.go index b21e3558f2..f4b796de44 100644 --- a/pkg/loop/internal/test/config.go +++ b/pkg/loop/internal/test/config.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/smartcontractkit/chainlink-relay/pkg/types" + "github.com/smartcontractkit/chainlink-relay/pkg/utils" ) type staticConfigProvider struct{} @@ -93,6 +94,10 @@ func (s staticContractTransmitter) FromAccount() (libocr.Account, error) { type staticChainReader struct{} func (c staticChainReader) GetLatestValue(ctx context.Context, bc types.BoundContract, method string, params, returnVal any) error { + if err := utils.ValidateStructPtr(returnVal); err != nil { + return err + } + if !assert.ObjectsAreEqual(bc, boundContract) { return fmt.Errorf("expected report context %v but got %v", boundContract, bc) } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index aed521c621..15e76cb418 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -2,8 +2,10 @@ package utils import ( "context" + "fmt" "math" mrand "math/rand" + "reflect" "time" ) @@ -38,6 +40,21 @@ func ContextFromChan(chStop <-chan struct{}) (context.Context, context.CancelFun return ctx, cancel } +// ValidateStructPtr checks if the input value is a pointer to a non-nil struct. +func ValidateStructPtr(val any) error { + rv := reflect.ValueOf(val) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("value of type %T is not a non-nil struct pointer", val) + } + + rv = rv.Elem() + if !rv.IsValid() || rv.Kind() != reflect.Struct { + return fmt.Errorf("dereferenced value of type %T is not a valid struct", val) + } + + return nil +} + // ContextWithDeadlineFn returns a copy of the parent context with the deadline modified by deadlineFn. // deadlineFn will only be called if the parent has a deadline. // The new deadline must be sooner than the old to have an effect. diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go new file mode 100644 index 0000000000..84c24ddda6 --- /dev/null +++ b/pkg/utils/utils_test.go @@ -0,0 +1,61 @@ +package utils_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/smartcontractkit/chainlink-relay/pkg/utils" +) + +type ExampleStruct struct { + Name string +} + +func TestValidateStructPtr(t *testing.T) { + testNum := 42 + testCases := []struct { + name string + input interface{} + expect error + }{ + { + name: "struct pointer", + input: &ExampleStruct{"Bob"}, + expect: nil, + }, + { + name: "struct", + input: ExampleStruct{"Alice"}, + expect: errors.New("value of type utils_test.ExampleStruct is not a non-nil struct pointer"), + }, + { + name: "nil struct pointer", + input: (*ExampleStruct)(nil), + expect: errors.New("value of type *utils_test.ExampleStruct is not a non-nil struct pointer"), + }, + { + name: "nil", + input: nil, + expect: errors.New("value of type is not a non-nil struct pointer"), + }, + { + name: "non struct pointer", + input: &testNum, + expect: errors.New("dereferenced value of type *int is not a valid struct"), + }, + { + name: "non struct type", + input: testNum, + expect: errors.New("value of type int is not a non-nil struct pointer"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := utils.ValidateStructPtr(tc.input) + assert.Equal(t, tc.expect, result) + }) + } +}