From 307227b9538fa19618078532fdb552b4ba772e65 Mon Sep 17 00:00:00 2001 From: wwestgarth Date: Tue, 6 Aug 2024 10:28:01 +0100 Subject: [PATCH] feat: improve snapshot performance by using more targetted caching --- CHANGELOG.md | 1 + core/execution/amm/engine.go | 29 ++++------------ core/execution/amm/pool.go | 67 +++++++++++++++++++++++++----------- 3 files changed, 54 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9417250cb0..be7eb45ec2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ - [11428](https://github.com/vegaprotocol/vega/issues/11428) - Add buy back and treasury fee and separate discount/reward factors. - [11468](https://github.com/vegaprotocol/vega/issues/11468) - Added support for volume rebate program. +- [11523](https://github.com/vegaprotocol/vega/issues/11523) - Change method of caching to improve `AMM` snapshot performance. - [11459](https://github.com/vegaprotocol/vega/issues/11459) - Deprecate time weight position reward metric and replace it with time weighted notional. ### 🐛 Fixes diff --git a/core/execution/amm/engine.go b/core/execution/amm/engine.go index c5be1221d6..302d83ca6f 100644 --- a/core/execution/amm/engine.go +++ b/core/execution/amm/engine.go @@ -99,17 +99,16 @@ func (s *Sqrter) sqrt(u *num.Uint) num.Decimal { return num.DecimalZero() } - if r, ok := s.cache[u.String()]; ok { - return r - } + // caching was disabled here since it caused problems with snapshots (https://github.com/vegaprotocol/vega/issues/11523) + // and we changed tact to instead cache constant terms in calculations that *involve* sqrt's instead of the sqrt result + // directly. I'm leaving the ghost of this cache here incase we need to introduce it again, maybe as a LRU cache instead. + // if r, ok := s.cache[u.String()]; ok { + // return r + // } - // TODO that we may need to re-visit this depending on the performance impact - // but for now lets do it "properly" in full decimals and work out how we can - // improve it once we have reg-tests and performance data. r := num.UintOne().Sqrt(u) - // and cache it -- we can also maybe be more clever here and use a LRU but thats for later - s.cache[u.String()] = r + // s.cache[u.String()] = r return r } @@ -198,11 +197,6 @@ func NewFromProto( e.ammParties[v.Key] = v.Value } - // TODO consider whether we want the cache in the snapshot, it might be pretty large/slow and I'm not sure what we gain - for _, v := range state.Sqrter { - e.rooter.cache[v.Key] = num.MustDecimalFromString(v.Value) - } - for _, v := range state.Pools { p, err := NewPoolFromProto(log, e.rooter.sqrt, e.collateral, e.position, v.Pool, v.Party, priceFactor, positionFactor) if err != nil { @@ -216,19 +210,10 @@ func NewFromProto( func (e *Engine) IntoProto() *v1.AmmState { state := &v1.AmmState{ - Sqrter: make([]*v1.StringMapEntry, 0, len(e.rooter.cache)), AmmPartyIds: make([]*v1.StringMapEntry, 0, len(e.ammParties)), Pools: make([]*v1.PoolMapEntry, 0, len(e.pools)), } - for k, v := range e.rooter.cache { - state.Sqrter = append(state.Sqrter, &v1.StringMapEntry{ - Key: k, - Value: v.String(), - }) - } - sort.Slice(state.Sqrter, func(i, j int) bool { return state.Sqrter[i].Key < state.Sqrter[j].Key }) - for k, v := range e.ammParties { state.AmmPartyIds = append(state.AmmPartyIds, &v1.StringMapEntry{ Key: k, diff --git a/core/execution/amm/pool.go b/core/execution/amm/pool.go index de2f4bda0a..20e7b37f15 100644 --- a/core/execution/amm/pool.go +++ b/core/execution/amm/pool.go @@ -42,6 +42,9 @@ type curve struct { // note that this equals Vega's position at the boundary only in the lower curve, since Vega position == curve-position // in the upper curve Vega's position == 0 => position of `pv`` in curve-position, Vega's position pv => 0 in curve-position pv num.Decimal + + lDivSqrtPu num.Decimal + sqrtHigh num.Decimal } func (c *curve) volumeBetweenPrices(sqrt sqrtFn, st, nd *num.Uint) uint64 { @@ -56,8 +59,8 @@ func (c *curve) volumeBetweenPrices(sqrt sqrtFn, st, nd *num.Uint) uint64 { return 0 } - stP := impliedPosition(sqrt(st), sqrt(c.high), c.l) - ndP := impliedPosition(sqrt(nd), sqrt(c.high), c.l) + stP := impliedPosition(sqrt(st), c.sqrtHigh, c.l) + ndP := impliedPosition(sqrt(nd), c.sqrtHigh, c.l) // abs(P(st) - P(nd)) volume := stP.Sub(ndP).Abs() @@ -67,7 +70,7 @@ func (c *curve) volumeBetweenPrices(sqrt sqrtFn, st, nd *num.Uint) uint64 { // positionAtPrice returns the position of the AMM if its fair-price were the given price. This // will be signed for long/short as usual. func (c *curve) positionAtPrice(sqrt sqrtFn, price *num.Uint) int64 { - pos := impliedPosition(sqrt(price), sqrt(c.high), c.l) + pos := impliedPosition(sqrt(price), c.sqrtHigh, c.l) if c.isLower { return pos.IntPart() } @@ -111,6 +114,8 @@ type Pool struct { maxCalculationLevels *num.Uint // maximum number of price levels the AMM will be expanded into oneTick *num.Uint // one price tick + + fpCache map[int64]*num.Uint } func NewPool( @@ -148,6 +153,7 @@ func NewPool( oneTick: oneTick, status: types.AMMPoolStatusActive, maxCalculationLevels: maxCalculationLevels, + fpCache: map[int64]*num.Uint{}, } err := pool.setCurves(rf, sf, linearSlippage) if err != nil { @@ -268,12 +274,18 @@ func NewCurveFromProto(c *snapshotpb.PoolMapEntry_Curve) (*curve, error) { if overflow { return nil, fmt.Errorf("failed to convert string to Uint: %s", c.Low) } + + sqrtHigh := num.UintOne().Sqrt(high) + lDivSqrtPu := l.Div(sqrtHigh) + return &curve{ - l: l, - high: high, - low: low, - empty: c.Empty, - pv: pv, + l: l, + high: high, + low: low, + empty: c.Empty, + pv: pv, + sqrtHigh: sqrtHigh, + lDivSqrtPu: lDivSqrtPu, }, nil } @@ -355,6 +367,7 @@ func (p *Pool) Update( sqrt: p.sqrt, oneTick: p.oneTick, maxCalculationLevels: p.maxCalculationLevels, + fpCache: map[int64]*num.Uint{}, } if err := updated.setCurves(rf, sf, linearSlippage); err != nil { return nil, err @@ -443,14 +456,20 @@ func generateCurve( // now we scale theoretical position by position factor so that is it feeds through into all subsequent equations pv = pv.Mul(positionFactor) + l := pv.Mul(lu) + + sqrtHigh := sqrt(high) + lDivSqrtPu := l.Div(sqrtHigh) // and finally calculate L = pv * Lu return &curve{ - l: pv.Mul(lu), - low: low, - high: high, - pv: pv, - isLower: isLower, + l: l, + low: low, + high: high, + pv: pv, + isLower: isLower, + lDivSqrtPu: lDivSqrtPu, + sqrtHigh: sqrtHigh, } } @@ -696,6 +715,10 @@ func (p *Pool) fairPrice() *num.Uint { return p.lower.high.Clone() } + if fp, ok := p.fpCache[pos]; ok { + return fp.Clone() + } + cu := p.lower pv := num.DecimalFromInt64(pos) if pos < 0 { @@ -708,10 +731,8 @@ func (p *Pool) fairPrice() *num.Uint { panic("should not be calculating fair-price on empty-curve side") } - l := cu.l - // pv * sqrt(pu) * (1/L) + 1 - denom := pv.Mul(p.sqrt(cu.high)).Div(l).Add(num.DecimalOne()) + denom := pv.Mul(cu.sqrtHigh).Div(cu.l).Add(num.DecimalOne()) // sqrt(fp) = sqrt(pu) / denom sqrtPf := p.sqrt(cu.high).Div(denom) @@ -729,12 +750,16 @@ func (p *Pool) fairPrice() *num.Uint { } fairPrice, _ := num.UintFromDecimal(fp) + + p.fpCache = map[int64]*num.Uint{ + pos: fairPrice.Clone(), + } return fairPrice } // virtualBalancesShort returns the pools x, y balances when the pool has a negative position // -// x = P + Pv + L / sqrt(pl) +// x = P + Pv + L / sqrt(pu) // y = L * sqrt(fair-price). func (p *Pool) virtualBalancesShort(pos int64, fp *num.Uint) (num.Decimal, num.Decimal) { cu := p.upper @@ -750,10 +775,10 @@ func (p *Pool) virtualBalancesShort(pos int64, fp *num.Uint) (num.Decimal, num.D // Pv term2x := cu.pv - // L / sqrt(pl) - term3x := cu.l.Div(p.sqrt(cu.high)) + // L / sqrt(pu) + term3x := cu.lDivSqrtPu - // x = P + (cc * rf / pu) + (L / sqrt(pl)) + // x = P + (cc * rf / pu) + (L / sqrt(pu)) x := term2x.Add(term3x).Add(term1x) // now lets get y @@ -779,7 +804,7 @@ func (p *Pool) virtualBalancesLong(pos int64, fp *num.Uint) (num.Decimal, num.De term1x := num.DecimalFromInt64(pos) // L / sqrt(pu) - term2x := cu.l.Div(p.sqrt(cu.high)) + term2x := cu.lDivSqrtPu // x = P + (L / sqrt(pu)) x := term1x.Add(term2x)