Skip to content

Commit

Permalink
feat: add subaccount support to position stream API.
Browse files Browse the repository at this point in the history
Signed-off-by: Jeremy Letang <[email protected]>
  • Loading branch information
jeremyletang committed Apr 29, 2024
1 parent 269216a commit cf5ae5d
Show file tree
Hide file tree
Showing 7 changed files with 4,368 additions and 4,254 deletions.
78 changes: 75 additions & 3 deletions datanode/api/trading_data_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bufio"
"bytes"
"context"
"encoding/hex"
"fmt"
"math"
"math/rand"
Expand Down Expand Up @@ -1453,19 +1454,72 @@ func (t *TradingDataServiceV2) ListAllPositions(ctx context.Context, req *v2.Lis
}, nil
}

const ammVersion = "AMMv1"

func deriveSubAccount(
party, market, version string,
index uint64,
) string {
hash := crypto.Hash([]byte(fmt.Sprintf("%v%v%v%v", version, market, party, index)))
return hex.EncodeToString(hash)
}

func (t *TradingDataServiceV2) getSubaccounts(ctx context.Context, party string, market *string) ([]string, error) {
if market != nil && len(*market) > 0 {
return []string{deriveSubAccount(party, *market, ammVersion, 0)}, nil
}

// get a list of all markets to generate all potential
// sub accounts for the party
markets, _, err := t.MarketsService.GetAllPaged(ctx, "", entities.DefaultCursorPagination(true), false)
if err != nil {
return nil, err
}

subs := make([]string, 0, len(markets))
for _, v := range markets {
subs = append(subs, deriveSubAccount(party, v.ID.String(), ammVersion, 0))
}

return subs, nil
}

// ObservePositions subscribes to a stream of Positions.
func (t *TradingDataServiceV2) ObservePositions(req *v2.ObservePositionsRequest, srv v2.TradingDataService_ObservePositionsServer) error {
// Wrap context from the request into cancellable. We can close internal chan on error.
ctx, cancel := context.WithCancel(srv.Context())
defer cancel()

if err := t.sendPositionsSnapshot(ctx, req, srv); err != nil {
// handle subaccounts
includeDerivedParties := ptr.UnBox(req.IncludeDerivedParties)
if includeDerivedParties && (req.PartyId == nil || len(*req.PartyId) <= 0) {
return formatE(newInvalidArgumentError("includeSubaccount requires a partyId"))
}

subaccounts := []string{}
if req.PartyId != nil && len(*req.PartyId) > 0 {
if includeDerivedParties {
subs, err := t.getSubaccounts(ctx, *req.PartyId, req.MarketId)
if err != nil {
return formatE(err)
}

subaccounts = append(subaccounts, subs...)
}
}

if err := t.sendPositionsSnapshot(ctx, req, srv, subaccounts); err != nil {
if !errors.Is(err, entities.ErrNotFound) {
return formatE(ErrPositionServiceSendSnapshot, err)
}
}

positionsChan, ref := t.positionService.Observe(ctx, t.config.StreamRetries, ptr.UnBox(req.PartyId), ptr.UnBox(req.MarketId))
// add the party to the subaccounts
if len(subaccounts) > 0 {
subaccounts = append(subaccounts, *req.PartyId)
}

positionsChan, ref := t.positionService.ObserveMany(ctx, t.config.StreamRetries, ptr.UnBox(req.MarketId), subaccounts...)

if t.log.GetLevel() == logging.DebugLevel {
t.log.Debug("Positions subscriber - new rpc stream", logging.Uint64("ref", ref))
Expand All @@ -1490,7 +1544,7 @@ func (t *TradingDataServiceV2) ObservePositions(req *v2.ObservePositionsRequest,
})
}

func (t *TradingDataServiceV2) sendPositionsSnapshot(ctx context.Context, req *v2.ObservePositionsRequest, srv v2.TradingDataService_ObservePositionsServer) error {
func (t *TradingDataServiceV2) sendPositionsSnapshot(ctx context.Context, req *v2.ObservePositionsRequest, srv v2.TradingDataService_ObservePositionsServer, subaccounts []string) error {
var (
positions []entities.Position
err error
Expand Down Expand Up @@ -1529,6 +1583,24 @@ func (t *TradingDataServiceV2) sendPositionsSnapshot(ctx context.Context, req *v
}
}

// finally handle subaccount
for _, v := range subaccounts {
if req.MarketId != nil {
position, err := t.positionService.GetByMarketAndParty(ctx, *req.MarketId, v)
if err != nil {
return errors.Wrap(err, "getting initial positions by market+party")
}
positions = append(positions, position)
continue
}

subaccountPositions, err := t.positionService.GetByParty(ctx, entities.PartyID(v))
if err != nil {
return errors.Wrap(err, "getting initial positions by party")
}
positions = append(positions, subaccountPositions...)
}

protos := make([]*vega.Position, len(positions))
for i := 0; i < len(positions); i++ {
protos[i] = positions[i].ToProto()
Expand Down
17 changes: 13 additions & 4 deletions datanode/gateway/graphql/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions datanode/gateway/graphql/resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3429,10 +3429,11 @@ func (r *mySubscriptionResolver) TradesStream(ctx context.Context, filter Trades
return c, nil
}

func (r *mySubscriptionResolver) Positions(ctx context.Context, party, market *string) (<-chan []*vegapb.Position, error) {
func (r *mySubscriptionResolver) Positions(ctx context.Context, party, market *string, includeDerivedParties *bool) (<-chan []*vegapb.Position, error) {
req := &v2.ObservePositionsRequest{
PartyId: party,
MarketId: market,
PartyId: party,
MarketId: market,
IncludeDerivedParties: includeDerivedParties,
}
stream, err := r.tradingDataClientV2.ObservePositions(ctx, req)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions datanode/gateway/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ type Subscription {
partyId: ID
"ID of the market from which you want position updates"
marketId: ID
"Also returns all derived parties for the given party, if specified and party_id is required as well"
includeDerivedParties: Boolean
): [PositionUpdate!]!

"Subscribe to proposals. Leave out all arguments to receive all proposals"
Expand Down
11 changes: 11 additions & 0 deletions datanode/service/position.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"code.vegaprotocol.io/vega/logging"

lru "github.com/hashicorp/golang-lru"
"golang.org/x/exp/slices"
)

type PositionStore interface {
Expand Down Expand Up @@ -169,3 +170,13 @@ func (p *Position) Observe(ctx context.Context, retries int, partyID, marketID s
})
return ch, ref
}

func (p *Position) ObserveMany(ctx context.Context, retries int, marketID string, parties ...string) (<-chan []entities.Position, uint64) {
ch, ref := p.observer.Observe(ctx,
retries,
func(pos entities.Position) bool {
return (len(marketID) == 0 || marketID == pos.MarketID.String()) &&
(len(parties) == 0 || slices.Contains(parties, pos.PartyID.String()))
})
return ch, ref
}
Loading

0 comments on commit cf5ae5d

Please sign in to comment.