Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: amm market depth perf - only calculate levels if they change an… #11494

Merged
merged 4 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions datanode/service/market_depth.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,62 @@ type ammCache struct {
ammOrders map[string][]*types.Order // map amm id -> expanded orders, so we can remove them if amended
activeAMMs map[string]entities.AMMPool // map amm id -> amm definition, so we can refresh its expansion
estimatedOrder map[string]struct{} // order-id -> whether it was an estimated order

// the lowest/highest bounds of all AMMs
lowestBound num.Decimal
highestBound num.Decimal

// reference -> calculation levels, if the reference point hasn't changed we can avoid the busy task
// of recalculating them
levels map[string][]*level
}

func (c *ammCache) addAMM(a entities.AMMPool) {
c.activeAMMs[a.AmmPartyID.String()] = a

low := a.ParametersBase
if a.ParametersLowerBound != nil {
low = *a.ParametersLowerBound
}

if c.lowestBound.IsZero() {
c.lowestBound = low
} else {
c.lowestBound = num.MinD(c.lowestBound, low)
}

high := a.ParametersBase
if a.ParametersUpperBound != nil {
high = *a.ParametersUpperBound
}
c.highestBound = num.MaxD(c.highestBound, high)
}

func (c *ammCache) removeAMM(ammParty string) {
delete(c.activeAMMs, ammParty)
delete(c.ammOrders, ammParty)

// now we need to recalculate the lowest/highest

c.lowestBound = num.DecimalZero()
c.highestBound = num.DecimalZero()
for _, a := range c.activeAMMs {
low := a.ParametersBase
if a.ParametersLowerBound != nil {
low = *a.ParametersLowerBound
}
if c.lowestBound.IsZero() {
c.lowestBound = low
} else {
c.lowestBound = num.MinD(c.lowestBound, low)
}

high := a.ParametersBase
if a.ParametersUpperBound != nil {
high = *a.ParametersUpperBound
}
c.lowestBound = num.MaxD(c.highestBound, high)
}
}

type MarketDepth struct {
Expand Down
61 changes: 52 additions & 9 deletions datanode/service/market_depth_amm.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,14 @@ func (m *MarketDepth) getActiveAMMs(ctx context.Context) map[string][]entities.A
return ammByMarket
}

func (m *MarketDepth) getCalculationBounds(reference num.Decimal, priceFactor num.Decimal) []*level {
func (m *MarketDepth) getCalculationBounds(cache *ammCache, reference num.Decimal, priceFactor num.Decimal) []*level {
if levels, ok := cache.levels[reference.String()]; ok {
return levels
}

lowestBound := cache.lowestBound
highestBound := cache.highestBound

// first lets calculate the region we will expand accurately, this will be some percentage either side of the reference price
factor := num.DecimalFromFloat(m.cfg.AmmFullExpansionPercentage).Div(hundred)

Expand Down Expand Up @@ -138,6 +145,34 @@ func (m *MarketDepth) getCalculationBounds(reference num.Decimal, priceFactor nu
estLow := num.UintZero().Sub(accLow, num.Min(accLow, eRange))
estHigh := num.UintZero().Add(accHigh, eRange)

// cap steps to the lowest/highest boundaries of all AMMs
lowD, _ := num.UintFromDecimal(lowestBound)
if accLow.LTE(lowD) {
accLow = lowD.Clone()
estLow = lowD.Clone()
}

highD, _ := num.UintFromDecimal(highestBound)
if accHigh.GTE(highD) {
accHigh = highD.Clone()
estHigh = highD.Clone()
}

// need to find the first n such that
// accLow - (n * eStep) < lowD
// accLow
if estLow.LT(lowD) {
delta, _ := num.UintZero().Delta(accLow, lowD)
delta.Div(delta, eStep)
estLow = num.UintZero().Sub(accLow, delta.Mul(delta, eStep))
}

if estHigh.GT(highD) {
delta, _ := num.UintZero().Delta(accHigh, highD)
delta.Div(delta, eStep)
estHigh = num.UintZero().Add(accHigh, delta.Mul(delta, eStep))
}

levels := []*level{}

// we now have our four prices [estLow, accLow, accHigh, estHigh] where from
Expand All @@ -164,6 +199,10 @@ func (m *MarketDepth) getCalculationBounds(reference num.Decimal, priceFactor nu
price = num.UintZero().Add(price, eStep)
}

cache.levels = map[string][]*level{
reference.String(): levels,
}

return levels
}

Expand Down Expand Up @@ -268,7 +307,6 @@ func (m *MarketDepth) expandByLevels(pool entities.AMMPool, levels []*level, pri
v1 = ammDefn.position
}
}

// calculate the volume
volume := v1.Sub(v2).Abs().IntPart()

Expand Down Expand Up @@ -303,15 +341,15 @@ func (m *MarketDepth) InitialiseAMMs(ctx context.Context) {

// add it to our active list, we want to do this even if we fail to get a reference
for _, a := range amms {
cache.activeAMMs[a.AmmPartyID.String()] = a
cache.addAMM(a)
}

reference, err := m.getReference(ctx, marketID)
if err != nil {
continue
}

levels := m.getCalculationBounds(reference, priceFactor)
levels := m.getCalculationBounds(cache, reference, priceFactor)

for _, amm := range amms {
orders, estimated, err := m.expandByLevels(amm, levels, priceFactor)
Expand Down Expand Up @@ -345,7 +383,12 @@ func (m *MarketDepth) ExpandAMM(ctx context.Context, pool entities.AMMPool, pric
return nil, nil, err
}

levels := m.getCalculationBounds(reference, priceFactor)
cache, err := m.getAMMCache(string(pool.MarketID))
if err != nil {
return nil, nil, err
}

levels := m.getCalculationBounds(cache, reference, priceFactor)

return m.expandByLevels(pool, levels, priceFactor)
}
Expand Down Expand Up @@ -391,11 +434,12 @@ func (m *MarketDepth) refreshAMM(pool entities.AMMPool, depth *entities.MarketDe
}

if pool.Status == entities.AMMStatusCancelled || pool.Status == entities.AMMStatusStopped {
delete(cache.activeAMMs, ammParty)
delete(cache.ammOrders, ammParty)
cache.removeAMM(ammParty)
return
}

cache.addAMM(pool)

// expand it again into new orders and push them into the market depth
orders, estimated, _ := m.ExpandAMM(context.Background(), pool, cache.priceFactor)
for i := range orders {
Expand All @@ -404,9 +448,7 @@ func (m *MarketDepth) refreshAMM(pool entities.AMMPool, depth *entities.MarketDe
cache.estimatedOrder[orders[i].ID] = struct{}{}
}
}

cache.ammOrders[ammParty] = orders
cache.activeAMMs[ammParty] = pool
}

// refreshAMM is used when an AMM has either traded or its definition has changed.
Expand Down Expand Up @@ -484,6 +526,7 @@ func (m *MarketDepth) getAMMCache(marketID string) (*ammCache, error) {
ammOrders: map[string][]*types.Order{},
activeAMMs: map[string]entities.AMMPool{},
estimatedOrder: map[string]struct{}{},
levels: map[string][]*level{},
}
m.ammCache[marketID] = cache

Expand Down
80 changes: 78 additions & 2 deletions datanode/service/market_depth_amm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,82 @@ func TestEstimatedStepOverAMMBound(t *testing.T) {
assert.Equal(t, 3, int(mds.service.GetVolumeAtPrice(marketID, types.SideSell, 2001)))
}

func TestExpansionMuchBiggerThanAMMs(t *testing.T) {
ctx := context.Background()

cfg := service.MarketDepthConfig{
AmmFullExpansionPercentage: 1,
AmmMaxEstimatedSteps: 10,
AmmEstimatedStepPercentage: 5,
}

mds := getServiceWithConfig(t, cfg)
defer mds.ctrl.Finish()

marketID := vgcrypto.RandomHash()

ensureLiveOrders(t, mds, marketID)
ensureDecimalPlaces(t, mds)
mds.pos.EXPECT().GetByMarketAndParty(gomock.Any(), gomock.Any(), gomock.Any()).Return(entities.Position{OpenVolume: 0}, nil)
mds.marketData.EXPECT().GetMarketDataByID(gomock.Any(), gomock.Any()).Times(1).Return(entities.MarketData{MidPrice: num.DecimalFromInt64(2000)}, nil)

// data node is starting from network history, initialise market-depth based on whats aleady there
ensureAMMs(t, mds, marketID)
mds.service.Initialise(ctx)

assert.Equal(t, 465, int(mds.service.GetTotalAMMVolume(marketID)))
assert.Equal(t, 345, int(mds.service.GetAMMVolume(marketID, true)))
assert.Equal(t, 120, int(mds.service.GetAMMVolume(marketID, false)))
assert.Equal(t, 485, int(mds.service.GetTotalVolume(marketID)))

assert.Equal(t, "1999", mds.service.GetBestBidPrice(marketID).String())
assert.Equal(t, "2001", mds.service.GetBestAskPrice(marketID).String())
}

func TestMidPriceMove(t *testing.T) {
ctx := context.Background()

mds := getService(t)
defer mds.ctrl.Finish()

marketID := vgcrypto.RandomHash()

ensureLiveOrders(t, mds, marketID)
ensureDecimalPlaces(t, mds)
mds.pos.EXPECT().GetByMarketAndParty(gomock.Any(), gomock.Any(), gomock.Any()).Return(entities.Position{OpenVolume: 0}, nil)
mds.marketData.EXPECT().GetMarketDataByID(gomock.Any(), gomock.Any()).Times(1).Return(entities.MarketData{MidPrice: num.DecimalFromInt64(2000)}, nil)

// data node is starting from network history, initialise market-depth based on whats aleady there
pool := ensureAMMs(t, mds, marketID)
mds.service.Initialise(ctx)

assert.Equal(t, 240, int(mds.service.GetTotalAMMVolume(marketID)))
assert.Equal(t, 120, int(mds.service.GetAMMVolume(marketID, true)))
assert.Equal(t, 120, int(mds.service.GetAMMVolume(marketID, false)))
assert.Equal(t, 260, int(mds.service.GetTotalVolume(marketID)))

assert.Equal(t, "1999", mds.service.GetBestBidPrice(marketID).String())
assert.Equal(t, "2001", mds.service.GetBestAskPrice(marketID).String())

// now say the mid-price moves a little, we want to check we recalculate the levels properly
mds.pos.EXPECT().GetByMarketAndParty(gomock.Any(), gomock.Any(), gomock.Any()).Return(entities.Position{OpenVolume: 500}, nil)
mds.marketData.EXPECT().GetMarketDataByID(gomock.Any(), gomock.Any()).Times(1).Return(entities.MarketData{MidPrice: num.DecimalFromInt64(1800)}, nil)
mds.service.AddOrder(
&types.Order{
ID: vgcrypto.RandomHash(),
Party: pool.AmmPartyID.String(),
MarketID: marketID,
Side: types.SideBuy,
Status: entities.OrderStatusFilled,
},
time.Date(2022, 3, 8, 16, 15, 39, 901022000, time.UTC),
37,
)

assert.Equal(t, "1828", mds.service.GetBestBidPrice(marketID).String())
assert.Equal(t, "3000", mds.service.GetBestAskPrice(marketID).String()) // this is an actual order volume not AMM volume
}

func ensureLiveOrders(t *testing.T, mds *MDS, marketID string) {
t.Helper()
mds.orders.EXPECT().GetLiveOrders(gomock.Any()).Return([]entities.Order{
Expand All @@ -285,7 +361,7 @@ func ensureLiveOrders(t *testing.T, mds *MDS, marketID string) {
MarketID: entities.MarketID(marketID),
PartyID: entities.PartyID(vgcrypto.RandomHash()),
Side: types.SideBuy,
Price: decimal.NewFromInt(1800),
Price: decimal.NewFromInt(1000),
Size: 10,
Remaining: 10,
Type: entities.OrderTypeLimit,
Expand All @@ -300,7 +376,7 @@ func ensureLiveOrders(t *testing.T, mds *MDS, marketID string) {
Side: types.SideSell,
Type: entities.OrderTypeLimit,
Status: entities.OrderStatusActive,
Price: decimal.NewFromInt(2200),
Price: decimal.NewFromInt(3000),
Size: 10,
Remaining: 10,
VegaTime: time.Date(2022, 3, 8, 14, 15, 39, 901022000, time.UTC),
Expand Down
Loading