diff --git a/services/horizon/internal/txsub/system.go b/services/horizon/internal/txsub/system.go index 2e61eeafc7..8712e5e418 100644 --- a/services/horizon/internal/txsub/system.go +++ b/services/horizon/internal/txsub/system.go @@ -147,6 +147,11 @@ func (sys *System) Submit( return } + if err = sys.waitUntilAccountSequence(ctx, db, sourceAddress, uint64(envelope.SeqNum())); err != nil { + sys.finish(ctx, hash, resultCh, Result{Err: err}) + return + } + // If error is txBAD_SEQ, check for the result again tx, err = txResultByHash(ctx, db, hash) if err != nil { @@ -164,6 +169,48 @@ func (sys *System) Submit( return } +// waitUntilAccountSequence blocks until either the context times out or the sequence number of the +// given source account is greater than or equal to `seq` +func (sys *System) waitUntilAccountSequence(ctx context.Context, db HorizonDB, sourceAddress string, seq uint64) error { + timer := time.NewTimer(sys.accountSeqPollInterval) + defer timer.Stop() + + for { + sequenceNumbers, err := db.GetSequenceNumbers(ctx, []string{sourceAddress}) + if err != nil { + sys.Log.Ctx(ctx). + WithError(err). + WithField("sourceAddress", sourceAddress). + Warn("cannot fetch sequence number") + } else { + num, ok := sequenceNumbers[sourceAddress] + if !ok { + sys.Log.Ctx(ctx). + WithField("sequenceNumbers", sequenceNumbers). + WithField("sourceAddress", sourceAddress). + Warn("missing sequence number for account") + } + if num >= seq { + return nil + } + } + + select { + case <-ctx.Done(): + return sys.deriveTxSubError(ctx) + case <-timer.C: + timer.Reset(sys.accountSeqPollInterval) + } + } +} + +func (sys *System) deriveTxSubError(ctx context.Context) error { + if ctx.Err() == context.Canceled { + return ErrCanceled + } + return ErrTimeout +} + // Submit submits the provided base64 encoded transaction envelope to the // network using this submission system. func (sys *System) submitOnce(ctx context.Context, env string) SubmissionResult { diff --git a/services/horizon/internal/txsub/system_test.go b/services/horizon/internal/txsub/system_test.go index 02c8f2c063..e6bffc5987 100644 --- a/services/horizon/internal/txsub/system_test.go +++ b/services/horizon/internal/txsub/system_test.go @@ -146,6 +146,73 @@ func (suite *SystemTestSuite) TestSubmit_Basic() { assert.False(suite.T(), suite.submitter.WasSubmittedTo) } +func (suite *SystemTestSuite) TestTimeoutDuringSequnceLoop() { + var cancel context.CancelFunc + suite.ctx, cancel = context.WithTimeout(suite.ctx, time.Duration(0)) + defer cancel() + + suite.submitter.R = suite.badSeq + suite.db.On("BeginTx", mock.AnythingOfType("*context.timerCtx"), &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }).Return(nil).Once() + suite.db.On("Rollback").Return(nil).Once() + suite.db.On("PreFilteredTransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash). + Return(sql.ErrNoRows).Once() + suite.db.On("TransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash). + Return(sql.ErrNoRows).Once() + suite.db.On("NoRows", sql.ErrNoRows).Return(true).Twice() + suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). + Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil) + + r := <-suite.system.Submit( + suite.ctx, + suite.successTx.Transaction.TxEnvelope, + suite.successXDR, + suite.successTx.Transaction.TransactionHash, + ) + + assert.NotNil(suite.T(), r.Err) + assert.Equal(suite.T(), ErrTimeout, r.Err) +} + +func (suite *SystemTestSuite) TestClientDisconnectedDuringSequnceLoop() { + var cancel context.CancelFunc + suite.ctx, cancel = context.WithCancel(suite.ctx) + + suite.submitter.R = suite.badSeq + suite.db.On("BeginTx", mock.AnythingOfType("*context.cancelCtx"), &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }).Return(nil).Once() + suite.db.On("Rollback").Return(nil).Once() + suite.db.On("PreFilteredTransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash). + Return(sql.ErrNoRows).Once() + suite.db.On("TransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash). + Return(sql.ErrNoRows).Once() + suite.db.On("NoRows", sql.ErrNoRows).Return(true).Twice() + suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). + Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil). + Run(func(args mock.Arguments) { + // simulate client disconnecting while looping on sequnce number check + cancel() + suite.ctx.Deadline() + }). + Once() + suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). + Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil) + + r := <-suite.system.Submit( + suite.ctx, + suite.successTx.Transaction.TxEnvelope, + suite.successXDR, + suite.successTx.Transaction.TransactionHash, + ) + + assert.NotNil(suite.T(), r.Err) + assert.Equal(suite.T(), ErrCanceled, r.Err) +} + func getMetricValue(metric prometheus.Metric) *dto.Metric { value := &dto.Metric{} err := metric.Write(value) @@ -198,6 +265,9 @@ func (suite *SystemTestSuite) TestSubmit_BadSeq() { suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil). Once() + suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). + Return(map[string]uint64{suite.unmuxedSource.Address(): 1}, nil). + Once() suite.db.On("PreFilteredTransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash). Return(sql.ErrNoRows).Twice() suite.db.On("NoRows", sql.ErrNoRows).Return(true).Once() @@ -239,6 +309,9 @@ func (suite *SystemTestSuite) TestSubmit_BadSeqNotFound() { suite.db.On("NoRows", sql.ErrNoRows).Return(true).Twice() suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil). + Times(3) + suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). + Return(map[string]uint64{suite.unmuxedSource.Address(): 1}, nil). Once() // set poll interval to 1ms so we don't need to wait 3 seconds for the test to complete