Skip to content

Commit

Permalink
Add WS Batch support with ctx mocking
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan O'Hara-Reid committed Jan 17, 2025
1 parent 9567ef0 commit c5b9558
Show file tree
Hide file tree
Showing 21 changed files with 291 additions and 15 deletions.
5 changes: 5 additions & 0 deletions exchanges/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/common/crypto"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/exchanges/account"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/log"
)

Expand Down Expand Up @@ -111,6 +112,10 @@ func (b *Base) GetDefaultCredentials() *account.Credentials {
// GetCredentials checks and validates current credentials, context credentials
// override default credentials, if no credentials found, will return an error.
func (b *Base) GetCredentials(ctx context.Context) (*account.Credentials, error) {
if request.IsMockResponse(ctx) {
return &account.Credentials{}, nil
}

value := ctx.Value(account.ContextCredentialsFlag)
if value != nil {
ctxCredStore, ok := value.(*account.ContextCredentialsStore)
Expand Down
7 changes: 7 additions & 0 deletions exchanges/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"errors"
"testing"

"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/exchanges/account"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
)

func TestGetCredentials(t *testing.T) {
Expand Down Expand Up @@ -127,6 +129,11 @@ func TestGetCredentials(t *testing.T) {
notOverrided.SubAccount != "" {
t.Fatal("unexpected values")
}

creds, err = b.GetCredentials(request.WithMockResponse(context.Background(), nil))
require.NoError(t, err)
require.NotNil(t, creds)
require.Empty(t, creds)
}

func TestAreCredentialsValid(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions exchanges/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -1948,3 +1948,8 @@ func (b *Base) GetTradingRequirements() protocol.TradingRequirements {
func (*Base) WebsocketSubmitOrder(context.Context, *order.Submit) (*order.SubmitResponse, error) {
return nil, common.ErrFunctionNotSupported
}

// WebsocketSubmitBatchOrders submits multiple orders in a batch via the websocket connection
func (*Base) WebsocketSubmitBatchOrders(context.Context, []*order.Submit) (responses []*order.SubmitResponse, err error) {
return nil, common.ErrFunctionNotSupported
}
5 changes: 5 additions & 0 deletions exchanges/exchange_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3081,3 +3081,8 @@ func TestWebsocketSubmitOrder(t *testing.T) {
_, err := (&Base{}).WebsocketSubmitOrder(context.Background(), nil)
require.ErrorIs(t, err, common.ErrFunctionNotSupported)
}

func TestWebsocketSubmitBatchOrders(t *testing.T) {
_, err := (&Base{}).WebsocketSubmitBatchOrders(context.Background(), nil)
require.ErrorIs(t, err, common.ErrFunctionNotSupported)
}
39 changes: 39 additions & 0 deletions exchanges/gateio/gateio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3845,3 +3845,42 @@ func TestDeriveFuturesSubmitOrderResponses(t *testing.T) {
})
}
}

func TestWebsocketSubmitBatchOrders(t *testing.T) {
t.Parallel()
_, err := g.WebsocketSubmitBatchOrders(context.Background(), []*order.Submit{{}})
require.ErrorIs(t, err, order.ErrExchangeNameUnset)

dummy := &order.Submit{
Exchange: "test",
Pair: currency.NewPair(currency.BTC, currency.USDT),
AssetType: asset.Spot,
Type: order.Market,
Side: order.Buy,
QuoteAmount: 1,
}

other := *dummy
other.AssetType = asset.Futures

_, err = g.WebsocketSubmitBatchOrders(context.Background(), []*order.Submit{dummy, &other})
require.ErrorIs(t, err, errSingleAssetRequired)

other.AssetType = asset.Futures
_, err = g.WebsocketSubmitBatchOrders(context.Background(), []*order.Submit{&other})
require.ErrorIs(t, err, common.ErrNotYetImplemented)

mockResponse1 := []byte(`{"header":{"response_time":"1736485230579","status":"200","channel":"spot.order_place","event":"api","client_id":"35.72.184.127-0xc05f1d8d00","conn_id":"6f7d1aad7c5d05cd","conn_trace_id":"4ca4bd9b2484834eb365c5e079ba69d8","trace_id":"b9835664bf217746b6a8e9644412d6cd","x_in_time":1736485230578603,"x_out_time":1736485230579021},"data":{"result":{"req_id":"777","api_key":"","timestamp":"1736485230","signature":"","trace_id":"","text":"","req_header":{},"req_param":[{"time_in_force":"ioc","text":"t-774","currency_pair":"RBC_USDT","type":"market","account":"spot","side":"buy","amount":"9.98662"},{"text":"t-775","currency_pair":"RBC_ETH","type":"market","account":"spot","side":"sell","amount":"397","time_in_force":"ioc"},{"time_in_force":"ioc","text":"t-776","currency_pair":"ETH_USDT","type":"market","account":"spot","side":"sell","amount":"0.003"}]}},"request_id":"777","ack":true}`)
mockResponse2 := []byte(`{"header":{"response_time":"1736485230624","status":"200","channel":"spot.order_place","event":"api","client_id":"35.72.184.127-0xc05f1d8d00","conn_trace_id":"4ca4bd9b2484834eb365c5e079ba69d8","trace_id":"b9835664bf217746b6a8e9644412d6cd","x_in_time":1736485230578603,"x_out_time":1736485230624136},"data":{"result":[{"account":"spot","status":"closed","side":"buy","amount":"9.98662","id":"771815277347","create_time":"1736485230","update_time":"1736485230","text":"t-774","left":"0.0002093","currency_pair":"RBC_USDT","type":"market","finish_as":"filled","price":"0","time_in_force":"ioc","iceberg":"0","filled_total":"9.9864107","fill_price":"9.9864107","create_time_ms":1736485230593,"update_time_ms":1736485230593,"succeeded":true},{"account":"spot","status":"closed","side":"sell","amount":"397","id":"771815277391","create_time":"1736485230","update_time":"1736485230","text":"t-775","left":"0","currency_pair":"RBC_ETH","type":"market","finish_as":"filled","price":"0","time_in_force":"ioc","iceberg":"0","filled_total":"0.002976309","fill_price":"0.002976309","create_time_ms":1736485230600,"update_time_ms":1736485230600,"succeeded":true},{"account":"spot","status":"closed","side":"sell","amount":"0.003","id":"771815277451","create_time":"1736485230","update_time":"1736485230","text":"t-776","left":"0","currency_pair":"ETH_USDT","type":"market","finish_as":"filled","price":"0","time_in_force":"ioc","iceberg":"0","filled_total":"9.76572","fill_price":"9.76572","create_time_ms":1736485230608,"update_time_ms":1736485230608,"succeeded":true}]},"request_id":"777"}`)
ctx := context.Background()
got, err := g.WebsocketSubmitBatchOrders(request.WithMockResponse(ctx, mockResponse1, mockResponse2), []*order.Submit{dummy, dummy, dummy})
require.NoError(t, err)
require.Len(t, got, 3)

mockResponse1 = []byte(`{"header":{"response_time":"1736980695937","status":"200","channel":"spot.order_place","event":"api","client_id":"35.72.184.127-0xc13ed551e0","conn_id":"138c696791d9dc0d","conn_trace_id":"dca3f78ba2d34e8c52a4217258783552","trace_id":"d096e1f953d017d054f678980aff4087","x_in_time":1736980695937125,"x_out_time":1736980695937383},"data":{"result":{"req_id":"743","api_key":"","timestamp":"1736980695","signature":"","trace_id":"","text":"","req_header":{},"req_param":[{"side":"buy","amount":"9.98","time_in_force":"fok","text":"t-740","currency_pair":"ETH_USDT","type":"market","account":"spot"},{"text":"t-741","currency_pair":"LIKE_ETH","type":"market","account":"spot","side":"buy","amount":"0.00289718","time_in_force":"fok"},{"type":"market","account":"spot","side":"sell","amount":"297.16","time_in_force":"fok","text":"t-742","currency_pair":"LIKE_USDT"}]}},"request_id":"743","ack":true}`)
mockResponse2 = []byte(`{"header":{"response_time":"1736980695972","status":"200","channel":"spot.order_place","event":"api","client_id":"35.72.184.127-0xc13ed551e0","conn_trace_id":"dca3f78ba2d34e8c52a4217258783552","trace_id":"d096e1f953d017d054f678980aff4087","x_in_time":1736980695937125,"x_out_time":1736980695972307},"data":{"result":[{"account":"spot","status":"closed","side":"buy","amount":"9.98","id":"775453816782","create_time":"1736980695","update_time":"1736980695","text":"t-740","left":"0.047239","currency_pair":"ETH_USDT","type":"market","finish_as":"filled","price":"0","time_in_force":"fok","iceberg":"0","filled_total":"9.932761","fill_price":"9.932761","create_time_ms":1736980695949,"update_time_ms":1736980695949,"succeeded":true},{"account":"spot","status":"closed","side":"buy","amount":"0.00289718","id":"775453816824","create_time":"1736980695","update_time":"1736980695","text":"t-741","left":"0.00000000962","currency_pair":"LIKE_ETH","type":"market","finish_as":"filled","price":"0","time_in_force":"fok","iceberg":"0","filled_total":"0.00289717038","fill_price":"0.00289717038","create_time_ms":1736980695956,"update_time_ms":1736980695956,"succeeded":true},{"text":"t-742","label":"BALANCE_NOT_ENOUGH","message":"Not enough balance"}]},"request_id":"743"}`)
got, err = g.WebsocketSubmitBatchOrders(request.WithMockResponse(ctx, mockResponse1, mockResponse2), []*order.Submit{dummy, dummy, dummy})
require.NoError(t, err)
require.Len(t, got, 3)
require.ErrorIs(t, got[2].Error, order.ErrUnableToPlaceOrder)
}
2 changes: 1 addition & 1 deletion exchanges/gateio/gateio_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ func (g *Gateio) SendWebsocketRequest(ctx context.Context, epl request.EndpointL
return err
}

conn, err := g.Websocket.GetConnection(connSignature)
conn, err := g.Websocket.GetConnection(ctx, connSignature)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion exchanges/gateio/gateio_websocket_request_spot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestWebsocketLogin(t *testing.T) {
testexch.UpdatePairsOnce(t, g)
g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes

demonstrationConn, err := g.Websocket.GetConnection(asset.Spot)
demonstrationConn, err := g.Websocket.GetConnection(context.Background(), asset.Spot)
require.NoError(t, err)

err = g.websocketLogin(context.Background(), demonstrationConn, "spot.login")
Expand Down
2 changes: 2 additions & 0 deletions exchanges/gateio/gateio_websocket_request_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ type WebsocketOrderResponse struct {
STPID int `json:"stp_id"`
STPAct string `json:"stp_act"`
AverageDealPrice types.Number `json:"avg_deal_price"`
Label string `json:"label"`
Message string `json:"message"`
}

// WebsocketFuturesOrderResponse defines a websocket futures order response
Expand Down
63 changes: 62 additions & 1 deletion exchanges/gateio/gateio_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ import (
// this error.
const unfundedFuturesAccount = `please transfer funds first to create futures account`

var errNoResponseReceived = errors.New("no response received")
var (
errNoResponseReceived = errors.New("no response received")
errSingleAssetRequired = errors.New("single asset type required")
)

// SetDefaults sets default values for the exchange
func (g *Gateio) SetDefaults() {
Expand Down Expand Up @@ -2668,6 +2671,55 @@ func (g *Gateio) WebsocketSubmitOrder(ctx context.Context, s *order.Submit) (*or
}
}

// WebsocketSubmitBatchOrders submits multiple orders to the exchange through the websocket
// RE: Spot batch orders; cannot derive purchased amount as the average price is omitted from the response and the fill
// price is not accurate.
func (g *Gateio) WebsocketSubmitBatchOrders(ctx context.Context, orders []*order.Submit) (responses []*order.SubmitResponse, err error) {
var a asset.Item
for x := range orders {
if err = orders[x].Validate(g.GetTradingRequirements()); err != nil {

Check failure on line 2680 in exchanges/gateio/gateio_wrapper.go

View workflow job for this annotation

GitHub Actions / lint

sloppyReassign: re-assignment to `err` can be replaced with `err := orders[x].Validate(g.GetTradingRequirements())` (gocritic)
return nil, err
}

if !a.IsValid() {
a = orders[x].AssetType
continue
}

if a != orders[x].AssetType {
return nil, fmt.Errorf("%w %v", errSingleAssetRequired, a)
}
}

if !g.CurrencyPairs.IsAssetSupported(a) {
return nil, fmt.Errorf("%w %v", asset.ErrNotSupported, a)
}

switch a {
case asset.Spot:
reqs := make([]*CreateOrderRequest, len(orders))
for x := range orders {
reqs[x], err = g.getSpotOrderRequest(orders[x])
if err != nil {
return nil, err
}
}

got, err := g.WebsocketSpotSubmitOrders(ctx, reqs...)
if err != nil {
return nil, err
}

resps, err := g.DeriveSpotSubmitOrderResponses(got)
if err != nil {
return nil, err
}
return resps, nil
default:
return nil, fmt.Errorf("%w for %s", common.ErrNotYetImplemented, a)
}
}

// DeriveSpotSubmitOrderResponses returns the order submission responses for spot
func (g *Gateio) DeriveSpotSubmitOrderResponses(responses []WebsocketOrderResponse) ([]*order.SubmitResponse, error) {
if len(responses) == 0 {
Expand All @@ -2676,6 +2728,15 @@ func (g *Gateio) DeriveSpotSubmitOrderResponses(responses []WebsocketOrderRespon

out := make([]*order.SubmitResponse, 0, len(responses))
for x := range responses {
if responses[x].Label != "" { // Only returned in a batch order response context
out = append(out, &order.SubmitResponse{
Exchange: g.Name,
ClientOrderID: responses[x].Text,
Error: fmt.Errorf("%w reason label:%s message:%s", order.ErrUnableToPlaceOrder, responses[x].Label, responses[x].Message),
})
continue
}

side, err := order.StringToOrderSide(responses[x].Side)
if err != nil {
return nil, err
Expand Down
2 changes: 2 additions & 0 deletions exchanges/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ type OrderManagement interface {

// WebsocketSubmitOrder submits an order via the websocket connection
WebsocketSubmitOrder(ctx context.Context, s *order.Submit) (*order.SubmitResponse, error)
// WebsocketSubmitBatchOrders submits multiple orders in a batch via the websocket connection
WebsocketSubmitBatchOrders(ctx context.Context, orders []*order.Submit) (responses []*order.SubmitResponse, err error)
}

// CurrencyStateManagement defines functionality for currency state management
Expand Down
2 changes: 1 addition & 1 deletion exchanges/order/order_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestSubmit_Validate(t *testing.T) {
Submit: nil,
}, // nil struct
{
ExpectedErr: errExchangeNameUnset,
ExpectedErr: ErrExchangeNameUnset,
Submit: &Submit{},
}, // empty exchange
{
Expand Down
3 changes: 3 additions & 0 deletions exchanges/order/order_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ type SubmitResponse struct {
BorrowSize float64
LoanApplyID string
MarginType margin.Type

// Error is populated if the order was not successful, this is used in batch order submissions
Error error
}

// Modify contains all properties of an order
Expand Down
4 changes: 2 additions & 2 deletions exchanges/order/orders.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ var (
ErrAmountMustBeSet = errors.New("amount must be set")
ErrClientOrderIDMustBeSet = errors.New("client order ID must be set")
ErrUnknownSubmissionAmountType = errors.New("unknown submission amount type")
ErrExchangeNameUnset = errors.New("exchange name unset")
)

var (
errTimeInForceConflict = errors.New("multiple time in force options applied")
errUnrecognisedOrderType = errors.New("unrecognised order type")
errUnrecognisedOrderStatus = errors.New("unrecognised order status")
errExchangeNameUnset = errors.New("exchange name unset")
errOrderSubmitIsNil = errors.New("order submit is nil")
errOrderSubmitResponseIsNil = errors.New("order submit response is nil")
errOrderDetailIsNil = errors.New("order detail is nil")
Expand All @@ -64,7 +64,7 @@ func (s *Submit) Validate(requirements protocol.TradingRequirements, opt ...vali
}

if s.Exchange == "" {
return errExchangeNameUnset
return ErrExchangeNameUnset
}

if s.Pair.IsEmpty() {
Expand Down
38 changes: 38 additions & 0 deletions exchanges/request/mock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package request

import (
"bytes"
"context"
"io"
"net/http"
)

var mockResponseFlag = struct{ name string }{name: "mockResponse"}

// IsMockResponse returns true if the request has a mock response set
func IsMockResponse(ctx context.Context) bool {
return ctx.Value(mockResponseFlag) != nil
}

// WithMockResponse sets the mock response for a request. This is used for testing purposes.
// REST response is single. Websocket response can be multiple. This allows expected responses to be set for a request if required.
func WithMockResponse(ctx context.Context, mockResponse ...[]byte) context.Context {
return context.WithValue(ctx, mockResponseFlag, mockResponse)
}

// GetMockResponse returns the mock response for a request
func GetMockResponse(ctx context.Context) [][]byte {
mockResponse, _ := ctx.Value(mockResponseFlag).([][]byte)
return mockResponse
}

func getRESTResponseFromMock(ctx context.Context) *http.Response {
mockResp := GetMockResponse(ctx)
if len(mockResp) != 1 {
panic("mock REST response invalid, requires exactly one response")
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(io.Reader(io.LimitReader(bytes.NewBuffer(mockResp[0]), drainBodyLimit))),

Check failure on line 36 in exchanges/request/mock.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary conversion (unconvert)
}
}
27 changes: 27 additions & 0 deletions exchanges/request/mock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package request

import (
"context"
"io"
"testing"

"github.com/stretchr/testify/require"
)

func TestMockResponse(t *testing.T) {
t.Parallel()

ctx := context.Background()
require.False(t, IsMockResponse(ctx))
require.Nil(t, GetMockResponse(ctx))
require.Panics(t, func() { getRESTResponseFromMock(ctx) })

Check failure on line 17 in exchanges/request/mock_test.go

View workflow job for this annotation

GitHub Actions / lint

response body must be closed (bodyclose)
mockCtx := WithMockResponse(ctx, []byte("test"))
require.True(t, IsMockResponse(mockCtx))
require.NotNil(t, GetMockResponse(mockCtx))
got := getRESTResponseFromMock(mockCtx)

Check failure on line 21 in exchanges/request/mock_test.go

View workflow job for this annotation

GitHub Actions / lint

response body must be closed (bodyclose)
require.NotNil(t, got)
require.Equal(t, 200, got.StatusCode)
hotBod, err := io.ReadAll(got.Body)
require.NoError(t, err)
require.Equal(t, []byte("test"), hotBod)
}
7 changes: 6 additions & 1 deletion exchanges/request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,12 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe

start := time.Now()

resp, err := r._HTTPClient.do(req)
var resp *http.Response
if IsMockResponse(ctx) {
resp = getRESTResponseFromMock(ctx)
} else {
resp, err = r._HTTPClient.do(req)

Check failure on line 201 in exchanges/request/request.go

View workflow job for this annotation

GitHub Actions / lint

response body must be closed (bodyclose)
}

if r.reporter != nil && err == nil {
r.reporter.Latency(r.name, p.Method, p.Path, time.Since(start))
Expand Down
9 changes: 9 additions & 0 deletions exchanges/request/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,15 @@ func TestDoRequest(t *testing.T) {
if failed != 0 {
t.Fatal("request failed")
}

m := struct {
Mock bool `json:"mock"`
}{}

ctx = WithMockResponse(ctx, []byte(`{"mock":true}`))
err = r.SendPayload(ctx, UnAuth, func() (*Item, error) { return &Item{Method: http.MethodGet, Path: testURL, Result: &m}, nil }, UnauthenticatedRequest)
require.NoError(t, err)
require.True(t, m.Mock)
}

func TestDoRequest_Retries(t *testing.T) {
Expand Down
33 changes: 33 additions & 0 deletions exchanges/stream/mock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package stream

import (
"context"

"github.com/thrasher-corp/gocryptotrader/exchanges/request"
)

// MockWebsocketConnection is a mock websocket connection
type MockWebsocketConnection struct {
WebsocketConnection
}

// SendMessageReturnResponse returns a mock response from context
func (m *MockWebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, payload any) ([]byte, error) {
resps, _ := m.SendMessageReturnResponses(ctx, epl, signature, payload, 1)
return resps[0], nil
}

// SendMessageReturnResponses returns a mock response from context
func (m *MockWebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) {
return m.SendMessageReturnResponsesWithInspector(ctx, epl, signature, payload, expected, nil)
}

// SendMessageReturnResponsesWithInspector returns a mock response from context
func (*MockWebsocketConnection) SendMessageReturnResponsesWithInspector(ctx context.Context, _ request.EndpointLimit, _, _ any, _ int, _ Inspector) ([][]byte, error) {
return request.GetMockResponse(ctx), nil
}

// newMockConnection returns a new mock websocket connection, used so that the websocket does not need to be connected
func newMockWebsocketConnection() Connection {
return &MockWebsocketConnection{}
}
Loading

0 comments on commit c5b9558

Please sign in to comment.