Skip to content

Commit

Permalink
Merge pull request #11525 from vegaprotocol/amm-perf
Browse files Browse the repository at this point in the history
feat: improve snapshot performance by using more targetted caching
  • Loading branch information
wwestgarth authored Aug 6, 2024
2 parents 4ca9285 + 307227b commit d378062
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

- [11428](https://github.com/vegaprotocol/vega/issues/11428) - Add buy back and treasury fee and separate discount/reward factors.
- [11468](https://github.com/vegaprotocol/vega/issues/11468) - Added support for volume rebate program.
- [11523](https://github.com/vegaprotocol/vega/issues/11523) - Change method of caching to improve `AMM` snapshot performance.
- [11459](https://github.com/vegaprotocol/vega/issues/11459) - Deprecate time weight position reward metric and replace it with time weighted notional.

### 🐛 Fixes
Expand Down
29 changes: 7 additions & 22 deletions core/execution/amm/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,16 @@ func (s *Sqrter) sqrt(u *num.Uint) num.Decimal {
return num.DecimalZero()
}

if r, ok := s.cache[u.String()]; ok {
return r
}
// caching was disabled here since it caused problems with snapshots (https://github.com/vegaprotocol/vega/issues/11523)
// and we changed tact to instead cache constant terms in calculations that *involve* sqrt's instead of the sqrt result
// directly. I'm leaving the ghost of this cache here incase we need to introduce it again, maybe as a LRU cache instead.
// if r, ok := s.cache[u.String()]; ok {
// return r
// }

// TODO that we may need to re-visit this depending on the performance impact
// but for now lets do it "properly" in full decimals and work out how we can
// improve it once we have reg-tests and performance data.
r := num.UintOne().Sqrt(u)

// and cache it -- we can also maybe be more clever here and use a LRU but thats for later
s.cache[u.String()] = r
// s.cache[u.String()] = r
return r
}

Expand Down Expand Up @@ -198,11 +197,6 @@ func NewFromProto(
e.ammParties[v.Key] = v.Value
}

// TODO consider whether we want the cache in the snapshot, it might be pretty large/slow and I'm not sure what we gain
for _, v := range state.Sqrter {
e.rooter.cache[v.Key] = num.MustDecimalFromString(v.Value)
}

for _, v := range state.Pools {
p, err := NewPoolFromProto(log, e.rooter.sqrt, e.collateral, e.position, v.Pool, v.Party, priceFactor, positionFactor)
if err != nil {
Expand All @@ -216,19 +210,10 @@ func NewFromProto(

func (e *Engine) IntoProto() *v1.AmmState {
state := &v1.AmmState{
Sqrter: make([]*v1.StringMapEntry, 0, len(e.rooter.cache)),
AmmPartyIds: make([]*v1.StringMapEntry, 0, len(e.ammParties)),
Pools: make([]*v1.PoolMapEntry, 0, len(e.pools)),
}

for k, v := range e.rooter.cache {
state.Sqrter = append(state.Sqrter, &v1.StringMapEntry{
Key: k,
Value: v.String(),
})
}
sort.Slice(state.Sqrter, func(i, j int) bool { return state.Sqrter[i].Key < state.Sqrter[j].Key })

for k, v := range e.ammParties {
state.AmmPartyIds = append(state.AmmPartyIds, &v1.StringMapEntry{
Key: k,
Expand Down
67 changes: 46 additions & 21 deletions core/execution/amm/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ type curve struct {
// note that this equals Vega's position at the boundary only in the lower curve, since Vega position == curve-position
// in the upper curve Vega's position == 0 => position of `pv`` in curve-position, Vega's position pv => 0 in curve-position
pv num.Decimal

lDivSqrtPu num.Decimal
sqrtHigh num.Decimal
}

func (c *curve) volumeBetweenPrices(sqrt sqrtFn, st, nd *num.Uint) uint64 {
Expand All @@ -56,8 +59,8 @@ func (c *curve) volumeBetweenPrices(sqrt sqrtFn, st, nd *num.Uint) uint64 {
return 0
}

stP := impliedPosition(sqrt(st), sqrt(c.high), c.l)
ndP := impliedPosition(sqrt(nd), sqrt(c.high), c.l)
stP := impliedPosition(sqrt(st), c.sqrtHigh, c.l)
ndP := impliedPosition(sqrt(nd), c.sqrtHigh, c.l)

// abs(P(st) - P(nd))
volume := stP.Sub(ndP).Abs()
Expand All @@ -67,7 +70,7 @@ func (c *curve) volumeBetweenPrices(sqrt sqrtFn, st, nd *num.Uint) uint64 {
// positionAtPrice returns the position of the AMM if its fair-price were the given price. This
// will be signed for long/short as usual.
func (c *curve) positionAtPrice(sqrt sqrtFn, price *num.Uint) int64 {
pos := impliedPosition(sqrt(price), sqrt(c.high), c.l)
pos := impliedPosition(sqrt(price), c.sqrtHigh, c.l)
if c.isLower {
return pos.IntPart()
}
Expand Down Expand Up @@ -111,6 +114,8 @@ type Pool struct {

maxCalculationLevels *num.Uint // maximum number of price levels the AMM will be expanded into
oneTick *num.Uint // one price tick

fpCache map[int64]*num.Uint
}

func NewPool(
Expand Down Expand Up @@ -148,6 +153,7 @@ func NewPool(
oneTick: oneTick,
status: types.AMMPoolStatusActive,
maxCalculationLevels: maxCalculationLevels,
fpCache: map[int64]*num.Uint{},
}
err := pool.setCurves(rf, sf, linearSlippage)
if err != nil {
Expand Down Expand Up @@ -268,12 +274,18 @@ func NewCurveFromProto(c *snapshotpb.PoolMapEntry_Curve) (*curve, error) {
if overflow {
return nil, fmt.Errorf("failed to convert string to Uint: %s", c.Low)
}

sqrtHigh := num.UintOne().Sqrt(high)
lDivSqrtPu := l.Div(sqrtHigh)

return &curve{
l: l,
high: high,
low: low,
empty: c.Empty,
pv: pv,
l: l,
high: high,
low: low,
empty: c.Empty,
pv: pv,
sqrtHigh: sqrtHigh,
lDivSqrtPu: lDivSqrtPu,
}, nil
}

Expand Down Expand Up @@ -355,6 +367,7 @@ func (p *Pool) Update(
sqrt: p.sqrt,
oneTick: p.oneTick,
maxCalculationLevels: p.maxCalculationLevels,
fpCache: map[int64]*num.Uint{},
}
if err := updated.setCurves(rf, sf, linearSlippage); err != nil {
return nil, err
Expand Down Expand Up @@ -443,14 +456,20 @@ func generateCurve(

// now we scale theoretical position by position factor so that is it feeds through into all subsequent equations
pv = pv.Mul(positionFactor)
l := pv.Mul(lu)

sqrtHigh := sqrt(high)
lDivSqrtPu := l.Div(sqrtHigh)

// and finally calculate L = pv * Lu
return &curve{
l: pv.Mul(lu),
low: low,
high: high,
pv: pv,
isLower: isLower,
l: l,
low: low,
high: high,
pv: pv,
isLower: isLower,
lDivSqrtPu: lDivSqrtPu,
sqrtHigh: sqrtHigh,
}
}

Expand Down Expand Up @@ -696,6 +715,10 @@ func (p *Pool) fairPrice() *num.Uint {
return p.lower.high.Clone()
}

if fp, ok := p.fpCache[pos]; ok {
return fp.Clone()
}

cu := p.lower
pv := num.DecimalFromInt64(pos)
if pos < 0 {
Expand All @@ -708,10 +731,8 @@ func (p *Pool) fairPrice() *num.Uint {
panic("should not be calculating fair-price on empty-curve side")
}

l := cu.l

// pv * sqrt(pu) * (1/L) + 1
denom := pv.Mul(p.sqrt(cu.high)).Div(l).Add(num.DecimalOne())
denom := pv.Mul(cu.sqrtHigh).Div(cu.l).Add(num.DecimalOne())

// sqrt(fp) = sqrt(pu) / denom
sqrtPf := p.sqrt(cu.high).Div(denom)
Expand All @@ -729,12 +750,16 @@ func (p *Pool) fairPrice() *num.Uint {
}

fairPrice, _ := num.UintFromDecimal(fp)

p.fpCache = map[int64]*num.Uint{
pos: fairPrice.Clone(),
}
return fairPrice
}

// virtualBalancesShort returns the pools x, y balances when the pool has a negative position
//
// x = P + Pv + L / sqrt(pl)
// x = P + Pv + L / sqrt(pu)
// y = L * sqrt(fair-price).
func (p *Pool) virtualBalancesShort(pos int64, fp *num.Uint) (num.Decimal, num.Decimal) {
cu := p.upper
Expand All @@ -750,10 +775,10 @@ func (p *Pool) virtualBalancesShort(pos int64, fp *num.Uint) (num.Decimal, num.D
// Pv
term2x := cu.pv

// L / sqrt(pl)
term3x := cu.l.Div(p.sqrt(cu.high))
// L / sqrt(pu)
term3x := cu.lDivSqrtPu

// x = P + (cc * rf / pu) + (L / sqrt(pl))
// x = P + (cc * rf / pu) + (L / sqrt(pu))
x := term2x.Add(term3x).Add(term1x)

// now lets get y
Expand All @@ -779,7 +804,7 @@ func (p *Pool) virtualBalancesLong(pos int64, fp *num.Uint) (num.Decimal, num.De
term1x := num.DecimalFromInt64(pos)

// L / sqrt(pu)
term2x := cu.l.Div(p.sqrt(cu.high))
term2x := cu.lDivSqrtPu

// x = P + (L / sqrt(pu))
x := term1x.Add(term2x)
Expand Down

0 comments on commit d378062

Please sign in to comment.