Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add discriminator to the account and allow encoding/decoding of non-account types #672

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReader
return err
}

idlCodec, err := codec.NewIDLCodec(idl, config.BuilderForEncoding(method.Encoding))
idlCodec, err := codec.NewIDLAccountCodec(idl, config.BuilderForEncoding(method.Encoding))
if err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/solana/chainreader/chain_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) {
t.FailNow()
}

entry, err := codec.NewIDLCodec(idl, binary.LittleEndian())
entry, err := codec.NewIDLAccountCodec(idl, binary.LittleEndian())
if err != nil {
t.Logf("failed to create new codec from test IDL: %s", err.Error())
t.FailNow()
Expand Down Expand Up @@ -328,6 +328,7 @@ type mockedRPCCall struct {
delay time.Duration
}

// TODO BCI-3156 use a localnet for testing instead of a mock.
type mockedRPCClient struct {
mu sync.Mutex
responseByAddress map[string]mockedRPCCall
Expand Down Expand Up @@ -722,7 +723,7 @@ func makeTestCodec(t *testing.T, rawIDL string, encoding config.EncodingType) ty
t.FailNow()
}

testCodec, err := codec.NewIDLCodec(idl, config.BuilderForEncoding(encoding))
testCodec, err := codec.NewIDLAccountCodec(idl, config.BuilderForEncoding(encoding))
if err != nil {
t.Logf("failed to create new codec from test IDL: %s", err.Error())
t.FailNow()
Expand Down
71 changes: 71 additions & 0 deletions pkg/solana/codec/discriminator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package codec

import (
"bytes"
"crypto/sha256"
"fmt"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings"
"github.com/smartcontractkit/chainlink-common/pkg/types"
)

const discriminatorLength = 8

func NewDiscriminator(name string) encodings.TypeCodec {
sum := sha256.Sum256([]byte("account:" + name))
return &discriminator{hashPrefix: sum[:discriminatorLength]}
}

type discriminator struct {
hashPrefix []byte
}

func (d discriminator) Encode(value any, into []byte) ([]byte, error) {
if value == nil {
return append(into, d.hashPrefix...), nil
}

raw, ok := value.(*[]byte)
if !ok {
return nil, fmt.Errorf("%w: value must be a byte slice got %T", types.ErrInvalidType, value)
}

// inject if not specified
if raw == nil {
return append(into, d.hashPrefix...), nil
}

// Not sure if we should really be encoding accounts...
if !bytes.Equal(*raw, d.hashPrefix) {
return nil, fmt.Errorf("%w: invalid discriminator expected %x got %x", types.ErrInvalidType, d.hashPrefix, raw)
}

return append(into, *raw...), nil
}

func (d discriminator) Decode(encoded []byte) (any, []byte, error) {
raw, remaining, err := encodings.SafeDecode(encoded, discriminatorLength, func(raw []byte) []byte { return raw })
if err != nil {
return nil, nil, err
}

if !bytes.Equal(raw, d.hashPrefix) {
return nil, nil, fmt.Errorf("%w: invalid discriminator expected %x got %x", types.ErrInvalidEncoding, d.hashPrefix, raw)
}

return &raw, remaining, nil
}

func (d discriminator) GetType() reflect.Type {
// Pointer type so that nil can inject values and so that the NamedCodec won't wrap with no-nil pointer.
return reflect.TypeOf(&[]byte{})
}

func (d discriminator) Size(_ int) (int, error) {
return discriminatorLength, nil
}

func (d discriminator) FixedSize() (int, error) {
return discriminatorLength, nil
}
83 changes: 83 additions & 0 deletions pkg/solana/codec/discriminator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package codec_test

import (
"crypto/sha256"
"errors"
"reflect"
"testing"

"github.com/smartcontractkit/chainlink-common/pkg/types"
"github.com/stretchr/testify/require"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec"
)

func TestDiscriminator(t *testing.T) {
t.Run("encode and decode return the discriminator", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
encoded, err := c.Encode(&expected, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
actual, remaining, err := c.Decode(encoded)
require.NoError(t, err)
require.Equal(t, &expected, actual)
require.Len(t, remaining, 0)
})

t.Run("encode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
_, err := c.Encode(&[]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("encode injects the discriminator if it's not provided", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
encoded, err := c.Encode(nil, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
encoded, err = c.Encode((*[]byte)(nil), nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
})

t.Run("decode returns an error if the encoded value is too short", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("decode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("encode returns an error if the value is not a byte slice", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
_, err := c.Encode(42, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("GetType returns the type of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
require.Equal(t, reflect.TypeOf(&[]byte{}), c.GetType())
})

t.Run("Size returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
size, err := c.Size(0)
require.NoError(t, err)
require.Equal(t, 8, size)
})

t.Run("FixedSize returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
size, err := c.FixedSize()
require.NoError(t, err)
require.Equal(t, 8, size)
})
}
41 changes: 30 additions & 11 deletions pkg/solana/codec/solana.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,18 @@ func NewNamedModifierCodec(original types.RemoteCodec, itemType string, modifier
return modCodec, err
}

// NewIDLCodec is for Anchor custom types
func NewIDLCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) {
accounts := make(encodings.LenientCodecFromTypeCodec)
// NewIDLAccountCodec is for Anchor custom types
func NewIDLAccountCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) {
return newIDLCoded(idl, builder, idl.Accounts, true)
}

func NewIDLDefinedTypesCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) {
return newIDLCoded(idl, builder, idl.Types, false)
}

func newIDLCoded(
idl IDL, builder encodings.Builder, from IdlTypeDefSlice, includeDiscriminator bool) (types.RemoteCodec, error) {
typeCodecs := make(encodings.LenientCodecFromTypeCodec)

refs := &codecRefs{
builder: builder,
Expand All @@ -71,22 +80,22 @@ func NewIDLCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error)
dependencies: make(map[string][]string),
}

for _, account := range idl.Accounts {
for _, def := range from {
var (
name string
accCodec encodings.TypeCodec
err error
)

name, accCodec, err = createNamedCodec(account, refs)
name, accCodec, err = createNamedCodec(def, refs, includeDiscriminator)
if err != nil {
return nil, err
}

accounts[name] = accCodec
typeCodecs[name] = accCodec
}

return accounts, nil
return typeCodecs, nil
}

type codecRefs struct {
Expand All @@ -99,13 +108,14 @@ type codecRefs struct {
func createNamedCodec(
def IdlTypeDef,
refs *codecRefs,
includeDiscriminator bool,
) (string, encodings.TypeCodec, error) {
caser := cases.Title(language.English)
name := def.Name

switch def.Type.Kind {
case IdlTypeDefTyKindStruct:
return asStruct(def, refs, name, caser)
return asStruct(def, refs, name, caser, includeDiscriminator)
case IdlTypeDefTyKindEnum:
variants := def.Type.Variants
if !variants.IsAllUint8() {
Expand All @@ -123,8 +133,17 @@ func asStruct(
refs *codecRefs,
name string, // name is the struct name and can be used in dependency checks
caser cases.Caser,
includeDiscriminator bool,
) (string, encodings.TypeCodec, error) {
named := make([]encodings.NamedTypeCodec, len(*def.Type.Fields))
desLen := 0
if includeDiscriminator {
desLen = 1
}
named := make([]encodings.NamedTypeCodec, len(*def.Type.Fields)+desLen)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a nil pointer concern for dereferencing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't add it in this PR, but I don't think so. The call only gets made if the type is a struct so I expect the fields would be present but empty if there are no fields.

There are a few derefs later on as well.


if includeDiscriminator {
named[0] = encodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name)}
}

for idx, field := range *def.Type.Fields {
fieldName := field.Name
Expand All @@ -134,7 +153,7 @@ func asStruct(
return name, nil, err
}

named[idx] = encodings.NamedTypeCodec{Name: caser.String(fieldName), Codec: typedCodec}
named[idx+desLen] = encodings.NamedTypeCodec{Name: caser.String(fieldName), Codec: typedCodec}
}

structCodec, err := encodings.NewStructCodec(named)
Expand Down Expand Up @@ -188,7 +207,7 @@ func asDefined(parentTypeName string, definedName *IdlTypeDefined, refs *codecRe

saveDependency(refs, parentTypeName, definedName.Defined)

newTypeName, newTypeCodec, err := createNamedCodec(*nextDef, refs)
newTypeName, newTypeCodec, err := createNamedCodec(*nextDef, refs, false)
if err != nil {
return nil, err
}
Expand Down
44 changes: 38 additions & 6 deletions pkg/solana/codec/solana_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@ import (
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec/testutils"
)

func TestNewIDLCodec(t *testing.T) {
func TestNewIDLAccountCodec(t *testing.T) {
/// TODO BCI-3155 this should run the codec interface tests
t.Parallel()

ctx := tests.Context(t)
_, _, entry := newTestIDLAndCodec(t)
_, _, entry := newTestIDLAndCodec(t, true)

expected := testutils.DefaultTestStruct
bts, err := entry.Encode(ctx, expected, testutils.TestStructWithNestedStruct)

// length of fields + discriminator
require.Equal(t, 262, len(bts))

require.NoError(t, err)

var decoded testutils.StructWithNestedStruct
Expand All @@ -35,11 +39,32 @@ func TestNewIDLCodec(t *testing.T) {
require.Equal(t, expected, decoded)
}

func TestNewIDLDefinedTypesCodecCodec(t *testing.T) {
/// TODO BCI-3155 this should run the codec interface tests
t.Parallel()

ctx := tests.Context(t)
_, _, entry := newTestIDLAndCodec(t, false)

expected := testutils.DefaultTestStruct
bts, err := entry.Encode(ctx, expected, testutils.TestStructWithNestedStructType)

// length of fields without a discriminator
require.Equal(t, 254, len(bts))

require.NoError(t, err)

var decoded testutils.StructWithNestedStruct

require.NoError(t, entry.Decode(ctx, bts, &decoded, testutils.TestStructWithNestedStructType))
require.Equal(t, expected, decoded)
}

func TestNewIDLCodec_WithModifiers(t *testing.T) {
t.Parallel()

ctx := tests.Context(t)
_, _, idlCodec := newTestIDLAndCodec(t)
_, _, idlCodec := newTestIDLAndCodec(t, true)
modConfig := codeccommon.ModifiersConfig{
&codeccommon.RenameModifierConfig{Fields: map[string]string{"Value": "V"}},
}
Expand Down Expand Up @@ -113,12 +138,12 @@ func TestNewIDLCodec_CircularDependency(t *testing.T) {
t.FailNow()
}

_, err := codec.NewIDLCodec(idl, binary.LittleEndian())
_, err := codec.NewIDLAccountCodec(idl, binary.LittleEndian())

assert.ErrorIs(t, err, types.ErrInvalidConfig)
}

func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) {
func newTestIDLAndCodec(t *testing.T, account bool) (string, codec.IDL, types.RemoteCodec) {
t.Helper()

var idl codec.IDL
Expand All @@ -127,7 +152,14 @@ func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) {
t.FailNow()
}

entry, err := codec.NewIDLCodec(idl, binary.LittleEndian())
var entry types.RemoteCodec
var err error
if account {
entry, err = codec.NewIDLAccountCodec(idl, binary.LittleEndian())
} else {
entry, err = codec.NewIDLDefinedTypesCodec(idl, binary.LittleEndian())
}

if err != nil {
t.Logf("failed to create new codec from test IDL: %s", err.Error())
t.FailNow()
Expand Down
Loading
Loading