Skip to content

Commit

Permalink
Merge pull request #11494 from vegaprotocol/amm-depth-perf
Browse files Browse the repository at this point in the history
feat: amm market depth perf - only calculate levels if they change an…
  • Loading branch information
jeremyletang authored Jul 25, 2024
2 parents c89ec09 + 5264e8f commit 9278fc8
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 11 deletions.
56 changes: 56 additions & 0 deletions datanode/service/market_depth.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,62 @@ type ammCache struct {
ammOrders map[string][]*types.Order // map amm id -> expanded orders, so we can remove them if amended
activeAMMs map[string]entities.AMMPool // map amm id -> amm definition, so we can refresh its expansion
estimatedOrder map[string]struct{} // order-id -> whether it was an estimated order

// the lowest/highest bounds of all AMMs
lowestBound num.Decimal
highestBound num.Decimal

// reference -> calculation levels, if the reference point hasn't changed we can avoid the busy task
// of recalculating them
levels map[string][]*level
}

func (c *ammCache) addAMM(a entities.AMMPool) {
c.activeAMMs[a.AmmPartyID.String()] = a

low := a.ParametersBase
if a.ParametersLowerBound != nil {
low = *a.ParametersLowerBound
}

if c.lowestBound.IsZero() {
c.lowestBound = low
} else {
c.lowestBound = num.MinD(c.lowestBound, low)
}

high := a.ParametersBase
if a.ParametersUpperBound != nil {
high = *a.ParametersUpperBound
}
c.highestBound = num.MaxD(c.highestBound, high)
}

func (c *ammCache) removeAMM(ammParty string) {
delete(c.activeAMMs, ammParty)
delete(c.ammOrders, ammParty)

// now we need to recalculate the lowest/highest

c.lowestBound = num.DecimalZero()
c.highestBound = num.DecimalZero()
for _, a := range c.activeAMMs {
low := a.ParametersBase
if a.ParametersLowerBound != nil {
low = *a.ParametersLowerBound
}
if c.lowestBound.IsZero() {
c.lowestBound = low
} else {
c.lowestBound = num.MinD(c.lowestBound, low)
}

high := a.ParametersBase
if a.ParametersUpperBound != nil {
high = *a.ParametersUpperBound
}
c.lowestBound = num.MaxD(c.highestBound, high)
}
}

type MarketDepth struct {
Expand Down
61 changes: 52 additions & 9 deletions datanode/service/market_depth_amm.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,14 @@ func (m *MarketDepth) getActiveAMMs(ctx context.Context) map[string][]entities.A
return ammByMarket
}

func (m *MarketDepth) getCalculationBounds(reference num.Decimal, priceFactor num.Decimal) []*level {
func (m *MarketDepth) getCalculationBounds(cache *ammCache, reference num.Decimal, priceFactor num.Decimal) []*level {
if levels, ok := cache.levels[reference.String()]; ok {
return levels
}

lowestBound := cache.lowestBound
highestBound := cache.highestBound

// first lets calculate the region we will expand accurately, this will be some percentage either side of the reference price
factor := num.DecimalFromFloat(m.cfg.AmmFullExpansionPercentage).Div(hundred)

Expand Down Expand Up @@ -138,6 +145,34 @@ func (m *MarketDepth) getCalculationBounds(reference num.Decimal, priceFactor nu
estLow := num.UintZero().Sub(accLow, num.Min(accLow, eRange))
estHigh := num.UintZero().Add(accHigh, eRange)

// cap steps to the lowest/highest boundaries of all AMMs
lowD, _ := num.UintFromDecimal(lowestBound)
if accLow.LTE(lowD) {
accLow = lowD.Clone()
estLow = lowD.Clone()
}

highD, _ := num.UintFromDecimal(highestBound)
if accHigh.GTE(highD) {
accHigh = highD.Clone()
estHigh = highD.Clone()
}

// need to find the first n such that
// accLow - (n * eStep) < lowD
// accLow
if estLow.LT(lowD) {
delta, _ := num.UintZero().Delta(accLow, lowD)
delta.Div(delta, eStep)
estLow = num.UintZero().Sub(accLow, delta.Mul(delta, eStep))
}

if estHigh.GT(highD) {
delta, _ := num.UintZero().Delta(accHigh, highD)
delta.Div(delta, eStep)
estHigh = num.UintZero().Add(accHigh, delta.Mul(delta, eStep))
}

levels := []*level{}

// we now have our four prices [estLow, accLow, accHigh, estHigh] where from
Expand All @@ -164,6 +199,10 @@ func (m *MarketDepth) getCalculationBounds(reference num.Decimal, priceFactor nu
price = num.UintZero().Add(price, eStep)
}

cache.levels = map[string][]*level{
reference.String(): levels,
}

return levels
}

Expand Down Expand Up @@ -268,7 +307,6 @@ func (m *MarketDepth) expandByLevels(pool entities.AMMPool, levels []*level, pri
v1 = ammDefn.position
}
}

// calculate the volume
volume := v1.Sub(v2).Abs().IntPart()

Expand Down Expand Up @@ -303,15 +341,15 @@ func (m *MarketDepth) InitialiseAMMs(ctx context.Context) {

// add it to our active list, we want to do this even if we fail to get a reference
for _, a := range amms {
cache.activeAMMs[a.AmmPartyID.String()] = a
cache.addAMM(a)
}

reference, err := m.getReference(ctx, marketID)
if err != nil {
continue
}

levels := m.getCalculationBounds(reference, priceFactor)
levels := m.getCalculationBounds(cache, reference, priceFactor)

for _, amm := range amms {
orders, estimated, err := m.expandByLevels(amm, levels, priceFactor)
Expand Down Expand Up @@ -345,7 +383,12 @@ func (m *MarketDepth) ExpandAMM(ctx context.Context, pool entities.AMMPool, pric
return nil, nil, err
}

levels := m.getCalculationBounds(reference, priceFactor)
cache, err := m.getAMMCache(string(pool.MarketID))
if err != nil {
return nil, nil, err
}

levels := m.getCalculationBounds(cache, reference, priceFactor)

return m.expandByLevels(pool, levels, priceFactor)
}
Expand Down Expand Up @@ -391,11 +434,12 @@ func (m *MarketDepth) refreshAMM(pool entities.AMMPool, depth *entities.MarketDe
}

if pool.Status == entities.AMMStatusCancelled || pool.Status == entities.AMMStatusStopped {
delete(cache.activeAMMs, ammParty)
delete(cache.ammOrders, ammParty)
cache.removeAMM(ammParty)
return
}

cache.addAMM(pool)

// expand it again into new orders and push them into the market depth
orders, estimated, _ := m.ExpandAMM(context.Background(), pool, cache.priceFactor)
for i := range orders {
Expand All @@ -404,9 +448,7 @@ func (m *MarketDepth) refreshAMM(pool entities.AMMPool, depth *entities.MarketDe
cache.estimatedOrder[orders[i].ID] = struct{}{}
}
}

cache.ammOrders[ammParty] = orders
cache.activeAMMs[ammParty] = pool
}

// refreshAMM is used when an AMM has either traded or its definition has changed.
Expand Down Expand Up @@ -484,6 +526,7 @@ func (m *MarketDepth) getAMMCache(marketID string) (*ammCache, error) {
ammOrders: map[string][]*types.Order{},
activeAMMs: map[string]entities.AMMPool{},
estimatedOrder: map[string]struct{}{},
levels: map[string][]*level{},
}
m.ammCache[marketID] = cache

Expand Down
80 changes: 78 additions & 2 deletions datanode/service/market_depth_amm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,82 @@ func TestEstimatedStepOverAMMBound(t *testing.T) {
assert.Equal(t, 3, int(mds.service.GetVolumeAtPrice(marketID, types.SideSell, 2001)))
}

func TestExpansionMuchBiggerThanAMMs(t *testing.T) {
ctx := context.Background()

cfg := service.MarketDepthConfig{
AmmFullExpansionPercentage: 1,
AmmMaxEstimatedSteps: 10,
AmmEstimatedStepPercentage: 5,
}

mds := getServiceWithConfig(t, cfg)
defer mds.ctrl.Finish()

marketID := vgcrypto.RandomHash()

ensureLiveOrders(t, mds, marketID)
ensureDecimalPlaces(t, mds)
mds.pos.EXPECT().GetByMarketAndParty(gomock.Any(), gomock.Any(), gomock.Any()).Return(entities.Position{OpenVolume: 0}, nil)
mds.marketData.EXPECT().GetMarketDataByID(gomock.Any(), gomock.Any()).Times(1).Return(entities.MarketData{MidPrice: num.DecimalFromInt64(2000)}, nil)

// data node is starting from network history, initialise market-depth based on whats aleady there
ensureAMMs(t, mds, marketID)
mds.service.Initialise(ctx)

assert.Equal(t, 465, int(mds.service.GetTotalAMMVolume(marketID)))
assert.Equal(t, 345, int(mds.service.GetAMMVolume(marketID, true)))
assert.Equal(t, 120, int(mds.service.GetAMMVolume(marketID, false)))
assert.Equal(t, 485, int(mds.service.GetTotalVolume(marketID)))

assert.Equal(t, "1999", mds.service.GetBestBidPrice(marketID).String())
assert.Equal(t, "2001", mds.service.GetBestAskPrice(marketID).String())
}

func TestMidPriceMove(t *testing.T) {
ctx := context.Background()

mds := getService(t)
defer mds.ctrl.Finish()

marketID := vgcrypto.RandomHash()

ensureLiveOrders(t, mds, marketID)
ensureDecimalPlaces(t, mds)
mds.pos.EXPECT().GetByMarketAndParty(gomock.Any(), gomock.Any(), gomock.Any()).Return(entities.Position{OpenVolume: 0}, nil)
mds.marketData.EXPECT().GetMarketDataByID(gomock.Any(), gomock.Any()).Times(1).Return(entities.MarketData{MidPrice: num.DecimalFromInt64(2000)}, nil)

// data node is starting from network history, initialise market-depth based on whats aleady there
pool := ensureAMMs(t, mds, marketID)
mds.service.Initialise(ctx)

assert.Equal(t, 240, int(mds.service.GetTotalAMMVolume(marketID)))
assert.Equal(t, 120, int(mds.service.GetAMMVolume(marketID, true)))
assert.Equal(t, 120, int(mds.service.GetAMMVolume(marketID, false)))
assert.Equal(t, 260, int(mds.service.GetTotalVolume(marketID)))

assert.Equal(t, "1999", mds.service.GetBestBidPrice(marketID).String())
assert.Equal(t, "2001", mds.service.GetBestAskPrice(marketID).String())

// now say the mid-price moves a little, we want to check we recalculate the levels properly
mds.pos.EXPECT().GetByMarketAndParty(gomock.Any(), gomock.Any(), gomock.Any()).Return(entities.Position{OpenVolume: 500}, nil)
mds.marketData.EXPECT().GetMarketDataByID(gomock.Any(), gomock.Any()).Times(1).Return(entities.MarketData{MidPrice: num.DecimalFromInt64(1800)}, nil)
mds.service.AddOrder(
&types.Order{
ID: vgcrypto.RandomHash(),
Party: pool.AmmPartyID.String(),
MarketID: marketID,
Side: types.SideBuy,
Status: entities.OrderStatusFilled,
},
time.Date(2022, 3, 8, 16, 15, 39, 901022000, time.UTC),
37,
)

assert.Equal(t, "1828", mds.service.GetBestBidPrice(marketID).String())
assert.Equal(t, "3000", mds.service.GetBestAskPrice(marketID).String()) // this is an actual order volume not AMM volume
}

func ensureLiveOrders(t *testing.T, mds *MDS, marketID string) {
t.Helper()
mds.orders.EXPECT().GetLiveOrders(gomock.Any()).Return([]entities.Order{
Expand All @@ -285,7 +361,7 @@ func ensureLiveOrders(t *testing.T, mds *MDS, marketID string) {
MarketID: entities.MarketID(marketID),
PartyID: entities.PartyID(vgcrypto.RandomHash()),
Side: types.SideBuy,
Price: decimal.NewFromInt(1800),
Price: decimal.NewFromInt(1000),
Size: 10,
Remaining: 10,
Type: entities.OrderTypeLimit,
Expand All @@ -300,7 +376,7 @@ func ensureLiveOrders(t *testing.T, mds *MDS, marketID string) {
Side: types.SideSell,
Type: entities.OrderTypeLimit,
Status: entities.OrderStatusActive,
Price: decimal.NewFromInt(2200),
Price: decimal.NewFromInt(3000),
Size: 10,
Remaining: 10,
VegaTime: time.Date(2022, 3, 8, 14, 15, 39, 901022000, time.UTC),
Expand Down

0 comments on commit 9278fc8

Please sign in to comment.