From d822ac3a54a82f1b356327f69fb02c4320dbe152 Mon Sep 17 00:00:00 2001 From: wwestgarth Date: Tue, 23 Jul 2024 17:16:34 +0100 Subject: [PATCH] fix: add nil check on AMM range partition for case where aggressive order is a market order --- core/execution/amm/engine.go | 6 ++++-- core/execution/amm/engine_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/core/execution/amm/engine.go b/core/execution/amm/engine.go index a97008aa14..5dbc803828 100644 --- a/core/execution/amm/engine.go +++ b/core/execution/amm/engine.go @@ -474,7 +474,9 @@ func (e *Engine) submit(active []*Pool, agg *types.Order, inner, outer *num.Uint } // partition takes the given price range and returns which pools have volume in that region, and -// divides that range into sub-levels where AMM boundaries end. +// divides that range into sub-levels where AMM boundaries end. Note that `outer` can be nil for the case +// where the incoming order is a market order (so we have no bound on the price), and we've already consumed +// all volume on the orderbook. func (e *Engine) partition(agg *types.Order, inner, outer *num.Uint) ([]*Pool, []*num.Uint) { active := []*Pool{} bounds := map[string]*num.Uint{} @@ -503,7 +505,7 @@ func (e *Engine) partition(agg *types.Order, inner, outer *num.Uint) ([]*Pool, [ // This is because to get the BUY volume an AMM has at price P, we need to calculate the difference // in its position between prices P -> P + 1. For SELL volume its the other way around and we // need the difference in position from P - 1 -> P. - if inner.EQ(outer) { + if inner != nil && outer != nil && inner.EQ(outer) { if agg.Side == types.SideSell { outer = num.UintZero().Add(outer, e.oneTick) } else { diff --git a/core/execution/amm/engine_test.go b/core/execution/amm/engine_test.go index 6f616d8569..77c0dd7b42 100644 --- a/core/execution/amm/engine_test.go +++ b/core/execution/amm/engine_test.go @@ -51,6 +51,7 @@ func TestAMMTrading(t *testing.T) { t.Run("test basic submit order", testBasicSubmitOrder) t.Run("test submit order at best price", testSubmitOrderAtBestPrice) t.Run("test submit market order", testSubmitMarketOrder) + t.Run("test submit market order unbounded", testSubmitMarketOrderUnbounded) t.Run("test submit order pro rata", testSubmitOrderProRata) t.Run("test best prices and volume", testBestPricesAndVolume) @@ -269,6 +270,31 @@ func testSubmitMarketOrder(t *testing.T) { assert.Equal(t, 126420, int(orders[0].Size)) } +func testSubmitMarketOrderUnbounded(t *testing.T) { + tst := getTestEngine(t) + + party, subAccount := getParty(t, tst) + submit := getPoolSubmission(t, party, tst.marketID) + + expectSubaccountCreation(t, tst, party, subAccount) + whenAMMIsSubmitted(t, tst, submit) + + // now submit an order against it + agg := &types.Order{ + Size: 1000000, + Remaining: 1000000, + Side: types.SideSell, + Price: num.NewUint(0), + Type: types.OrderTypeMarket, + } + + ensurePosition(t, tst.pos, 0, num.NewUint(0)) + orders := tst.engine.SubmitOrder(agg, num.NewUint(1980), nil) + require.Len(t, orders, 1) + assert.Equal(t, "1960", orders[0].Price.String()) + assert.Equal(t, 1000000, int(orders[0].Size)) +} + func testSubmitOrderProRata(t *testing.T) { tst := getTestEngine(t)