Skip to content

Commit

Permalink
Undo removing waitUntilAccountSequence
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya1702 committed Sep 12, 2023
1 parent c257fff commit 61de345
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
47 changes: 47 additions & 0 deletions services/horizon/internal/txsub/system.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ func (sys *System) Submit(
return
}

if err = sys.waitUntilAccountSequence(ctx, db, sourceAddress, uint64(envelope.SeqNum())); err != nil {
sys.finish(ctx, hash, resultCh, Result{Err: err})
return
}

// If error is txBAD_SEQ, check for the result again
tx, err = txResultByHash(ctx, db, hash)
if err != nil {
Expand All @@ -164,6 +169,48 @@ func (sys *System) Submit(
return
}

// waitUntilAccountSequence blocks until either the context times out or the sequence number of the
// given source account is greater than or equal to `seq`
func (sys *System) waitUntilAccountSequence(ctx context.Context, db HorizonDB, sourceAddress string, seq uint64) error {
timer := time.NewTimer(sys.accountSeqPollInterval)
defer timer.Stop()

for {
sequenceNumbers, err := db.GetSequenceNumbers(ctx, []string{sourceAddress})
if err != nil {
sys.Log.Ctx(ctx).
WithError(err).
WithField("sourceAddress", sourceAddress).
Warn("cannot fetch sequence number")
} else {
num, ok := sequenceNumbers[sourceAddress]
if !ok {
sys.Log.Ctx(ctx).
WithField("sequenceNumbers", sequenceNumbers).
WithField("sourceAddress", sourceAddress).
Warn("missing sequence number for account")
}
if num >= seq {
return nil
}
}

select {
case <-ctx.Done():
return sys.deriveTxSubError(ctx)
case <-timer.C:
timer.Reset(sys.accountSeqPollInterval)
}
}
}

func (sys *System) deriveTxSubError(ctx context.Context) error {
if ctx.Err() == context.Canceled {
return ErrCanceled
}
return ErrTimeout
}

// Submit submits the provided base64 encoded transaction envelope to the
// network using this submission system.
func (sys *System) submitOnce(ctx context.Context, env string) SubmissionResult {
Expand Down
73 changes: 73 additions & 0 deletions services/horizon/internal/txsub/system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,73 @@ func (suite *SystemTestSuite) TestSubmit_Basic() {
assert.False(suite.T(), suite.submitter.WasSubmittedTo)
}

func (suite *SystemTestSuite) TestTimeoutDuringSequnceLoop() {
var cancel context.CancelFunc
suite.ctx, cancel = context.WithTimeout(suite.ctx, time.Duration(0))
defer cancel()

suite.submitter.R = suite.badSeq
suite.db.On("BeginTx", mock.AnythingOfType("*context.timerCtx"), &sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: true,
}).Return(nil).Once()
suite.db.On("Rollback").Return(nil).Once()
suite.db.On("PreFilteredTransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash).
Return(sql.ErrNoRows).Once()
suite.db.On("TransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash).
Return(sql.ErrNoRows).Once()
suite.db.On("NoRows", sql.ErrNoRows).Return(true).Twice()
suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}).
Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil)

r := <-suite.system.Submit(
suite.ctx,
suite.successTx.Transaction.TxEnvelope,
suite.successXDR,
suite.successTx.Transaction.TransactionHash,
)

assert.NotNil(suite.T(), r.Err)
assert.Equal(suite.T(), ErrTimeout, r.Err)
}

func (suite *SystemTestSuite) TestClientDisconnectedDuringSequnceLoop() {
var cancel context.CancelFunc
suite.ctx, cancel = context.WithCancel(suite.ctx)

suite.submitter.R = suite.badSeq
suite.db.On("BeginTx", mock.AnythingOfType("*context.cancelCtx"), &sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: true,
}).Return(nil).Once()
suite.db.On("Rollback").Return(nil).Once()
suite.db.On("PreFilteredTransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash).
Return(sql.ErrNoRows).Once()
suite.db.On("TransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash).
Return(sql.ErrNoRows).Once()
suite.db.On("NoRows", sql.ErrNoRows).Return(true).Twice()
suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}).
Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil).
Run(func(args mock.Arguments) {
// simulate client disconnecting while looping on sequnce number check
cancel()
suite.ctx.Deadline()
}).
Once()
suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}).
Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil)

r := <-suite.system.Submit(
suite.ctx,
suite.successTx.Transaction.TxEnvelope,
suite.successXDR,
suite.successTx.Transaction.TransactionHash,
)

assert.NotNil(suite.T(), r.Err)
assert.Equal(suite.T(), ErrCanceled, r.Err)
}

func getMetricValue(metric prometheus.Metric) *dto.Metric {
value := &dto.Metric{}
err := metric.Write(value)
Expand Down Expand Up @@ -198,6 +265,9 @@ func (suite *SystemTestSuite) TestSubmit_BadSeq() {
suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}).
Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil).
Once()
suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}).
Return(map[string]uint64{suite.unmuxedSource.Address(): 1}, nil).
Once()
suite.db.On("PreFilteredTransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash).
Return(sql.ErrNoRows).Twice()
suite.db.On("NoRows", sql.ErrNoRows).Return(true).Once()
Expand Down Expand Up @@ -239,6 +309,9 @@ func (suite *SystemTestSuite) TestSubmit_BadSeqNotFound() {
suite.db.On("NoRows", sql.ErrNoRows).Return(true).Twice()
suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}).
Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil).
Times(3)
suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}).
Return(map[string]uint64{suite.unmuxedSource.Address(): 1}, nil).
Once()

// set poll interval to 1ms so we don't need to wait 3 seconds for the test to complete
Expand Down

0 comments on commit 61de345

Please sign in to comment.