diff --git a/pkg/solana/chainreader/chain_reader.go b/pkg/solana/chainreader/chain_reader.go index b52c72bb5..1332743fa 100644 --- a/pkg/solana/chainreader/chain_reader.go +++ b/pkg/solana/chainreader/chain_reader.go @@ -187,7 +187,7 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReader return err } - idlCodec, err := codec.NewIDLCodec(idl, config.BuilderForEncoding(method.Encoding)) + idlCodec, err := codec.NewIDLAccountCodec(idl, config.BuilderForEncoding(method.Encoding)) if err != nil { return err } diff --git a/pkg/solana/chainreader/chain_reader_test.go b/pkg/solana/chainreader/chain_reader_test.go index 7df246214..d1387af00 100644 --- a/pkg/solana/chainreader/chain_reader_test.go +++ b/pkg/solana/chainreader/chain_reader_test.go @@ -271,7 +271,7 @@ func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) { t.FailNow() } - entry, err := codec.NewIDLCodec(idl, binary.LittleEndian()) + entry, err := codec.NewIDLAccountCodec(idl, binary.LittleEndian()) if err != nil { t.Logf("failed to create new codec from test IDL: %s", err.Error()) t.FailNow() @@ -722,7 +722,7 @@ func makeTestCodec(t *testing.T, rawIDL string, encoding config.EncodingType) ty t.FailNow() } - testCodec, err := codec.NewIDLCodec(idl, config.BuilderForEncoding(encoding)) + testCodec, err := codec.NewIDLAccountCodec(idl, config.BuilderForEncoding(encoding)) if err != nil { t.Logf("failed to create new codec from test IDL: %s", err.Error()) t.FailNow() diff --git a/pkg/solana/codec/discriminator.go b/pkg/solana/codec/discriminator.go new file mode 100644 index 000000000..f712a3f68 --- /dev/null +++ b/pkg/solana/codec/discriminator.go @@ -0,0 +1,71 @@ +package codec + +import ( + "bytes" + "crypto/sha256" + "fmt" + "reflect" + + "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" + "github.com/smartcontractkit/chainlink-common/pkg/types" +) + +const discriminatorLength = 8 + +func NewDiscriminator(name string) encodings.TypeCodec { + sum := sha256.Sum256([]byte("account:" + name)) + return &discriminator{hashPrefix: sum[:discriminatorLength]} +} + +type discriminator struct { + hashPrefix []byte +} + +func (d discriminator) Encode(value any, into []byte) ([]byte, error) { + if value == nil { + return append(into, d.hashPrefix...), nil + } + + raw, ok := value.(*[]byte) + if !ok { + return nil, fmt.Errorf("%w: value must be a byte slice got %T", types.ErrInvalidType, value) + } + + // inject if not specified + if raw == nil { + return append(into, d.hashPrefix...), nil + } + + // Not sure if we should really be encoding accounts... + if !bytes.Equal(*raw, d.hashPrefix) { + return nil, fmt.Errorf("%w: invalid discriminator expected %x got %x", types.ErrInvalidType, d.hashPrefix, raw) + } + + return append(into, *raw...), nil +} + +func (d discriminator) Decode(encoded []byte) (any, []byte, error) { + raw, remaining, err := encodings.SafeDecode(encoded, discriminatorLength, func(raw []byte) []byte { return raw }) + if err != nil { + return nil, nil, err + } + + if !bytes.Equal(raw, d.hashPrefix) { + return nil, nil, fmt.Errorf("%w: invalid discriminator expected %x got %x", types.ErrInvalidEncoding, d.hashPrefix, raw) + } + + return &raw, remaining, nil +} + +func (d discriminator) GetType() reflect.Type { + // Pointer type so that nil can inject values and so that the NamedCodec won't wrap with no-nil pointer. + return reflect.TypeOf(&[]byte{}) +} + +func (d discriminator) Size(_ int) (int, error) { + return discriminatorLength, nil +} + +func (d discriminator) FixedSize() (int, error) { + return discriminatorLength, nil +} diff --git a/pkg/solana/codec/discriminator_test.go b/pkg/solana/codec/discriminator_test.go new file mode 100644 index 000000000..ebd72a60a --- /dev/null +++ b/pkg/solana/codec/discriminator_test.go @@ -0,0 +1,83 @@ +package codec_test + +import ( + "crypto/sha256" + "errors" + "reflect" + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" +) + +func TestDiscriminator(t *testing.T) { + t.Run("encode and decode return the discriminator", func(t *testing.T) { + tmp := sha256.Sum256([]byte("account:Foo")) + expected := tmp[:8] + c := codec.NewDiscriminator("Foo") + encoded, err := c.Encode(&expected, nil) + require.NoError(t, err) + require.Equal(t, expected, encoded) + actual, remaining, err := c.Decode(encoded) + require.NoError(t, err) + require.Equal(t, &expected, actual) + require.Len(t, remaining, 0) + }) + + t.Run("encode returns an error if the discriminator is invalid", func(t *testing.T) { + c := codec.NewDiscriminator("Foo") + _, err := c.Encode(&[]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil) + require.True(t, errors.Is(err, types.ErrInvalidType)) + }) + + t.Run("encode injects the discriminator if it's not provided", func(t *testing.T) { + tmp := sha256.Sum256([]byte("account:Foo")) + expected := tmp[:8] + c := codec.NewDiscriminator("Foo") + encoded, err := c.Encode(nil, nil) + require.NoError(t, err) + require.Equal(t, expected, encoded) + encoded, err = c.Encode((*[]byte)(nil), nil) + require.NoError(t, err) + require.Equal(t, expected, encoded) + }) + + t.Run("decode returns an error if the encoded value is too short", func(t *testing.T) { + c := codec.NewDiscriminator("Foo") + _, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) + require.True(t, errors.Is(err, types.ErrInvalidEncoding)) + }) + + t.Run("decode returns an error if the discriminator is invalid", func(t *testing.T) { + c := codec.NewDiscriminator("Foo") + _, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}) + require.True(t, errors.Is(err, types.ErrInvalidEncoding)) + }) + + t.Run("encode returns an error if the value is not a byte slice", func(t *testing.T) { + c := codec.NewDiscriminator("Foo") + _, err := c.Encode(42, nil) + require.True(t, errors.Is(err, types.ErrInvalidType)) + }) + + t.Run("GetType returns the type of the discriminator", func(t *testing.T) { + c := codec.NewDiscriminator("Foo") + require.Equal(t, reflect.TypeOf(&[]byte{}), c.GetType()) + }) + + t.Run("Size returns the length of the discriminator", func(t *testing.T) { + c := codec.NewDiscriminator("Foo") + size, err := c.Size(0) + require.NoError(t, err) + require.Equal(t, 8, size) + }) + + t.Run("FixedSize returns the length of the discriminator", func(t *testing.T) { + c := codec.NewDiscriminator("Foo") + size, err := c.FixedSize() + require.NoError(t, err) + require.Equal(t, 8, size) + }) +} diff --git a/pkg/solana/codec/solana.go b/pkg/solana/codec/solana.go index fc28beb65..3675d954e 100644 --- a/pkg/solana/codec/solana.go +++ b/pkg/solana/codec/solana.go @@ -60,9 +60,18 @@ func NewNamedModifierCodec(original types.RemoteCodec, itemType string, modifier return modCodec, err } -// NewIDLCodec is for Anchor custom types -func NewIDLCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) { - accounts := make(encodings.LenientCodecFromTypeCodec) +// NewIDLAccountCodec is for Anchor custom types +func NewIDLAccountCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) { + return newIdlCoded(idl, builder, idl.Accounts, true) +} + +func NewIDLDefinedTypesCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) { + return newIdlCoded(idl, builder, idl.Types, false) +} + +func newIdlCoded( + idl IDL, builder encodings.Builder, from IdlTypeDefSlice, includeDiscriminator bool) (types.RemoteCodec, error) { + typeCodecs := make(encodings.LenientCodecFromTypeCodec) refs := &codecRefs{ builder: builder, @@ -71,22 +80,22 @@ func NewIDLCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) dependencies: make(map[string][]string), } - for _, account := range idl.Accounts { + for _, def := range from { var ( name string accCodec encodings.TypeCodec err error ) - name, accCodec, err = createNamedCodec(account, refs) + name, accCodec, err = createNamedCodec(def, refs, includeDiscriminator) if err != nil { return nil, err } - accounts[name] = accCodec + typeCodecs[name] = accCodec } - return accounts, nil + return typeCodecs, nil } type codecRefs struct { @@ -99,13 +108,14 @@ type codecRefs struct { func createNamedCodec( def IdlTypeDef, refs *codecRefs, + includeDiscriminator bool, ) (string, encodings.TypeCodec, error) { caser := cases.Title(language.English) name := def.Name switch def.Type.Kind { case IdlTypeDefTyKindStruct: - return asStruct(def, refs, name, caser) + return asStruct(def, refs, name, caser, includeDiscriminator) case IdlTypeDefTyKindEnum: variants := def.Type.Variants if !variants.IsAllUint8() { @@ -123,8 +133,17 @@ func asStruct( refs *codecRefs, name string, // name is the struct name and can be used in dependency checks caser cases.Caser, + includeDiscriminator bool, ) (string, encodings.TypeCodec, error) { - named := make([]encodings.NamedTypeCodec, len(*def.Type.Fields)) + desLen := 0 + if includeDiscriminator { + desLen = 1 + } + named := make([]encodings.NamedTypeCodec, len(*def.Type.Fields)+desLen) + + if includeDiscriminator { + named[0] = encodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name)} + } for idx, field := range *def.Type.Fields { fieldName := field.Name @@ -134,7 +153,7 @@ func asStruct( return name, nil, err } - named[idx] = encodings.NamedTypeCodec{Name: caser.String(fieldName), Codec: typedCodec} + named[idx+desLen] = encodings.NamedTypeCodec{Name: caser.String(fieldName), Codec: typedCodec} } structCodec, err := encodings.NewStructCodec(named) @@ -188,7 +207,7 @@ func asDefined(parentTypeName string, definedName *IdlTypeDefined, refs *codecRe saveDependency(refs, parentTypeName, definedName.Defined) - newTypeName, newTypeCodec, err := createNamedCodec(*nextDef, refs) + newTypeName, newTypeCodec, err := createNamedCodec(*nextDef, refs, false) if err != nil { return nil, err } diff --git a/pkg/solana/codec/solana_test.go b/pkg/solana/codec/solana_test.go index 403abed02..c10482eed 100644 --- a/pkg/solana/codec/solana_test.go +++ b/pkg/solana/codec/solana_test.go @@ -18,15 +18,19 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec/testutils" ) -func TestNewIDLCodec(t *testing.T) { +func TestNewIDLAccountCodec(t *testing.T) { + /// TODO this should run the codec interface tests t.Parallel() ctx := tests.Context(t) - _, _, entry := newTestIDLAndCodec(t) + _, _, entry := newTestIDLAndCodec(t, true) expected := testutils.DefaultTestStruct bts, err := entry.Encode(ctx, expected, testutils.TestStructWithNestedStruct) + // length of fields + discriminator + require.Equal(t, 262, len(bts)) + require.NoError(t, err) var decoded testutils.StructWithNestedStruct @@ -35,11 +39,32 @@ func TestNewIDLCodec(t *testing.T) { require.Equal(t, expected, decoded) } +func TestNewIDLDefinedTypesCodecCodec(t *testing.T) { + /// TODO this should run the codec interface tests + t.Parallel() + + ctx := tests.Context(t) + _, _, entry := newTestIDLAndCodec(t, false) + + expected := testutils.DefaultTestStruct + bts, err := entry.Encode(ctx, expected, testutils.TestStructWithNestedStructType) + + // length of fields without a discriminator + require.Equal(t, 254, len(bts)) + + require.NoError(t, err) + + var decoded testutils.StructWithNestedStruct + + require.NoError(t, entry.Decode(ctx, bts, &decoded, testutils.TestStructWithNestedStructType)) + require.Equal(t, expected, decoded) +} + func TestNewIDLCodec_WithModifiers(t *testing.T) { t.Parallel() ctx := tests.Context(t) - _, _, idlCodec := newTestIDLAndCodec(t) + _, _, idlCodec := newTestIDLAndCodec(t, true) modConfig := codeccommon.ModifiersConfig{ &codeccommon.RenameModifierConfig{Fields: map[string]string{"Value": "V"}}, } @@ -113,12 +138,12 @@ func TestNewIDLCodec_CircularDependency(t *testing.T) { t.FailNow() } - _, err := codec.NewIDLCodec(idl, binary.LittleEndian()) + _, err := codec.NewIDLAccountCodec(idl, binary.LittleEndian()) assert.ErrorIs(t, err, types.ErrInvalidConfig) } -func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) { +func newTestIDLAndCodec(t *testing.T, account bool) (string, codec.IDL, types.RemoteCodec) { t.Helper() var idl codec.IDL @@ -127,7 +152,14 @@ func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) { t.FailNow() } - entry, err := codec.NewIDLCodec(idl, binary.LittleEndian()) + var entry types.RemoteCodec + var err error + if account { + entry, err = codec.NewIDLAccountCodec(idl, binary.LittleEndian()) + } else { + entry, err = codec.NewIDLDefinedTypesCodec(idl, binary.LittleEndian()) + } + if err != nil { t.Logf("failed to create new codec from test IDL: %s", err.Error()) t.FailNow() diff --git a/pkg/solana/codec/testutils/testIDL.json b/pkg/solana/codec/testutils/testIDL.json index d05496ee5..0ab037721 100644 --- a/pkg/solana/codec/testutils/testIDL.json +++ b/pkg/solana/codec/testutils/testIDL.json @@ -77,6 +77,79 @@ } ], "types": [ + { + "name": "StructWithNestedStructType", + "type": { + "kind": "struct", + "fields": [ + { + "name": "value", + "type": "u8" + }, + { + "name": "innerStruct", + "type": { + "defined": "ObjectRef1" + } + }, + { + "name": "basicNestedArray", + "type": { + "array": [ + { + "array": [ + "u32", + 3 + ] + }, + 3 + ] + } + }, + { + "name": "option", + "type": { + "option": "string" + } + }, + { + "name": "definedArray", + "type": { + "array": [ + { + "defined": "ObjectRef2" + }, + 2 + ] + } + }, + { + "name": "basicVector", + "type": { + "vec": "string" + } + }, + { + "name": "timeVal", + "type": "unixTimestamp" + }, + { + "name": "durationVal", + "type": "duration" + }, + { + "name": "publicKey", + "type": "publicKey" + }, + { + "name": "enumVal", + "type": { + "defined": "SimpleEnum" + } + } + ] + } + }, { "name": "ObjectRef1", "type": { diff --git a/pkg/solana/codec/testutils/types.go b/pkg/solana/codec/testutils/types.go index 7c20762f7..533e88b0b 100644 --- a/pkg/solana/codec/testutils/types.go +++ b/pkg/solana/codec/testutils/types.go @@ -9,9 +9,10 @@ import ( ) var ( - TestStructWithNestedStruct = "StructWithNestedStruct" - DefaultStringRef = "test string" - DefaultTestStruct = StructWithNestedStruct{ + TestStructWithNestedStruct = "StructWithNestedStruct" + TestStructWithNestedStructType = "StructWithNestedStructType" + DefaultStringRef = "test string" + DefaultTestStruct = StructWithNestedStruct{ Value: 80, InnerStruct: ObjectRef1{ Prop1: 10,