From 9496e42bf88ebe91d713675cb84914fff3c40a83 Mon Sep 17 00:00:00 2001 From: boolangery Date: Thu, 14 Nov 2024 12:42:17 +0100 Subject: [PATCH] add ctx to Recv (#259) --- README.md | 25 ++++++++++++------- programs/serum/rpc.go | 4 +-- programs/serum/rpc_test.go | 2 +- rpc/ws/accountSubscribe.go | 6 ++++- rpc/ws/blockSubscribe.go | 5 +++- rpc/ws/client_test.go | 8 +++--- .../accountSubscribe/accountSubscribe.go | 4 +-- .../examples/logsSubscribe/logsSubscribe.go | 5 ++-- .../programSubscribe/programSubscribe.go | 3 ++- .../examples/rootSubscribe/rootSubscribe.go | 3 ++- .../signatureSubscribe/signatureSubscribe.go | 3 ++- .../examples/slotSubscribe/slotSubscribe.go | 3 ++- .../examples/voteSubscribe/voteSubscribe.go | 3 ++- rpc/ws/logsSubscribe.go | 6 ++++- rpc/ws/parsedBlockSubscribe.go | 6 ++++- rpc/ws/programSubscribe.go | 6 ++++- rpc/ws/rootSubscribe.go | 6 ++++- rpc/ws/signatureSubscribe.go | 5 +++- rpc/ws/slotSubscribe.go | 6 ++++- rpc/ws/slotsUpdatesSubscribe.go | 10 ++++++-- rpc/ws/subscription.go | 6 ++++- rpc/ws/voteSubscribe.go | 6 ++++- 22 files changed, 94 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 5041c563..629a3f5c 100644 --- a/README.md +++ b/README.md @@ -2954,6 +2954,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.MainNetBeta_WS) if err != nil { panic(err) @@ -2971,7 +2972,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } @@ -2991,7 +2992,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } @@ -3016,6 +3017,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.MainNetBeta_WS) if err != nil { panic(err) @@ -3034,7 +3036,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } @@ -3053,7 +3055,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } @@ -3078,6 +3080,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.MainNetBeta_WS) if err != nil { panic(err) @@ -3096,7 +3099,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } @@ -3130,6 +3133,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.TestNet_WS) if err != nil { panic(err) @@ -3141,7 +3145,7 @@ func main() { } for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } @@ -3165,6 +3169,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.TestNet_WS) if err != nil { panic(err) @@ -3182,7 +3187,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } @@ -3205,6 +3210,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.TestNet_WS) if err != nil { panic(err) @@ -3217,7 +3223,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } @@ -3240,6 +3246,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.MainNetBeta_WS) if err != nil { panic(err) @@ -3254,7 +3261,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } diff --git a/programs/serum/rpc.go b/programs/serum/rpc.go index dbb6d8f1..e680b5d2 100644 --- a/programs/serum/rpc.go +++ b/programs/serum/rpc.go @@ -101,14 +101,14 @@ func FetchMarket(ctx context.Context, rpcCli *rpc.Client, marketAddr solana.Publ return meta, nil } -func StreamOpenOrders(client *ws.Client) error { +func StreamOpenOrders(ctx context.Context, client *ws.Client) error { sub, err := client.ProgramSubscribe(DEXProgramIDV2, rpc.CommitmentSingleGossip) if err != nil { return fmt.Errorf("unable to subscribe to programID %q: %w", DEXProgramIDV2, err) } count := 0 for { - d, err := sub.Recv() + d, err := sub.Recv(ctx) if err != nil { return fmt.Errorf("received error from programID subscription: %w", err) } diff --git a/programs/serum/rpc_test.go b/programs/serum/rpc_test.go index bea8434a..5e3c20b2 100644 --- a/programs/serum/rpc_test.go +++ b/programs/serum/rpc_test.go @@ -66,6 +66,6 @@ func TestStreamOpenOrders(t *testing.T) { client, err := ws.Connect(context.Background(), rpcURL) require.NoError(t, err) - err = StreamOpenOrders(client) + err = StreamOpenOrders(context.Background(), client) require.NoError(t, err) } diff --git a/rpc/ws/accountSubscribe.go b/rpc/ws/accountSubscribe.go index 65bd59e7..5e65d3a2 100644 --- a/rpc/ws/accountSubscribe.go +++ b/rpc/ws/accountSubscribe.go @@ -15,6 +15,8 @@ package ws import ( + "context" + "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" ) @@ -83,8 +85,10 @@ type AccountSubscription struct { sub *Subscription } -func (sw *AccountSubscription) Recv() (*AccountResult, error) { +func (sw *AccountSubscription) Recv(ctx context.Context) (*AccountResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d, ok := <-sw.sub.stream: if !ok { return nil, ErrSubscriptionClosed diff --git a/rpc/ws/blockSubscribe.go b/rpc/ws/blockSubscribe.go index 24e4bb57..404ce072 100644 --- a/rpc/ws/blockSubscribe.go +++ b/rpc/ws/blockSubscribe.go @@ -15,6 +15,7 @@ package ws import ( + "context" "fmt" "github.com/gagliardetto/solana-go" @@ -148,8 +149,10 @@ type BlockSubscription struct { sub *Subscription } -func (sw *BlockSubscription) Recv() (*BlockResult, error) { +func (sw *BlockSubscription) Recv(ctx context.Context) (*BlockResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d, ok := <-sw.sub.stream: if !ok { return nil, ErrSubscriptionClosed diff --git a/rpc/ws/client_test.go b/rpc/ws/client_test.go index 14d7a261..91490441 100644 --- a/rpc/ws/client_test.go +++ b/rpc/ws/client_test.go @@ -45,7 +45,7 @@ func Test_AccountSubscribe(t *testing.T) { sub, err := c.AccountSubscribe(accountID, "") require.NoError(t, err) - data, err := sub.Recv() + data, err := sub.Recv(context.Background()) if err != nil { fmt.Println("receive an error: ", err) return @@ -95,7 +95,7 @@ func Test_AccountSubscribeWithHttpHeader(t *testing.T) { sub.Unsubscribe() }(sub) - data, err := sub.Recv() + data, err := sub.Recv(context.Background()) if err != nil { t.Errorf("Received an error: %v", err) } @@ -127,7 +127,7 @@ func Test_ProgramSubscribe(t *testing.T) { require.NoError(t, err) for { - data, err := sub.Recv() + data, err := sub.Recv(context.Background()) if err != nil { fmt.Println("receive an error: ", err) return @@ -148,7 +148,7 @@ func Test_SlotSubscribe(t *testing.T) { sub, err := c.SlotSubscribe() require.NoError(t, err) - data, err := sub.Recv() + data, err := sub.Recv(context.Background()) if err != nil { fmt.Println("receive an error: ", err) return diff --git a/rpc/ws/examples/accountSubscribe/accountSubscribe.go b/rpc/ws/examples/accountSubscribe/accountSubscribe.go index 61193df3..a82d6b75 100644 --- a/rpc/ws/examples/accountSubscribe/accountSubscribe.go +++ b/rpc/ws/examples/accountSubscribe/accountSubscribe.go @@ -42,7 +42,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(context.Background()) if err != nil { panic(err) } @@ -62,7 +62,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(context.Background()) if err != nil { panic(err) } diff --git a/rpc/ws/examples/logsSubscribe/logsSubscribe.go b/rpc/ws/examples/logsSubscribe/logsSubscribe.go index 3a32e3d8..2d7f5677 100644 --- a/rpc/ws/examples/logsSubscribe/logsSubscribe.go +++ b/rpc/ws/examples/logsSubscribe/logsSubscribe.go @@ -24,6 +24,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.MainNetBeta_WS) if err != nil { panic(err) @@ -43,7 +44,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } @@ -62,7 +63,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } diff --git a/rpc/ws/examples/programSubscribe/programSubscribe.go b/rpc/ws/examples/programSubscribe/programSubscribe.go index 2a54e77f..05e3ab0b 100644 --- a/rpc/ws/examples/programSubscribe/programSubscribe.go +++ b/rpc/ws/examples/programSubscribe/programSubscribe.go @@ -24,6 +24,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.MainNetBeta_WS) if err != nil { panic(err) @@ -43,7 +44,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } diff --git a/rpc/ws/examples/rootSubscribe/rootSubscribe.go b/rpc/ws/examples/rootSubscribe/rootSubscribe.go index 87ccd962..f254ff75 100644 --- a/rpc/ws/examples/rootSubscribe/rootSubscribe.go +++ b/rpc/ws/examples/rootSubscribe/rootSubscribe.go @@ -23,6 +23,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.TestNet_WS) if err != nil { panic(err) @@ -35,7 +36,7 @@ func main() { } for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } diff --git a/rpc/ws/examples/signatureSubscribe/signatureSubscribe.go b/rpc/ws/examples/signatureSubscribe/signatureSubscribe.go index 6a2f40b2..9c9a7dfc 100644 --- a/rpc/ws/examples/signatureSubscribe/signatureSubscribe.go +++ b/rpc/ws/examples/signatureSubscribe/signatureSubscribe.go @@ -24,6 +24,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.TestNet_WS) if err != nil { panic(err) @@ -42,7 +43,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } diff --git a/rpc/ws/examples/slotSubscribe/slotSubscribe.go b/rpc/ws/examples/slotSubscribe/slotSubscribe.go index 1da17f87..8e3dd7e9 100644 --- a/rpc/ws/examples/slotSubscribe/slotSubscribe.go +++ b/rpc/ws/examples/slotSubscribe/slotSubscribe.go @@ -23,6 +23,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.TestNet_WS) if err != nil { panic(err) @@ -36,7 +37,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } diff --git a/rpc/ws/examples/voteSubscribe/voteSubscribe.go b/rpc/ws/examples/voteSubscribe/voteSubscribe.go index e5f8d7cf..6205a4fa 100644 --- a/rpc/ws/examples/voteSubscribe/voteSubscribe.go +++ b/rpc/ws/examples/voteSubscribe/voteSubscribe.go @@ -23,6 +23,7 @@ import ( ) func main() { + ctx := context.Background() client, err := ws.Connect(context.Background(), rpc.MainNetBeta_WS) if err != nil { panic(err) @@ -38,7 +39,7 @@ func main() { defer sub.Unsubscribe() for { - got, err := sub.Recv() + got, err := sub.Recv(ctx) if err != nil { panic(err) } diff --git a/rpc/ws/logsSubscribe.go b/rpc/ws/logsSubscribe.go index 98a4160d..393f32fa 100644 --- a/rpc/ws/logsSubscribe.go +++ b/rpc/ws/logsSubscribe.go @@ -15,6 +15,8 @@ package ws import ( + "context" + "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" ) @@ -107,8 +109,10 @@ type LogSubscription struct { sub *Subscription } -func (sw *LogSubscription) Recv() (*LogResult, error) { +func (sw *LogSubscription) Recv(ctx context.Context) (*LogResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d, ok := <-sw.sub.stream: if !ok { return nil, ErrSubscriptionClosed diff --git a/rpc/ws/parsedBlockSubscribe.go b/rpc/ws/parsedBlockSubscribe.go index c6c7f58b..67215902 100644 --- a/rpc/ws/parsedBlockSubscribe.go +++ b/rpc/ws/parsedBlockSubscribe.go @@ -15,6 +15,8 @@ package ws import ( + "context" + "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" ) @@ -92,8 +94,10 @@ type ParsedBlockSubscription struct { sub *Subscription } -func (sw *ParsedBlockSubscription) Recv() (*ParsedBlockResult, error) { +func (sw *ParsedBlockSubscription) Recv(ctx context.Context) (*ParsedBlockResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d := <-sw.sub.stream: return d.(*ParsedBlockResult), nil case err := <-sw.sub.err: diff --git a/rpc/ws/programSubscribe.go b/rpc/ws/programSubscribe.go index 3a4f127e..cac4fca8 100644 --- a/rpc/ws/programSubscribe.go +++ b/rpc/ws/programSubscribe.go @@ -15,6 +15,8 @@ package ws import ( + "context" + "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" ) @@ -86,8 +88,10 @@ type ProgramSubscription struct { sub *Subscription } -func (sw *ProgramSubscription) Recv() (*ProgramResult, error) { +func (sw *ProgramSubscription) Recv(ctx context.Context) (*ProgramResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d, ok := <-sw.sub.stream: if !ok { return nil, ErrSubscriptionClosed diff --git a/rpc/ws/rootSubscribe.go b/rpc/ws/rootSubscribe.go index 46ce656b..9a022adb 100644 --- a/rpc/ws/rootSubscribe.go +++ b/rpc/ws/rootSubscribe.go @@ -14,6 +14,8 @@ package ws +import "context" + type RootResult uint64 // SignatureSubscribe subscribes to receive notification @@ -42,8 +44,10 @@ type RootSubscription struct { sub *Subscription } -func (sw *RootSubscription) Recv() (*RootResult, error) { +func (sw *RootSubscription) Recv(ctx context.Context) (*RootResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d, ok := <-sw.sub.stream: if !ok { return nil, ErrSubscriptionClosed diff --git a/rpc/ws/signatureSubscribe.go b/rpc/ws/signatureSubscribe.go index d0f08025..2929be91 100644 --- a/rpc/ws/signatureSubscribe.go +++ b/rpc/ws/signatureSubscribe.go @@ -15,6 +15,7 @@ package ws import ( + "context" "fmt" "time" @@ -67,8 +68,10 @@ type SignatureSubscription struct { sub *Subscription } -func (sw *SignatureSubscription) Recv() (*SignatureResult, error) { +func (sw *SignatureSubscription) Recv(ctx context.Context) (*SignatureResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d, ok := <-sw.sub.stream: if !ok { return nil, ErrSubscriptionClosed diff --git a/rpc/ws/slotSubscribe.go b/rpc/ws/slotSubscribe.go index c49ac9dc..cc058ced 100644 --- a/rpc/ws/slotSubscribe.go +++ b/rpc/ws/slotSubscribe.go @@ -14,6 +14,8 @@ package ws +import "context" + type SlotResult struct { Parent uint64 `json:"parent"` Root uint64 `json:"root"` @@ -45,8 +47,10 @@ type SlotSubscription struct { sub *Subscription } -func (sw *SlotSubscription) Recv() (*SlotResult, error) { +func (sw *SlotSubscription) Recv(ctx context.Context) (*SlotResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d, ok := <-sw.sub.stream: if !ok { return nil, ErrSubscriptionClosed diff --git a/rpc/ws/slotsUpdatesSubscribe.go b/rpc/ws/slotsUpdatesSubscribe.go index 36fb564b..d636ace1 100644 --- a/rpc/ws/slotsUpdatesSubscribe.go +++ b/rpc/ws/slotsUpdatesSubscribe.go @@ -14,7 +14,11 @@ package ws -import "github.com/gagliardetto/solana-go" +import ( + "context" + + "github.com/gagliardetto/solana-go" +) type SlotsUpdatesResult struct { // The parent slot. @@ -77,8 +81,10 @@ type SlotsUpdatesSubscription struct { sub *Subscription } -func (sw *SlotsUpdatesSubscription) Recv() (*SlotsUpdatesResult, error) { +func (sw *SlotsUpdatesSubscription) Recv(ctx context.Context) (*SlotsUpdatesResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d, ok := <-sw.sub.stream: if !ok { return nil, ErrSubscriptionClosed diff --git a/rpc/ws/subscription.go b/rpc/ws/subscription.go index d2a16a18..cb8de86e 100644 --- a/rpc/ws/subscription.go +++ b/rpc/ws/subscription.go @@ -17,6 +17,8 @@ package ws +import "context" + type Subscription struct { req *request subID uint64 @@ -47,8 +49,10 @@ func newSubscription( } } -func (s *Subscription) Recv() (interface{}, error) { +func (s *Subscription) Recv(ctx context.Context) (interface{}, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d := <-s.stream: return d, nil case err := <-s.err: diff --git a/rpc/ws/voteSubscribe.go b/rpc/ws/voteSubscribe.go index b0e56694..8fbe8bdf 100644 --- a/rpc/ws/voteSubscribe.go +++ b/rpc/ws/voteSubscribe.go @@ -15,6 +15,8 @@ package ws import ( + "context" + "github.com/gagliardetto/solana-go" ) @@ -59,8 +61,10 @@ type VoteSubscription struct { sub *Subscription } -func (sw *VoteSubscription) Recv() (*VoteResult, error) { +func (sw *VoteSubscription) Recv(ctx context.Context) (*VoteResult, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case d, ok := <-sw.sub.stream: if !ok { return nil, ErrSubscriptionClosed