diff --git a/.mockery.yaml b/.mockery.yaml index 1df96bfec..347d69c58 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -37,6 +37,7 @@ packages: config: filename: simple_keystore.go case: underscore + TxManager: github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller: interfaces: RPCClient: diff --git a/contracts/artifacts/localnet/write_test-keypair.json b/contracts/artifacts/localnet/write_test-keypair.json index c4e6e125c..dfb18e9c4 100644 --- a/contracts/artifacts/localnet/write_test-keypair.json +++ b/contracts/artifacts/localnet/write_test-keypair.json @@ -1 +1,6 @@ -[26,39,164,161,246,97,149,0,58,187,146,162,53,35,107,2,117,242,83,171,48,7,63,240,69,221,239,45,97,55,112,106,192,228,214,205,123,71,58,23,62,229,166,213,149,122,96,145,35,150,16,156,247,199,242,108,173,80,62,231,39,196,27,192] \ No newline at end of file +[ + 26, 39, 164, 161, 246, 97, 149, 0, 58, 187, 146, 162, 53, 35, 107, 2, 117, + 242, 83, 171, 48, 7, 63, 240, 69, 221, 239, 45, 97, 55, 112, 106, 192, 228, + 214, 205, 123, 71, 58, 23, 62, 229, 166, 213, 149, 122, 96, 145, 35, 150, 16, + 156, 247, 199, 242, 108, 173, 80, 62, 231, 39, 196, 27, 192 +] diff --git a/contracts/programs/write_test/src/lib.rs b/contracts/programs/write_test/src/lib.rs index 4078bca4d..8d8fa3cac 100644 --- a/contracts/programs/write_test/src/lib.rs +++ b/contracts/programs/write_test/src/lib.rs @@ -12,10 +12,9 @@ pub mod write_test { data.administrator = ctx.accounts.admin.key(); data.pending_administrator = Pubkey::default(); data.lookup_table = lookup_table; - + Ok(()) } - } #[derive(Accounts)] diff --git a/go.mod b/go.mod index 43e82aaa0..f9363c1cc 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/smartcontractkit/chainlink-common v0.3.1-0.20241127162636-07aa781ee1f4 github.com/smartcontractkit/libocr v0.0.0-20241007185508-adbe57025f12 github.com/stretchr/testify v1.9.0 - github.com/test-go/testify v1.1.4 go.uber.org/zap v1.27.0 golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 golang.org/x/sync v0.8.0 diff --git a/pkg/solana/chain.go b/pkg/solana/chain.go index 630248aff..f8c49a4cc 100644 --- a/pkg/solana/chain.go +++ b/pkg/solana/chain.go @@ -30,6 +30,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" "github.com/smartcontractkit/chainlink-solana/pkg/solana/monitor" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) type Chain interface { @@ -576,12 +577,12 @@ func (c *chain) sendTx(ctx context.Context, from, to string, amount *big.Int, ba chainTxm := c.TxManager() err = chainTxm.Enqueue(ctx, "", tx, nil, - txm.SetComputeUnitLimit(500), // reduce from default 200K limit - should only take 450 compute units + txmutils.SetComputeUnitLimit(500), // reduce from default 200K limit - should only take 450 compute units // no fee bumping and no additional fee - makes validating balance accurate - txm.SetComputeUnitPriceMax(0), - txm.SetComputeUnitPriceMin(0), - txm.SetBaseComputeUnitPrice(0), - txm.SetFeeBumpPeriod(0), + txmutils.SetComputeUnitPriceMax(0), + txmutils.SetComputeUnitPriceMin(0), + txmutils.SetBaseComputeUnitPrice(0), + txmutils.SetFeeBumpPeriod(0), ) if err != nil { return fmt.Errorf("transaction failed: %w", err) diff --git a/pkg/solana/chainwriter/ccip_example_config.go b/pkg/solana/chainwriter/ccip_example_config.go index bd5087af8..89038fd6a 100644 --- a/pkg/solana/chainwriter/ccip_example_config.go +++ b/pkg/solana/chainwriter/ccip_example_config.go @@ -79,7 +79,7 @@ func TestConfig() { // 3. Lookup Table content - Get all the accounts from a lookup table // 4. PDA Account Lookup - Based on another account and a seed/s // Nested PDA Account with seeds from: - // -> input paramters + // -> input parameters // -> constant // PDALookups can resolve to multiple addresses if: // A) The PublicKey lookup resolves to multiple addresses (i.e. multiple token addresses) @@ -102,8 +102,8 @@ func TestConfig() { }, // Lookup Table content - Get the accounts from the derived lookup table above AccountsFromLookupTable{ - LookupTablesName: "RegistryTokenState", - IncludeIndexes: []int{}, // If left empty, all addresses will be included. Otherwise, only the specified indexes will be included. + LookupTableName: "RegistryTokenState", + IncludeIndexes: []int{}, // If left empty, all addresses will be included. Otherwise, only the specified indexes will be included. }, // Account Lookup - Based on data from input parameters // In this case, the user wants to add the destination token addresses to the transaction. diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go index 608f3c610..4fcc5caa0 100644 --- a/pkg/solana/chainwriter/chain_writer.go +++ b/pkg/solana/chainwriter/chain_writer.go @@ -21,12 +21,13 @@ import ( type SolanaChainWriterService struct { reader client.Reader - txm txm.Txm + txm txm.TxManager ge fees.Estimator config ChainWriterConfig codecs map[string]types.Codec } +// nolint // ignoring naming suggestion type ChainWriterConfig struct { Programs map[string]ProgramConfig } @@ -46,7 +47,7 @@ type MethodConfig struct { DebugIDLocation string } -func NewSolanaChainWriterService(reader client.Reader, txm txm.Txm, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) { +func NewSolanaChainWriterService(reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) { codecs, err := parseIDLCodecs(config) if err != nil { return nil, fmt.Errorf("failed to parse IDL codecs: %w", err) @@ -68,7 +69,7 @@ func parseIDLCodecs(config ChainWriterConfig) (map[string]types.Codec, error) { if err := json.Unmarshal([]byte(programConfig.IDL), &idl); err != nil { return nil, fmt.Errorf("failed to unmarshal IDL: %w", err) } - idlCodec, err := codec.NewIDLAccountCodec(idl, binary.LittleEndian()) + idlCodec, err := codec.NewIDLInstructionsCodec(idl, binary.LittleEndian()) if err != nil { return nil, fmt.Errorf("failed to create codec from IDL: %w", err) } @@ -79,7 +80,7 @@ func parseIDLCodecs(config ChainWriterConfig) (map[string]types.Codec, error) { return nil, fmt.Errorf("failed to create input modifications: %w", err) } // add mods to codec - idlCodec, err = codec.NewNamedModifierCodec(idlCodec, WrapItemType(program, method, true), modConfig) + idlCodec, err = codec.NewNamedModifierCodec(idlCodec, method, modConfig) if err != nil { return nil, fmt.Errorf("failed to create named codec: %w", err) } @@ -90,14 +91,6 @@ func parseIDLCodecs(config ChainWriterConfig) (map[string]types.Codec, error) { return codecs, nil } -func WrapItemType(programName, itemType string, isParams bool) string { - if isParams { - return fmt.Sprintf("params.%s.%s", programName, itemType) - } - - return fmt.Sprintf("return.%s.%s", programName, itemType) -} - /* GetAddresses resolves account addresses from various `Lookup` configurations to build the required `solana.AccountMeta` list for Solana transactions. It handles constant addresses, dynamic lookups, program-derived addresses (PDAs), and lookup tables. @@ -161,7 +154,7 @@ func (s *SolanaChainWriterService) FilterLookupTableAddresses( for innerIdentifier, metas := range innerMap { tableKey, err := solana.PublicKeyFromBase58(innerIdentifier) if err != nil { - fmt.Errorf("error parsing lookup table key: %w", err) + continue } // Collect public keys that are actually used @@ -198,18 +191,31 @@ func (s *SolanaChainWriterService) FilterLookupTableAddresses( } func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contractName, method string, args any, transactionID string, toAddress string, meta *types.TxMeta, value *big.Int) error { - programConfig := s.config.Programs[contractName] - methodConfig := programConfig.Methods[method] + programConfig, exists := s.config.Programs[contractName] + if !exists { + return fmt.Errorf("failed to find program config for contract name: %s", contractName) + } + methodConfig, exists := programConfig.Methods[method] + if !exists { + return fmt.Errorf("failed to find method config for method: %s", method) + } // Configure debug ID debugID := "" if methodConfig.DebugIDLocation != "" { - debugID, err := GetDebugIDAtLocation(args, methodConfig.DebugIDLocation) + var err error + debugID, err = GetDebugIDAtLocation(args, methodConfig.DebugIDLocation) if err != nil { return errorWithDebugID(fmt.Errorf("error getting debug ID from input args: %w", err), debugID) } } + codec := s.codecs[contractName] + encodedPayload, err := codec.Encode(ctx, args, method) + if err != nil { + return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID) + } + // Fetch derived and static table maps derivedTableMap, staticTableMap, err := s.ResolveLookupTables(ctx, args, methodConfig.LookupTables) if err != nil { @@ -232,7 +238,7 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra } // Prepare transaction - programId, err := solana.PublicKeyFromBase58(contractName) + programID, err := solana.PublicKeyFromBase58(contractName) if err != nil { return errorWithDebugID(fmt.Errorf("error parsing program ID: %w", err), debugID) } @@ -242,15 +248,9 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID) } - codec := s.codecs[contractName] - encodedPayload, err := codec.Encode(ctx, args, WrapItemType(contractName, method, true)) - if err != nil { - return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID) - } - tx, err := solana.NewTransaction( []solana.Instruction{ - solana.NewInstruction(programId, accounts, encodedPayload), + solana.NewInstruction(programID, accounts, encodedPayload), }, blockhash.Value.Blockhash, solana.TransactionPayer(feePayer), @@ -269,13 +269,13 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra } var ( - _ services.Service = &SolanaChainWriterService{} - _ types.ChainWriter = &SolanaChainWriterService{} + _ services.Service = &SolanaChainWriterService{} + _ types.ContractWriter = &SolanaChainWriterService{} ) // GetTransactionStatus returns the current status of a transaction in the underlying chain's TXM. func (s *SolanaChainWriterService) GetTransactionStatus(ctx context.Context, transactionID string) (types.TransactionStatus, error) { - return types.Unknown, nil + return s.txm.GetTransactionStatus(ctx, transactionID) } // GetFeeComponents retrieves the associated gas costs for executing a transaction. @@ -286,7 +286,7 @@ func (s *SolanaChainWriterService) GetFeeComponents(ctx context.Context) (*types fee := s.ge.BaseComputeUnitPrice() return &types.ChainFeeComponents{ - ExecutionFee: big.NewInt(int64(fee)), + ExecutionFee: new(big.Int).SetUint64(fee), DataAvailabilityFee: nil, }, nil } diff --git a/pkg/solana/chainwriter/chain_writer_test.go b/pkg/solana/chainwriter/chain_writer_test.go new file mode 100644 index 000000000..d931fb6d8 --- /dev/null +++ b/pkg/solana/chainwriter/chain_writer_test.go @@ -0,0 +1,690 @@ +package chainwriter_test + +import ( + "bytes" + "errors" + "math/big" + "reflect" + "testing" + + ag_binary "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" + addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table" + "github.com/gagliardetto/solana-go/rpc" + "github.com/google/uuid" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter" + clientmocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks" + feemocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees/mocks" + txmMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" +) + +var writeTestIdlJSON = `{"version": "0.1.0","name": "write_test","instructions": [{"name": "initialize","accounts": [{"name": "dataAccount","isMut": true,"isSigner": false,"docs": ["PDA account, derived from seeds and created by the System Program in this instruction"]},{"name": "admin","isMut": true,"isSigner": true,"docs": ["Admin account that pays for PDA creation and signs the transaction"]},{"name": "systemProgram","isMut": false,"isSigner": false,"docs": ["System Program is required for PDA creation"]}],"args": [{"name": "lookupTable","type": "publicKey"}]}],"accounts": [{"name": "DataAccount","type": {"kind": "struct","fields": [{"name": "version","type": "u8"},{"name": "administrator","type": "publicKey"},{"name": "pendingAdministrator","type": "publicKey"},{"name": "lookupTable","type": "publicKey"}]}}]}` + +func TestChainWriter_GetAddresses(t *testing.T) { + ctx := tests.Context(t) + + // mock client + rw := clientmocks.NewReaderWriter(t) + // mock estimator + ge := feemocks.NewEstimator(t) + // mock txm + txm := txmMocks.NewTxManager(t) + + // initialize chain writer + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + + // expected account meta for constant account + constantAccountMeta := &solana.AccountMeta{ + IsSigner: true, + IsWritable: true, + } + + // expected account meta for account lookup + accountLookupMeta := &solana.AccountMeta{ + IsSigner: true, + IsWritable: false, + } + + // setup pda account address + seed1 := []byte("seed1") + pda1 := mustFindPdaProgramAddress(t, [][]byte{seed1}, solana.SystemProgramID) + // expected account meta for pda lookup + pdaLookupMeta := &solana.AccountMeta{ + PublicKey: pda1, + IsSigner: false, + IsWritable: false, + } + + // setup pda account with inner field lookup + programID := chainwriter.GetRandomPubKey(t) + seed2 := []byte("seed2") + pda2 := mustFindPdaProgramAddress(t, [][]byte{seed2}, programID) + // mock data account response from program + lookupTablePubkey := mockDataAccountLookupTable(t, rw, pda2) + // mock fetch lookup table addresses call + storedPubKeys := chainwriter.CreateTestPubKeys(t, 3) + mockFetchLookupTableAddresses(t, rw, lookupTablePubkey, storedPubKeys) + // expected account meta for derived table lookup + derivedTablePdaLookupMeta := &solana.AccountMeta{ + IsSigner: false, + IsWritable: true, + } + + lookupTableConfig := chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed2 for PDA lookup + chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}, + }, + IsSigner: derivedTablePdaLookupMeta.IsSigner, + IsWritable: derivedTablePdaLookupMeta.IsWritable, + InternalField: chainwriter.InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + }, + StaticLookupTables: nil, + } + + t.Run("resolve addresses from different types of lookups", func(t *testing.T) { + constantAccountMeta.PublicKey = chainwriter.GetRandomPubKey(t) + accountLookupMeta.PublicKey = chainwriter.GetRandomPubKey(t) + // correlates to DerivedTable index in account lookup config + derivedTablePdaLookupMeta.PublicKey = storedPubKeys[0] + + args := map[string]interface{}{ + "lookup_table": accountLookupMeta.PublicKey.Bytes(), + "seed1": seed1, + "seed2": seed2, + } + + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountConstant{ + Name: "Constant", + Address: constantAccountMeta.PublicKey.String(), + IsSigner: constantAccountMeta.IsSigner, + IsWritable: constantAccountMeta.IsWritable, + }, + chainwriter.AccountLookup{ + Name: "LookupTable", + Location: "lookup_table", + IsSigner: accountLookupMeta.IsSigner, + IsWritable: accountLookupMeta.IsWritable, + }, + chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed1 for PDA lookup + chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}, + }, + IsSigner: pdaLookupMeta.IsSigner, + IsWritable: pdaLookupMeta.IsWritable, + // Just get the address of the account, nothing internal. + InternalField: chainwriter.InternalField{}, + }, + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + IncludeIndexes: []int{0}, + }, + } + + // Fetch derived table map + derivedTableMap, _, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + // account metas should be returned in the same order as the provided account lookup configs + require.Len(t, accounts, 4) + + // Validate account constant + require.Equal(t, constantAccountMeta.PublicKey, accounts[0].PublicKey) + require.Equal(t, constantAccountMeta.IsSigner, accounts[0].IsSigner) + require.Equal(t, constantAccountMeta.IsWritable, accounts[0].IsWritable) + + // Validate account lookup + require.Equal(t, accountLookupMeta.PublicKey, accounts[1].PublicKey) + require.Equal(t, accountLookupMeta.IsSigner, accounts[1].IsSigner) + require.Equal(t, accountLookupMeta.IsWritable, accounts[1].IsWritable) + + // Validate pda lookup + require.Equal(t, pdaLookupMeta.PublicKey, accounts[2].PublicKey) + require.Equal(t, pdaLookupMeta.IsSigner, accounts[2].IsSigner) + require.Equal(t, pdaLookupMeta.IsWritable, accounts[2].IsWritable) + + // Validate pda lookup with inner field from derived table + require.Equal(t, derivedTablePdaLookupMeta.PublicKey, accounts[3].PublicKey) + require.Equal(t, derivedTablePdaLookupMeta.IsSigner, accounts[3].IsSigner) + require.Equal(t, derivedTablePdaLookupMeta.IsWritable, accounts[3].IsWritable) + }) + + t.Run("resolve addresses for multiple indices from derived lookup table", func(t *testing.T) { + args := map[string]interface{}{ + "seed2": seed2, + } + + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + IncludeIndexes: []int{0, 2}, + }, + } + + // Fetch derived table map + derivedTableMap, _, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + require.Len(t, accounts, 2) + require.Equal(t, storedPubKeys[0], accounts[0].PublicKey) + require.Equal(t, storedPubKeys[2], accounts[1].PublicKey) + }) + + t.Run("resolve all addresses from derived lookup table if indices not specified", func(t *testing.T) { + args := map[string]interface{}{ + "seed2": seed2, + } + + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + }, + } + + // Fetch derived table map + derivedTableMap, _, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + require.Len(t, accounts, 3) + for i, storedPubkey := range storedPubKeys { + require.Equal(t, storedPubkey, accounts[i].PublicKey) + } + }) +} + +func TestChainWriter_FilterLookupTableAddresses(t *testing.T) { + ctx := tests.Context(t) + + // mock client + rw := clientmocks.NewReaderWriter(t) + // mock estimator + ge := feemocks.NewEstimator(t) + // mock txm + txm := txmMocks.NewTxManager(t) + + // initialize chain writer + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + + programID := chainwriter.GetRandomPubKey(t) + seed1 := []byte("seed1") + pda1 := mustFindPdaProgramAddress(t, [][]byte{seed1}, programID) + // mock data account response from program + lookupTablePubkey := mockDataAccountLookupTable(t, rw, pda1) + // mock fetch lookup table addresses call + storedPubKey := chainwriter.GetRandomPubKey(t) + mockFetchLookupTableAddresses(t, rw, lookupTablePubkey, []solana.PublicKey{storedPubKey}) + + unusedProgramID := chainwriter.GetRandomPubKey(t) + seed2 := []byte("seed2") + unusedPda := mustFindPdaProgramAddress(t, [][]byte{seed2}, unusedProgramID) + // mock data account response from program + unusedLookupTable := mockDataAccountLookupTable(t, rw, unusedPda) + // mock fetch lookup table addresses call + unusedKeys := chainwriter.GetRandomPubKey(t) + mockFetchLookupTableAddresses(t, rw, unusedLookupTable, []solana.PublicKey{unusedKeys}) + + // mock static lookup table calls + staticLookupTablePubkey1 := chainwriter.GetRandomPubKey(t) + mockFetchLookupTableAddresses(t, rw, staticLookupTablePubkey1, chainwriter.CreateTestPubKeys(t, 2)) + staticLookupTablePubkey2 := chainwriter.GetRandomPubKey(t) + mockFetchLookupTableAddresses(t, rw, staticLookupTablePubkey2, chainwriter.CreateTestPubKeys(t, 2)) + + lookupTableConfig := chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed2 for PDA lookup + chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}, + }, + IsSigner: true, + IsWritable: true, + InternalField: chainwriter.InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + { + Name: "MiscDerivedTable", + Accounts: chainwriter.PDALookups{ + Name: "MiscPDA", + PublicKey: chainwriter.AccountConstant{Name: "UnusedAccount", Address: unusedProgramID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed2 for PDA lookup + chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}, + }, + IsSigner: true, + IsWritable: true, + InternalField: chainwriter.InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + }, + StaticLookupTables: []string{staticLookupTablePubkey1.String(), staticLookupTablePubkey2.String()}, + } + + args := map[string]interface{}{ + "seed1": seed1, + "seed2": seed2, + } + + t.Run("returns filtered map with only relevant addresses required by account lookup config", func(t *testing.T) { + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + IncludeIndexes: []int{0}, + }, + } + + // Fetch derived table map + derivedTableMap, staticTableMap, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + // Filter the lookup table addresses based on which accounts are actually used + filteredLookupTableMap := cw.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) + + // Filter map should only contain the address for the DerivedTable lookup defined in the account lookup config + require.Len(t, filteredLookupTableMap, len(accounts)) + entry, exists := filteredLookupTableMap[lookupTablePubkey] + require.True(t, exists) + require.Len(t, entry, 1) + require.Equal(t, storedPubKey, entry[0]) + }) + + t.Run("returns empty map if empty account lookup config provided", func(t *testing.T) { + accountLookupConfig := []chainwriter.Lookup{} + + // Fetch derived table map + derivedTableMap, staticTableMap, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + // Filter the lookup table addresses based on which accounts are actually used + filteredLookupTableMap := cw.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) + require.Empty(t, filteredLookupTableMap) + }) + + t.Run("returns empty map if only constant account lookup required", func(t *testing.T) { + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountConstant{ + Name: "Constant", + Address: chainwriter.GetRandomPubKey(t).String(), + IsSigner: false, + IsWritable: false, + }, + } + + // Fetch derived table map + derivedTableMap, staticTableMap, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + // Filter the lookup table addresses based on which accounts are actually used + filteredLookupTableMap := cw.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) + require.Empty(t, filteredLookupTableMap) + }) +} + +func TestChainWriter_SubmitTransaction(t *testing.T) { + t.Parallel() + + ctx := tests.Context(t) + // mock client + rw := clientmocks.NewReaderWriter(t) + // mock estimator + ge := feemocks.NewEstimator(t) + // mock txm + txm := txmMocks.NewTxManager(t) + + // setup admin key + adminPk, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + admin := adminPk.PublicKey() + + account1 := chainwriter.GetRandomPubKey(t) + account2 := chainwriter.GetRandomPubKey(t) + + seed1 := []byte("seed1") + account3 := mustFindPdaProgramAddress(t, [][]byte{seed1}, solana.SystemProgramID) + + // create lookup table addresses + seed2 := []byte("seed2") + programID := chainwriter.GetRandomPubKey(t) + derivedTablePda := mustFindPdaProgramAddress(t, [][]byte{seed2}, programID) + // mock data account response from program + derivedLookupTablePubkey := mockDataAccountLookupTable(t, rw, derivedTablePda) + // mock fetch lookup table addresses call + derivedLookupKeys := chainwriter.CreateTestPubKeys(t, 1) + mockFetchLookupTableAddresses(t, rw, derivedLookupTablePubkey, derivedLookupKeys) + + // mock static lookup table call + staticLookupTablePubkey := chainwriter.GetRandomPubKey(t) + staticLookupKeys := chainwriter.CreateTestPubKeys(t, 2) + mockFetchLookupTableAddresses(t, rw, staticLookupTablePubkey, staticLookupKeys) + + cwConfig := chainwriter.ChainWriterConfig{ + Programs: map[string]chainwriter.ProgramConfig{ + "39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU": { + Methods: map[string]chainwriter.MethodConfig{ + "initialize": { + FromAddress: admin.String(), + ChainSpecificName: "initialize", + LookupTables: chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed2 for PDA lookup + chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}, + }, + IsSigner: false, + IsWritable: false, + InternalField: chainwriter.InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + }, + StaticLookupTables: []string{staticLookupTablePubkey.String()}, + }, + Accounts: []chainwriter.Lookup{ + chainwriter.AccountConstant{ + Name: "Constant", + Address: account1.String(), + IsSigner: false, + IsWritable: false, + }, + chainwriter.AccountLookup{ + Name: "LookupTable", + Location: "lookup_table", + IsSigner: false, + IsWritable: false, + }, + chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed1 for PDA lookup + chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}, + }, + IsSigner: false, + IsWritable: false, + // Just get the address of the account, nothing internal. + InternalField: chainwriter.InternalField{}, + }, + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + IncludeIndexes: []int{0}, + }, + }, + }, + }, + IDL: writeTestIdlJSON, + }, + }, + } + + // initialize chain writer + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, cwConfig) + require.NoError(t, err) + + t.Run("fails with invalid ABI", func(t *testing.T) { + invalidCWConfig := chainwriter.ChainWriterConfig{ + Programs: map[string]chainwriter.ProgramConfig{ + "write_test": { + Methods: map[string]chainwriter.MethodConfig{ + "invalid": { + ChainSpecificName: "invalid", + }, + }, + IDL: "", + }, + }, + } + + _, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, invalidCWConfig) + require.Error(t, err) + }) + + t.Run("fails to encode payload if args with missing values provided", func(t *testing.T) { + txID := uuid.NewString() + args := map[string]interface{}{} + submitErr := cw.SubmitTransaction(ctx, "39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU", "initialize", args, txID, programID.String(), nil, nil) + require.Error(t, submitErr) + }) + + t.Run("fails if invalid contract name provided", func(t *testing.T) { + txID := uuid.NewString() + args := map[string]interface{}{} + submitErr := cw.SubmitTransaction(ctx, "write_test", "initialize", args, txID, programID.String(), nil, nil) + require.Error(t, submitErr) + }) + + t.Run("fails if invalid method provided", func(t *testing.T) { + txID := uuid.NewString() + args := map[string]interface{}{} + submitErr := cw.SubmitTransaction(ctx, "39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU", "badMethod", args, txID, programID.String(), nil, nil) + require.Error(t, submitErr) + }) + + t.Run("submits transaction successfully", func(t *testing.T) { + recentBlockHash := solana.Hash{} + rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once() + txID := uuid.NewString() + configProgramID := solana.MustPublicKeyFromBase58("39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU") + + txm.On("Enqueue", mock.Anything, account1.String(), mock.MatchedBy(func(tx *solana.Transaction) bool { + // match transaction fields to ensure it was built as expected + require.Equal(t, recentBlockHash, tx.Message.RecentBlockhash) + require.Len(t, tx.Message.Instructions, 1) + require.Len(t, tx.Message.AccountKeys, 5) // fee payer + derived accounts + require.Equal(t, admin, tx.Message.AccountKeys[0]) // fee payer + require.Equal(t, account1, tx.Message.AccountKeys[1]) // account constant + require.Equal(t, account2, tx.Message.AccountKeys[2]) // account lookup + require.Equal(t, account3, tx.Message.AccountKeys[3]) // pda lookup + require.Equal(t, configProgramID, tx.Message.AccountKeys[4]) // instruction program ID + require.Len(t, tx.Message.AddressTableLookups, 1) // address table look contains entry + require.Equal(t, derivedLookupTablePubkey, tx.Message.AddressTableLookups[0].AccountKey) // address table + return true + }), &txID).Return(nil).Once() + + args := map[string]interface{}{ + "lookupTable": chainwriter.GetRandomPubKey(t).Bytes(), + "lookup_table": account2.Bytes(), + "seed1": seed1, + "seed2": seed2, + } + submitErr := cw.SubmitTransaction(ctx, "39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU", "initialize", args, txID, programID.String(), nil, nil) + require.NoError(t, submitErr) + }) +} + +func TestChainWriter_GetTransactionStatus(t *testing.T) { + t.Parallel() + + ctx := tests.Context(t) + rw := clientmocks.NewReaderWriter(t) + ge := feemocks.NewEstimator(t) + + // mock txm + txm := txmMocks.NewTxManager(t) + + // initialize chain writer + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + + t.Run("returns unknown with error if ID not found", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Unknown, errors.New("tx not found")).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.Error(t, err) + require.Equal(t, types.Unknown, status) + }) + + t.Run("returns pending when transaction is pending", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Pending, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Pending, status) + }) + + t.Run("returns unconfirmed when transaction is unconfirmed", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Unconfirmed, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Unconfirmed, status) + }) + + t.Run("returns finalized when transaction is finalized", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Finalized, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + }) + + t.Run("returns failed when transaction error classfied as failed", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Failed, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) + }) + + t.Run("returns fatal when transaction error classfied as fatal", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Fatal, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Fatal, status) + }) +} + +func TestChainWriter_GetFeeComponents(t *testing.T) { + t.Parallel() + + ctx := tests.Context(t) + rw := clientmocks.NewReaderWriter(t) + ge := feemocks.NewEstimator(t) + ge.On("BaseComputeUnitPrice").Return(uint64(100)) + + // mock txm + txm := txmMocks.NewTxManager(t) + + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + + t.Run("returns valid compute unit price", func(t *testing.T) { + feeComponents, err := cw.GetFeeComponents(ctx) + require.NoError(t, err) + require.Equal(t, big.NewInt(100), feeComponents.ExecutionFee) + require.Nil(t, feeComponents.DataAvailabilityFee) // always nil for Solana + }) + + t.Run("fails if gas estimator not set", func(t *testing.T) { + cwNoEstimator, err := chainwriter.NewSolanaChainWriterService(rw, txm, nil, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + _, err = cwNoEstimator.GetFeeComponents(ctx) + require.Error(t, err) + }) +} + +func mustBorshEncodeStruct(t *testing.T, data interface{}) []byte { + buf := new(bytes.Buffer) + err := ag_binary.NewBorshEncoder(buf).Encode(data) + require.NoError(t, err) + return buf.Bytes() +} + +func mustFindPdaProgramAddress(t *testing.T, seeds [][]byte, programID solana.PublicKey) solana.PublicKey { + pda, _, err := solana.FindProgramAddress(seeds, programID) + require.NoError(t, err) + return pda +} + +func mockDataAccountLookupTable(t *testing.T, rw *clientmocks.ReaderWriter, pda solana.PublicKey) solana.PublicKey { + lookupTablePubkey := chainwriter.GetRandomPubKey(t) + dataAccount := DataAccount{ + Discriminator: [8]byte{}, + Version: 1, + Administrator: chainwriter.GetRandomPubKey(t), + PendingAdministrator: chainwriter.GetRandomPubKey(t), + LookupTable: lookupTablePubkey, + } + dataAccountBytes := mustBorshEncodeStruct(t, dataAccount) + rw.On("GetAccountInfoWithOpts", mock.Anything, pda, mock.Anything).Return(&rpc.GetAccountInfoResult{ + RPCContext: rpc.RPCContext{}, + Value: &rpc.Account{Data: rpc.DataBytesOrJSONFromBytes(dataAccountBytes)}, + }, nil) + return lookupTablePubkey +} + +func mockFetchLookupTableAddresses(t *testing.T, rw *clientmocks.ReaderWriter, lookupTablePubkey solana.PublicKey, storedPubkeys []solana.PublicKey) { + var lookupTablePubkeySlice solana.PublicKeySlice + lookupTablePubkeySlice.Append(storedPubkeys...) + lookupTableState := addresslookuptable.AddressLookupTableState{ + Addresses: lookupTablePubkeySlice, + } + lookupTableStateBytes := mustBorshEncodeStruct(t, lookupTableState) + rw.On("GetAccountInfoWithOpts", mock.Anything, lookupTablePubkey, mock.Anything).Return(&rpc.GetAccountInfoResult{ + RPCContext: rpc.RPCContext{}, + Value: &rpc.Account{Data: rpc.DataBytesOrJSONFromBytes(lookupTableStateBytes)}, + }, nil) +} diff --git a/pkg/solana/chainwriter/helpers.go b/pkg/solana/chainwriter/helpers.go index 4d5d00600..bc256c60a 100644 --- a/pkg/solana/chainwriter/helpers.go +++ b/pkg/solana/chainwriter/helpers.go @@ -1,12 +1,19 @@ package chainwriter import ( + "context" + "crypto/sha256" "errors" "fmt" "reflect" "strings" + "testing" "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/utils" ) // GetValuesAtLocation parses through nested types and arrays to find all locations of values @@ -54,11 +61,11 @@ func GetValueAtLocation(args any, location string) ([][]byte, error) { var values [][]byte for _, value := range valueList { - if byteArray, ok := value.([]byte); ok { - values = append(values, byteArray) - } else { + byteArray, ok := value.([]byte) + if !ok { return nil, fmt.Errorf("invalid value format at path: %s", location) } + values = append(values, byteArray) } return values, nil @@ -120,3 +127,76 @@ func traversePath(data any, path []string) ([]any, error) { return nil, errors.New("unexpected type encountered at path: " + path[0]) } } + +func InitializeDataAccount( + ctx context.Context, + t *testing.T, + client *rpc.Client, + programID solana.PublicKey, + admin solana.PrivateKey, + lookupTable solana.PublicKey, +) { + pda, _, err := solana.FindProgramAddress([][]byte{[]byte("data")}, programID) + require.NoError(t, err) + + discriminator := GetDiscriminator("initialize") + + instructionData := append(discriminator[:], lookupTable.Bytes()...) + + instruction := solana.NewInstruction( + programID, + solana.AccountMetaSlice{ + solana.Meta(pda).WRITE(), + solana.Meta(admin.PublicKey()).SIGNER().WRITE(), + solana.Meta(solana.SystemProgramID), + }, + instructionData, + ) + + // Send and confirm the transaction + utils.SendAndConfirm(ctx, t, client, []solana.Instruction{instruction}, admin, rpc.CommitmentFinalized) +} + +func GetDiscriminator(instruction string) [8]byte { + fullHash := sha256.Sum256([]byte("global:" + instruction)) + var discriminator [8]byte + copy(discriminator[:], fullHash[:8]) + return discriminator +} + +func GetRandomPubKey(t *testing.T) solana.PublicKey { + privKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + return privKey.PublicKey() +} + +func CreateTestPubKeys(t *testing.T, num int) solana.PublicKeySlice { + addresses := make([]solana.PublicKey, num) + for i := 0; i < num; i++ { + addresses[i] = GetRandomPubKey(t) + } + return addresses +} + +func CreateTestLookupTable(ctx context.Context, t *testing.T, c *rpc.Client, sender solana.PrivateKey, addresses []solana.PublicKey) solana.PublicKey { + // Create lookup tables + slot, serr := c.GetSlot(ctx, rpc.CommitmentFinalized) + require.NoError(t, serr) + table, instruction, ierr := utils.NewCreateLookupTableInstruction( + sender.PublicKey(), + sender.PublicKey(), + slot, + ) + require.NoError(t, ierr) + utils.SendAndConfirm(ctx, t, c, []solana.Instruction{instruction}, sender, rpc.CommitmentConfirmed) + + // add entries to lookup table + utils.SendAndConfirm(ctx, t, c, []solana.Instruction{ + utils.NewExtendLookupTableInstruction( + table, sender.PublicKey(), sender.PublicKey(), + addresses, + ), + }, sender, rpc.CommitmentConfirmed) + + return table +} diff --git a/pkg/solana/chainwriter/lookups.go b/pkg/solana/chainwriter/lookups.go index 1aa9ae92d..1947b060d 100644 --- a/pkg/solana/chainwriter/lookups.go +++ b/pkg/solana/chainwriter/lookups.go @@ -9,6 +9,7 @@ import ( "github.com/gagliardetto/solana-go" addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table" "github.com/gagliardetto/solana-go/rpc" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" ) @@ -69,8 +70,8 @@ type DerivedLookupTable struct { // AccountsFromLookupTable extracts accounts from a lookup table that was previously read and stored in memory. type AccountsFromLookupTable struct { - LookupTablesName string - IncludeIndexes []int + LookupTableName string + IncludeIndexes []int } func (ac AccountConstant) Resolve(_ context.Context, _ any, _ map[string]map[string][]*solana.AccountMeta, _ client.Reader) ([]*solana.AccountMeta, error) { @@ -106,9 +107,9 @@ func (al AccountLookup) Resolve(_ context.Context, args any, _ map[string]map[st func (alt AccountsFromLookupTable) Resolve(_ context.Context, _ any, derivedTableMap map[string]map[string][]*solana.AccountMeta, _ client.Reader) ([]*solana.AccountMeta, error) { // Fetch the inner map for the specified lookup table name - innerMap, ok := derivedTableMap[alt.LookupTablesName] + innerMap, ok := derivedTableMap[alt.LookupTableName] if !ok { - return nil, fmt.Errorf("lookup table not found: %s", alt.LookupTablesName) + return nil, fmt.Errorf("lookup table not found: %s", alt.LookupTableName) } var result []*solana.AccountMeta @@ -125,7 +126,7 @@ func (alt AccountsFromLookupTable) Resolve(_ context.Context, _ any, derivedTabl for publicKey, metas := range innerMap { for _, index := range alt.IncludeIndexes { if index < 0 || index >= len(metas) { - return nil, fmt.Errorf("invalid index %d for account %s in lookup table %s", index, publicKey, alt.LookupTablesName) + return nil, fmt.Errorf("invalid index %d for account %s in lookup table %s", index, publicKey, alt.LookupTableName) } result = append(result, metas[index]) } @@ -161,6 +162,7 @@ func (pda PDALookups) Resolve(ctx context.Context, args any, derivedTableMap map Encoding: "base64", Commitment: rpc.CommitmentFinalized, }) + fmt.Printf("Accounts Info: %+v", accountInfo) if err != nil || accountInfo == nil || accountInfo.Value == nil { return nil, fmt.Errorf("error fetching account info for PDA account: %s, error: %w", accountMeta.PublicKey.String(), err) @@ -325,8 +327,8 @@ func (s *SolanaChainWriterService) LoadTable(ctx context.Context, args any, rlt for _, addr := range addresses { resultMap[rlt.Name][addressMeta.PublicKey.String()] = append(resultMap[rlt.Name][addressMeta.PublicKey.String()], &solana.AccountMeta{ PublicKey: addr, - IsSigner: false, - IsWritable: false, + IsSigner: addressMeta.IsSigner, + IsWritable: addressMeta.IsWritable, }) } diff --git a/pkg/solana/chainwriter/lookups_test.go b/pkg/solana/chainwriter/lookups_test.go index 2a75814bf..53972feac 100644 --- a/pkg/solana/chainwriter/lookups_test.go +++ b/pkg/solana/chainwriter/lookups_test.go @@ -2,24 +2,24 @@ package chainwriter_test import ( "context" - "crypto/sha256" "reflect" "testing" "time" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" + "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + commonutils "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" "github.com/smartcontractkit/chainlink-solana/pkg/solana/utils" - - commonutils "github.com/smartcontractkit/chainlink-common/pkg/utils" - "github.com/test-go/testify/require" ) type TestArgs struct { @@ -40,7 +40,7 @@ type DataAccount struct { func TestAccountContant(t *testing.T) { t.Run("AccountConstant resolves valid address", func(t *testing.T) { - expectedAddr := getRandomPubKey(t) + expectedAddr := chainwriter.GetRandomPubKey(t) expectedMeta := []*solana.AccountMeta{ { PublicKey: expectedAddr, @@ -54,14 +54,15 @@ func TestAccountContant(t *testing.T) { IsSigner: true, IsWritable: true, } - result, err := constantConfig.Resolve(nil, nil, nil, nil) + result, err := constantConfig.Resolve(tests.Context(t), nil, nil, nil) require.NoError(t, err) require.Equal(t, expectedMeta, result) }) } func TestAccountLookups(t *testing.T) { + ctx := tests.Context(t) t.Run("AccountLookup resolves valid address with just one address", func(t *testing.T) { - expectedAddr := getRandomPubKey(t) + expectedAddr := chainwriter.GetRandomPubKey(t) testArgs := TestArgs{ Inner: []InnerArgs{ {Address: expectedAddr.Bytes()}, @@ -81,14 +82,14 @@ func TestAccountLookups(t *testing.T) { IsSigner: true, IsWritable: true, } - result, err := lookupConfig.Resolve(nil, testArgs, nil, nil) + result, err := lookupConfig.Resolve(ctx, testArgs, nil, nil) require.NoError(t, err) require.Equal(t, expectedMeta, result) }) t.Run("AccountLookup resolves valid address with just multiple addresses", func(t *testing.T) { - expectedAddr1 := getRandomPubKey(t) - expectedAddr2 := getRandomPubKey(t) + expectedAddr1 := chainwriter.GetRandomPubKey(t) + expectedAddr2 := chainwriter.GetRandomPubKey(t) testArgs := TestArgs{ Inner: []InnerArgs{ @@ -115,7 +116,7 @@ func TestAccountLookups(t *testing.T) { IsSigner: true, IsWritable: true, } - result, err := lookupConfig.Resolve(nil, testArgs, nil, nil) + result, err := lookupConfig.Resolve(ctx, testArgs, nil, nil) require.NoError(t, err) for i, meta := range result { require.Equal(t, expectedMeta[i], meta) @@ -123,7 +124,7 @@ func TestAccountLookups(t *testing.T) { }) t.Run("AccountLookup fails when address isn't in args", func(t *testing.T) { - expectedAddr := getRandomPubKey(t) + expectedAddr := chainwriter.GetRandomPubKey(t) testArgs := TestArgs{ Inner: []InnerArgs{ @@ -136,7 +137,7 @@ func TestAccountLookups(t *testing.T) { IsSigner: true, IsWritable: true, } - _, err := lookupConfig.Resolve(nil, testArgs, nil, nil) + _, err := lookupConfig.Resolve(ctx, testArgs, nil, nil) require.Error(t, err) }) } @@ -145,7 +146,7 @@ func TestPDALookups(t *testing.T) { programID := solana.SystemProgramID t.Run("PDALookup resolves valid PDA with constant address seeds", func(t *testing.T) { - seed := getRandomPubKey(t) + seed := chainwriter.GetRandomPubKey(t) pda, _, err := solana.FindProgramAddress([][]byte{seed.Bytes()}, programID) require.NoError(t, err) @@ -211,8 +212,6 @@ func TestPDALookups(t *testing.T) { }) t.Run("PDALookup fails with missing seeds", func(t *testing.T) { - programID := solana.SystemProgramID - pdaLookup := chainwriter.PDALookups{ Name: "TestPDA", PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()}, @@ -234,8 +233,8 @@ func TestPDALookups(t *testing.T) { }) t.Run("PDALookup resolves valid PDA with address lookup seeds", func(t *testing.T) { - seed1 := getRandomPubKey(t) - seed2 := getRandomPubKey(t) + seed1 := chainwriter.GetRandomPubKey(t) + seed2 := chainwriter.GetRandomPubKey(t) pda, _, err := solana.FindProgramAddress([][]byte{seed1.Bytes(), seed2.Bytes()}, programID) require.NoError(t, err) @@ -292,22 +291,22 @@ func TestLookupTables(t *testing.T) { txm := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr) - cw, err := chainwriter.NewSolanaChainWriterService(solanaClient, *txm, nil, chainwriter.ChainWriterConfig{}) + cw, err := chainwriter.NewSolanaChainWriterService(solanaClient, txm, nil, chainwriter.ChainWriterConfig{}) t.Run("StaticLookup table resolves properly", func(t *testing.T) { - pubKeys := createTestPubKeys(t, 8) - table := CreateTestLookupTable(ctx, t, rpcClient, sender, pubKeys) + pubKeys := chainwriter.CreateTestPubKeys(t, 8) + table := chainwriter.CreateTestLookupTable(ctx, t, rpcClient, sender, pubKeys) lookupConfig := chainwriter.LookupTables{ DerivedLookupTables: nil, StaticLookupTables: []string{table.String()}, } - _, staticTableMap, err := cw.ResolveLookupTables(ctx, nil, lookupConfig) - require.NoError(t, err) + _, staticTableMap, resolveErr := cw.ResolveLookupTables(ctx, nil, lookupConfig) + require.NoError(t, resolveErr) require.Equal(t, pubKeys, staticTableMap[table]) }) t.Run("Derived lookup table resolves properly with constant address", func(t *testing.T) { - pubKeys := createTestPubKeys(t, 8) - table := CreateTestLookupTable(ctx, t, rpcClient, sender, pubKeys) + pubKeys := chainwriter.CreateTestPubKeys(t, 8) + table := chainwriter.CreateTestLookupTable(ctx, t, rpcClient, sender, pubKeys) lookupConfig := chainwriter.LookupTables{ DerivedLookupTables: []chainwriter.DerivedLookupTable{ { @@ -322,8 +321,8 @@ func TestLookupTables(t *testing.T) { }, StaticLookupTables: nil, } - derivedTableMap, _, err := cw.ResolveLookupTables(ctx, nil, lookupConfig) - require.NoError(t, err) + derivedTableMap, _, resolveErr := cw.ResolveLookupTables(ctx, nil, lookupConfig) + require.NoError(t, resolveErr) addresses, ok := derivedTableMap["DerivedTable"][table.String()] require.True(t, ok) @@ -333,7 +332,7 @@ func TestLookupTables(t *testing.T) { }) t.Run("Derived lookup table fails with invalid address", func(t *testing.T) { - invalidTable := getRandomPubKey(t) + invalidTable := chainwriter.GetRandomPubKey(t) lookupConfig := chainwriter.LookupTables{ DerivedLookupTables: []chainwriter.DerivedLookupTable{ @@ -356,7 +355,7 @@ func TestLookupTables(t *testing.T) { }) t.Run("Static lookup table fails with invalid address", func(t *testing.T) { - invalidTable := getRandomPubKey(t) + invalidTable := chainwriter.GetRandomPubKey(t) lookupConfig := chainwriter.LookupTables{ DerivedLookupTables: nil, @@ -369,8 +368,8 @@ func TestLookupTables(t *testing.T) { }) t.Run("Derived lookup table resolves properly with account lookup address", func(t *testing.T) { - pubKeys := createTestPubKeys(t, 8) - table := CreateTestLookupTable(ctx, t, rpcClient, sender, pubKeys) + pubKeys := chainwriter.CreateTestPubKeys(t, 8) + table := chainwriter.CreateTestLookupTable(ctx, t, rpcClient, sender, pubKeys) lookupConfig := chainwriter.LookupTables{ DerivedLookupTables: []chainwriter.DerivedLookupTable{ { @@ -405,10 +404,10 @@ func TestLookupTables(t *testing.T) { // Deployed write_test contract programID := solana.MustPublicKeyFromBase58("39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU") - lookupKeys := createTestPubKeys(t, 5) - lookupTable := CreateTestLookupTable(ctx, t, rpcClient, sender, lookupKeys) + lookupKeys := chainwriter.CreateTestPubKeys(t, 5) + lookupTable := chainwriter.CreateTestLookupTable(ctx, t, rpcClient, sender, lookupKeys) - InitializeDataAccount(ctx, t, rpcClient, programID, sender, lookupTable) + chainwriter.InitializeDataAccount(ctx, t, rpcClient, programID, sender, lookupTable) args := map[string]interface{}{ "seed1": []byte("data"), @@ -446,76 +445,3 @@ func TestLookupTables(t *testing.T) { } }) } - -func InitializeDataAccount( - ctx context.Context, - t *testing.T, - client *rpc.Client, - programID solana.PublicKey, - admin solana.PrivateKey, - lookupTable solana.PublicKey, -) { - pda, _, err := solana.FindProgramAddress([][]byte{[]byte("data")}, programID) - require.NoError(t, err) - - discriminator := getDiscriminator("initialize") - - instructionData := append(discriminator[:], lookupTable.Bytes()...) - - instruction := solana.NewInstruction( - programID, - solana.AccountMetaSlice{ - solana.Meta(pda).WRITE(), - solana.Meta(admin.PublicKey()).SIGNER().WRITE(), - solana.Meta(solana.SystemProgramID), - }, - instructionData, - ) - - // Send and confirm the transaction - utils.SendAndConfirm(ctx, t, client, []solana.Instruction{instruction}, admin, rpc.CommitmentFinalized) -} - -func getDiscriminator(instruction string) [8]byte { - fullHash := sha256.Sum256([]byte("global:" + instruction)) - var discriminator [8]byte - copy(discriminator[:], fullHash[:8]) - return discriminator -} - -func getRandomPubKey(t *testing.T) solana.PublicKey { - privKey, err := solana.NewRandomPrivateKey() - require.NoError(t, err) - return privKey.PublicKey() -} - -func createTestPubKeys(t *testing.T, num int) solana.PublicKeySlice { - addresses := make([]solana.PublicKey, num) - for i := 0; i < num; i++ { - addresses[i] = getRandomPubKey(t) - } - return addresses -} - -func CreateTestLookupTable(ctx context.Context, t *testing.T, c *rpc.Client, sender solana.PrivateKey, addresses []solana.PublicKey) solana.PublicKey { - // Create lookup tables - slot, serr := c.GetSlot(ctx, rpc.CommitmentFinalized) - require.NoError(t, serr) - table, instruction, ierr := utils.NewCreateLookupTableInstruction( - sender.PublicKey(), - sender.PublicKey(), - slot, - ) - require.NoError(t, ierr) - utils.SendAndConfirm(ctx, t, c, []solana.Instruction{instruction}, sender, rpc.CommitmentConfirmed) - - // add entries to lookup table - utils.SendAndConfirm(ctx, t, c, []solana.Instruction{ - utils.NewExtendLookupTableInstruction( - table, sender.PublicKey(), sender.PublicKey(), - addresses, - ), - }, sender, rpc.CommitmentConfirmed) - - return table -} diff --git a/pkg/solana/relay.go b/pkg/solana/relay.go index 1f2fbdffd..f925434d2 100644 --- a/pkg/solana/relay.go +++ b/pkg/solana/relay.go @@ -19,12 +19,13 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) var _ TxManager = (*txm.Txm)(nil) type TxManager interface { - Enqueue(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, txCfgs ...txm.SetTxConfig) error + Enqueue(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, txCfgs ...txmutils.SetTxConfig) error } var _ relaytypes.Relayer = &Relayer{} //nolint:staticcheck diff --git a/pkg/solana/transmitter_test.go b/pkg/solana/transmitter_test.go index 1d058d36a..b4372515a 100644 --- a/pkg/solana/transmitter_test.go +++ b/pkg/solana/transmitter_test.go @@ -17,7 +17,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" clientmocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks" "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) // custom mock txm instead of mockery generated because SetTxConfig causes circular imports @@ -27,7 +27,7 @@ type verifyTxSize struct { s *solana.PrivateKey } -func (txm verifyTxSize) Enqueue(_ context.Context, _ string, tx *solana.Transaction, txID *string, _ ...txm.SetTxConfig) error { +func (txm verifyTxSize) Enqueue(_ context.Context, _ string, tx *solana.Transaction, txID *string, _ ...txmutils.SetTxConfig) error { // additional components that transaction manager adds to the transaction require.NoError(txm.t, fees.SetComputeUnitPrice(tx, 0)) require.NoError(txm.t, fees.SetComputeUnitLimit(tx, 0)) diff --git a/pkg/solana/txm/mocks/tx_manager.go b/pkg/solana/txm/mocks/tx_manager.go new file mode 100644 index 000000000..50806a4da --- /dev/null +++ b/pkg/solana/txm/mocks/tx_manager.go @@ -0,0 +1,390 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + solana "github.com/gagliardetto/solana-go" + mock "github.com/stretchr/testify/mock" + + types "github.com/smartcontractkit/chainlink-common/pkg/types" + + utils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" +) + +// TxManager is an autogenerated mock type for the TxManager type +type TxManager struct { + mock.Mock +} + +type TxManager_Expecter struct { + mock *mock.Mock +} + +func (_m *TxManager) EXPECT() *TxManager_Expecter { + return &TxManager_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *TxManager) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TxManager_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type TxManager_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *TxManager_Expecter) Close() *TxManager_Close_Call { + return &TxManager_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *TxManager_Close_Call) Run(run func()) *TxManager_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TxManager_Close_Call) Return(_a0 error) *TxManager_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Close_Call) RunAndReturn(run func() error) *TxManager_Close_Call { + _c.Call.Return(run) + return _c +} + +// Enqueue provides a mock function with given fields: ctx, accountID, tx, txID, txCfgs +func (_m *TxManager) Enqueue(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, txCfgs ...utils.SetTxConfig) error { + _va := make([]interface{}, len(txCfgs)) + for _i := range txCfgs { + _va[_i] = txCfgs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, accountID, tx, txID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Enqueue") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, *solana.Transaction, *string, ...utils.SetTxConfig) error); ok { + r0 = rf(ctx, accountID, tx, txID, txCfgs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TxManager_Enqueue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Enqueue' +type TxManager_Enqueue_Call struct { + *mock.Call +} + +// Enqueue is a helper method to define mock.On call +// - ctx context.Context +// - accountID string +// - tx *solana.Transaction +// - txID *string +// - txCfgs ...utils.SetTxConfig +func (_e *TxManager_Expecter) Enqueue(ctx interface{}, accountID interface{}, tx interface{}, txID interface{}, txCfgs ...interface{}) *TxManager_Enqueue_Call { + return &TxManager_Enqueue_Call{Call: _e.mock.On("Enqueue", + append([]interface{}{ctx, accountID, tx, txID}, txCfgs...)...)} +} + +func (_c *TxManager_Enqueue_Call) Run(run func(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, txCfgs ...utils.SetTxConfig)) *TxManager_Enqueue_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]utils.SetTxConfig, len(args)-4) + for i, a := range args[4:] { + if a != nil { + variadicArgs[i] = a.(utils.SetTxConfig) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(*solana.Transaction), args[3].(*string), variadicArgs...) + }) + return _c +} + +func (_c *TxManager_Enqueue_Call) Return(_a0 error) *TxManager_Enqueue_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Enqueue_Call) RunAndReturn(run func(context.Context, string, *solana.Transaction, *string, ...utils.SetTxConfig) error) *TxManager_Enqueue_Call { + _c.Call.Return(run) + return _c +} + +// GetTransactionStatus provides a mock function with given fields: ctx, transactionID +func (_m *TxManager) GetTransactionStatus(ctx context.Context, transactionID string) (types.TransactionStatus, error) { + ret := _m.Called(ctx, transactionID) + + if len(ret) == 0 { + panic("no return value specified for GetTransactionStatus") + } + + var r0 types.TransactionStatus + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (types.TransactionStatus, error)); ok { + return rf(ctx, transactionID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) types.TransactionStatus); ok { + r0 = rf(ctx, transactionID) + } else { + r0 = ret.Get(0).(types.TransactionStatus) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, transactionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// TxManager_GetTransactionStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTransactionStatus' +type TxManager_GetTransactionStatus_Call struct { + *mock.Call +} + +// GetTransactionStatus is a helper method to define mock.On call +// - ctx context.Context +// - transactionID string +func (_e *TxManager_Expecter) GetTransactionStatus(ctx interface{}, transactionID interface{}) *TxManager_GetTransactionStatus_Call { + return &TxManager_GetTransactionStatus_Call{Call: _e.mock.On("GetTransactionStatus", ctx, transactionID)} +} + +func (_c *TxManager_GetTransactionStatus_Call) Run(run func(ctx context.Context, transactionID string)) *TxManager_GetTransactionStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *TxManager_GetTransactionStatus_Call) Return(_a0 types.TransactionStatus, _a1 error) *TxManager_GetTransactionStatus_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *TxManager_GetTransactionStatus_Call) RunAndReturn(run func(context.Context, string) (types.TransactionStatus, error)) *TxManager_GetTransactionStatus_Call { + _c.Call.Return(run) + return _c +} + +// HealthReport provides a mock function with given fields: +func (_m *TxManager) HealthReport() map[string]error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for HealthReport") + } + + var r0 map[string]error + if rf, ok := ret.Get(0).(func() map[string]error); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]error) + } + } + + return r0 +} + +// TxManager_HealthReport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HealthReport' +type TxManager_HealthReport_Call struct { + *mock.Call +} + +// HealthReport is a helper method to define mock.On call +func (_e *TxManager_Expecter) HealthReport() *TxManager_HealthReport_Call { + return &TxManager_HealthReport_Call{Call: _e.mock.On("HealthReport")} +} + +func (_c *TxManager_HealthReport_Call) Run(run func()) *TxManager_HealthReport_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TxManager_HealthReport_Call) Return(_a0 map[string]error) *TxManager_HealthReport_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_HealthReport_Call) RunAndReturn(run func() map[string]error) *TxManager_HealthReport_Call { + _c.Call.Return(run) + return _c +} + +// Name provides a mock function with given fields: +func (_m *TxManager) Name() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Name") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// TxManager_Name_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Name' +type TxManager_Name_Call struct { + *mock.Call +} + +// Name is a helper method to define mock.On call +func (_e *TxManager_Expecter) Name() *TxManager_Name_Call { + return &TxManager_Name_Call{Call: _e.mock.On("Name")} +} + +func (_c *TxManager_Name_Call) Run(run func()) *TxManager_Name_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TxManager_Name_Call) Return(_a0 string) *TxManager_Name_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Name_Call) RunAndReturn(run func() string) *TxManager_Name_Call { + _c.Call.Return(run) + return _c +} + +// Ready provides a mock function with given fields: +func (_m *TxManager) Ready() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Ready") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TxManager_Ready_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ready' +type TxManager_Ready_Call struct { + *mock.Call +} + +// Ready is a helper method to define mock.On call +func (_e *TxManager_Expecter) Ready() *TxManager_Ready_Call { + return &TxManager_Ready_Call{Call: _e.mock.On("Ready")} +} + +func (_c *TxManager_Ready_Call) Run(run func()) *TxManager_Ready_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TxManager_Ready_Call) Return(_a0 error) *TxManager_Ready_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Ready_Call) RunAndReturn(run func() error) *TxManager_Ready_Call { + _c.Call.Return(run) + return _c +} + +// Start provides a mock function with given fields: _a0 +func (_m *TxManager) Start(_a0 context.Context) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Start") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TxManager_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type TxManager_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +// - _a0 context.Context +func (_e *TxManager_Expecter) Start(_a0 interface{}) *TxManager_Start_Call { + return &TxManager_Start_Call{Call: _e.mock.On("Start", _a0)} +} + +func (_c *TxManager_Start_Call) Run(run func(_a0 context.Context)) *TxManager_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *TxManager_Start_Call) Return(_a0 error) *TxManager_Start_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Start_Call) RunAndReturn(run func(context.Context) error) *TxManager_Start_Call { + _c.Call.Return(run) + return _c +} + +// NewTxManager creates a new instance of TxManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTxManager(t interface { + mock.TestingT + Cleanup(func()) +}) *TxManager { + mock := &TxManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/solana/txm/pendingtx.go b/pkg/solana/txm/pendingtx.go index ecae7243b..033b0c16f 100644 --- a/pkg/solana/txm/pendingtx.go +++ b/pkg/solana/txm/pendingtx.go @@ -9,6 +9,8 @@ import ( "github.com/gagliardetto/solana-go" "golang.org/x/exp/maps" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) var ( @@ -37,11 +39,11 @@ type PendingTxContext interface { // OnFinalized marks transaction as Finalized, moves it from the broadcasted or confirmed map to finalized map, removes signatures from signature map to stop confirmation checks OnFinalized(sig solana.Signature, retentionTimeout time.Duration) (string, error) // OnPrebroadcastError adds transaction that has not yet been broadcasted to the finalized/errored map as errored, matches err type using enum - OnPrebroadcastError(id string, retentionTimeout time.Duration, txState TxState, errType TxErrType) error + OnPrebroadcastError(id string, retentionTimeout time.Duration, txState utils.TxState, errType TxErrType) error // OnError marks transaction as errored, matches err type using enum, moves it from the broadcasted or confirmed map to finalized/errored map, removes signatures from signature map to stop confirmation checks - OnError(sig solana.Signature, retentionTimeout time.Duration, txState TxState, errType TxErrType) (string, error) + OnError(sig solana.Signature, retentionTimeout time.Duration, txState utils.TxState, errType TxErrType) (string, error) // GetTxState returns the transaction state for the provided ID if it exists - GetTxState(id string) (TxState, error) + GetTxState(id string) (utils.TxState, error) // TrimFinalizedErroredTxs removes transactions that have reached their retention time TrimFinalizedErroredTxs() int } @@ -49,17 +51,17 @@ type PendingTxContext interface { // finishedTx is used to store info required to track transactions to finality or error type pendingTx struct { tx solana.Transaction - cfg TxConfig + cfg utils.TxConfig signatures []solana.Signature id string createTs time.Time - state TxState + state utils.TxState } // finishedTx is used to store minimal info specifically for finalized or errored transactions for external status checks type finishedTx struct { retentionTs time.Time - state TxState + state utils.TxState } var _ PendingTxContext = &pendingTxContext{} @@ -116,7 +118,7 @@ func (c *pendingTxContext) New(tx pendingTx, sig solana.Signature, cancel contex // add signature to tx tx.signatures = append(tx.signatures, sig) tx.createTs = time.Now() - tx.state = Broadcasted + tx.state = utils.Broadcasted // save to the broadcasted map since transaction was just broadcasted c.broadcastedTxs[tx.id] = tx return "", nil @@ -251,7 +253,7 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { return ErrTransactionNotFound } // Check if tranasction already in processed state - if tx.state == Processed { + if tx.state == utils.Processed { return ErrAlreadyInExpectedState } return nil @@ -271,7 +273,7 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { return id, ErrTransactionNotFound } // update tx state to Processed - tx.state = Processed + tx.state = utils.Processed // save updated tx back to the broadcasted map c.broadcastedTxs[id] = tx return id, nil @@ -286,7 +288,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { return ErrSigDoesNotExist } // Check if transaction already in confirmed state - if tx, exists := c.confirmedTxs[id]; exists && tx.state == Confirmed { + if tx, exists := c.confirmedTxs[id]; exists && tx.state == utils.Confirmed { return ErrAlreadyInExpectedState } // Transactions should only move to confirmed from broadcasted/processed @@ -315,7 +317,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { delete(c.cancelBy, id) } // update tx state to Confirmed - tx.state = Confirmed + tx.state = utils.Confirmed // move tx to confirmed map c.confirmedTxs[id] = tx // remove tx from broadcasted map @@ -379,7 +381,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti return id, nil } finalizedTx := finishedTx{ - state: Finalized, + state: utils.Finalized, retentionTs: time.Now().Add(retentionTimeout), } // move transaction from confirmed to finalized map @@ -388,7 +390,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti }) } -func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState TxState, _ TxErrType) error { +func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState utils.TxState, _ TxErrType) error { // nothing to do if retention timeout is 0 since transaction is not stored yet. if retentionTimeout == 0 { return nil @@ -429,7 +431,7 @@ func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time. return err } -func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.Duration, txState TxState, _ TxErrType) (string, error) { +func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.Duration, txState utils.TxState, _ TxErrType) (string, error) { err := c.withReadLock(func() error { id, sigExists := c.sigToID[sig] if !sigExists { @@ -494,7 +496,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D }) } -func (c *pendingTxContext) GetTxState(id string) (TxState, error) { +func (c *pendingTxContext) GetTxState(id string) (utils.TxState, error) { c.lock.RLock() defer c.lock.RUnlock() if tx, exists := c.broadcastedTxs[id]; exists { @@ -506,7 +508,7 @@ func (c *pendingTxContext) GetTxState(id string) (TxState, error) { if tx, exists := c.finalizedErroredTxs[id]; exists { return tx.state, nil } - return NotFound, fmt.Errorf("failed to find transaction for id: %s", id) + return utils.NotFound, fmt.Errorf("failed to find transaction for id: %s", id) } // TrimFinalizedErroredTxs deletes transactions from the finalized/errored map and the allTxs map after the retention period has passed @@ -617,7 +619,7 @@ func (c *pendingTxContextWithProm) OnFinalized(sig solana.Signature, retentionTi return id, err } -func (c *pendingTxContextWithProm) OnError(sig solana.Signature, retentionTimeout time.Duration, txState TxState, errType TxErrType) (string, error) { +func (c *pendingTxContextWithProm) OnError(sig solana.Signature, retentionTimeout time.Duration, txState utils.TxState, errType TxErrType) (string, error) { id, err := c.pendingTx.OnError(sig, retentionTimeout, txState, errType) // err indicates transaction not found so may already be removed if err == nil { incrementErrorMetrics(errType, c.chainID) @@ -625,7 +627,7 @@ func (c *pendingTxContextWithProm) OnError(sig solana.Signature, retentionTimeou return id, err } -func (c *pendingTxContextWithProm) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState TxState, errType TxErrType) error { +func (c *pendingTxContextWithProm) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState utils.TxState, errType TxErrType) error { err := c.pendingTx.OnPrebroadcastError(id, retentionTimeout, txState, errType) // err indicates transaction not found so may already be removed if err == nil { incrementErrorMetrics(errType, c.chainID) @@ -652,7 +654,7 @@ func incrementErrorMetrics(errType TxErrType, chainID string) { promSolTxmErrorTxs.WithLabelValues(chainID).Inc() } -func (c *pendingTxContextWithProm) GetTxState(id string) (TxState, error) { +func (c *pendingTxContextWithProm) GetTxState(id string) (utils.TxState, error) { return c.pendingTx.GetTxState(id) } diff --git a/pkg/solana/txm/pendingtx_test.go b/pkg/solana/txm/pendingtx_test.go index e7b7fc51e..759f54ca3 100644 --- a/pkg/solana/txm/pendingtx_test.go +++ b/pkg/solana/txm/pendingtx_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) func TestPendingTxContext_add_remove_multiple(t *testing.T) { @@ -90,7 +92,7 @@ func TestPendingTxContext_new(t *testing.T) { require.Equal(t, sig, tx.signatures[0]) // Check status is Broadcasted - require.Equal(t, Broadcasted, tx.state) + require.Equal(t, utils.Broadcasted, tx.state) // Check it does not exist in confirmed map _, exists = txs.confirmedTxs[msg.id] @@ -222,7 +224,7 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { require.Equal(t, sig, tx.signatures[0]) // Check status is Processed - require.Equal(t, Processed, tx.state) + require.Equal(t, utils.Processed, tx.state) // Check it does not exist in confirmed map _, exists = txs.confirmedTxs[msg.id] @@ -293,7 +295,7 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, Errored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -361,7 +363,7 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { require.Equal(t, sig, tx.signatures[0]) // Check status is Confirmed - require.Equal(t, Confirmed, tx.state) + require.Equal(t, utils.Confirmed, tx.state) // Check it does not exist in finalized map _, exists = txs.finalizedErroredTxs[msg.id] @@ -405,7 +407,7 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, Errored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -475,7 +477,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.True(t, exists) // Check status is Finalized - require.Equal(t, Finalized, tx.state) + require.Equal(t, utils.Finalized, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig1] @@ -525,7 +527,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.True(t, exists) // Check status is Finalized - require.Equal(t, Finalized, tx.state) + require.Equal(t, utils.Finalized, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig1] @@ -583,7 +585,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, Errored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -608,7 +610,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, Errored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -625,7 +627,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.True(t, exists) // Check status is Finalized - require.Equal(t, Errored, tx.state) + require.Equal(t, utils.Errored, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig] @@ -646,7 +648,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Transition to errored state - id, err = txs.OnError(sig, retentionTimeout, Errored, 0) + id, err = txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -663,7 +665,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.True(t, exists) // Check status is Finalized - require.Equal(t, Errored, tx.state) + require.Equal(t, utils.Errored, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig] @@ -679,7 +681,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.NoError(t, err) // Transition to fatally errored state - id, err := txs.OnError(sig, retentionTimeout, FatallyErrored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.FatallyErrored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -692,7 +694,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.True(t, exists) // Check status is Errored - require.Equal(t, FatallyErrored, tx.state) + require.Equal(t, utils.FatallyErrored, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig] @@ -713,7 +715,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Transition to errored state - id, err = txs.OnError(sig, 0*time.Second, Errored, 0) + id, err = txs.OnError(sig, 0*time.Second, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -748,7 +750,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Transition back to confirmed state - id, err = txs.OnError(sig, retentionTimeout, Errored, 0) + id, err = txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.Error(t, err) require.Equal(t, "", id) }) @@ -764,7 +766,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} // Transition to errored state - err := txs.OnPrebroadcastError(msg.id, retentionTimeout, Errored, 0) + err := txs.OnPrebroadcastError(msg.id, retentionTimeout, utils.Errored, 0) require.NoError(t, err) // Check it exists in errored map @@ -772,7 +774,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { require.True(t, exists) // Check status is Errored - require.Equal(t, Errored, tx.state) + require.Equal(t, utils.Errored, tx.state) }) t.Run("successfully adds transaction with fatally errored state", func(t *testing.T) { @@ -780,7 +782,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { msg := pendingTx{id: uuid.NewString()} // Transition to fatally errored state - err := txs.OnPrebroadcastError(msg.id, retentionTimeout, FatallyErrored, 0) + err := txs.OnPrebroadcastError(msg.id, retentionTimeout, utils.FatallyErrored, 0) require.NoError(t, err) // Check it exists in errored map @@ -788,7 +790,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { require.True(t, exists) // Check status is Errored - require.Equal(t, FatallyErrored, tx.state) + require.Equal(t, utils.FatallyErrored, tx.state) }) t.Run("fails to add transaction to errored map if id exists in another map already", func(t *testing.T) { @@ -801,7 +803,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { require.NoError(t, err) // Transition to errored state - err = txs.OnPrebroadcastError(msg.id, retentionTimeout, FatallyErrored, 0) + err = txs.OnPrebroadcastError(msg.id, retentionTimeout, utils.FatallyErrored, 0) require.ErrorIs(t, err, ErrIDAlreadyExists) }) @@ -809,11 +811,11 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { txID := uuid.NewString() // Transition to errored state - err := txs.OnPrebroadcastError(txID, retentionTimeout, Errored, 0) + err := txs.OnPrebroadcastError(txID, retentionTimeout, utils.Errored, 0) require.NoError(t, err) // Transition back to errored state - err = txs.OnPrebroadcastError(txID, retentionTimeout, Errored, 0) + err = txs.OnPrebroadcastError(txID, retentionTimeout, utils.Errored, 0) require.ErrorIs(t, err, ErrAlreadyInExpectedState) }) } @@ -867,7 +869,7 @@ func TestPendingTxContext_remove(t *testing.T) { erroredMsg := pendingTx{id: uuid.NewString()} err = txs.New(erroredMsg, erroredSig, cancel) require.NoError(t, err) - id, err = txs.OnError(erroredSig, retentionTimeout, Errored, 0) + id, err = txs.OnError(erroredSig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, erroredMsg.id, id) @@ -1062,7 +1064,7 @@ func TestGetTxState(t *testing.T) { err := txs.New(broadcastedMsg, broadcastedSig, cancel) require.NoError(t, err) - var state TxState + var state utils.TxState // Create new processed transaction processedMsg := pendingTx{id: uuid.NewString()} err = txs.New(processedMsg, processedSig, cancel) @@ -1073,7 +1075,7 @@ func TestGetTxState(t *testing.T) { // Check Processed state is returned state, err = txs.GetTxState(processedMsg.id) require.NoError(t, err) - require.Equal(t, Processed, state) + require.Equal(t, utils.Processed, state) // Create new confirmed transaction confirmedMsg := pendingTx{id: uuid.NewString()} @@ -1085,7 +1087,7 @@ func TestGetTxState(t *testing.T) { // Check Confirmed state is returned state, err = txs.GetTxState(confirmedMsg.id) require.NoError(t, err) - require.Equal(t, Confirmed, state) + require.Equal(t, utils.Confirmed, state) // Create new finalized transaction finalizedMsg := pendingTx{id: uuid.NewString()} @@ -1097,36 +1099,36 @@ func TestGetTxState(t *testing.T) { // Check Finalized state is returned state, err = txs.GetTxState(finalizedMsg.id) require.NoError(t, err) - require.Equal(t, Finalized, state) + require.Equal(t, utils.Finalized, state) // Create new errored transaction erroredMsg := pendingTx{id: uuid.NewString()} err = txs.New(erroredMsg, erroredSig, cancel) require.NoError(t, err) - id, err = txs.OnError(erroredSig, retentionTimeout, Errored, 0) + id, err = txs.OnError(erroredSig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, erroredMsg.id, id) // Check Errored state is returned state, err = txs.GetTxState(erroredMsg.id) require.NoError(t, err) - require.Equal(t, Errored, state) + require.Equal(t, utils.Errored, state) // Create new fatally errored transaction fatallyErroredMsg := pendingTx{id: uuid.NewString()} err = txs.New(fatallyErroredMsg, fatallyErroredSig, cancel) require.NoError(t, err) - id, err = txs.OnError(fatallyErroredSig, retentionTimeout, FatallyErrored, 0) + id, err = txs.OnError(fatallyErroredSig, retentionTimeout, utils.FatallyErrored, 0) require.NoError(t, err) require.Equal(t, fatallyErroredMsg.id, id) // Check Errored state is returned state, err = txs.GetTxState(fatallyErroredMsg.id) require.NoError(t, err) - require.Equal(t, FatallyErrored, state) + require.Equal(t, utils.FatallyErrored, state) // Check NotFound state is returned if unknown id provided state, err = txs.GetTxState("unknown id") require.Error(t, err) - require.Equal(t, NotFound, state) + require.Equal(t, utils.NotFound, state) } func randomSignature(t *testing.T) solana.Signature { diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index 10cc1acd2..f5d3d8705 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -25,6 +25,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) const ( @@ -36,8 +37,6 @@ const ( MaxComputeUnitLimit = 1_400_000 // max compute unit limit a transaction can have ) -var _ services.Service = (*Txm)(nil) - type SimpleKeystore interface { Sign(ctx context.Context, account string, data []byte) (signature []byte, err error) Accounts(ctx context.Context) (accounts []string, err error) @@ -45,6 +44,14 @@ type SimpleKeystore interface { var _ loop.Keystore = (SimpleKeystore)(nil) +type TxManager interface { + services.Service + Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txCfgs ...txmutils.SetTxConfig) error + GetTransactionStatus(ctx context.Context, transactionID string) (commontypes.TransactionStatus, error) +} + +var _ TxManager = (*Txm)(nil) + // Txm manages transactions for the solana blockchain. // simple implementation with no persistently stored txs type Txm struct { @@ -64,19 +71,6 @@ type Txm struct { sendTx func(ctx context.Context, tx *solanaGo.Transaction) (solanaGo.Signature, error) } -type TxConfig struct { - Timeout time.Duration // transaction broadcast timeout - - // compute unit price config - FeeBumpPeriod time.Duration // how often to bump fee - BaseComputeUnitPrice uint64 // starting price - ComputeUnitPriceMin uint64 // min price - ComputeUnitPriceMax uint64 // max price - - EstimateComputeUnitLimit bool // enable compute limit estimations using simulation - ComputeUnitLimit uint32 // compute unit limit -} - // NewTxm creates a txm. Uses simulation so should only be used to send txes to trusted contracts i.e. OCR. func NewTxm(chainID string, client internal.Loader[client.ReaderWriter], sendTx func(ctx context.Context, tx *solanaGo.Transaction) (solanaGo.Signature, error), @@ -240,7 +234,7 @@ func (txm *Txm) sendWithRetry(ctx context.Context, msg pendingTx) (solanaGo.Tran sig, initSendErr := txm.sendTx(ctx, &initTx) if initSendErr != nil { cancel() // cancel context when exiting early - stateTransitionErr := txm.txs.OnPrebroadcastError(msg.id, txm.cfg.TxRetentionTimeout(), Errored, TxFailReject) + stateTransitionErr := txm.txs.OnPrebroadcastError(msg.id, txm.cfg.TxRetentionTimeout(), txmutils.Errored, TxFailReject) return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("tx failed initial transmit: %w", errors.Join(initSendErr, stateTransitionErr)) } @@ -252,7 +246,7 @@ func (txm *Txm) sendWithRetry(ctx context.Context, msg pendingTx) (solanaGo.Tran } // used for tracking rebroadcasting only in SendWithRetry - var sigs signatureList + var sigs txmutils.SignatureList sigs.Allocate() if initSetErr := sigs.Set(0, sig); initSetErr != nil { return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save initial signature in signature list: %w", initSetErr) @@ -402,7 +396,7 @@ func (txm *Txm) confirm() { // process signatures processSigs := func(s []solanaGo.Signature, res []*rpc.SignatureStatusesResult) { // sort signatures and results process successful first - s, res, err := SortSignaturesAndResults(s, res) + s, res, err := txmutils.SortSignaturesAndResults(s, res) if err != nil { txm.lggr.Errorw("sorting error", "error", err) return @@ -418,7 +412,7 @@ func (txm *Txm) confirm() { // check confirm timeout exceeded if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) + id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), txmutils.Errored, TxFailDrop) if err != nil { txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) } else { @@ -454,7 +448,7 @@ func (txm *Txm) confirm() { } // check confirm timeout exceeded if TxConfirmTimeout set if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) + id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), txmutils.Errored, TxFailDrop) if err != nil { txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) } else { @@ -580,7 +574,7 @@ func (txm *Txm) reap() { } // Enqueue enqueues a msg destined for the solana chain. -func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txCfgs ...SetTxConfig) error { +func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txCfgs ...txmutils.SetTxConfig) error { if err := txm.Ready(); err != nil { return fmt.Errorf("error in soltxm.Enqueue: %w", err) } @@ -650,15 +644,15 @@ func (txm *Txm) GetTransactionStatus(ctx context.Context, transactionID string) } switch state { - case Broadcasted: + case txmutils.Broadcasted: return commontypes.Pending, nil - case Processed, Confirmed: + case txmutils.Processed, txmutils.Confirmed: return commontypes.Unconfirmed, nil - case Finalized: + case txmutils.Finalized: return commontypes.Finalized, nil - case Errored: + case txmutils.Errored: return commontypes.Failed, nil - case FatallyErrored: + case txmutils.FatallyErrored: return commontypes.Fatal, nil default: return commontypes.Unknown, fmt.Errorf("found unknown transaction state: %s", state.String()) @@ -821,7 +815,7 @@ func (txm *Txm) ProcessError(sig solanaGo.Signature, resErr interface{}, simulat errType = TxFailSimOther } txm.lggr.Errorw("unrecognized error", logValues...) - return Errored, errType + return txmutils.Errored, errType } } return @@ -843,8 +837,8 @@ func (txm *Txm) Name() string { return txm.lggr.Name() } func (txm *Txm) HealthReport() map[string]error { return map[string]error{txm.Name(): txm.Healthy()} } -func (txm *Txm) defaultTxConfig() TxConfig { - return TxConfig{ +func (txm *Txm) defaultTxConfig() txmutils.TxConfig { + return txmutils.TxConfig{ Timeout: txm.cfg.TxRetryTimeout(), FeeBumpPeriod: txm.cfg.FeeBumpPeriod(), BaseComputeUnitPrice: txm.fee.BaseComputeUnitPrice(), diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index 0054e0a2b..759883287 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -24,6 +24,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -676,7 +677,7 @@ func TestTxm(t *testing.T) { // send tx - with disabled fee bumping testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, SetFeeBumpPeriod(0))) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, txmutils.SetFeeBumpPeriod(0))) wg.Wait() // no transactions stored inflight txs list @@ -728,7 +729,7 @@ func TestTxm(t *testing.T) { // send tx - with disabled fee bumping and disabled compute unit limit testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, SetFeeBumpPeriod(0), SetComputeUnitLimit(0))) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, txmutils.SetFeeBumpPeriod(0), txmutils.SetComputeUnitLimit(0))) wg.Wait() // no transactions stored inflight txs list diff --git a/pkg/solana/txm/utils.go b/pkg/solana/txm/utils/utils.go similarity index 82% rename from pkg/solana/txm/utils.go rename to pkg/solana/txm/utils/utils.go index fef260e3d..7f3ffb9e2 100644 --- a/pkg/solana/txm/utils.go +++ b/pkg/solana/txm/utils/utils.go @@ -1,4 +1,4 @@ -package txm +package utils import ( "errors" @@ -111,39 +111,39 @@ func convertStatus(res *rpc.SignatureStatusesResult) TxState { return NotFound } -type signatureList struct { +type SignatureList struct { sigs []solana.Signature lock sync.RWMutex wg []*sync.WaitGroup } // internal function that should be called using the proper lock -func (s *signatureList) get(index int) (sig solana.Signature, err error) { +func (s *SignatureList) get(index int) (sig solana.Signature, err error) { if index >= len(s.sigs) { return sig, errors.New("invalid index") } return s.sigs[index], nil } -func (s *signatureList) Get(index int) (sig solana.Signature, err error) { +func (s *SignatureList) Get(index int) (sig solana.Signature, err error) { s.lock.RLock() defer s.lock.RUnlock() return s.get(index) } -func (s *signatureList) List() []solana.Signature { +func (s *SignatureList) List() []solana.Signature { s.lock.RLock() defer s.lock.RUnlock() return s.sigs } -func (s *signatureList) Length() int { +func (s *SignatureList) Length() int { s.lock.RLock() defer s.lock.RUnlock() return len(s.sigs) } -func (s *signatureList) Allocate() (index int) { +func (s *SignatureList) Allocate() (index int) { s.lock.Lock() defer s.lock.Unlock() @@ -156,7 +156,7 @@ func (s *signatureList) Allocate() (index int) { return len(s.sigs) - 1 } -func (s *signatureList) Set(index int, sig solana.Signature) error { +func (s *SignatureList) Set(index int, sig solana.Signature) error { s.lock.Lock() defer s.lock.Unlock() @@ -174,7 +174,7 @@ func (s *signatureList) Set(index int, sig solana.Signature) error { return nil } -func (s *signatureList) Wait(index int) { +func (s *SignatureList) Wait(index int) { wg := &sync.WaitGroup{} s.lock.RLock() if index < len(s.wg) { @@ -185,6 +185,19 @@ func (s *signatureList) Wait(index int) { wg.Wait() } +type TxConfig struct { + Timeout time.Duration // transaction broadcast timeout + + // compute unit price config + FeeBumpPeriod time.Duration // how often to bump fee + BaseComputeUnitPrice uint64 // starting price + ComputeUnitPriceMin uint64 // min price + ComputeUnitPriceMax uint64 // max price + + EstimateComputeUnitLimit bool // enable compute limit estimations using simulation + ComputeUnitLimit uint32 // compute unit limit +} + type SetTxConfig func(*TxConfig) func SetTimeout(t time.Duration) SetTxConfig { diff --git a/pkg/solana/txm/utils_test.go b/pkg/solana/txm/utils/utils_test.go similarity index 98% rename from pkg/solana/txm/utils_test.go rename to pkg/solana/txm/utils/utils_test.go index f4ac868ff..676f04202 100644 --- a/pkg/solana/txm/utils_test.go +++ b/pkg/solana/txm/utils/utils_test.go @@ -1,4 +1,4 @@ -package txm +package utils import ( "sync" @@ -42,7 +42,7 @@ func TestSortSignaturesAndResults(t *testing.T) { } func TestSignatureList_AllocateWaitSet(t *testing.T) { - sigs := signatureList{} + sigs := SignatureList{} assert.Equal(t, 0, sigs.Length()) // can't set without pre-allocating diff --git a/pkg/solana/utils/utils.go b/pkg/solana/utils/utils.go index 0c772065b..3353d40b3 100644 --- a/pkg/solana/utils/utils.go +++ b/pkg/solana/utils/utils.go @@ -13,9 +13,10 @@ import ( "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" "github.com/pelletier/go-toml/v2" + "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" - "github.com/test-go/testify/require" ) var ( diff --git a/pkg/solana/utils/utils_test.go b/pkg/solana/utils/utils_test.go index 15a3e47d8..0f41f80c9 100644 --- a/pkg/solana/utils/utils_test.go +++ b/pkg/solana/utils/utils_test.go @@ -3,8 +3,9 @@ package utils_test import ( "testing" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/utils" "github.com/stretchr/testify/assert" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/utils" ) func TestLamportsToSol(t *testing.T) {