Skip to content

Commit

Permalink
Updated lookup tests and helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan committed Nov 25, 2024
1 parent b0e66c2 commit ad38c25
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 34 deletions.
22 changes: 16 additions & 6 deletions pkg/solana/chainwriter/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (
"github.com/gagliardetto/solana-go"
)

// GetAddressAtLocation parses through nested types and arrays to find all address locations.
func GetAddressAtLocation(args any, location string, debugID string) ([]solana.PublicKey, error) {
var addresses []solana.PublicKey
// GetValuesAtLocation parses through nested types and arrays to find all locations of values
func GetValuesAtLocation(args any, location string, debugID string) ([][]byte, error) {
var vals [][]byte

path := strings.Split(location, ".")

Expand All @@ -22,13 +22,15 @@ func GetAddressAtLocation(args any, location string, debugID string) ([]solana.P

for _, value := range addressList {
if byteArray, ok := value.([]byte); ok {
addresses = append(addresses, solana.PublicKeyFromBytes(byteArray))
vals = append(vals, byteArray)
} else if address, ok := value.(solana.PublicKey); ok {
vals = append(vals, address.Bytes())
} else {
return nil, errorWithDebugID(fmt.Errorf("invalid address format at path: %s", location), debugID)
return nil, errorWithDebugID(fmt.Errorf("invalid value format at path: %s", location), debugID)
}
}

return addresses, nil
return vals, nil
}

func GetDebugIDAtLocation(args any, location string) (string, error) {
Expand Down Expand Up @@ -83,6 +85,7 @@ func traversePath(data any, path []string) ([]any, error) {
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
fmt.Printf("Current path: %v, Current value type: %v\n", path, val.Kind())

switch val.Kind() {
case reflect.Struct:
Expand All @@ -105,6 +108,13 @@ func traversePath(data any, path []string) ([]any, error) {
}
return nil, errors.New("no matching field found in array")

case reflect.Map:
key := reflect.ValueOf(path[0])
value := val.MapIndex(key)
if !value.IsValid() {
return nil, errors.New("key not found: " + path[0])
}
return traversePath(value.Interface(), path[1:])
default:
if len(path) == 1 && val.Kind() == reflect.Slice && val.Type().Elem().Kind() == reflect.Uint8 {
return []any{val.Interface()}, nil
Expand Down
32 changes: 21 additions & 11 deletions pkg/solana/chainwriter/lookups.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ func (ac AccountConstant) Resolve(_ context.Context, _ any, _ map[string]map[str
}

func (al AccountLookup) Resolve(_ context.Context, args any, _ map[string]map[string][]*solana.AccountMeta, debugID string) ([]*solana.AccountMeta, error) {
derivedAddresses, err := GetAddressAtLocation(args, al.Location, debugID)
derivedValues, err := GetValuesAtLocation(args, al.Location, debugID)
if err != nil {
return nil, errorWithDebugID(fmt.Errorf("error getting account from lookup: %w", err), debugID)
}

var metas []*solana.AccountMeta
for _, address := range derivedAddresses {
for _, address := range derivedValues {
metas = append(metas, &solana.AccountMeta{
PublicKey: address,
PublicKey: solana.PublicKeyFromBytes(address),
IsSigner: al.IsSigner,
IsWritable: al.IsWritable,
})
Expand Down Expand Up @@ -146,16 +146,26 @@ func getSeedBytes(ctx context.Context, lookup PDALookups, args any, derivedTable

// Process AddressSeeds first (e.g., public keys)
for _, seed := range lookup.Seeds {
// Get the address(es) at the seed location
seedAddresses, err := GetAddresses(ctx, args, []Lookup{seed}, derivedTableMap, debugID)
if err != nil {
return nil, errorWithDebugID(fmt.Errorf("error getting address seed: %w", err), debugID)
}
if lookupSeed, ok := seed.(AccountLookup); ok {
// Get the values at the seed location
bytes, err := GetValuesAtLocation(args, lookupSeed.Location, debugID)
if err != nil {
return nil, errorWithDebugID(fmt.Errorf("error getting address seed: %w", err), debugID)
}
seedBytes = append(seedBytes, bytes...)
} else {
// Get the address(es) at the seed location
seedAddresses, err := GetAddresses(ctx, args, []Lookup{seed}, derivedTableMap, debugID)
if err != nil {
return nil, errorWithDebugID(fmt.Errorf("error getting address seed: %w", err), debugID)
}

// Add each address seed as bytes
for _, address := range seedAddresses {
seedBytes = append(seedBytes, address.PublicKey.Bytes())
// Add each address seed as bytes
for _, address := range seedAddresses {
seedBytes = append(seedBytes, address.PublicKey.Bytes())
}
}

}

return seedBytes, nil
Expand Down
128 changes: 111 additions & 17 deletions pkg/solana/chainwriter/lookups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ type InnerArgs struct {
}

func TestAccountContant(t *testing.T) {

t.Run("AccountConstant resolves valid address", func(t *testing.T) {
expectedAddr := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6M"
expectedMeta := []*solana.AccountMeta{
Expand Down Expand Up @@ -131,22 +130,117 @@ func TestAccountLookups(t *testing.T) {
}

func TestPDALookups(t *testing.T) {
// TODO: May require deploying a program to test
// t.Run("PDALookup resolves valid address", func(t *testing.T) {
// expectedAddr := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6M"
// expectedMeta := []*solana.AccountMeta{
// {
// PublicKey: solana.MustPublicKeyFromBase58(expectedAddr),
// IsSigner: true,
// IsWritable: true,
// },
// }
// lookupConfig := chainwriter.PDALookups{
// Name: "TestAccount",
// PublicKey:
// }

// })
programID := solana.SystemProgramID

t.Run("PDALookup resolves valid PDA with constant address seeds", func(t *testing.T) {
privKey, err := solana.NewRandomPrivateKey()
require.NoError(t, err)
seed := privKey.PublicKey()

pda, _, err := solana.FindProgramAddress([][]byte{seed.Bytes()}, programID)
require.NoError(t, err)

expectedMeta := []*solana.AccountMeta{
{
PublicKey: pda,
IsSigner: false,
IsWritable: true,
},
}

pdaLookup := chainwriter.PDALookups{
Name: "TestPDA",
PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()},
Seeds: []chainwriter.Lookup{
chainwriter.AccountConstant{Name: "seed", Address: seed.String()},
},
IsSigner: false,
IsWritable: true,
}

ctx := context.Background()
result, err := pdaLookup.Resolve(ctx, nil, nil, "")
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
t.Run("PDALookup resolves valid PDA with non-address lookup seeds", func(t *testing.T) {
seed1 := []byte("test_seed")
seed2 := []byte("another_seed")

pda, _, err := solana.FindProgramAddress([][]byte{seed1, seed2}, programID)
require.NoError(t, err)

expectedMeta := []*solana.AccountMeta{
{
PublicKey: pda,
IsSigner: false,
IsWritable: true,
},
}

pdaLookup := chainwriter.PDALookups{
Name: "TestPDA",
PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()},
Seeds: []chainwriter.Lookup{
chainwriter.AccountLookup{Name: "seed1", Location: "test_seed"},
chainwriter.AccountLookup{Name: "seed2", Location: "another_seed"},
},
IsSigner: false,
IsWritable: true,
}

ctx := context.Background()
args := map[string]interface{}{
"test_seed": seed1,
"another_seed": seed2,
}

result, err := pdaLookup.Resolve(ctx, args, nil, "")
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})

t.Run("PDALookup resolves valid PDA with address lookup seeds", func(t *testing.T) {
privKey1, err := solana.NewRandomPrivateKey()
require.NoError(t, err)
seed1 := privKey1.PublicKey()

privKey2, err := solana.NewRandomPrivateKey()
require.NoError(t, err)
seed2 := privKey2.PublicKey()

pda, _, err := solana.FindProgramAddress([][]byte{seed1.Bytes(), seed2.Bytes()}, programID)
require.NoError(t, err)

expectedMeta := []*solana.AccountMeta{
{
PublicKey: pda,
IsSigner: false,
IsWritable: true,
},
}

pdaLookup := chainwriter.PDALookups{
Name: "TestPDA",
PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()},
Seeds: []chainwriter.Lookup{
chainwriter.AccountLookup{Name: "seed1", Location: "test_seed"},
chainwriter.AccountLookup{Name: "seed2", Location: "another_seed"},
},
IsSigner: false,
IsWritable: true,
}

ctx := context.Background()
args := map[string]interface{}{
"test_seed": seed1,
"another_seed": seed2,
}

result, err := pdaLookup.Resolve(ctx, args, nil, "")
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
}

func TestLookupTables(t *testing.T) {
Expand Down

0 comments on commit ad38c25

Please sign in to comment.