Skip to content

Commit

Permalink
feat: update core AMM equations to not rely on average-entry-price
Browse files Browse the repository at this point in the history
  • Loading branch information
wwestgarth committed Apr 29, 2024
1 parent 0f29607 commit 6de88b3
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 267 deletions.
18 changes: 3 additions & 15 deletions core/execution/amm/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ func (e *Engine) OnMTM(ctx context.Context) {
if !p.closing() {
continue
}
if pos, _ := p.getPosition(); pos != 0 {
if pos := p.getPosition(); pos != 0 {
continue
}

Expand Down Expand Up @@ -420,26 +420,14 @@ func (e *Engine) submit(active []*Pool, agg *types.Order, inner, outer *num.Uint
continue
}

pos, ae := p.getPosition()
x, y := p.virtualBalances(pos, ae, agg.Side)
dx := num.DecimalFromInt64(int64(volume))
// calculate the price the pool wil give for the trading volume
price := p.PriceForVolume(volume, agg.Side)

// dy = x*y / (x - dx) - y
// where y and x are the balances on either side of the pool, and dx is the change in volume
// then the trade price is dy/dx
dy := x.Mul(y).Div(x.Sub(dx)).Sub(y)
price, _ := num.UintFromDecimal(dy.Div(dx))
if e.log.GetLevel() == logging.DebugLevel {
e.log.Debug("generated order at price",
logging.String("price", price.String()),
logging.Uint64("volume", volume),
logging.String("id", p.ID),
logging.Int64("pos", pos),
logging.String("average-entry", ae.String()),
logging.String("y", y.String()),
logging.String("x", x.String()),
logging.String("dy", dy.String()),
logging.String("dx", dx.String()),
)
}

Expand Down
12 changes: 6 additions & 6 deletions core/execution/amm/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,10 @@ func testBasicSubmitOrder(t *testing.T) {
Price: num.NewUint(1900),
}

ensureBalancesN(t, tst.col, 10000000000, 2)
ensureBalancesN(t, tst.col, 10000000000, 3)
orders = tst.engine.SubmitOrder(agg, num.NewUint(2020), num.NewUint(1990))
require.Len(t, orders, 1)
assert.Equal(t, "2035", orders[0].Price.String())
assert.Equal(t, "2036", orders[0].Price.String())
// note that this volume being bigger than 242367 above means we've moved back to position, then flipped
// sign, and took volume from the other curve.
assert.Equal(t, uint64(371231), orders[0].Size)
Expand Down Expand Up @@ -408,7 +408,7 @@ func testSubmitOrderAcrossAMMBoundary(t *testing.T) {
assert.Equal(t, "2049", orders[2].Price.String())

// second round, 2 orders moving all pool's to the upper boundary of the second shortest
assert.Equal(t, "2124", orders[3].Price.String())
assert.Equal(t, "2125", orders[3].Price.String())
assert.Equal(t, "2124", orders[4].Price.String())

// third round, 1 orders moving the last pool to its boundary
Expand Down Expand Up @@ -459,11 +459,11 @@ func testSubmitOrderAcrossAMMBoundarySell(t *testing.T) {
assert.Equal(t, "2053", orders[2].Price.String())

// second round, 2 orders moving all pool's to the upper boundary of the second shortest
assert.Equal(t, "1923", orders[3].Price.String())
assert.Equal(t, "1923", orders[4].Price.String())
assert.Equal(t, "1925", orders[3].Price.String())
assert.Equal(t, "1925", orders[4].Price.String())

// third round, 1 orders moving the last pool to its boundary
assert.Equal(t, "1872", orders[5].Price.String())
assert.Equal(t, "1875", orders[5].Price.String())
}

func testBestPricesAndVolume(t *testing.T) {
Expand Down
165 changes: 99 additions & 66 deletions core/execution/amm/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package amm

import (
"code.vegaprotocol.io/vega/core/positions"
"code.vegaprotocol.io/vega/core/types"
"code.vegaprotocol.io/vega/libs/num"
"code.vegaprotocol.io/vega/libs/ptr"
Expand All @@ -25,8 +24,7 @@ import (

// ephemeralPosition keeps track of the pools position as if its generated orders had traded.
type ephemeralPosition struct {
size int64
averageEntry *num.Uint
size int64
}

type curve struct {
Expand Down Expand Up @@ -383,13 +381,31 @@ func (p *Pool) OrderbookShape(from, to *num.Uint) ([]*types.Order, []*types.Orde
return buys, sells
}

// PriceForVolume returns the price the AMM is willing to trade at to match with the given volume.
func (p *Pool) PriceForVolume(volume uint64, side types.Side) *num.Uint {
x, y := p.virtualBalances(p.getPosition(), p.fairPrice(), side)

// dy = x*y / (x - dx) - y
// where y and x are the balances on either side of the pool, and dx is the change in volume
// then the trade price is dy/dx
dx := num.DecimalFromInt64(int64(volume))
dy := x.Mul(y).Div(x.Sub(dx)).Sub(y)

// dy / dx
price, overflow := num.UintFromDecimal(dy.Div(dx))
if overflow {
panic("calculated negative price")
}
return price
}

// TradableVolumeInRange returns the volume the pool is willing to provide between the two given price levels for side of a given order
// that is trading with the pool. If `nil` is provided for either price then we take the full volume in that direction.
func (p *Pool) TradableVolumeInRange(side types.Side, price1 *num.Uint, price2 *num.Uint) uint64 {
if !p.canTrade(side) {
return 0
}
pos, _ := p.getPosition()
pos := p.getPosition()
st, nd := price1, price2

if price1 == nil {
Expand Down Expand Up @@ -454,25 +470,20 @@ func (p *Pool) setEphemeralPosition() {
return
}
p.eph = &ephemeralPosition{
size: 0,
averageEntry: num.UintZero(),
size: 0,
}

if pos := p.position.GetPositionsByParty(p.SubAccount); len(pos) != 0 {
p.eph.size = pos[0].Size()
p.eph.averageEntry = pos[0].AverageEntryPrice()
}
}

// updateEphemeralPosition sets the pools transient position given a generated order.
func (p *Pool) updateEphemeralPosition(order *types.Order) {
if order.Side == types.SideSell {
p.eph.averageEntry = positions.CalcVWAP(p.eph.averageEntry, -p.eph.size, int64(order.Size), order.Price)
p.eph.size -= int64(order.Size)
return
}

p.eph.averageEntry = positions.CalcVWAP(p.eph.averageEntry, p.eph.size, int64(order.Size), order.Price)
p.eph.size += int64(order.Size)
}

Expand All @@ -483,22 +494,76 @@ func (p *Pool) clearEphemeralPosition() {
}

// getPosition gets the pools current position an average-entry price.
func (p *Pool) getPosition() (int64, *num.Uint) {
func (p *Pool) getPosition() int64 {
if p.eph != nil {
return p.eph.size, p.eph.averageEntry.Clone()
return p.eph.size
}

if pos := p.position.GetPositionsByParty(p.SubAccount); len(pos) != 0 {
return pos[0].Size(), pos[0].AverageEntryPrice()
return pos[0].Size()
}
return 0
}

// fairPrice returns the fair price of the pool given its current position.

// sqrt(pf) = sqrt(pu) / (1 + pv * sqrt(pu) * 1/L )
// where pv is the virtual-position
// pv = pos, when the pool is long
// pv = pos + rf * cu / pu, when pool is short
//
// this transformation is needed since for each curve its virtual position is 0 at the lower bound which maps to the Vega position when the pool is
// long, but when the pool is short Vega position == 0 at the upper bounds and -ve at the lower.
func (p *Pool) fairPrice() *num.Uint {
pos := p.getPosition()
if pos == 0 {
// if no position fair price is base price
return p.lower.high.Clone()
}

cu := p.lower
pv := num.DecimalFromInt64(pos)

if pos < 0 {
cu = p.upper

// c / pu * rf
balance := p.getBalance()
term2 := cu.rf.Mul(num.DecimalFromUint(balance)).Div(num.DecimalFromUint(cu.high))

// pos + c / pu * rf
pv = pv.Add(term2)
}

l := num.DecimalFromUint(cu.l)

// pv * sqrt(pu) * (1/L) + 1
denom := pv.Mul(p.sqrt(cu.high)).Div(l).Add(num.DecimalOne())

// sqrt(fp) = sqrt(pu) / denom
sqrtPf := p.sqrt(cu.high).Div(denom)

// fair-price = sqrt(fp) * sqrt(fp)
fp := sqrtPf.Mul(sqrtPf)

// we want to round such that the price is further away from the base. This is so that once
// a pool's position is at its boundary we do not report volume that doesn't exist. For example
// say a pool's upper boundary is 1000 and for it to be at that boundary its position needs to
// be 10.5. The closest we can get is 10 but then we'd report a fair-price of 999.78. If
// we use 999 we'd be implying volume between 999 and 1000 which we don't want to trade.
if pos < 0 {
fp = fp.Ceil()
}
return 0, num.UintZero()

fairPrice, _ := num.UintFromDecimal(fp)
return fairPrice
}

// virtualBalancesShort returns the pools x, y balances when the pool has a negative position, where
// virtualBalancesShort returns the pools x, y balances when the pool has a negative position
//
// x = P + (cc * rf) / sqrt(pl) + L / sqrt(pl),
// y = abs(P) * average-entry + L * sqrt(pl).
func (p *Pool) virtualBalancesShort(pos int64, ae *num.Uint) (num.Decimal, num.Decimal) {
// x = P + (cc * rf) / pu + L / sqrt(pl)
// y = L * sqrt(fair-price).
func (p *Pool) virtualBalancesShort(pos int64, fp *num.Uint) (num.Decimal, num.Decimal) {
cu := p.upper
if cu.empty {
panic("should not be calculating balances on empty-curve side")
Expand All @@ -509,7 +574,7 @@ func (p *Pool) virtualBalancesShort(pos int64, ae *num.Uint) (num.Decimal, num.D
// lets start with x

// P
term1x := num.DecimalFromInt64(-pos)
term1x := num.DecimalFromInt64(pos)

// cc * rf / pu
term2x := cu.rf.Mul(num.DecimalFromUint(balance)).Div(num.DecimalFromUint(cu.high))
Expand All @@ -518,26 +583,20 @@ func (p *Pool) virtualBalancesShort(pos int64, ae *num.Uint) (num.Decimal, num.D
term3x := cu.l.ToDecimal().Div(p.sqrt(cu.high))

// x = P + (cc * rf / pu) + (L / sqrt(pl))
x := term2x.Add(term3x).Sub(term1x)
x := term2x.Add(term3x).Add(term1x)

// now lets get y

// abs(P) * average-entry
term1y := ae.Mul(ae, num.NewUint(uint64(-pos)))

// L * sqrt(pl)
term2y := cu.l.ToDecimal().Mul(p.sqrt(cu.low))

// y = abs(P) * average-entry + L * pl
y := term1y.ToDecimal().Add(term2y)
// y = L * sqrt(fair-price)
y := cu.l.ToDecimal().Mul(p.sqrt(fp))
return x, y
}

// virtualBalancesLong returns the pools x, y balances when the pool has a positive position, where
// virtualBalancesLong returns the pools x, y balances when the pool has a positive position
//
// x = P + (L / sqrt(pu)),
// y = L * (sqrt(pu) - sqrt(pl)) - P * average-entry + (L * sqrt(pl)).
func (p *Pool) virtualBalancesLong(pos int64, ae *num.Uint) (num.Decimal, num.Decimal) {
// x = P + (L / sqrt(pu))
// y = L * sqrt(fair-price).
func (p *Pool) virtualBalancesLong(pos int64, fp *num.Uint) (num.Decimal, num.Decimal) {
cu := p.lower
if cu.empty {
panic("should not be calculating balances on empty-curve side")
Expand All @@ -550,52 +609,26 @@ func (p *Pool) virtualBalancesLong(pos int64, ae *num.Uint) (num.Decimal, num.De

// L / sqrt(pu)
term2x := cu.l.ToDecimal().Div(p.sqrt(cu.high))

// x = P + (L / sqrt(pu))
x := term1x.Add(term2x)

// now lets move to y

// L * (sqrt(pu) - sqrt(pl)) + (L * sqrt(pl)) => L * sqrt(pu)
term1y := cu.l.ToDecimal().Mul(p.sqrt(cu.high))

// P * average-entry
term2y := ae.Mul(ae, num.NewUint(uint64(pos)))

y := term1y.Sub(term2y.ToDecimal())
// y = L * sqrt(fair-price)
y := cu.l.ToDecimal().Mul(p.sqrt(fp))
return x, y
}

// fairPrice returns the fair price of the pool given its current position.
func (p *Pool) fairPrice() *num.Uint {
pos, ae := p.getPosition()
if pos == 0 {
return p.lower.high.Clone()
}

x, y := p.virtualBalances(pos, ae, types.SideUnspecified)

// we want to round such that the price is further away from the base. This is so that once
// a pool's position is at its boundary we do not report volume that doesn't exist. For example
// say a pool's upper boundary is 1000 and for it to be at that boundary its position needs to
// be 10.5. The closest we can get is 10 but then we'd report a fair-price of 999.78. If
// we use 999 we'd be implying volume between 999 and 1000 which we don't want to trade.
fp := y.Div(x)
if pos < 0 {
fp = fp.Ceil()
}

fairPrice, _ := num.UintFromDecimal(fp)
return fairPrice
}

// virtualBalances returns the pools x, y values where x is the balance in contracts and y is the balance in asset.
func (p *Pool) virtualBalances(pos int64, ae *num.Uint, side types.Side) (num.Decimal, num.Decimal) {
func (p *Pool) virtualBalances(pos int64, fp *num.Uint, side types.Side) (num.Decimal, num.Decimal) {
switch {
case pos < 0, pos == 0 && side == types.SideBuy:
// zero position but incoming is buy which will make pool short
return p.virtualBalancesShort(pos, ae)
return p.virtualBalancesShort(pos, fp)
case pos > 0, pos == 0 && side == types.SideSell:
// zero position but incoming is sell which will make pool long
return p.virtualBalancesLong(pos, ae)
return p.virtualBalancesLong(pos, fp)
default:
panic("should not reach here")
}
Expand Down Expand Up @@ -636,7 +669,7 @@ func (p *Pool) canTrade(side types.Side) bool {
return true
}

pos, _ := p.getPosition()
pos := p.getPosition()
// pool is long incoming order is a buy and will make it shorter, its ok
if pos > 0 && side == types.SideBuy {
return true
Expand Down
Loading

0 comments on commit 6de88b3

Please sign in to comment.