From 4118c49d8666e701af298423f93c77f5acf60be0 Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Thu, 4 Apr 2024 14:55:21 +0200 Subject: [PATCH] feat: nonce handling with signer (backport #3196) (#3213) Closes: #1910 This covers most cases by serializing the actual broadcasts to the consensus node and enabling resubmissions in the case that there is a sequence mismatch. This covers most fail cases with the possible exception of proposal nodes receiving the transactions in the reverse order to the initial nodes that the user broadcasted to There are also some interesting side affects that need to be handled when an existing accepted transaction is later kicked out of the mempool via CheckTx but overall I think this is a huge improvement for the UX of users
This is an automatic backport of pull request #3196 done by [Mergify](https://mergify.com). Co-authored-by: Callum Waters --- Makefile | 2 +- app/errors/nonce_mismatch.go | 20 ++- app/test/priority_test.go | 57 ++++--- pkg/user/e2e_test.go | 85 ++++++++++ pkg/user/signer.go | 321 ++++++++++++++++++++++++++++------- pkg/user/signer_test.go | 39 ++--- test/util/direct_tx_gen.go | 20 +-- 7 files changed, 417 insertions(+), 127 deletions(-) create mode 100644 pkg/user/e2e_test.go diff --git a/Makefile b/Makefile index cfdb71adf4..230d47f200 100644 --- a/Makefile +++ b/Makefile @@ -120,7 +120,7 @@ test-short: ## test-race: Run unit tests in race mode. test-race: @echo "--> Running tests in race mode" - @go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestQGBRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestQGBCLI|TestUpgrade|TestMaliciousTestNode|TestMaxTotalBlobSizeSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext" + @go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestQGBRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestQGBCLI|TestUpgrade|TestMaliciousTestNode|TestMaxTotalBlobSizeSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext|TestConcurrentTxSubmission" .PHONY: test-race ## test-bench: Run unit tests in bench mode. diff --git a/app/errors/nonce_mismatch.go b/app/errors/nonce_mismatch.go index 2726d61060..8209aac8b7 100644 --- a/app/errors/nonce_mismatch.go +++ b/app/errors/nonce_mismatch.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strconv" + "strings" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" ) @@ -13,6 +14,11 @@ func IsNonceMismatch(err error) bool { return errors.Is(err, sdkerrors.ErrWrongSequence) } +// IsNonceMismatch checks if the error code matches the sequence mismatch. +func IsNonceMismatchCode(code uint32) bool { + return code == sdkerrors.ErrWrongSequence.ABCICode() +} + // ParseNonceMismatch extracts the expected sequence number from the // ErrWrongSequence error. func ParseNonceMismatch(err error) (uint64, error) { @@ -20,9 +26,19 @@ func ParseNonceMismatch(err error) (uint64, error) { return 0, errors.New("error is not a sequence mismatch") } - numbers := regexpInt.FindAllString(err.Error(), -1) + return ParseExpectedSequence(err.Error()) +} + +// ParseExpectedSequence extracts the expected sequence number from the +// ErrWrongSequence error. +func ParseExpectedSequence(str string) (uint64, error) { + if !strings.HasPrefix(str, "account sequence mismatch") { + return 0, fmt.Errorf("unexpected wrong sequence error: %s", str) + } + + numbers := regexpInt.FindAllString(str, -1) if len(numbers) != 2 { - return 0, fmt.Errorf("unexpected wrong sequence error: %w", err) + return 0, fmt.Errorf("expected two numbers in string, got %d", len(numbers)) } // the first number is the expected sequence number diff --git a/app/test/priority_test.go b/app/test/priority_test.go index 6605cb564d..87639ef180 100644 --- a/app/test/priority_test.go +++ b/app/test/priority_test.go @@ -3,6 +3,7 @@ package app_test import ( "encoding/hex" "sort" + "sync" "testing" "time" @@ -70,43 +71,47 @@ func (s *PriorityTestSuite) TestPriorityByGasPrice() { t := s.T() // quickly submit blobs with a random fee - hashes := make([]string, 0, len(s.signers)) + + hashes := make(chan string, len(s.signers)) + blobSize := uint32(100) + gasLimit := blobtypes.DefaultEstimateGas([]uint32{blobSize}) + wg := &sync.WaitGroup{} for _, signer := range s.signers { - blobSize := uint32(100) - gasLimit := blobtypes.DefaultEstimateGas([]uint32{blobSize}) - gasPrice := s.rand.Float64() - btx, err := signer.CreatePayForBlob( - blobfactory.ManyBlobs( - t, - s.rand, - []namespace.Namespace{namespace.RandomBlobNamespace()}, - []int{100}), - user.SetGasLimitAndFee(gasLimit, gasPrice), - ) - require.NoError(t, err) - resp, err := signer.BroadcastTx(s.cctx.GoContext(), btx) - require.NoError(t, err) - require.Equal(t, abci.CodeTypeOK, resp.Code) - hashes = append(hashes, resp.TxHash) + wg.Add(1) + go func() { + defer wg.Done() + gasPrice := float64(s.rand.Intn(1000)+1) / 1000 + resp, err := signer.SubmitPayForBlob( + s.cctx.GoContext(), + blobfactory.ManyBlobs( + t, + s.rand, + []namespace.Namespace{namespace.RandomBlobNamespace()}, + []int{100}), + user.SetGasLimitAndFee(gasLimit, gasPrice), + ) + require.NoError(t, err) + require.Equal(t, abci.CodeTypeOK, resp.Code, resp.RawLog) + hashes <- resp.TxHash + }() } + wg.Wait() + close(hashes) + err := s.cctx.WaitForNextBlock() require.NoError(t, err) // get the responses for each tx for analysis and sort by height // note: use rpc types because they contain the tx index heightMap := make(map[int64][]*rpctypes.ResultTx) - for _, hash := range hashes { - resp, err := s.signers[0].ConfirmTx(s.cctx.GoContext(), hash) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, abci.CodeTypeOK, resp.Code) + for hash := range hashes { // use the core rpc type because it contains the tx index hash, err := hex.DecodeString(hash) require.NoError(t, err) coreRes, err := s.cctx.Client.Tx(s.cctx.GoContext(), hash, false) require.NoError(t, err) - heightMap[resp.Height] = append(heightMap[resp.Height], coreRes) + heightMap[coreRes.Height] = append(heightMap[coreRes.Height], coreRes) } require.GreaterOrEqual(t, len(heightMap), 1) @@ -123,7 +128,7 @@ func (s *PriorityTestSuite) TestPriorityByGasPrice() { // check that there was at least one block with more than three transactions // in it. This is more of a sanity check than a test. - require.True(t, highestNumOfTxsPerBlock > 3) + require.Greater(t, highestNumOfTxsPerBlock, 3) } func sortByIndex(txs []*rpctypes.ResultTx) []*rpctypes.ResultTx { @@ -135,14 +140,14 @@ func sortByIndex(txs []*rpctypes.ResultTx) []*rpctypes.ResultTx { func isSortedByFee(t *testing.T, ecfg encoding.Config, responses []*rpctypes.ResultTx) bool { for i := 0; i < len(responses)-1; i++ { - if gasPrice(t, ecfg, responses[i]) <= gasPrice(t, ecfg, responses[i+1]) { + if getGasPrice(t, ecfg, responses[i]) <= getGasPrice(t, ecfg, responses[i+1]) { return false } } return true } -func gasPrice(t *testing.T, ecfg encoding.Config, resp *rpctypes.ResultTx) float64 { +func getGasPrice(t *testing.T, ecfg encoding.Config, resp *rpctypes.ResultTx) float64 { sdkTx, err := ecfg.TxConfig.TxDecoder()(resp.Tx) require.NoError(t, err) feeTx := sdkTx.(sdk.FeeTx) diff --git a/pkg/user/e2e_test.go b/pkg/user/e2e_test.go new file mode 100644 index 0000000000..b195d27ed8 --- /dev/null +++ b/pkg/user/e2e_test.go @@ -0,0 +1,85 @@ +package user_test + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/celestiaorg/celestia-app/app" + "github.com/celestiaorg/celestia-app/app/encoding" + "github.com/celestiaorg/celestia-app/pkg/appconsts" + "github.com/celestiaorg/celestia-app/pkg/user" + "github.com/celestiaorg/celestia-app/test/util/blobfactory" + "github.com/celestiaorg/celestia-app/test/util/testnode" + "github.com/stretchr/testify/require" + tmrand "github.com/tendermint/tendermint/libs/rand" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" +) + +func TestConcurrentTxSubmission(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + // Setup network + tmConfig := testnode.DefaultTendermintConfig() + tmConfig.Consensus.TimeoutCommit = 10 * time.Second + ctx, _, _ := testnode.NewNetwork(t, testnode.DefaultConfig().WithTendermintConfig(tmConfig)) + _, err := ctx.WaitForHeight(1) + require.NoError(t, err) + + // Setup signer + signer, err := newSingleSignerFromContext(ctx) + require.NoError(t, err) + + // Pregenerate all the blobs + numTxs := 10 + blobs := blobfactory.ManyRandBlobs(t, tmrand.NewRand(), blobfactory.Repeat(2048, numTxs)...) + + // Prepare transactions + var ( + wg sync.WaitGroup + errCh = make(chan error) + ) + + subCtx, cancel := context.WithCancel(ctx.GoContext()) + defer cancel() + time.AfterFunc(time.Minute, cancel) + for i := 0; i < numTxs; i++ { + wg.Add(1) + go func(b *tmproto.Blob) { + defer wg.Done() + _, err := signer.SubmitPayForBlob(subCtx, []*tmproto.Blob{b}, user.SetGasLimitAndFee(500_000, appconsts.DefaultMinGasPrice)) + if err != nil && !errors.Is(err, context.Canceled) { + // only catch the first error + select { + case errCh <- err: + cancel() + default: + } + } + }(blobs[i]) + } + wg.Wait() + + select { + case err := <-errCh: + require.NoError(t, err) + default: + } +} + +func newSingleSignerFromContext(ctx testnode.Context) (*user.Signer, error) { + encCfg := encoding.MakeConfig(app.ModuleEncodingRegisters...) + record, err := ctx.Keyring.Key("validator") + if err != nil { + return nil, err + } + address, err := record.GetAddress() + if err != nil { + return nil, err + } + return user.SetupSigner(ctx.GoContext(), ctx.Keyring, ctx.GRPCClient, address, encCfg) +} diff --git a/pkg/user/signer.go b/pkg/user/signer.go index 6f3918d18f..0f84556bed 100644 --- a/pkg/user/signer.go +++ b/pkg/user/signer.go @@ -9,16 +9,18 @@ import ( "time" "github.com/celestiaorg/celestia-app/app/encoding" + apperrors "github.com/celestiaorg/celestia-app/app/errors" blob "github.com/celestiaorg/celestia-app/x/blob/types" "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/grpc/tmservice" "github.com/cosmos/cosmos-sdk/crypto/keyring" cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" sdktypes "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/types/tx" + sdktx "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx/signing" authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + abci "github.com/tendermint/tendermint/abci/types" tmproto "github.com/tendermint/tendermint/proto/tendermint/types" tmtypes "github.com/tendermint/tendermint/types" "google.golang.org/grpc" @@ -35,11 +37,18 @@ type Signer struct { pk cryptotypes.PubKey chainID string accountNumber uint64 - pollTime time.Duration - mtx sync.RWMutex - lastSignedSequence uint64 - lastConfirmedSequence uint64 + mtx sync.RWMutex + // how often to poll the network for confirmation of a transaction + pollTime time.Duration + // the signers local view of the sequence number + localSequence uint64 + // the chains last known sequence number + networkSequence uint64 + // lookup map of all pending and yet to be confirmed outbound transactions + outboundSequences map[uint64]struct{} + // a reverse map for confirming which sequence numbers have been committed + reverseTxHashSequenceMap map[string]uint64 } // NewSigner returns a new signer using the provided keyring @@ -64,16 +73,18 @@ func NewSigner( } return &Signer{ - keys: keys, - address: address, - grpc: conn, - enc: enc, - pk: pk, - chainID: chainID, - accountNumber: accountNumber, - lastSignedSequence: sequence, - lastConfirmedSequence: sequence, - pollTime: DefaultPollTime, + keys: keys, + address: address, + grpc: conn, + enc: enc, + pk: pk, + chainID: chainID, + accountNumber: accountNumber, + localSequence: sequence, + networkSequence: sequence, + pollTime: DefaultPollTime, + outboundSequences: make(map[uint64]struct{}), + reverseTxHashSequenceMap: make(map[string]uint64), }, nil } @@ -125,17 +136,14 @@ func SetupSigner( // SubmitTx forms a transaction from the provided messages, signs it, and submits it to the chain. TxOptions // may be provided to set the fee and gas limit. func (s *Signer) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (*sdktypes.TxResponse, error) { - txBytes, err := s.CreateTx(msgs, opts...) + tx, err := s.CreateTx(msgs, opts...) if err != nil { return nil, err } - resp, err := s.BroadcastTx(ctx, txBytes) + resp, err := s.BroadcastTx(ctx, tx) if err != nil { - return nil, err - } - if resp.Code != 0 { - return resp, fmt.Errorf("tx failed with code %d: %s", resp.Code, resp.RawLog) + return resp, err } return s.ConfirmTx(ctx, resp.TxHash) @@ -144,25 +152,35 @@ func (s *Signer) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOp // SubmitPayForBlob forms a transaction from the provided blobs, signs it, and submits it to the chain. // TxOptions may be provided to set the fee and gas limit. func (s *Signer) SubmitPayForBlob(ctx context.Context, blobs []*tmproto.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { - txBytes, err := s.CreatePayForBlob(blobs, opts...) + resp, err := s.broadcastPayForBlob(ctx, blobs, opts...) if err != nil { - return nil, err + return resp, err } - resp, err := s.BroadcastTx(ctx, txBytes) + return s.ConfirmTx(ctx, resp.TxHash) +} + +func (s *Signer) broadcastPayForBlob(ctx context.Context, blobs []*blob.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + txBytes, seqNum, err := s.createPayForBlobs(blobs, opts...) if err != nil { return nil, err } - if resp.Code != 0 { - return resp, fmt.Errorf("tx failed with code %d: %s", resp.Code, resp.RawLog) - } - return s.ConfirmTx(ctx, resp.TxHash) + return s.broadcastTx(ctx, txBytes, seqNum) } // CreateTx forms a transaction from the provided messages and signs it. TxOptions may be optionally // used to set the gas limit and fee. -func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) ([]byte, error) { +func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + + return s.createTx(msgs, opts...) +} + +func (s *Signer) createTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { txBuilder := s.txBuilder(opts...) if err := txBuilder.SetMsgs(msgs...); err != nil { return nil, err @@ -172,71 +190,200 @@ func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) ([]byte, error) return nil, err } - return s.enc.TxEncoder()(txBuilder.GetTx()) + return txBuilder.GetTx(), nil } func (s *Signer) CreatePayForBlob(blobs []*tmproto.Blob, opts ...TxOption) ([]byte, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + blobTx, _, err := s.createPayForBlobs(blobs, opts...) + return blobTx, err +} + +func (s *Signer) createPayForBlobs(blobs []*tmproto.Blob, opts ...TxOption) ([]byte, uint64, error) { msg, err := blob.NewMsgPayForBlobs(s.address.String(), blobs...) if err != nil { - return nil, err + return nil, 0, err } - txBytes, err := s.CreateTx([]sdktypes.Msg{msg}, opts...) + tx, err := s.createTx([]sdktypes.Msg{msg}, opts...) if err != nil { - return nil, err + return nil, 0, err + } + + seqNum, err := getSequenceNumber(tx) + if err != nil { + panic(err) + } + + txBytes, err := s.EncodeTx(tx) + if err != nil { + return nil, 0, err } - return tmtypes.MarshalBlobTx(txBytes, blobs...) + blobTx, err := tmtypes.MarshalBlobTx(txBytes, blobs...) + return blobTx, seqNum, err +} + +func (s *Signer) EncodeTx(tx sdktypes.Tx) ([]byte, error) { + return s.enc.TxEncoder()(tx) +} + +func (s *Signer) DecodeTx(txBytes []byte) (authsigning.Tx, error) { + tx, err := s.enc.TxDecoder()(txBytes) + if err != nil { + return nil, err + } + authTx, ok := tx.(authsigning.Tx) + if !ok { + return nil, errors.New("not an authsigning transaction") + } + return authTx, nil } // BroadcastTx submits the provided transaction bytes to the chain and returns the response. -func (s *Signer) BroadcastTx(ctx context.Context, txBytes []byte) (*sdktypes.TxResponse, error) { - txClient := tx.NewServiceClient(s.grpc) +func (s *Signer) BroadcastTx(ctx context.Context, tx authsigning.Tx) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + txBytes, err := s.EncodeTx(tx) + if err != nil { + return nil, err + } + sequence, err := getSequenceNumber(tx) + if err != nil { + return nil, err + } + return s.broadcastTx(ctx, txBytes, sequence) +} - // TODO (@cmwaters): handle nonce mismatch errors +// CONTRACT: assumes the caller has the lock +func (s *Signer) broadcastTx(ctx context.Context, txBytes []byte, sequence uint64) (*sdktypes.TxResponse, error) { + if _, exists := s.outboundSequences[sequence]; exists { + return s.retryBroadcastingTx(ctx, txBytes, sequence+1) + } + + if sequence < s.networkSequence { + s.localSequence = s.networkSequence + return s.retryBroadcastingTx(ctx, txBytes, s.localSequence) + } + + txClient := sdktx.NewServiceClient(s.grpc) resp, err := txClient.BroadcastTx( ctx, - &tx.BroadcastTxRequest{ - Mode: tx.BroadcastMode_BROADCAST_MODE_SYNC, + &sdktx.BroadcastTxRequest{ + Mode: sdktx.BroadcastMode_BROADCAST_MODE_SYNC, TxBytes: txBytes, }, ) if err != nil { return nil, err } - return resp.TxResponse, nil + if apperrors.IsNonceMismatchCode(resp.TxResponse.Code) { + // extract what the lastCommittedNonce on chain is + nextSequence, err := apperrors.ParseExpectedSequence(resp.TxResponse.RawLog) + if err != nil { + return nil, fmt.Errorf("parsing nonce mismatch upon retry: %w", err) + } + s.networkSequence = nextSequence + s.localSequence = nextSequence + // FIXME: We can't actually resign the transaction. A malicious node + // may manipulate us into signing the same transaction several times + // and then executing them. We need some proof of what the last network + // sequence is rather than relying on an error provided by the node + // return s.retryBroadcastingTx(ctx, txBytes, nextSequence) + // Ref: https://github.com/celestiaorg/celestia-app/issues/3256 + // return s.retryBroadcastingTx(ctx, txBytes, nextSequence) + } else if resp.TxResponse.Code == abci.CodeTypeOK { + s.outboundSequences[sequence] = struct{}{} + s.reverseTxHashSequenceMap[resp.TxResponse.TxHash] = sequence + return resp.TxResponse, nil + } + return resp.TxResponse, fmt.Errorf("tx failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) +} + +// retryBroadcastingTx creates a new transaction by copying over an existing transaction but creates a new signature with the +// new sequence number. It then calls `broadcastTx` and attempts to submit the transaction +func (s *Signer) retryBroadcastingTx(ctx context.Context, txBytes []byte, newSequenceNumber uint64) (*sdktypes.TxResponse, error) { + blobTx, isBlobTx := tmtypes.UnmarshalBlobTx(txBytes) + if isBlobTx { + txBytes = blobTx.Tx + } + tx, err := s.DecodeTx(txBytes) + if err != nil { + return nil, err + } + txBuilder := s.txBuilder() + if err := txBuilder.SetMsgs(tx.GetMsgs()...); err != nil { + return nil, err + } + if granter := tx.FeeGranter(); granter != nil { + txBuilder.SetFeeGranter(granter) + } + if payer := tx.FeePayer(); payer != nil { + txBuilder.SetFeePayer(payer) + } + if memo := tx.GetMemo(); memo != "" { + txBuilder.SetMemo(memo) + } + if fee := tx.GetFee(); fee != nil { + txBuilder.SetFeeAmount(fee) + } + if gas := tx.GetGas(); gas > 0 { + txBuilder.SetGasLimit(gas) + } + + if err := s.signTransaction(txBuilder, newSequenceNumber); err != nil { + return nil, fmt.Errorf("resigning transaction: %w", err) + } + + newTxBytes, err := s.EncodeTx(txBuilder.GetTx()) + if err != nil { + return nil, err + } + + // rewrap the blob tx if it was originally a blob tx + if isBlobTx { + newTxBytes, err = tmtypes.MarshalBlobTx(newTxBytes, blobTx.Blobs...) + if err != nil { + return nil, err + } + } + + return s.broadcastTx(ctx, newTxBytes, newSequenceNumber) } // ConfirmTx periodically pings the provided node for the commitment of a transaction by its // hash. It will continually loop until the context is cancelled, the tx is found or an error // is encountered. func (s *Signer) ConfirmTx(ctx context.Context, txHash string) (*sdktypes.TxResponse, error) { - txClient := tx.NewServiceClient(s.grpc) + txClient := sdktx.NewServiceClient(s.grpc) + + pollTime := s.getPollTime() timer := time.NewTimer(0) defer timer.Stop() + for { select { case <-ctx.Done(): return &sdktypes.TxResponse{}, ctx.Err() case <-timer.C: - resp, err := txClient.GetTx( - ctx, - &tx.GetTxRequest{ - Hash: txHash, - }, - ) + resp, err := txClient.GetTx(ctx, &sdktx.GetTxRequest{Hash: txHash}) if err == nil { if resp.TxResponse.Code != 0 { - return resp.TxResponse, fmt.Errorf("tx failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) + s.updateNetworkSequence(txHash, false) + return resp.TxResponse, fmt.Errorf("tx was included but failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) } + s.updateNetworkSequence(txHash, true) return resp.TxResponse, nil } - + // FIXME: this is a relatively brittle of working out whether to retry or not. The tx might be not found for other + // reasons. It may have been removed from the mempool at a later point. We should build an endpoint that gives the + // signer more information on the status of their transaction and then update the logic here if !strings.Contains(err.Error(), "not found") { return &sdktypes.TxResponse{}, err } - timer.Reset(s.pollTime) + timer.Reset(pollTime) } } } @@ -247,7 +394,7 @@ func (s *Signer) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...T return 0, err } - if err := s.signTransaction(txBuilder, s.Sequence()); err != nil { + if err := s.signTransaction(txBuilder, s.LocalSequence()); err != nil { return 0, err } @@ -256,7 +403,7 @@ func (s *Signer) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...T return 0, err } - resp, err := tx.NewServiceClient(s.grpc).Simulate(ctx, &tx.SimulateRequest{ + resp, err := sdktx.NewServiceClient(s.grpc).Simulate(ctx, &sdktx.SimulateRequest{ TxBytes: txBytes, }) if err != nil { @@ -288,37 +435,67 @@ func (s *Signer) SetPollTime(pollTime time.Duration) { s.pollTime = pollTime } +func (s *Signer) getPollTime() time.Duration { + s.mtx.Lock() + defer s.mtx.Unlock() + return s.pollTime +} + // PubKey returns the public key of the signer func (s *Signer) PubKey() cryptotypes.PubKey { return s.pk } -func (s *Signer) Sequence() uint64 { - s.mtx.Lock() - defer s.mtx.Unlock() - return s.lastSignedSequence -} - // DEPRECATED: use Sequence instead func (s *Signer) GetSequence() uint64 { return s.getAndIncrementSequence() } -// getAndIncrementSequence gets the lastest signed sequnce and increments the local sequence number +// LocalSequence returns the next sequence number of the signers +// locally saved +func (s *Signer) LocalSequence() uint64 { + s.mtx.RLock() + defer s.mtx.RUnlock() + return s.localSequence +} + +func (s *Signer) NetworkSequence() uint64 { + s.mtx.RLock() + defer s.mtx.RUnlock() + return s.networkSequence +} + +// getAndIncrementSequence gets the latest signed sequence and increments the +// local sequence number func (s *Signer) getAndIncrementSequence() uint64 { + defer func() { s.localSequence++ }() + return s.localSequence +} + +// ForceSetSequence manually overrides the current local and network level +// sequence number. Be careful when invoking this as it may cause the +// transactions to reject the sequence if it doesn't match the one in state +func (s *Signer) ForceSetSequence(seq uint64) { s.mtx.Lock() defer s.mtx.Unlock() - defer func() { s.lastSignedSequence++ }() - return s.lastSignedSequence + s.localSequence = seq + s.networkSequence = seq } -// ForceSetSequence manually overrides the current sequence number. Be careful when -// invoking this as it may cause the transactions to reject the sequence if -// it doesn't match the one in state -func (s *Signer) ForceSetSequence(seq uint64) { +// updateNetworkSequence is called once a transaction is confirmed +// and updates the chains last known sequence number +func (s *Signer) updateNetworkSequence(txHash string, success bool) { s.mtx.Lock() defer s.mtx.Unlock() - s.lastSignedSequence = seq + sequence, exists := s.reverseTxHashSequenceMap[txHash] + if !exists { + return + } + if success && sequence >= s.networkSequence { + s.networkSequence = sequence + 1 + } + delete(s.outboundSequences, sequence) + delete(s.reverseTxHashSequenceMap, txHash) } // Keyring exposes the signers underlying keyring @@ -426,3 +603,15 @@ func (s *Signer) getSignatureV2(sequence uint64, signature []byte) signing.Signa } return sigV2 } + +func getSequenceNumber(tx authsigning.Tx) (uint64, error) { + sigs, err := tx.GetSignaturesV2() + if err != nil { + return 0, err + } + if len(sigs) > 1 { + return 0, fmt.Errorf("only a signle signature is supported, got %d", len(sigs)) + } + + return sigs[0].Sequence, nil +} diff --git a/pkg/user/signer_test.go b/pkg/user/signer_test.go index ddea206ea7..ade88b3a00 100644 --- a/pkg/user/signer_test.go +++ b/pkg/user/signer_test.go @@ -14,7 +14,6 @@ import ( "github.com/celestiaorg/celestia-app/test/util/testnode" sdk "github.com/cosmos/cosmos-sdk/types" bank "github.com/cosmos/cosmos-sdk/x/bank/types" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" abci "github.com/tendermint/tendermint/abci/types" @@ -78,28 +77,30 @@ func (s *SignerTestSuite) TestConfirmTx() { gas := user.SetGasLimit(1e6) t.Run("deadline exceeded when the context times out", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), time.Second) defer cancel() _, err := s.signer.ConfirmTx(ctx, "E32BD15CAF57AF15D17B0D63CF4E63A9835DD1CEBB059C335C79586BC3013728") - assert.Error(t, err) - assert.Contains(t, err.Error(), context.DeadlineExceeded.Error()) + require.Error(t, err) + require.Contains(t, err.Error(), context.DeadlineExceeded.Error()) }) t.Run("should error when tx is not found", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), 5*time.Second) defer cancel() _, err := s.signer.ConfirmTx(ctx, "not found tx") - assert.Error(t, err) + require.Error(t, err) }) t.Run("should success when tx is found immediately", func(t *testing.T) { msg := bank.NewMsgSend(s.signer.Address(), testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) resp, err := s.submitTxWithoutConfirm([]sdk.Msg{msg}, fee, gas) - assert.NoError(t, err) - assert.NotNil(t, resp) - resp, err = s.signer.ConfirmTx(s.ctx.GoContext(), resp.TxHash) - assert.NoError(t, err) - assert.Equal(t, abci.CodeTypeOK, resp.Code) + require.NoError(t, err) + require.NotNil(t, resp) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), 30*time.Second) + defer cancel() + resp, err = s.signer.ConfirmTx(ctx, resp.TxHash) + require.NoError(t, err) + require.Equal(t, abci.CodeTypeOK, resp.Code) }) t.Run("should error when tx is found with a non-zero error code", func(t *testing.T) { @@ -107,17 +108,17 @@ func (s *SignerTestSuite) TestConfirmTx() { // Create a msg send with out of balance, ensure this tx fails msg := bank.NewMsgSend(s.signer.Address(), testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 1+balance))) resp, err := s.submitTxWithoutConfirm([]sdk.Msg{msg}, fee, gas) - assert.NoError(t, err) - assert.NotNil(t, resp) + require.NoError(t, err) + require.NotNil(t, resp) resp, err = s.signer.ConfirmTx(s.ctx.GoContext(), resp.TxHash) - assert.Error(t, err) - assert.NotEqual(t, abci.CodeTypeOK, resp.Code) + require.Error(t, err) + require.NotEqual(t, abci.CodeTypeOK, resp.Code) }) } func (s *SignerTestSuite) TestGasEstimation() { msg := bank.NewMsgSend(s.signer.Address(), testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) - gas, err := s.signer.EstimateGas(context.Background(), []sdk.Msg{msg}) + gas, err := s.signer.EstimateGas(s.ctx.GoContext(), []sdk.Msg{msg}) require.NoError(s.T(), err) require.Greater(s.T(), gas, uint64(0)) } @@ -148,13 +149,13 @@ func (s *SignerTestSuite) TestGasConsumption() { // verify that the amount deducted depends on the fee set in the tx. amountDeducted := balanceBefore - balanceAfter - utiaToSend - assert.Equal(t, int64(fee), amountDeducted) + require.Equal(t, int64(fee), amountDeducted) // verify that the amount deducted does not depend on the actual gas used. gasUsedBasedDeduction := resp.GasUsed * gasPrice - assert.NotEqual(t, gasUsedBasedDeduction, amountDeducted) + require.NotEqual(t, gasUsedBasedDeduction, amountDeducted) // The gas used based deduction should be less than the fee because the fee is 1 TIA. - assert.Less(t, gasUsedBasedDeduction, int64(fee)) + require.Less(t, gasUsedBasedDeduction, int64(fee)) } func (s *SignerTestSuite) queryCurrentBalance(t *testing.T) int64 { diff --git a/test/util/direct_tx_gen.go b/test/util/direct_tx_gen.go index 6340ce52f6..0e6d4dcb17 100644 --- a/test/util/direct_tx_gen.go +++ b/test/util/direct_tx_gen.go @@ -111,7 +111,7 @@ func DirectQueryAccount(app *app.App, addr sdk.AccAddress) authtypes.AccountI { // provided configuration. One blob transaction is generated per account // provided. The sequence and account numbers are set manually using the provided values. func RandBlobTxsWithManualSequence( - _ *testing.T, + t *testing.T, _ sdk.TxEncoder, kr keyring.Keyring, size int, @@ -172,25 +172,19 @@ func RandBlobTxsWithManualSequence( } if invalidSignature { invalidSig, err := builder.GetTx().GetSignaturesV2() - if err != nil { - panic(err) - } + require.NoError(t, err) invalidSig[0].Data.(*signing.SingleSignatureData).Signature = []byte("invalid signature") - if err := builder.SetSignatures(invalidSig...); err != nil { - panic(err) - } + err = builder.SetSignatures(invalidSig...) + require.NoError(t, err) stx = builder.GetTx() } rawTx, err := signer.EncodeTx(stx) - if err != nil { - panic(err) - } + require.NoError(t, err) + cTx, err := coretypes.MarshalBlobTx(rawTx, blobs...) - if err != nil { - panic(err) - } + require.NoError(t, err) txs[i] = cTx }