diff --git a/common/types/types.go b/common/types/types.go index fcf37789a..d57225fbc 100644 --- a/common/types/types.go +++ b/common/types/types.go @@ -173,7 +173,6 @@ type MiniHeader struct { type Metadata struct { EthereumChainID int - MaxExpirationTime *big.Int EthRPCRequestsSentInCurrentUTCDay int StartOfCurrentUTCDay time.Time } diff --git a/core/core.go b/core/core.go index e5120f151..2e981299a 100644 --- a/core/core.go +++ b/core/core.go @@ -280,7 +280,7 @@ func newWithPrivateConfig(ctx context.Context, config Config, pConfig privateCon } // Initialize metadata and check stored chain id (if any). - metadata, err := initMetadata(config.EthereumChainID, database) + _, err = initMetadata(config.EthereumChainID, database) if err != nil { return nil, err } @@ -354,7 +354,6 @@ func newWithPrivateConfig(ctx context.Context, config Config, pConfig privateCon ChainID: config.EthereumChainID, ContractAddresses: contractAddresses, MaxOrders: config.MaxOrdersInStorage, - MaxExpirationTime: metadata.MaxExpirationTime, }) if err != nil { return nil, err @@ -462,8 +461,7 @@ func initMetadata(chainID int, database *db.DB) (*types.Metadata, error) { if err == db.ErrNotFound { // No stored metadata found (first startup) metadata = &types.Metadata{ - EthereumChainID: chainID, - MaxExpirationTime: constants.UnlimitedExpirationTime, + EthereumChainID: chainID, } if err := database.SaveMetadata(metadata); err != nil { return nil, err @@ -1014,6 +1012,10 @@ func (app *App) GetStats() (*types.Stats, error) { if err != nil { return nil, err } + maxExpirationTime, err := app.db.GetCurrentMaxExpirationTime() + if err != nil { + return nil, err + } response := &types.Stats{ Version: version, @@ -1027,7 +1029,7 @@ func (app *App) GetStats() (*types.Stats, error) { NumPeers: app.node.GetNumPeers(), NumOrdersIncludingRemoved: numOrdersIncludingRemoved, NumPinnedOrders: numPinnedOrders, - MaxExpirationTime: app.orderWatcher.MaxExpirationTime().String(), + MaxExpirationTime: maxExpirationTime.String(), StartOfCurrentUTCDay: metadata.StartOfCurrentUTCDay, EthRPCRequestsSentInCurrentUTCDay: metadata.EthRPCRequestsSentInCurrentUTCDay, EthRPCRateLimitExpiredRequests: app.ethRPCClient.GetRateLimitDroppedRequests(), diff --git a/db/common.go b/db/common.go index 98ad4243e..775cfa1c2 100644 --- a/db/common.go +++ b/db/common.go @@ -7,6 +7,7 @@ import ( "time" "github.com/0xProject/0x-mesh/common/types" + "github.com/0xProject/0x-mesh/constants" "github.com/0xProject/0x-mesh/ethereum" "github.com/0xProject/0x-mesh/zeroex" "github.com/ethereum/go-ethereum/common" @@ -224,6 +225,37 @@ func (db *DB) GetLatestMiniHeader() (*types.MiniHeader, error) { return latestMiniHeaders[0], nil } +// GetCurrentMaxExpirationTime returns the maximum expiration time for non-pinned orders +// stored in the database. If there are no non-pinned orders in the database, it returns +// constants.UnlimitedExpirationTime. +func (db *DB) GetCurrentMaxExpirationTime() (*big.Int, error) { + // Note(albrow): We don't include pinned orders because they are + // never removed due to exceeding the max expiration time. + ordersWithLongestExpirationTime, err := db.FindOrders(&OrderQuery{ + Filters: []OrderFilter{ + { + Field: OFIsPinned, + Kind: Equal, + Value: false, + }, + }, + Sort: []OrderSort{ + { + Field: OFExpirationTimeSeconds, + Direction: Descending, + }, + }, + Limit: 1, + }) + if err != nil { + return nil, err + } + if len(ordersWithLongestExpirationTime) == 0 { + return constants.UnlimitedExpirationTime, nil + } + return ordersWithLongestExpirationTime[0].ExpirationTimeSeconds, nil +} + func ParseContractAddressesAndTokenIdsFromAssetData(assetDataDecoder *zeroex.AssetDataDecoder, assetData []byte, contractAddresses ethereum.ContractAddresses) ([]*types.SingleAssetData, error) { if len(assetData) == 0 { return []*types.SingleAssetData{}, nil diff --git a/db/db_test.go b/db/db_test.go index e479404d8..e484617f0 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -50,6 +50,102 @@ func TestAddOrders(t *testing.T) { } } +func TestAddOrdersMaxExpirationTime(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + opts := TestOptions() + opts.MaxOrders = 10 + db, err := New(ctx, opts) + require.NoError(t, err) + + // Create the max number of orders with increasing expiration time + // 0, 1, 2, etc. + originalOrders := []*types.OrderWithMetadata{} + for i := 0; i < opts.MaxOrders; i++ { + testOrder := newTestOrder() + testOrder.ExpirationTimeSeconds = big.NewInt(int64(i)) + testOrder.IsPinned = false + originalOrders = append(originalOrders, testOrder) + } + + added, removed, err := db.AddOrders(originalOrders) + require.NoError(t, err) + assert.Len(t, removed, 0, "Expected no orders to be removed") + assertOrderSlicesAreUnsortedEqual(t, originalOrders, added) + + // Add two new orders, one with an expiration time too far in the future + // and another with an expiration time soon enough to replace an existing + // order. + currentMaxExpirationTime := originalOrders[len(originalOrders)-1].ExpirationTimeSeconds + orderWithLongerExpirationTime := newTestOrder() + orderWithLongerExpirationTime.IsPinned = false + orderWithLongerExpirationTime.ExpirationTimeSeconds = big.NewInt(0).Add(currentMaxExpirationTime, big.NewInt(1)) + orderWithShorterExpirationTime := newTestOrder() + orderWithShorterExpirationTime.IsPinned = false + orderWithShorterExpirationTime.ExpirationTimeSeconds = big.NewInt(0).Add(currentMaxExpirationTime, big.NewInt(-1)) + newOrders := []*types.OrderWithMetadata{orderWithLongerExpirationTime, orderWithShorterExpirationTime} + added, removed, err = db.AddOrders(newOrders) + require.NoError(t, err) + assertOrderSlicesAreUnsortedEqual(t, []*types.OrderWithMetadata{orderWithShorterExpirationTime}, added) + assertOrderSlicesAreUnsortedEqual(t, []*types.OrderWithMetadata{originalOrders[len(originalOrders)-1]}, removed) + + // Check the remaining orders in the database to make sure they are what we expect. + expectedStoredOrders := make([]*types.OrderWithMetadata, len(originalOrders)) + copy(expectedStoredOrders, originalOrders) + expectedStoredOrders[len(expectedStoredOrders)-1] = orderWithShorterExpirationTime + actualStoredOrders, err := db.FindOrders(nil) + assertOrderSlicesAreUnsortedEqual(t, expectedStoredOrders, actualStoredOrders) + + // Add some pinned orders. Pinned orders should replace non-pinned orders, even if + // they have a later expiration time. + pinnedOrders := []*types.OrderWithMetadata{} + for i := 0; i < opts.MaxOrders; i++ { + testOrder := newTestOrder() + testOrder.ExpirationTimeSeconds = big.NewInt(int64(i * 10)) + testOrder.IsPinned = true + pinnedOrders = append(pinnedOrders, testOrder) + } + added, removed, err = db.AddOrders(pinnedOrders) + require.NoError(t, err) + assert.Len(t, removed, 10, "expected all non-pinned orders to be removed") + assertOrderSlicesAreUnsortedEqual(t, pinnedOrders, added) + + // Add two new pinned orders, one with an expiration time too far in the future + // and another with an expiration time soon enough to replace an existing + // order. Then check that new pinned orders do replace existing pinned orders with + // longer expiration times. + currentMaxExpirationTime = pinnedOrders[len(pinnedOrders)-1].ExpirationTimeSeconds + pinnedOrderWithLongerExpirationTime := newTestOrder() + pinnedOrderWithLongerExpirationTime.IsPinned = true + pinnedOrderWithLongerExpirationTime.ExpirationTimeSeconds = big.NewInt(0).Add(currentMaxExpirationTime, big.NewInt(1)) + pinnedOrderWithShorterExpirationTime := newTestOrder() + pinnedOrderWithShorterExpirationTime.IsPinned = true + pinnedOrderWithShorterExpirationTime.ExpirationTimeSeconds = big.NewInt(0).Add(currentMaxExpirationTime, big.NewInt(-1)) + newPinnedOrders := []*types.OrderWithMetadata{pinnedOrderWithLongerExpirationTime, pinnedOrderWithShorterExpirationTime} + added, removed, err = db.AddOrders(newPinnedOrders) + require.NoError(t, err) + assertOrderSlicesAreUnsortedEqual(t, []*types.OrderWithMetadata{pinnedOrderWithShorterExpirationTime}, added) + assertOrderSlicesAreUnsortedEqual(t, []*types.OrderWithMetadata{pinnedOrders[len(pinnedOrders)-1]}, removed) + + // Check the remaining orders in the database to make sure they are what we expect. + expectedStoredOrders = make([]*types.OrderWithMetadata, len(pinnedOrders)) + copy(expectedStoredOrders, pinnedOrders) + expectedStoredOrders[len(expectedStoredOrders)-1] = pinnedOrderWithShorterExpirationTime + actualStoredOrders, err = db.FindOrders(nil) + assertOrderSlicesAreUnsortedEqual(t, expectedStoredOrders, actualStoredOrders) + + // Try to re-add the original (non-pinned) orders. Non-pinned orders should never replace pinned orders. + added, removed, err = db.AddOrders(originalOrders) + require.NoError(t, err) + assert.Len(t, removed, 0, "expected no pinned orders to be removed") + assert.Len(t, added, 0, "expected no non-pinned orders to be added") + + // Check that the orders stored in the database are the same as before (only + // pinned orders with the shortest expiration time) + actualStoredOrders, err = db.FindOrders(nil) + assertOrderSlicesAreUnsortedEqual(t, expectedStoredOrders, actualStoredOrders) +} + func TestGetOrder(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -68,6 +164,39 @@ func TestGetOrder(t *testing.T) { assert.EqualError(t, err, ErrNotFound.Error(), "calling GetOrder with a hash that doesn't exist should return ErrNotFound") } +func TestGetCurrentMaxExpirationTime(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + // Create some non-pinned orders with expiration times 0, 1, 2, etc. + nonPinnedOrders := []*types.OrderWithMetadata{} + for i := 0; i < 5; i++ { + order := newTestOrder() + order.ExpirationTimeSeconds = big.NewInt(int64(i)) + order.IsPinned = false + nonPinnedOrders = append(nonPinnedOrders, order) + } + _, _, err := db.AddOrders(nonPinnedOrders) + require.NoError(t, err) + + // Create some pinned orders with expiration times 0, 2, 4, etc. + pinnedOrders := []*types.OrderWithMetadata{} + for i := 0; i < 5; i++ { + order := newTestOrder() + order.ExpirationTimeSeconds = big.NewInt(int64(i * 2)) + order.IsPinned = true + pinnedOrders = append(pinnedOrders, order) + } + _, _, err = db.AddOrders(pinnedOrders) + require.NoError(t, err) + + expectedMaxExpirationTime := nonPinnedOrders[len(nonPinnedOrders)-1].ExpirationTimeSeconds + actualMaxExpirationTime, err := db.GetCurrentMaxExpirationTime() + require.NoError(t, err) + assert.Equal(t, expectedMaxExpirationTime, actualMaxExpirationTime) +} + func TestUpdateOrder(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -942,15 +1071,15 @@ func TestUpdateMetadata(t *testing.T) { err = db.SaveMetadata(originalMetadata) require.NoError(t, err) - updatedMaxExpirationTime := originalMetadata.MaxExpirationTime.Add(originalMetadata.MaxExpirationTime, big.NewInt(500)) + updatedETHRPCRequests := originalMetadata.EthRPCRequestsSentInCurrentUTCDay + 200 err = db.UpdateMetadata(func(existingMetadata *types.Metadata) *types.Metadata { updatedMetadata := existingMetadata - updatedMetadata.MaxExpirationTime = updatedMaxExpirationTime + updatedMetadata.EthRPCRequestsSentInCurrentUTCDay = updatedETHRPCRequests return updatedMetadata }) expectedMetadata := originalMetadata - expectedMetadata.MaxExpirationTime = updatedMaxExpirationTime + expectedMetadata.EthRPCRequestsSentInCurrentUTCDay = updatedETHRPCRequests foundMetadata, err := db.GetMetadata() require.NoError(t, err) assertMetadatasAreEqual(t, expectedMetadata, foundMetadata) @@ -1106,7 +1235,6 @@ func newTestEventLogs() []ethtypes.Log { func newTestMetadata() *types.Metadata { return &types.Metadata{ EthereumChainID: 42, - MaxExpirationTime: big.NewInt(12345), EthRPCRequestsSentInCurrentUTCDay: 1337, StartOfCurrentUTCDay: time.Date(1992, time.September, 29, 8, 0, 0, 0, time.UTC), } diff --git a/db/dexietypes/dexietypes.go b/db/dexietypes/dexietypes.go index 587a1dff5..00aa11e6f 100644 --- a/db/dexietypes/dexietypes.go +++ b/db/dexietypes/dexietypes.go @@ -144,6 +144,7 @@ type Order struct { FillableTakerAssetAmount *SortedBigInt `json:"fillableTakerAssetAmount"` IsRemoved uint8 `json:"isRemoved"` IsPinned uint8 `json:"isPinned"` + IsNotPinned uint8 `json:"isNotPinned"` // Used in a compound index in queries related to max expiration time. ParsedMakerAssetData string `json:"parsedMakerAssetData"` ParsedMakerFeeAssetData string `json:"parsedMakerFeeAssetData"` } @@ -157,10 +158,9 @@ type MiniHeader struct { } type Metadata struct { - EthereumChainID int `json:"ethereumChainID"` - MaxExpirationTime *SortedBigInt `json:"maxExpirationTime"` - EthRPCRequestsSentInCurrentUTCDay int `json:"ethRPCRequestsSentInCurrentUTCDay"` - StartOfCurrentUTCDay time.Time `json:"startOfCurrentUTCDay"` + EthereumChainID int `json:"ethereumChainID"` + EthRPCRequestsSentInCurrentUTCDay int `json:"ethRPCRequestsSentInCurrentUTCDay"` + StartOfCurrentUTCDay time.Time `json:"startOfCurrentUTCDay"` } func OrderToCommonType(order *Order) *types.OrderWithMetadata { @@ -222,6 +222,7 @@ func OrderFromCommonType(order *types.OrderWithMetadata) *Order { FillableTakerAssetAmount: NewSortedBigInt(order.FillableTakerAssetAmount), IsRemoved: BoolToUint8(order.IsRemoved), IsPinned: BoolToUint8(order.IsPinned), + IsNotPinned: BoolToUint8(!order.IsPinned), ParsedMakerAssetData: ParsedAssetDataFromCommonType(order.ParsedMakerAssetData), ParsedMakerFeeAssetData: ParsedAssetDataFromCommonType(order.ParsedMakerFeeAssetData), } @@ -352,7 +353,6 @@ func MetadataToCommonType(metadata *Metadata) *types.Metadata { } return &types.Metadata{ EthereumChainID: metadata.EthereumChainID, - MaxExpirationTime: metadata.MaxExpirationTime.Int, EthRPCRequestsSentInCurrentUTCDay: metadata.EthRPCRequestsSentInCurrentUTCDay, StartOfCurrentUTCDay: metadata.StartOfCurrentUTCDay, } @@ -364,7 +364,6 @@ func MetadataFromCommonType(metadata *types.Metadata) *Metadata { } return &Metadata{ EthereumChainID: metadata.EthereumChainID, - MaxExpirationTime: NewSortedBigInt(metadata.MaxExpirationTime), EthRPCRequestsSentInCurrentUTCDay: metadata.EthRPCRequestsSentInCurrentUTCDay, StartOfCurrentUTCDay: metadata.StartOfCurrentUTCDay, } diff --git a/db/sql_implementation.go b/db/sql_implementation.go index 8e74e9158..eb03a5c9c 100644 --- a/db/sql_implementation.go +++ b/db/sql_implementation.go @@ -7,6 +7,7 @@ import ( "database/sql" "errors" "fmt" + "math" "math/big" "os" "path/filepath" @@ -21,6 +22,9 @@ import ( _ "github.com/mattn/go-sqlite3" ) +// largeLimit is used as a workaround due to the fact that SQL does not allow limit without offset. +const largeLimit = math.MaxInt64 + var _ Database = (*DB)(nil) // DB instantiates the DB connection and creates all the collections used by the application @@ -128,7 +132,6 @@ CREATE TABLE IF NOT EXISTS miniHeaders ( CREATE TABLE IF NOT EXISTS metadata ( ethereumChainID BIGINT NOT NULL, - maxExpirationTime TEXT NOT NULL, ethRPCRequestsSentInCurrentUTCDay BIGINT NOT NULL, startOfCurrentUTCDay DATETIME NOT NULL ); @@ -232,19 +235,16 @@ const insertMiniHeaderQuery = `INSERT INTO miniHeaders ( const insertMetadataQuery = `INSERT INTO metadata ( ethereumChainID, - maxExpirationTime, ethRPCRequestsSentInCurrentUTCDay, startOfCurrentUTCDay ) VALUES ( :ethereumChainID, - :maxExpirationTime, :ethRPCRequestsSentInCurrentUTCDay, :startOfCurrentUTCDay )` const updateMetadataQuery = `UPDATE metadata SET ethereumChainID = :ethereumChainID, - maxExpirationTime = :maxExpirationTime, ethRPCRequestsSentInCurrentUTCDay = :ethRPCRequestsSentInCurrentUTCDay, startOfCurrentUTCDay = :startOfCurrentUTCDay ` @@ -258,33 +258,61 @@ func (db *DB) AddOrders(orders []*types.OrderWithMetadata) (added []*types.Order defer func() { err = convertErr(err) }() - txn, err := db.sqldb.BeginTxx(db.ctx, nil) - if err != nil { - return nil, nil, err - } - defer func() { - _ = txn.Rollback() - }() - for _, order := range orders { - result, err := txn.NamedExecContext(db.ctx, insertOrderQuery, sqltypes.OrderFromCommonType(order)) - if err != nil { - return nil, nil, err + addedMap := map[common.Hash]*types.OrderWithMetadata{} + err = db.sqldb.TransactionalContext(db.ctx, nil, func(txn *sqlz.Tx) error { + for _, order := range orders { + result, err := txn.NamedExecContext(db.ctx, insertOrderQuery, sqltypes.OrderFromCommonType(order)) + if err != nil { + return err + } + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + addedMap[order.Hash] = order + } } - affected, err := result.RowsAffected() + + // Remove orders with an expiration time too far in the future. + // HACK(albrow): sqlz doesn't support ORDER BY, LIMIT, and OFFSET + // for DELETE statements. It also doesn't support RETURNING. As a + // workaround, we do a SELECT and DELETE inside a transaction. + // HACK(albrow): SQL doesn't support limit without offset. As a + // workaround, we set the limit to an extremely large number. + removeQuery := txn.Select("*").From("orders"). + OrderBy(sqlz.Desc(string(OFIsPinned)), sqlz.Asc(string(OFExpirationTimeSeconds))). + Limit(largeLimit). + Offset(int64(db.opts.MaxOrders)) + var ordersToRemove []*sqltypes.Order + err = removeQuery.GetAllContext(db.ctx, &ordersToRemove) if err != nil { - return nil, nil, err + return err } - if affected > 0 { - added = append(added, order) + for _, order := range ordersToRemove { + _, err := txn.DeleteFrom("orders").Where(sqlz.Eq(string(OFHash), order.Hash)).ExecContext(db.ctx) + if err != nil { + return err + } + if _, found := addedMap[order.Hash]; found { + // If the order was previously added, remove it from + // the added set and don't add it to the removed set. + delete(addedMap, order.Hash) + } else { + removed = append(removed, sqltypes.OrderToCommonType(order)) + } } - } - if err := txn.Commit(); err != nil { + return nil + }) + if err != nil { return nil, nil, err } + for _, order := range addedMap { + added = append(added, order) + } - // TODO(albrow): Remove orders with longest expiration time. - return added, nil, nil + return added, removed, nil } func (db *DB) GetOrder(hash common.Hash) (order *types.OrderWithMetadata, err error) { @@ -452,40 +480,35 @@ func (db *DB) UpdateOrder(hash common.Hash, updateFunc func(existingOrder *types return errors.New("db.UpdateOrders: updateFunc cannot be nil") } - txn, err := db.sqldb.BeginTxx(db.ctx, nil) - if err != nil { - return err - } - defer func() { - _ = txn.Rollback() - }() - - var existingOrder sqltypes.Order - if err := txn.GetContext(db.ctx, &existingOrder, "SELECT * FROM orders WHERE hash = $1", hash); err != nil { - if err == sql.ErrNoRows { - return ErrNotFound + return db.sqldb.TransactionalContext(db.ctx, nil, func(txn *sqlz.Tx) error { + var existingOrder sqltypes.Order + if err := txn.GetContext(db.ctx, &existingOrder, "SELECT * FROM orders WHERE hash = $1", hash); err != nil { + if err == sql.ErrNoRows { + return ErrNotFound + } + return err } - return err - } - commonOrder := sqltypes.OrderToCommonType(&existingOrder) - commonUpdatedOrder, err := updateFunc(commonOrder) - if err != nil { - return fmt.Errorf("db.UpdateOrders: updateFunc returned error") - } - updatedOrder := sqltypes.OrderFromCommonType(commonUpdatedOrder) - _, err = txn.NamedExecContext(db.ctx, updateOrderQuery, updatedOrder) - if err != nil { - return err - } - return txn.Commit() + commonOrder := sqltypes.OrderToCommonType(&existingOrder) + commonUpdatedOrder, err := updateFunc(commonOrder) + if err != nil { + return fmt.Errorf("db.UpdateOrders: updateFunc returned error") + } + updatedOrder := sqltypes.OrderFromCommonType(commonUpdatedOrder) + _, err = txn.NamedExecContext(db.ctx, updateOrderQuery, updatedOrder) + if err != nil { + return err + } + return nil + }) } func (db *DB) AddMiniHeaders(miniHeaders []*types.MiniHeader) (added []*types.MiniHeader, removed []*types.MiniHeader, err error) { defer func() { err = convertErr(err) }() - var miniHeadersToRemove []*sqltypes.MiniHeader + + addedMap := map[common.Hash]*types.MiniHeader{} err = db.sqldb.TransactionalContext(db.ctx, nil, func(txn *sqlz.Tx) error { for _, miniHeader := range miniHeaders { result, err := txn.NamedExecContext(db.ctx, insertMiniHeaderQuery, sqltypes.MiniHeaderFromCommonType(miniHeader)) @@ -497,15 +520,18 @@ func (db *DB) AddMiniHeaders(miniHeaders []*types.MiniHeader) (added []*types.Mi return err } if affected > 0 { - added = append(added, miniHeader) + addedMap[miniHeader.Hash] = miniHeader } } // HACK(albrow): sqlz doesn't support ORDER BY, LIMIT, and OFFSET // for DELETE statements. It also doesn't support RETURNING. As a // workaround, we do a SELECT and DELETE inside a transaction. - trimQuery := txn.Select("*").From("miniHeaders").OrderBy(sqlz.Desc(string(MFNumber))).Limit(99999999999).Offset(int64(db.opts.MaxMiniHeaders)) - if err := trimQuery.GetAllContext(db.ctx, &miniHeadersToRemove); err != nil { + // HACK(albrow): SQL doesn't support limit without offset. As a + // workaround, we set the limit to an extremely large number. + removeQuery := txn.Select("*").From("miniHeaders").OrderBy(sqlz.Desc(string(MFNumber))).Limit(largeLimit).Offset(int64(db.opts.MaxMiniHeaders)) + var miniHeadersToRemove []*sqltypes.MiniHeader + if err := removeQuery.GetAllContext(db.ctx, &miniHeadersToRemove); err != nil { return err } for _, miniHeader := range miniHeadersToRemove { @@ -513,38 +539,24 @@ func (db *DB) AddMiniHeaders(miniHeaders []*types.MiniHeader) (added []*types.Mi if err != nil { return err } + if _, found := addedMap[miniHeader.Hash]; found { + // If the miniHeader was previously added, remove it from + // the added set and don't add it to the removed set. + delete(addedMap, miniHeader.Hash) + } else { + removed = append(removed, sqltypes.MiniHeaderToCommonType(miniHeader)) + } } return nil }) if err != nil { return nil, nil, err } - - // Because of how the above code is written, a single miniHeader could exist - // in both added and removed sets. We should remove such miniHeaders from both - // sets in this case. - addedMap := map[common.Hash]*types.MiniHeader{} - removedMap := map[common.Hash]*sqltypes.MiniHeader{} - for _, a := range added { - addedMap[a.Hash] = a - } - for _, r := range miniHeadersToRemove { - removedMap[r.Hash] = r - } - dedupedAdded := []*types.MiniHeader{} - dedupedRemoved := []*sqltypes.MiniHeader{} - for _, a := range added { - if _, wasRemoved := removedMap[a.Hash]; !wasRemoved { - dedupedAdded = append(dedupedAdded, a) - } - } - for _, r := range miniHeadersToRemove { - if _, wasAdded := addedMap[r.Hash]; !wasAdded { - dedupedRemoved = append(dedupedRemoved, r) - } + for _, miniHeader := range addedMap { + added = append(added, miniHeader) } - return dedupedAdded, sqltypes.MiniHeadersToCommonType(dedupedRemoved), nil + return added, removed, nil } func (db *DB) GetMiniHeader(hash common.Hash) (miniHeader *types.MiniHeader, err error) { @@ -725,30 +737,25 @@ func (db *DB) UpdateMetadata(updateFunc func(oldmetadata *types.Metadata) (newMe return errors.New("db.UpdateMetadata: updateFunc cannot be nil") } - txn, err := db.sqldb.BeginTxx(db.ctx, nil) - if err != nil { - return err - } - defer func() { - _ = txn.Rollback() - }() + return db.sqldb.TransactionalContext(db.ctx, nil, func(txn *sqlz.Tx) error { + var existingMetadata sqltypes.Metadata + if err := txn.GetContext(db.ctx, &existingMetadata, "SELECT * FROM metadata LIMIT 1"); err != nil { + if err == sql.ErrNoRows { + return ErrNotFound + } + return err + } - var existingMetadata sqltypes.Metadata - if err := txn.GetContext(db.ctx, &existingMetadata, "SELECT * FROM metadata LIMIT 1"); err != nil { - if err == sql.ErrNoRows { - return ErrNotFound + commonMetadata := sqltypes.MetadataToCommonType(&existingMetadata) + commonUpdatedMetadata := updateFunc(commonMetadata) + updatedMetadata := sqltypes.MetadataFromCommonType(commonUpdatedMetadata) + _, err = txn.NamedExecContext(db.ctx, updateMetadataQuery, updatedMetadata) + if err != nil { + return err } - return err - } - commonMetadata := sqltypes.MetadataToCommonType(&existingMetadata) - commonUpdatedMetadata := updateFunc(commonMetadata) - updatedMetadata := sqltypes.MetadataFromCommonType(commonUpdatedMetadata) - _, err = txn.NamedExecContext(db.ctx, updateMetadataQuery, updatedMetadata) - if err != nil { - return err - } - return txn.Commit() + return nil + }) } func convertFilterValue(value interface{}) interface{} { diff --git a/db/sqltypes/sqltypes.go b/db/sqltypes/sqltypes.go index 0f2d02a97..14a655b25 100644 --- a/db/sqltypes/sqltypes.go +++ b/db/sqltypes/sqltypes.go @@ -294,10 +294,9 @@ type MiniHeader struct { } type Metadata struct { - EthereumChainID int `db:"ethereumChainID"` - MaxExpirationTime *SortedBigInt `db:"maxExpirationTime"` - EthRPCRequestsSentInCurrentUTCDay int `db:"ethRPCRequestsSentInCurrentUTCDay"` - StartOfCurrentUTCDay time.Time `db:"startOfCurrentUTCDay"` + EthereumChainID int `db:"ethereumChainID"` + EthRPCRequestsSentInCurrentUTCDay int `db:"ethRPCRequestsSentInCurrentUTCDay"` + StartOfCurrentUTCDay time.Time `db:"startOfCurrentUTCDay"` } func OrderToCommonType(order *Order) *types.OrderWithMetadata { @@ -460,7 +459,6 @@ func MetadataToCommonType(metadata *Metadata) *types.Metadata { } return &types.Metadata{ EthereumChainID: metadata.EthereumChainID, - MaxExpirationTime: metadata.MaxExpirationTime.Int, EthRPCRequestsSentInCurrentUTCDay: metadata.EthRPCRequestsSentInCurrentUTCDay, StartOfCurrentUTCDay: metadata.StartOfCurrentUTCDay, } @@ -472,7 +470,6 @@ func MetadataFromCommonType(metadata *types.Metadata) *Metadata { } return &Metadata{ EthereumChainID: metadata.EthereumChainID, - MaxExpirationTime: NewSortedBigInt(metadata.MaxExpirationTime), EthRPCRequestsSentInCurrentUTCDay: metadata.EthRPCRequestsSentInCurrentUTCDay, StartOfCurrentUTCDay: metadata.StartOfCurrentUTCDay, } diff --git a/ethereum/ratelimit/rate_limiter_test.go b/ethereum/ratelimit/rate_limiter_test.go index 714f7a77f..8aa86637f 100644 --- a/ethereum/ratelimit/rate_limiter_test.go +++ b/ethereum/ratelimit/rate_limiter_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/0xProject/0x-mesh/common/types" - "github.com/0xProject/0x-mesh/constants" "github.com/0xProject/0x-mesh/db" "github.com/benbjohnson/clock" "github.com/stretchr/testify/assert" @@ -77,7 +76,6 @@ func TestScenario2(t *testing.T) { // Set metadata to just short of maximum requests per 24 hours. metadata := &types.Metadata{ EthereumChainID: 1337, - MaxExpirationTime: constants.UnlimitedExpirationTime, StartOfCurrentUTCDay: startOfCurrentUTCDay, EthRPCRequestsSentInCurrentUTCDay: requestsSentInCurrentDay, } @@ -133,7 +131,6 @@ func TestScenario3(t *testing.T) { // non-zero `EthRPCRequestsSentInCurrentUTCDay` metadata := &types.Metadata{ EthereumChainID: 1337, - MaxExpirationTime: constants.UnlimitedExpirationTime, StartOfCurrentUTCDay: yesterdayMidnightUTC, EthRPCRequestsSentInCurrentUTCDay: 5000, } @@ -184,8 +181,7 @@ func TestScenario3(t *testing.T) { func initMetadata(t *testing.T, database *db.DB) { metadata := &types.Metadata{ - EthereumChainID: 1337, - MaxExpirationTime: constants.UnlimitedExpirationTime, + EthereumChainID: 1337, } err := database.SaveMetadata(metadata) require.NoError(t, err) diff --git a/go.sum b/go.sum index 727acc8e5..82dbef071 100644 --- a/go.sum +++ b/go.sum @@ -37,7 +37,9 @@ github.com/albrow/go-envvar v1.1.1-0.20200123010345-a6ece4436cb7 h1:KyGi2bFjYJwa github.com/albrow/go-envvar v1.1.1-0.20200123010345-a6ece4436cb7/go.mod h1:jGxERjkVawmx7yWrFUix71jtSXm1ZtUai96wBHTwkPo= github.com/albrow/stringset v2.1.0+incompatible h1:P90SSV7fle22yLbhDSLRC8Jtec0tCE3A8hJihfxf25E= github.com/albrow/stringset v2.1.0+incompatible/go.mod h1:ltP0XRz96SPEM8ofD1BaE4IpTR2uCGSk6Z2VRfh1Llw= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc h1:cAKDfWh5VpdgMhJosfJnn5/FoN2SRZ4p7fJNX58YPaU= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf h1:qet1QNfXsQxTZqLG4oE62mJzwPIB8+Tee4RNCL9ulrY= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= @@ -705,6 +707,7 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= gopkg.in/DATA-DOG/go-sqlmock.v1 v1.3.0 h1:FVCohIoYO7IJoDDVpV2pdq7SgrMH6wHnuTyrdrxJNoY= gopkg.in/DATA-DOG/go-sqlmock.v1 v1.3.0/go.mod h1:OdE7CF6DbADk7lN8LIKRzRJTTZXIjtWgA5THM5lhBAw= +gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/packages/browser-lite/src/database.ts b/packages/browser-lite/src/database.ts index 7bf9af250..0ff79f174 100644 --- a/packages/browser-lite/src/database.ts +++ b/packages/browser-lite/src/database.ts @@ -75,6 +75,7 @@ export interface Order { lastUpdated: string; isRemoved: number; isPinned: number; + isNotPinned: number; // Used in a compound index in queries related to max expiration time. parsedMakerAssetData: string; parsedMakerFeeAssetData: string; } @@ -115,7 +116,6 @@ export interface AddMiniHeadersResult { export interface Metadata { ethereumChainID: number; - maxExpirationTime: string; ethRPCRequestsSentInCurrentUTCDay: number; startOfCurrentUTCDay: string; } @@ -139,7 +139,7 @@ export function createDatabase(opts: Options): Database { export class Database { private readonly _db: Dexie; - // private readonly _maxOrders: number; + private readonly _maxOrders: number; private readonly _maxMiniHeaders: number; private readonly _orders: Dexie.Table; private readonly _miniHeaders: Dexie.Table; @@ -147,12 +147,12 @@ export class Database { constructor(opts: Options) { this._db = new Dexie(opts.dataSourceName); - // this._maxOrders = opts.maxOrders; + this._maxOrders = opts.maxOrders; this._maxMiniHeaders = opts.maxMiniHeaders; this._db.version(1).stores({ orders: - '&hash,chainId,makerAddress,makerAssetData,makerAssetAmount,makerFee,makerFeeAssetData,takerAddress,takerAssetData,takerFeeAssetData,takerAssetAmount,takerFee,senderAddress,feeRecipientAddress,expirationTimeSeconds,salt,signature,exchangeAddress,fillableTakerAssetAmount,lastUpdated,isRemoved,isPinned,parsedMakerAssetData,parsedMakerFeeAssetData', + '&hash,chainId,makerAddress,makerAssetData,makerAssetAmount,makerFee,makerFeeAssetData,takerAddress,takerAssetData,takerFeeAssetData,takerAssetAmount,takerFee,senderAddress,feeRecipientAddress,expirationTimeSeconds,salt,signature,exchangeAddress,fillableTakerAssetAmount,lastUpdated,isRemoved,isPinned,parsedMakerAssetData,parsedMakerFeeAssetData,[isNotPinned+expirationTimeSeconds]', miniHeaders: '&hash,parent,number,timestamp,logs', metadata: 'ðereumChainID', }); @@ -168,9 +168,9 @@ export class Database { // AddOrders(orders []*types.OrderWithMetadata) (added []*types.OrderWithMetadata, removed []*types.OrderWithMetadata, err error) public async addOrdersAsync(orders: Order[]): Promise { - // TODO(albrow): Remove orders with max expiration time. - const added: Order[] = []; - await this._db.transaction('rw!', this._orders, async () => { + const addedMap = new Map(); + const removed: Order[] = []; + await this._db.transaction('rw', this._orders, async () => { for (const order of orders) { try { await this._orders.add(order); @@ -182,12 +182,29 @@ export class Database { } throw e; } - added.push(order); + addedMap.set(order.hash, order); + } + + // Remove orders with an expiration time too far in the future. + const ordersToRemove = await this._orders + .orderBy('[isNotPinned+expirationTimeSeconds]') + .offset(this._maxOrders) + .toArray(); + for (const order of ordersToRemove) { + await this._orders.delete(order.hash); + if (addedMap.has(order.hash)) { + // If the order was previously added, remove it from + // the added set and don't add it to the removed set. + addedMap.delete(order.hash); + } else { + removed.push(order); + } } }); + return { - added, - removed: [], + added: Array.from(addedMap.values()), + removed, }; } @@ -257,7 +274,7 @@ export class Database { // AddMiniHeaders(miniHeaders []*types.MiniHeader) (added []*types.MiniHeader, removed []*types.MiniHeader, err error) public async addMiniHeadersAsync(miniHeaders: MiniHeader[]): Promise { - const added: MiniHeader[] = []; + const addedMap = new Map(); const removed: MiniHeader[] = []; await this._db.transaction('rw!', this._miniHeaders, async () => { for (const miniHeader of miniHeaders) { @@ -271,20 +288,29 @@ export class Database { } throw e; } - added.push(miniHeader); - const outdatedMiniHeaders = await this._miniHeaders - .orderBy('number') - .offset(this._maxMiniHeaders) - .reverse() - .toArray(); - for (const outdated of outdatedMiniHeaders) { - await this._miniHeaders.delete(outdated.hash); + addedMap.set(miniHeader.hash, miniHeader); + } + + // Remove any outdated miniHeaders. + const outdatedMiniHeaders = await this._miniHeaders + .orderBy('number') + .offset(this._maxMiniHeaders) + .reverse() + .toArray(); + for (const outdated of outdatedMiniHeaders) { + await this._miniHeaders.delete(outdated.hash); + if (addedMap.has(outdated.hash)) { + // If the order was previously added, remove it from + // the added set and don't add it to the removed set. + addedMap.delete(outdated.hash); + } else { removed.push(outdated); } } }); + return { - added, + added: Array.from(addedMap.values()), removed, }; } diff --git a/zeroex/orderwatch/order_watcher.go b/zeroex/orderwatch/order_watcher.go index 3d3c932f5..5057caa25 100644 --- a/zeroex/orderwatch/order_watcher.go +++ b/zeroex/orderwatch/order_watcher.go @@ -17,7 +17,6 @@ import ( "github.com/0xProject/0x-mesh/zeroex" "github.com/0xProject/0x-mesh/zeroex/ordervalidator" "github.com/0xProject/0x-mesh/zeroex/orderwatch/decoder" - "github.com/0xProject/0x-mesh/zeroex/orderwatch/slowcounter" "github.com/ethereum/go-ethereum/common" ethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" @@ -43,35 +42,12 @@ const ( // corresponds to a block depth of ~25. permanentlyDeleteAfter = 5 * time.Minute - // expirationPollingInterval specifies the interval in which the order watcher should check for expired - // orders - expirationPollingInterval = 50 * time.Millisecond - - // maxOrdersTrimRatio affects how many orders are trimmed whenever we reach the - // maximum number of orders. When order storage is full, Watcher will remove - // orders until the total number of remaining orders is equal to - // maxOrdersTrimRatio * maxOrders. - maxOrdersTrimRatio = 0.9 - // defaultMaxOrders is the default max number of orders in storage. defaultMaxOrders = 100000 - // maxExpirationTimeCheckInterval is how often to check whether we can - // increase the max expiration time. - maxExpirationTimeCheckInterval = 30 * time.Second - // maxBlockEventsToHandle is the max number of block events we want to process in a single // call to `handleBlockEvents` maxBlockEventsToHandle = 500 - - // configuration options for the SlowCounter used for increasing max - // expiration time. Effectively, we will increase every 5 minutes as long as - // there is enough space in the database for orders. The first increase will - // be 5 seconds and the amount doubles from there (second increase will be 10 - // seconds, then 20 seconds, then 40, etc.) - slowCounterOffset = 5 // seconds - slowCounterRate = 2.0 - slowCounterInterval = 5 * time.Minute ) var errNoBlocksStored = errors.New("no blocks were stored in the database") @@ -91,8 +67,6 @@ type Watcher struct { orderValidator *ordervalidator.OrderValidator wasStartedOnce bool mu sync.Mutex - maxExpirationTime *big.Int - maxExpirationCounter *slowcounter.SlowCounter maxOrders int handleBlockEventsMu sync.RWMutex // atLeastOneBlockProcessed is closed to signal that the BlockWatcher has processed at least one @@ -109,7 +83,6 @@ type Config struct { ChainID int ContractAddresses ethereum.ContractAddresses MaxOrders int - MaxExpirationTime *big.Int } // New instantiates a new order watcher @@ -124,24 +97,6 @@ func New(config Config) (*Watcher, error) { if config.MaxOrders == 0 { return nil, errors.New("config.MaxOrders is required and cannot be zero") } - if config.MaxExpirationTime == nil { - return nil, errors.New("config.MaxExpirationTime is required and cannot be nil") - } else if big.NewInt(time.Now().Unix()).Cmp(config.MaxExpirationTime) == 1 { - // MaxExpirationTime should never be in the past. - config.MaxExpirationTime = big.NewInt(time.Now().Unix()) - } - - // Configure a SlowCounter to be used for increasing max expiration time. - slowCounterConfig := slowcounter.Config{ - Offset: big.NewInt(slowCounterOffset), - Rate: slowCounterRate, - Interval: slowCounterInterval, - MaxCount: constants.UnlimitedExpirationTime, - } - maxExpirationCounter, err := slowcounter.New(slowCounterConfig, config.MaxExpirationTime) - if err != nil { - return nil, err - } w := &Watcher{ db: config.DB, @@ -151,8 +106,6 @@ func New(config Config) (*Watcher, error) { eventDecoder: decoder, assetDataDecoder: assetDataDecoder, contractAddresses: config.ContractAddresses, - maxExpirationTime: big.NewInt(0).Set(config.MaxExpirationTime), - maxExpirationCounter: maxExpirationCounter, maxOrders: config.MaxOrders, blockEventsChan: make(chan []*blockwatch.Event, 100), atLeastOneBlockProcessed: make(chan struct{}), @@ -194,8 +147,7 @@ func (w *Watcher) Watch(ctx context.Context) error { // A waitgroup lets us wait for all goroutines to exit. wg := &sync.WaitGroup{} - // Start four independent goroutines. The main loop, cleanup loop, removed orders - // checker and max expirationTime checker. Use four separate channels to communicate errors. + // Start some independent goroutines, each with a separate channel for communicating errors. mainLoopErrChan := make(chan error, 1) wg.Add(1) go func() { @@ -208,12 +160,6 @@ func (w *Watcher) Watch(ctx context.Context) error { defer wg.Done() cleanupLoopErrChan <- w.cleanupLoop(innerCtx) }() - maxExpirationTimeLoopErrChan := make(chan error, 1) - wg.Add(1) - go func() { - defer wg.Done() - maxExpirationTimeLoopErrChan <- w.maxExpirationTimeLoop(innerCtx) - }() removedCheckerLoopErrChan := make(chan error, 1) wg.Add(1) go func() { @@ -237,12 +183,6 @@ func (w *Watcher) Watch(ctx context.Context) error { cancel() return err } - case err := <-maxExpirationTimeLoopErrChan: - if err != nil { - logger.WithError(err).Error("error in orderwatcher maxExpirationTimeLoop") - cancel() - return err - } case err := <-removedCheckerLoopErrChan: if err != nil { logger.WithError(err).Error("error in orderwatcher removedCheckerLoop") @@ -323,21 +263,6 @@ func (w *Watcher) cleanupLoop(ctx context.Context) error { } } -func (w *Watcher) maxExpirationTimeLoop(ctx context.Context) error { - ticker := time.NewTicker(maxExpirationTimeCheckInterval) - for { - select { - case <-ctx.Done(): - ticker.Stop() - return nil - case <-ticker.C: - if err := w.increaseMaxExpirationTimeIfPossible(); err != nil { - return err - } - } - } -} - func (w *Watcher) removedCheckerLoop(ctx context.Context) error { for { start := time.Now() @@ -895,10 +820,22 @@ func (w *Watcher) add(orderInfos []*ordervalidator.AcceptedOrderInfo, validation return nil, err } dbOrders = append(dbOrders, dbOrder) + + // We create an ADDED event for all orders in orderInfos. + // Some orders might not actually be added, as a workaround we + // will also emit a STOPPED_WATCHING event in some cases (see + // below) + addedEvent := &zeroex.OrderEvent{ + Timestamp: now, + OrderHash: orderInfo.OrderHash, + SignedOrder: orderInfo.SignedOrder, + FillableTakerAssetAmount: orderInfo.FillableTakerAssetAmount, + EndState: zeroex.ESOrderAdded, + } + orderEvents = append(orderEvents, addedEvent) } - // TODO(albrow): Should AddOrders return the new max expiration time? - // Or is there a better way to do this? + addedMap := map[common.Hash]*types.OrderWithMetadata{} addedOrders, removedOrders, err := w.db.AddOrders(dbOrders) if err != nil { return nil, err @@ -908,14 +845,7 @@ func (w *Watcher) add(orderInfos []*ordervalidator.AcceptedOrderInfo, validation if err != nil { return orderEvents, err } - addedEvent := &zeroex.OrderEvent{ - Timestamp: now, - OrderHash: order.Hash, - SignedOrder: order.SignedOrder(), - FillableTakerAssetAmount: order.FillableTakerAssetAmount, - EndState: zeroex.ESOrderAdded, - } - orderEvents = append(orderEvents, addedEvent) + addedMap[order.Hash] = order } for _, order := range removedOrders { stoppedWatchingEvent := &zeroex.OrderEvent{ @@ -940,8 +870,48 @@ func (w *Watcher) add(orderInfos []*ordervalidator.AcceptedOrderInfo, validation } } - // TODO(albrow): How to handle the edge case of orders that were not - // added due to the max expiration time changing? + // HACK(albrow): We need to handle orders in the orderInfos argument that + // were never added due to the max expiration time effectively changing + // within the database transaction above. In other words, new orders that + // _were_ added can change the effective max expiration time, meaning some + // orders in orderInfos were actually not added. This should not happen + // often. For now, we respond by emitting an ADDED event (above) immediately + // followed by a STOPPED_WATCHING event. If this order was submitted via + // RPC, the RPC client will see a response that indicates the order was + // successfully added, and then it will look like we immediately stopped + // watching it. This is not too far off from what really happened but is + // slightly inefficient. + // + // TODO(albrow): In the future, we should add an additional return value and + // then react to that differently depending on whether the order was + // received via RPC or from a peer. In the former case, we should return an + // RPC error response indicating that the order was not in fact added. In + // the latter case, we should not emit any order events but might potentially + // want to adjust the peer's score. + for _, orderToAdd := range orderInfos { + _, wasAdded := addedMap[orderToAdd.OrderHash] + if !wasAdded { + stoppedWatchingEvent := &zeroex.OrderEvent{ + Timestamp: now, + OrderHash: orderToAdd.OrderHash, + SignedOrder: orderToAdd.SignedOrder, + FillableTakerAssetAmount: orderToAdd.FillableTakerAssetAmount, + EndState: zeroex.ESStoppedWatching, + } + orderEvents = append(orderEvents, stoppedWatchingEvent) + } + } + + if len(removedOrders) > 0 { + newMaxExpirationTime, err := w.db.GetCurrentMaxExpirationTime() + if err != nil { + return nil, err + } + logger.WithFields(logger.Fields{ + "ordersRemoved": len(removedOrders), + "newMaxExpirationTime": newMaxExpirationTime.String(), + }).Debug("removed orders due to exceeding max expiration time") + } return orderEvents, nil } @@ -983,12 +953,6 @@ func (w *Watcher) orderInfoToOrderWithMetadata(orderInfo *ordervalidator.Accepte }, nil } -// MaxExpirationTime returns the current maximum expiration time for incoming -// orders. -func (w *Watcher) MaxExpirationTime() *big.Int { - return w.maxExpirationTime -} - // TODO(albrow): All in-memory state can be removed. func (w *Watcher) setupInMemoryOrderState(order *types.OrderWithMetadata) error { w.eventDecoder.AddKnownExchange(order.ExchangeAddress) @@ -1341,7 +1305,7 @@ func (w *Watcher) generateOrderEventsIfChanged( // ValidateAndStoreValidOrders applies general 0x validation and Mesh-specific validation to // the given orders and if they are valid, adds them to the OrderWatcher func (w *Watcher) ValidateAndStoreValidOrders(ctx context.Context, orders []*zeroex.SignedOrder, pinned bool, chainID int) (*ordervalidator.ValidationResults, error) { - results, validMeshOrders, err := w.meshSpecificOrderValidation(orders, chainID) + results, validMeshOrders, err := w.meshSpecificOrderValidation(orders, chainID, pinned) if err != nil { return nil, err } @@ -1414,9 +1378,36 @@ func (w *Watcher) onchainOrderValidation(ctx context.Context, orders []*zeroex.S return latestMiniHeader, zeroexResults, nil } -func (w *Watcher) meshSpecificOrderValidation(orders []*zeroex.SignedOrder, chainID int) (*ordervalidator.ValidationResults, []*zeroex.SignedOrder, error) { +func (w *Watcher) meshSpecificOrderValidation(orders []*zeroex.SignedOrder, chainID int, pinned bool) (*ordervalidator.ValidationResults, []*zeroex.SignedOrder, error) { results := &ordervalidator.ValidationResults{} validMeshOrders := []*zeroex.SignedOrder{} + + // Calculate max expiration time based on number of orders stored. + // This value is *exclusive*. Any incoming orders with an expiration time + // greater or equal to this will be rejected. + // + // Note(albrow): Technically speaking this is sub-optimal. We are assuming + // that we need to have space in the database for the entire slice of orders, + // but some of them could be invalid and therefore not actually get stored. + // However, the optimal implementation would be less efficient and could + // result in sending more ETH RPC requests than necessary. The edge case + // where potentially valid orders are rejected should be rare in practice, and + // would affect at most len(orders)/2 orders. + maxExpirationTime := constants.UnlimitedExpirationTime + if !pinned { + orderCount, err := w.db.CountOrders(nil) + if err != nil { + return nil, nil, err + } + if orderCount+len(orders) > w.maxOrders { + storedMaxExpirationTime, err := w.db.GetCurrentMaxExpirationTime() + if err != nil { + return nil, nil, err + } + maxExpirationTime = storedMaxExpirationTime + } + } + for _, order := range orders { orderHash, err := order.ComputeOrderHash() if err != nil { @@ -1429,7 +1420,7 @@ func (w *Watcher) meshSpecificOrderValidation(orders []*zeroex.SignedOrder, chai }) continue } - if order.ExpirationTimeSeconds.Cmp(w.MaxExpirationTime()) == 1 { + if !pinned && order.ExpirationTimeSeconds.Cmp(maxExpirationTime) != -1 { results.Rejected = append(results.Rejected, &ordervalidator.RejectedOrderInfo{ OrderHash: orderHash, SignedOrder: order, @@ -1765,36 +1756,6 @@ func (w *Watcher) removeAssetDataAddressFromEventDecoder(assetData []byte) error return nil } -func (w *Watcher) increaseMaxExpirationTimeIfPossible() error { - if orderCount, err := w.db.CountOrders(nil); err != nil { - return err - } else if orderCount < w.maxOrders { - // We have enough space for new orders. Set the new max expiration time to the - // value of slow counter. - newMaxExpiration := w.maxExpirationCounter.Count() - if w.maxExpirationTime.Cmp(newMaxExpiration) != 0 { - logger.WithFields(logger.Fields{ - "oldMaxExpirationTime": w.maxExpirationTime.String(), - "newMaxExpirationTime": fmt.Sprint(newMaxExpiration), - }).Debug("increasing max expiration time") - w.maxExpirationTime.Set(newMaxExpiration) - w.saveMaxExpirationTime(newMaxExpiration) - } - } - - return nil -} - -// saveMaxExpirationTime saves the new max expiration time in the database. -func (w *Watcher) saveMaxExpirationTime(maxExpirationTime *big.Int) { - if err := w.db.UpdateMetadata(func(metadata *types.Metadata) *types.Metadata { - metadata.MaxExpirationTime = maxExpirationTime - return metadata - }); err != nil { - logger.WithError(err).Error("could not update max expiration time in database") - } -} - func (w *Watcher) getBlockchainState(events []*blockwatch.Event) (*big.Int, time.Time) { var latestBlockNumber *big.Int var latestBlockTimestamp time.Time diff --git a/zeroex/orderwatch/order_watcher_test.go b/zeroex/orderwatch/order_watcher_test.go index 2bcb4e221..12ae9ec6c 100644 --- a/zeroex/orderwatch/order_watcher_test.go +++ b/zeroex/orderwatch/order_watcher_test.go @@ -825,7 +825,7 @@ func TestOrderWatcherOrderExpiredThenUnexpired(t *testing.T) { orderopts.ExpirationTimeSeconds(expirationTimeSeconds), ) blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) - watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrder) + watchOrder(ctx, t, orderWatcher, blockwatcher, signedOrder, false) orderEventsChan := make(chan []*zeroex.OrderEvent, 2*orderWatcher.maxOrders) orderWatcher.Subscribe(orderEventsChan) @@ -904,9 +904,7 @@ func TestOrderWatcherOrderExpiredThenUnexpired(t *testing.T) { assert.Equal(t, signedOrder.TakerAssetAmount, newOrders[0].FillableTakerAssetAmount) } -// TODO(albrow): Re-enable this test or move it. func TestOrderWatcherDecreaseExpirationTime(t *testing.T) { - t.Skip("Decreasing expiratin time is not yet implemented") if !serialTestsEnabled { t.Skip("Serial tests (tests which cannot run in parallel) are disabled. You can enable them with the --serial flag") } @@ -916,19 +914,14 @@ func TestOrderWatcherDecreaseExpirationTime(t *testing.T) { defer teardownSubTest(t) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() - database, err := db.New(ctx, db.TestOptions()) - require.NoError(t, err) - - // Store metadata entry in DB - metadata := &types.Metadata{ - EthereumChainID: 1337, - MaxExpirationTime: constants.UnlimitedExpirationTime, - } - err = database.SaveMetadata(metadata) + maxOrders := 10 + dbOpts := db.TestOptions() + dbOpts.MaxOrders = maxOrders + database, err := db.New(ctx, dbOpts) require.NoError(t, err) blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) - orderWatcher.maxOrders = 20 + orderWatcher.maxOrders = maxOrders // Create and watch maxOrders orders. Each order has a different expiration time. optionsForIndex := func(index int) []orderopts.Option { @@ -939,14 +932,14 @@ func TestOrderWatcherDecreaseExpirationTime(t *testing.T) { orderopts.ExpirationTimeSeconds(expirationTimeSeconds), } } - signedOrders := scenario.NewSignedTestOrdersBatch(t, orderWatcher.maxOrders, optionsForIndex) + signedOrders := scenario.NewSignedTestOrdersBatch(t, maxOrders, optionsForIndex) for _, signedOrder := range signedOrders { - watchOrder(ctx, t, orderWatcher, blockWatcher, ethClient, signedOrder) + watchOrder(ctx, t, orderWatcher, blockWatcher, signedOrder, false) } // We don't care about the order events above for the purposes of this test, // so we only subscribe now. - orderEventsChan := make(chan []*zeroex.OrderEvent, 2*orderWatcher.maxOrders) + orderEventsChan := make(chan []*zeroex.OrderEvent, 2*maxOrders) orderWatcher.Subscribe(orderEventsChan) // The next order should cause some orders to be removed and the appropriate @@ -957,32 +950,76 @@ func TestOrderWatcherDecreaseExpirationTime(t *testing.T) { orderopts.SetupMakerState(true), orderopts.ExpirationTimeSeconds(expirationTimeSeconds), ) - watchOrder(ctx, t, orderWatcher, blockWatcher, ethClient, signedOrder) - expectedOrderEvents := int(float64(orderWatcher.maxOrders)*(1-maxOrdersTrimRatio)) + 1 + watchOrder(ctx, t, orderWatcher, blockWatcher, signedOrder, false) + expectedOrderEvents := 2 orderEvents := waitForOrderEvents(t, orderEventsChan, expectedOrderEvents, 4*time.Second) require.Len(t, orderEvents, expectedOrderEvents, "wrong number of order events were fired") - for i, orderEvent := range orderEvents { - // Last event should be ADDED. The other events should be STOPPED_WATCHING. - if i == expectedOrderEvents-1 { - assert.Equal(t, zeroex.ESOrderAdded, orderEvent.EndState, "order event %d had wrong EndState", i) - } else { - // For STOPPED_WATCHING events, we also make sure that the expiration time is after - // the current max expiration time. - assert.Equal(t, zeroex.ESStoppedWatching, orderEvent.EndState, "order event %d had wrong EndState", i) + + storedMaxExpirationTime, err := database.GetCurrentMaxExpirationTime() + require.NoError(t, err) + + // One event should be STOPPED_WATCHING. The other event should be ADDED. + // The order in which the events are emitted is not guaranteed. + numAdded := 0 + numStoppedWatching := 0 + for _, orderEvent := range orderEvents { + switch orderEvent.EndState { + case zeroex.ESOrderAdded: + numAdded += 1 + orderExpirationTime := orderEvent.SignedOrder.ExpirationTimeSeconds + assert.True(t, orderExpirationTime.Cmp(storedMaxExpirationTime) == -1, "ADDED order has an expiration time of %s which is *greater than* the maximum of %s", orderExpirationTime, storedMaxExpirationTime) + case zeroex.ESStoppedWatching: + numStoppedWatching += 1 orderExpirationTime := orderEvent.SignedOrder.ExpirationTimeSeconds - assert.True(t, orderExpirationTime.Cmp(orderWatcher.MaxExpirationTime()) != -1, "remaining order has an expiration time of %s which is *less than* the maximum of %s", orderExpirationTime, orderWatcher.MaxExpirationTime()) + assert.True(t, orderExpirationTime.Cmp(storedMaxExpirationTime) != -1, "STOPPED_WATCHING order has an expiration time of %s which is *less than* the maximum of %s", orderExpirationTime, storedMaxExpirationTime) + default: + t.Errorf("unexpected order event type: %s", orderEvent.EndState) } } + assert.Equal(t, 1, numAdded, "wrong number of ADDED events") + assert.Equal(t, 1, numStoppedWatching, "wrong number of STOPPED_WATCHING events") // Now we check that the correct number of orders remain and that all // remaining orders have an expiration time less than the current max. - expectedRemainingOrders := int(float64(orderWatcher.maxOrders)*maxOrdersTrimRatio) + 1 + expectedRemainingOrders := orderWatcher.maxOrders remainingOrders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, remainingOrders, expectedRemainingOrders) for _, order := range remainingOrders { - assert.True(t, order.ExpirationTimeSeconds.Cmp(orderWatcher.MaxExpirationTime()) == -1, "remaining order has an expiration time of %s which is *greater than* the maximum of %s", order.ExpirationTimeSeconds, orderWatcher.MaxExpirationTime()) + assert.True(t, order.ExpirationTimeSeconds.Cmp(storedMaxExpirationTime) != 1, "remaining order has an expiration time of %s which is *greater than* the maximum of %s", order.ExpirationTimeSeconds, storedMaxExpirationTime) + } + + // Confirm that a pinned order will be accepted even if its expiration + // is greater than the current max. + pinnedOrder := scenario.NewSignedTestOrder(t, + orderopts.SetupMakerState(true), + orderopts.ExpirationTimeSeconds(big.NewInt(0).Add(storedMaxExpirationTime, big.NewInt(10))), + ) + pinnedOrderHash, err := pinnedOrder.ComputeOrderHash() + require.NoError(t, err) + watchOrder(ctx, t, orderWatcher, blockWatcher, pinnedOrder, true) + + expectedOrderEvents = 2 + orderEvents = waitForOrderEvents(t, orderEventsChan, expectedOrderEvents, 4*time.Second) + require.Len(t, orderEvents, expectedOrderEvents, "wrong number of order events were fired") + + // One event should be STOPPED_WATCHING. The other event should be ADDED. + // The order in which the events are emitted is not guaranteed. + numAdded = 0 + numStoppedWatching = 0 + for _, orderEvent := range orderEvents { + switch orderEvent.EndState { + case zeroex.ESOrderAdded: + numAdded += 1 + assert.Equal(t, pinnedOrderHash.Hex(), orderEvent.OrderHash.Hex(), "ADDED event had wrong order hash") + case zeroex.ESStoppedWatching: + numStoppedWatching += 1 + default: + t.Errorf("unexpected order event type: %s", orderEvent.EndState) + } } + assert.Equal(t, 1, numAdded, "wrong number of ADDED events") + assert.Equal(t, 1, numStoppedWatching, "wrong number of STOPPED_WATCHING events") } func TestOrderWatcherBatchEmitsAddedEvents(t *testing.T) { @@ -1052,9 +1089,9 @@ func TestOrderWatcherCleanup(t *testing.T) { orderOptions := scenario.OptionsForAll(orderopts.SetupMakerState(true)) signedOrders := scenario.NewSignedTestOrdersBatch(t, 2, orderOptions) signedOrderOne := signedOrders[0] - watchOrder(ctx, t, orderWatcher, blockWatcher, ethClient, signedOrderOne) + watchOrder(ctx, t, orderWatcher, blockWatcher, signedOrderOne, false) signedOrderTwo := signedOrders[1] - watchOrder(ctx, t, orderWatcher, blockWatcher, ethClient, signedOrderTwo) + watchOrder(ctx, t, orderWatcher, blockWatcher, signedOrderTwo, false) signedOrderOneHash, err := signedOrderTwo.ComputeOrderHash() require.NoError(t, err) @@ -1107,8 +1144,8 @@ func TestOrderWatcherHandleOrderExpirationsExpired(t *testing.T) { signedOrderOne := signedOrders[0] signedOrderTwo := signedOrders[1] blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) - watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrderOne) - watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrderTwo) + watchOrder(ctx, t, orderWatcher, blockwatcher, signedOrderOne, false) + watchOrder(ctx, t, orderWatcher, blockwatcher, signedOrderTwo, false) signedOrderOneHash, err := signedOrderOne.ComputeOrderHash() require.NoError(t, err) @@ -1163,8 +1200,8 @@ func TestOrderWatcherHandleOrderExpirationsUnexpired(t *testing.T) { signedOrderOne := signedOrders[0] signedOrderTwo := signedOrders[1] blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) - watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrderOne) - watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrderTwo) + watchOrder(ctx, t, orderWatcher, blockwatcher, signedOrderOne, false) + watchOrder(ctx, t, orderWatcher, blockwatcher, signedOrderTwo, false) orderEventsChan := make(chan []*zeroex.OrderEvent, 2*orderWatcher.maxOrders) orderWatcher.Subscribe(orderEventsChan) @@ -1248,7 +1285,7 @@ func TestConvertValidationResultsIntoOrderEventsUnexpired(t *testing.T) { orderopts.ExpirationTimeSeconds(expirationTimeSeconds), ) blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) - watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrder) + watchOrder(ctx, t, orderWatcher, blockwatcher, signedOrder, false) orderEventsChan := make(chan []*zeroex.OrderEvent, 2*orderWatcher.maxOrders) orderWatcher.Subscribe(orderEventsChan) @@ -1372,7 +1409,7 @@ func setupOrderWatcherScenario(ctx context.Context, t *testing.T, ethClient *eth blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) // Start watching an order - watchOrder(ctx, t, orderWatcher, blockWatcher, ethClient, signedOrder) + watchOrder(ctx, t, orderWatcher, blockWatcher, signedOrder, false) // Subscribe to OrderWatcher orderEventsChan := make(chan []*zeroex.OrderEvent, 10) @@ -1381,11 +1418,11 @@ func setupOrderWatcherScenario(ctx context.Context, t *testing.T, ethClient *eth return blockWatcher, orderEventsChan } -func watchOrder(ctx context.Context, t *testing.T, orderWatcher *Watcher, blockWatcher *blockwatch.Watcher, ethClient *ethclient.Client, signedOrder *zeroex.SignedOrder) { +func watchOrder(ctx context.Context, t *testing.T, orderWatcher *Watcher, blockWatcher *blockwatch.Watcher, signedOrder *zeroex.SignedOrder, pinned bool) { err := blockWatcher.SyncToLatestBlock() require.NoError(t, err) - validationResults, err := orderWatcher.ValidateAndStoreValidOrders(ctx, []*zeroex.SignedOrder{signedOrder}, false, constants.TestChainID) + validationResults, err := orderWatcher.ValidateAndStoreValidOrders(ctx, []*zeroex.SignedOrder{signedOrder}, pinned, constants.TestChainID) require.NoError(t, err) if len(validationResults.Rejected) != 0 { spew.Dump(validationResults.Rejected) @@ -1413,7 +1450,6 @@ func setupOrderWatcher(ctx context.Context, t *testing.T, ethRPCClient ethrpccli OrderValidator: orderValidator, ChainID: constants.TestChainID, ContractAddresses: ganacheAddresses, - MaxExpirationTime: constants.UnlimitedExpirationTime, MaxOrders: 1000, }) require.NoError(t, err) diff --git a/zeroex/orderwatch/slowcounter/slow_counter.go b/zeroex/orderwatch/slowcounter/slow_counter.go deleted file mode 100644 index 8d86436a5..000000000 --- a/zeroex/orderwatch/slowcounter/slow_counter.go +++ /dev/null @@ -1,114 +0,0 @@ -package slowcounter - -import ( - "errors" - "math/big" - "sync" - "time" -) - -// SlowCounter is an exponentially increasing counter that is slowly incremented -// after a certain time interval, unless reset. It has a few configuration -// options to control the rate of increase. SlowCounter uses the following -// exponential growth formula: -// -// currentCount = startingCount + offset * (rate ^ n) -// -// where n is the number of increments that have occurred. And the number of -// increments is calculated as: -// -// n = math.Floor(time.Since(startTime) / interval) -// -type SlowCounter struct { - mut sync.Mutex - config Config - startingCount *big.Int - // startingTime is the time the counter was started or reset. - startingTime time.Time - // isMax is a boolean cache which is used to prevent any computation from - // occurring if the counter has already hit MaxCount. - isMax bool -} - -// Config is a set of configuration options for SlowCounter. -type Config struct { - // Offset affects how much the count is increased on the first - // increment. - Offset *big.Int - // Rate controls how fast the offset increases after each increment. - Rate float64 - // Interval is the amount of time to wait before each time the counter is - // incremented. - Interval time.Duration - // MaxCount is the maximum value for the counter. After reaching MaxCount, the - // counter will stop incrementing until reset. - MaxCount *big.Int - - // maxCountFloat is MaxCount converted to a big.Float in order to make the - // math easier. - maxCountFloat *big.Float -} - -// New returns a new SlowCounter with the given starting count. -func New(config Config, startingCount *big.Int) (*SlowCounter, error) { - if config.MaxCount == nil { - return nil, errors.New("config.MaxCount cannot be nil") - } else if config.Interval == 0 { - return nil, errors.New("config.Interval cannot be 0") - } - config.maxCountFloat = big.NewFloat(1).SetInt(config.MaxCount) - return &SlowCounter{ - config: config, - startingCount: big.NewInt(0).Set(startingCount), - startingTime: time.Now(), - }, nil -} - -// Count returns the current count. -func (sc *SlowCounter) Count() *big.Int { - sc.mut.Lock() - defer sc.mut.Unlock() - - if sc.isMax { - currentCount := big.NewInt(0).Set(sc.config.MaxCount) - return currentCount - } - - // TODO(albrow): Could be further optimized to reduce memory allocations and - // math/big operations. - // - // currentCount = startingCount + offset * (rate ^ numIncrements) - // - numIncrements := time.Since(sc.startingTime) / sc.config.Interval - if numIncrements == 0 { - currentCount := big.NewInt(0).Set(sc.startingCount) - return currentCount - } - currentCount := big.NewFloat(0).SetInt(sc.startingCount) - offset := big.NewFloat(0).SetInt(sc.config.Offset) - rate := big.NewFloat(sc.config.Rate) - for i := 0; i < int(numIncrements)-1; i++ { - offset.Mul(offset, rate) - } - currentCount.Add(currentCount, offset) - currentCountInt := big.NewInt(0) - currentCount.Int(currentCountInt) - - // If current count exceeds max, return max. - if currentCountInt.Cmp(sc.config.MaxCount) == 1 { - currentCountInt.Set(sc.config.MaxCount) - sc.isMax = true - } - - return currentCountInt -} - -// Reset resets the counter to the given count. -func (sc *SlowCounter) Reset(count *big.Int) { - sc.mut.Lock() - defer sc.mut.Unlock() - - sc.isMax = false - sc.startingCount.Set(count) - sc.startingTime = time.Now() -} diff --git a/zeroex/orderwatch/slowcounter/slow_counter_test.go b/zeroex/orderwatch/slowcounter/slow_counter_test.go deleted file mode 100644 index ac8ad4a00..000000000 --- a/zeroex/orderwatch/slowcounter/slow_counter_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package slowcounter - -import ( - "math/big" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSlowCounter(t *testing.T) { - t.Parallel() - - config := Config{ - Offset: big.NewInt(10), - Rate: 2, - Interval: 250 * time.Millisecond, - MaxCount: big.NewInt(1000), - } - counter, err := New(config, big.NewInt(0)) - require.NoError(t, err) - - { - expectedCount := big.NewInt(0) - actualCount := counter.Count() - assert.Equal(t, expectedCount, actualCount, "wrong count before any increments") - } - - time.Sleep(config.Interval) - - { - expectedCount := big.NewInt(10) - actualCount := counter.Count() - assert.Equal(t, expectedCount, actualCount, "wrong count after 1 increment") - } - - time.Sleep(config.Interval) - - { - expectedCount := big.NewInt(20) - actualCount := counter.Count() - assert.Equal(t, expectedCount, actualCount, "wrong count after 2 increments") - } -} - -func TestSlowCounterReset(t *testing.T) { - t.Parallel() - - config := Config{ - Offset: big.NewInt(10), - Rate: 2, - Interval: 250 * time.Millisecond, - MaxCount: big.NewInt(1000), - } - counter, err := New(config, big.NewInt(20)) - require.NoError(t, err) - - time.Sleep(config.Interval) - - // Reset the counter and check that the count was correctly reset. - counter.Reset(big.NewInt(30)) - { - expectedCount := big.NewInt(30) - actualCount := counter.Count() - assert.Equal(t, expectedCount, actualCount, "wrong count after counter was reset") - } - - time.Sleep(config.Interval) - - // Check the counter was incremented once from the new value after reset. - { - expectedCount := big.NewInt(40) - actualCount := counter.Count() - assert.Equal(t, expectedCount, actualCount, "wrong count after counter was reset and then incremented") - } -} - -func TestSlowCounterMaxCount(t *testing.T) { - t.Parallel() - - config := Config{ - Offset: big.NewInt(10), - Rate: 2, - // Note(albrow): For this test, we're okay with a much faster interval since - // we don't need to be precise. We only need to check that *at least* N - // increments have occurred. It is okay if more than N have occurred. - Interval: 1 * time.Millisecond, - MaxCount: big.NewInt(100), - } - - counter, err := New(config, big.NewInt(0)) - require.NoError(t, err) - - for i := 0; i < 10; i++ { - time.Sleep(config.Interval) - actualCount := counter.Count() - assert.False(t, actualCount.Cmp(config.MaxCount) == 1, "count should never exceed max count") - } -}