Skip to content

Commit

Permalink
refactor: index submitted txs in tx client and improve nonce manageme…
Browse files Browse the repository at this point in the history
…nt (#3830)

## Overview

Fixes #3899

---------

Co-authored-by: Callum Waters <[email protected]>
Co-authored-by: Rootul P <[email protected]>
  • Loading branch information
3 people authored Sep 25, 2024
1 parent e24c0e8 commit e9278ed
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 121 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ test-race:
# TODO: Remove the -skip flag once the following tests no longer contain data races.
# https://github.com/celestiaorg/celestia-app/issues/1369
@echo "--> Running tests in race mode"
@go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestBlobstreamRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestBlobstreamCLI|TestUpgrade|TestMaliciousTestNode|TestBigBlobSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext|TestBlobstream|TestCLITestSuite|TestLegacyUpgrade|TestSignerTwins|TestConcurrentTxSubmission|TestTxClientTestSuite|Test_testnode"
@go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestBlobstreamRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestBlobstreamCLI|TestUpgrade|TestMaliciousTestNode|TestBigBlobSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext|TestBlobstream|TestCLITestSuite|TestLegacyUpgrade|TestSignerTwins|TestConcurrentTxSubmission|TestTxClientTestSuite|Test_testnode|TestEvictions"
.PHONY: test-race

## test-bench: Run unit tests in bench mode.
Expand Down
50 changes: 50 additions & 0 deletions pkg/user/pruning_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package user

import (
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"
)

// TestPruningInTxTracker verifies that pruneTxTracker removes only the tracked
// transactions whose timestamp is at least 10 minutes old and keeps the
// younger ones.
func TestPruningInTxTracker(t *testing.T) {
	txClient := &TxClient{
		txTracker: make(map[string]txInfo),
	}
	numTransactions := 10

	// Add 10 transactions to the tracker: even-indexed ones are 10 minutes old
	// (due for pruning), odd-indexed ones are 5 minutes old (must be kept).
	var txsToBePruned int
	var txsNotReadyToBePruned int
	for i := 0; i < numTransactions; i++ {
		age := 5 * time.Minute
		if i%2 == 0 {
			age = 10 * time.Minute
			txsToBePruned++
		} else {
			txsNotReadyToBePruned++
		}
		txClient.txTracker["tx"+fmt.Sprint(i)] = txInfo{
			signer:    "signer" + fmt.Sprint(i),
			sequence:  uint64(i),
			timestamp: time.Now().Add(-age),
		}
	}

	// All transactions were indexed
	txTrackerBeforePruning := len(txClient.txTracker)
	require.Equal(t, numTransactions, txTrackerBeforePruning)

	txClient.pruneTxTracker()

	// Compare against the size of the tracker *after* pruning; comparing two
	// pre-computed counters would pass even if nothing had been pruned.
	txTrackerAfterPruning := len(txClient.txTracker)
	require.Equal(t, txsToBePruned, txTrackerBeforePruning-txTrackerAfterPruning)
	require.Equal(t, txsNotReadyToBePruned, txTrackerAfterPruning)

	// Exactly the 5-minute-old (odd-indexed) transactions must remain.
	for i := 0; i < numTransactions; i++ {
		_, exists := txClient.txTracker["tx"+fmt.Sprint(i)]
		require.Equal(t, i%2 != 0, exists, "tx%d", i)
	}
}
198 changes: 100 additions & 98 deletions pkg/user/tx_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"time"

"github.com/celestiaorg/go-square/v2/share"
blobtx "github.com/celestiaorg/go-square/v2/tx"
"github.com/cosmos/cosmos-sdk/client"
nodeservice "github.com/cosmos/cosmos-sdk/client/grpc/node"
"github.com/cosmos/cosmos-sdk/client/grpc/tmservice"
Expand All @@ -27,20 +26,28 @@ import (

"github.com/celestiaorg/celestia-app/v3/app"
"github.com/celestiaorg/celestia-app/v3/app/encoding"
apperrors "github.com/celestiaorg/celestia-app/v3/app/errors"
"github.com/celestiaorg/celestia-app/v3/app/grpc/tx"
"github.com/celestiaorg/celestia-app/v3/pkg/appconsts"
"github.com/celestiaorg/celestia-app/v3/x/blob/types"
"github.com/celestiaorg/celestia-app/v3/x/minfee"
)

const (
DefaultPollTime = 3 * time.Second
DefaultGasMultiplier float64 = 1.1
DefaultPollTime = 3 * time.Second
DefaultGasMultiplier float64 = 1.1
txTrackerPruningInterval = 10 * time.Minute
)

// Option is a function that receives a *TxClient. NewTxClient accepts a list
// of options and iterates over them; presumably each one is applied to the
// freshly built client to customize it (confirm against the constructor body).
type Option func(client *TxClient)

// txInfo is the per-transaction record stored in TxClient.txTracker, keyed by
// tx hash. It holds the signer and sequence a transaction was broadcast with,
// plus the time the record was created.
type txInfo struct {
// sequence is the signer's sequence at broadcast time, captured before the
// client increments it. handleEvictions rolls the signer back to this value.
sequence uint64
// signer is the name of the account that signed the transaction.
signer string
// timestamp is when the entry was added to the tracker; pruneTxTracker
// removes entries whose timestamp is at least txTrackerPruningInterval old.
timestamp time.Time
}

// TxResponse is a response from the chain after
// a transaction has been submitted.
type TxResponse struct {
Expand Down Expand Up @@ -137,6 +144,9 @@ type TxClient struct {
defaultGasPrice float64
defaultAccount string
defaultAddress sdktypes.AccAddress
// txTracker maps the tx hash to the Sequence and signer of the transaction
// that was submitted to the chain
txTracker map[string]txInfo
}

// NewTxClient returns a new signer using the provided keyring
Expand Down Expand Up @@ -169,6 +179,7 @@ func NewTxClient(
defaultGasPrice: appconsts.DefaultMinGasPrice,
defaultAccount: records[0].Name,
defaultAddress: addr,
txTracker: make(map[string]txInfo),
}

for _, opt := range options {
Expand Down Expand Up @@ -302,6 +313,12 @@ func (client *TxClient) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts
func (client *TxClient) BroadcastTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (*sdktypes.TxResponse, error) {
client.mtx.Lock()
defer client.mtx.Unlock()

// prune transactions that are older than 10 minutes
// pruning has to be done in broadcast, since users
// might not always call ConfirmTx().
client.pruneTxTracker()

account, err := client.getAccountNameFromMsgs(msgs)
if err != nil {
return nil, err
Expand Down Expand Up @@ -368,23 +385,20 @@ func (client *TxClient) broadcastTx(ctx context.Context, txBytes []byte, signer
return nil, err
}
if resp.TxResponse.Code != abci.CodeTypeOK {
if apperrors.IsNonceMismatchCode(resp.TxResponse.Code) {
// query the account to update the sequence number on-chain for the account
_, seqNum, err := QueryAccount(ctx, client.grpc, client.registry, client.signer.accounts[signer].address)
if err != nil {
return nil, fmt.Errorf("querying account for new sequence number: %w\noriginal tx response: %s", err, resp.TxResponse.RawLog)
}
if err := client.signer.SetSequence(signer, seqNum); err != nil {
return nil, fmt.Errorf("setting sequence: %w", err)
}
return client.retryBroadcastingTx(ctx, txBytes)
}
broadcastTxErr := &BroadcastTxError{
TxHash: resp.TxResponse.TxHash,
Code: resp.TxResponse.Code,
ErrorLog: resp.TxResponse.RawLog,
}
return resp.TxResponse, broadcastTxErr
return nil, broadcastTxErr
}

// save the sequence and signer of the transaction in the local txTracker
// before the sequence is incremented
client.txTracker[resp.TxResponse.TxHash] = txInfo{
sequence: client.signer.accounts[signer].Sequence(),
signer: signer,
timestamp: time.Now(),
}

// after the transaction has been submitted, we can increment the
Expand All @@ -395,62 +409,13 @@ func (client *TxClient) broadcastTx(ctx context.Context, txBytes []byte, signer
return resp.TxResponse, nil
}

// retryBroadcastingTx creates a new transaction by copying over an existing transaction but creates a new signature with the
// new sequence number. It then calls `broadcastTx` and attempts to submit the transaction
func (client *TxClient) retryBroadcastingTx(ctx context.Context, txBytes []byte) (*sdktypes.TxResponse, error) {
blobTx, isBlobTx, err := blobtx.UnmarshalBlobTx(txBytes)
if isBlobTx {
// only check the error if the bytes are supposed to be of type blob tx
if err != nil {
return nil, err
// pruneTxTracker removes transactions from the local tx tracker that are older than 10 minutes
func (client *TxClient) pruneTxTracker() {
for hash, txInfo := range client.txTracker {
if time.Since(txInfo.timestamp) >= txTrackerPruningInterval {
delete(client.txTracker, hash)
}
txBytes = blobTx.Tx
}
tx, err := client.signer.DecodeTx(txBytes)
if err != nil {
return nil, err
}

opts := make([]TxOption, 0)
if granter := tx.FeeGranter(); granter != nil {
opts = append(opts, SetFeeGranter(granter))
}
if payer := tx.FeePayer(); payer != nil {
opts = append(opts, SetFeePayer(payer))
}
if memo := tx.GetMemo(); memo != "" {
opts = append(opts, SetMemo(memo))
}
if fee := tx.GetFee(); fee != nil {
opts = append(opts, SetFee(fee.AmountOf(appconsts.BondDenom).Uint64()))
}
if gas := tx.GetGas(); gas > 0 {
opts = append(opts, SetGasLimit(gas))
}

txBuilder, err := client.signer.txBuilder(tx.GetMsgs(), opts...)
if err != nil {
return nil, err
}
signer, _, err := client.signer.signTransaction(txBuilder)
if err != nil {
return nil, fmt.Errorf("resigning transaction: %w", err)
}

newTxBytes, err := client.signer.EncodeTx(txBuilder.GetTx())
if err != nil {
return nil, err
}

// rewrap the blob tx if it was originally a blob tx
if isBlobTx {
newTxBytes, err = blobtx.MarshalBlobTx(newTxBytes, blobTx.Blobs...)
if err != nil {
return nil, err
}
}

return client.broadcastTx(ctx, newTxBytes, signer)
}

// ConfirmTx periodically pings the provided node for the commitment of a transaction by its
Expand All @@ -468,40 +433,68 @@ func (client *TxClient) ConfirmTx(ctx context.Context, txHash string) (*TxRespon
return nil, err
}

if resp != nil {
switch resp.Status {
case core.TxStatusPending:
// Continue polling if the transaction is still pending
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-pollTicker.C:
continue
}
case core.TxStatusCommitted:
txResponse := &TxResponse{
Height: resp.Height,
TxHash: txHash,
Code: resp.ExecutionCode,
}
if resp.ExecutionCode != abci.CodeTypeOK {
executionErr := &ExecutionError{
TxHash: txHash,
Code: resp.ExecutionCode,
ErrorLog: resp.Error,
}
return nil, executionErr
switch resp.Status {
case core.TxStatusPending:
// Continue polling if the transaction is still pending
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-pollTicker.C:
continue
}
case core.TxStatusCommitted:
txResponse := &TxResponse{
Height: resp.Height,
TxHash: txHash,
Code: resp.ExecutionCode,
}
if resp.ExecutionCode != abci.CodeTypeOK {
executionErr := &ExecutionError{
TxHash: txHash,
Code: resp.ExecutionCode,
ErrorLog: resp.Error,
}
return txResponse, nil
case core.TxStatusEvicted:
return nil, fmt.Errorf("tx was evicted from the mempool")
default:
return nil, fmt.Errorf("unknown tx: %s", txHash)
client.deleteFromTxTracker(txHash)
return nil, executionErr
}
client.deleteFromTxTracker(txHash)
return txResponse, nil
case core.TxStatusEvicted:
return nil, client.handleEvictions(txHash)
default:
client.deleteFromTxTracker(txHash)
return nil, fmt.Errorf("transaction with hash %s not found; it was likely rejected", txHash)
}
}
}

// handleEvictions is called when a transaction has been reported as evicted
// from the mempool. It looks the transaction up in the local tx tracker,
// resets the signer's sequence to the one the evicted transaction used, and
// drops the tracker entry. It always returns a non-nil error describing the
// outcome: the tx was unknown to the tracker, the sequence could not be
// reset, or (on the normal path) the tx was evicted.
func (client *TxClient) handleEvictions(txHash string) error {
	client.mtx.Lock()
	defer client.mtx.Unlock()

	// Without a tracker entry there is no recorded sequence to roll back to.
	evicted, tracked := client.txTracker[txHash]
	if !tracked {
		return fmt.Errorf("tx: %s not found in tx client txTracker; likely failed during broadcast", txHash)
	}

	// Roll the sequence back to that of the evicted transaction so it is
	// ready for resubmission. All transactions with a later nonce will be
	// kicked by the node's tx pool.
	err := client.signer.SetSequence(evicted.signer, evicted.sequence)
	if err != nil {
		return fmt.Errorf("setting sequence: %w", err)
	}

	delete(client.txTracker, txHash)
	return fmt.Errorf("tx was evicted from the mempool")
}

// deleteFromTxTracker removes the entry for txHash from the local tx tracker
// while holding the client's mutex. Removing a hash that is not tracked is a
// no-op.
func (client *TxClient) deleteFromTxTracker(txHash string) {
	client.mtx.Lock()
	delete(client.txTracker, txHash)
	client.mtx.Unlock()
}

// EstimateGas simulates the transaction, calculating the amount of gas that was consumed during execution. The final
// result will be multiplied by gasMultiplier(that is set in TxClient)
func (client *TxClient) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (uint64, error) {
Expand Down Expand Up @@ -576,6 +569,7 @@ func (client *TxClient) checkAccountLoaded(ctx context.Context, account string)
if err != nil {
return fmt.Errorf("retrieving address from keyring: %w", err)
}
// FIXME: have a less trusting way of getting the account number and sequence
accNum, sequence, err := QueryAccount(ctx, client.grpc, client.registry, addr)
if err != nil {
return fmt.Errorf("querying account %s: %w", account, err)
Expand Down Expand Up @@ -604,6 +598,14 @@ func (client *TxClient) getAccountNameFromMsgs(msgs []sdktypes.Msg) (string, err
return record.Name, nil
}

// GetTxFromTxTracker looks up a transaction by hash in the tx client's local
// tx tracker. It returns the sequence and signer recorded for it and
// exists=true when the hash is tracked; otherwise it returns zero values and
// exists=false. The lookup is performed while holding the client's mutex.
func (client *TxClient) GetTxFromTxTracker(hash string) (sequence uint64, signer string, exists bool) {
	client.mtx.Lock()
	defer client.mtx.Unlock()

	if info, ok := client.txTracker[hash]; ok {
		return info.sequence, info.signer, true
	}
	return 0, "", false
}

// Signer exposes the tx clients underlying signer
func (client *TxClient) Signer() *Signer {
return client.signer
Expand Down
Loading

0 comments on commit e9278ed

Please sign in to comment.