Skip to content

Commit

Permalink
Gateway unsubscribe for multiple subscribed clients (#1637)
Browse files Browse the repository at this point in the history
  • Loading branch information
zkokelj authored Nov 9, 2023
1 parent e7ab386 commit 9e3cf21
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 82 deletions.
1 change: 1 addition & 0 deletions go/rpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const (
Attestation = "obscuroscan_attestation"
StopHost = "test_stopHost"
Subscribe = "eth_subscribe"
Unsubscribe = "eth_unsubscribe"
SubscribeNamespace = "eth"
SubscriptionTypeLogs = "logs"

Expand Down
50 changes: 1 addition & 49 deletions go/rpc/encrypted_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (c *EncRPCClient) CallContext(ctx context.Context, result interface{}, meth
return c.executeSensitiveCall(ctx, result, method, args...)
}

func (c *EncRPCClient) Subscribe(ctx context.Context, result interface{}, namespace string, ch interface{}, args ...interface{}) (*rpc.ClientSubscription, error) {
func (c *EncRPCClient) Subscribe(ctx context.Context, _ interface{}, namespace string, ch interface{}, args ...interface{}) (*rpc.ClientSubscription, error) {
if len(args) == 0 {
return nil, fmt.Errorf("subscription did not specify its type")
}
Expand Down Expand Up @@ -125,15 +125,6 @@ func (c *EncRPCClient) Subscribe(ctx context.Context, result interface{}, namesp
return nil, err
}

// We need to return the subscription ID, to allow unsubscribing. However, the client API has already converted
// from a subscription ID to a subscription object under the hood, so we can't retrieve the subscription ID.
// To hack around this, we always return the subscription ID as the first message on the newly-created subscription.
err = c.setResultToSubID(clientChannel, result, subscriptionToObscuro)
if err != nil {
subscriptionToObscuro.Unsubscribe()
return nil, err
}

go c.forwardLogs(clientChannel, logCh, subscriptionToObscuro)

return subscriptionToObscuro, nil
Expand Down Expand Up @@ -214,24 +205,6 @@ func (c *EncRPCClient) createAuthenticatedLogSubscription(args []interface{}) (*
return logSubscription, nil
}

func (c *EncRPCClient) setResultToSubID(clientChannel chan common.IDAndEncLog, result interface{}, subscription *rpc.ClientSubscription) error {
select {
case idAndEncLog := <-clientChannel:
if idAndEncLog.SubID == "" || idAndEncLog.EncLog != nil {
return fmt.Errorf("expected an initial subscription response with the subscription ID only")
}
if result != nil {
err := c.setResult([]byte(idAndEncLog.SubID), result)
if err != nil {
return fmt.Errorf("failed to extract result from subscription response: %w", err)
}
}
case <-subscription.Err():
return fmt.Errorf("did not receive the initial subscription response with the subscription ID")
}
return nil
}

func (c *EncRPCClient) executeSensitiveCall(ctx context.Context, result interface{}, method string, args ...interface{}) error {
// encode the params into a json blob and encrypt them
encryptedParams, err := c.encryptArgs(args...)
Expand Down Expand Up @@ -361,27 +334,6 @@ func (c *EncRPCClient) decryptResponse(encryptedBytes []byte) ([]byte, error) {
return decryptedResult, nil
}

// setResult tries to cast/unmarshal data into the result pointer, based on its type
func (c *EncRPCClient) setResult(data []byte, result interface{}) error {
switch result := result.(type) {
case *string:
*result = string(data)
return nil

case *interface{}:
err := json.Unmarshal(data, result)
if err != nil {
// if unmarshal failed with generic return we can try to send it back as a string
*result = string(data)
}
return nil

default:
// for any other type we attempt to json unmarshal it
return json.Unmarshal(data, result)
}
}

// IsSensitiveMethod indicates whether the RPC method's requests and responses should be encrypted.
func IsSensitiveMethod(method string) bool {
for _, m := range SensitiveMethods {
Expand Down
122 changes: 116 additions & 6 deletions integration/obscurogateway/obscurogateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ func TestObscuroGateway(t *testing.T) {
// run the tests against the exis
for name, test := range map[string]func(*testing.T, string, string, wallet.Wallet){
//"testAreTxsMinted": testAreTxsMinted, this breaks the other tests bc, enable once concurency issues are fixed
"testErrorHandling": testErrorHandling,
"testMultipleAccountsSubscription": testMultipleAccountsSubscription,
"testErrorsRevertedArePassed": testErrorsRevertedArePassed,
"testErrorHandling": testErrorHandling,
"testMultipleAccountsSubscription": testMultipleAccountsSubscription,
"testErrorsRevertedArePassed": testErrorsRevertedArePassed,
"testUnsubscribe": testUnsubscribe,
"testClosingConnectionWhileSubscribed": testClosingConnectionWhileSubscribed,
} {
t.Run(name, func(t *testing.T) {
test(t, httpURL, wsURL, w)
Expand Down Expand Up @@ -383,6 +385,114 @@ func testErrorsRevertedArePassed(t *testing.T, httpURL, wsURL string, w wallet.W
require.Equal(t, err.Error(), "execution reverted")
}

func testUnsubscribe(t *testing.T, httpURL, wsURL string, w wallet.Wallet) {
// create a user with multiple accounts
user, err := NewUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.ObscuroChainID)}, httpURL, wsURL)
require.NoError(t, err)
fmt.Printf("Created user with UserID: %s\n", user.ogClient.UserID())

// register all the accounts for the user
err = user.RegisterAccounts()
require.NoError(t, err)

// deploy events contract
deployTx := &types.LegacyTx{
Nonce: w.GetNonceAndIncrement(),
Gas: uint64(1_000_000),
GasPrice: gethcommon.Big1,
Data: gethcommon.FromHex(eventsContractBytecode),
}

signedTx, err := w.SignTransaction(deployTx)
require.NoError(t, err)

err = user.HTTPClient.SendTransaction(context.Background(), signedTx)
require.NoError(t, err)

contractReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), user.HTTPClient, signedTx.Hash(), time.Minute)
require.NoError(t, err)

fmt.Println("Deployed contract address: ", contractReceipt.ContractAddress)

// subscribe to an event
var userLogs []types.Log
subscription := subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user.WSClient, &userLogs)

// make an action that will trigger events
_, err = integrationCommon.InteractWithSmartContract(user.HTTPClient, user.Wallets[0], eventsContractABI, "setMessage", "foo", contractReceipt.ContractAddress)
require.NoError(t, err)

assert.Equal(t, 1, len(userLogs))

// Unsubscribe from events
subscription.Unsubscribe()

// make another action that will trigger events
_, err = integrationCommon.InteractWithSmartContract(user.HTTPClient, user.Wallets[0], eventsContractABI, "setMessage", "bar", contractReceipt.ContractAddress)
require.NoError(t, err)

// check that we are not receiving events after unsubscribing
assert.Equal(t, 1, len(userLogs))
}

func testClosingConnectionWhileSubscribed(t *testing.T, httpURL, wsURL string, w wallet.Wallet) {
// create a user with multiple accounts
user, err := NewUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.ObscuroChainID)}, httpURL, wsURL)
require.NoError(t, err)
fmt.Printf("Created user with UserID: %s\n", user.ogClient.UserID())

// register all the accounts for the user
err = user.RegisterAccounts()
require.NoError(t, err)

// deploy events contract
deployTx := &types.LegacyTx{
Nonce: w.GetNonceAndIncrement(),
Gas: uint64(1_000_000),
GasPrice: gethcommon.Big1,
Data: gethcommon.FromHex(eventsContractBytecode),
}

signedTx, err := w.SignTransaction(deployTx)
require.NoError(t, err)

err = user.HTTPClient.SendTransaction(context.Background(), signedTx)
require.NoError(t, err)

contractReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), user.HTTPClient, signedTx.Hash(), time.Minute)
require.NoError(t, err)

fmt.Println("Deployed contract address: ", contractReceipt.ContractAddress)

// subscribe to an event
var userLogs []types.Log
subscription := subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user.WSClient, &userLogs)

// Close the websocket connection and make sure nothing breaks, but user does not receive events
user.WSClient.Close()

// make an action that will emmit events
_, err = integrationCommon.InteractWithSmartContract(user.HTTPClient, user.Wallets[0], eventsContractABI, "setMessage2", "foo", contractReceipt.ContractAddress)
require.NoError(t, err)
// but with closed connection we don't receive any logs
assert.Equal(t, 0, len(userLogs))

// re-establish connection
wsClient, err := ethclient.Dial(wsURL + "/v1/" + "?u=" + user.ogClient.UserID())
require.NoError(t, err)
user.WSClient = wsClient

// make an action that will emmit events again
_, err = integrationCommon.InteractWithSmartContract(user.HTTPClient, user.Wallets[0], eventsContractABI, "setMessage2", "foo", contractReceipt.ContractAddress)
require.NoError(t, err)

// closing connection (above) unsubscribes, and we still should see no logs
assert.Equal(t, 0, len(userLogs))

// Call unsubscribe (should handle it without issues even if it is already unsubscribed by closing the channel)
subscription.Unsubscribe()
}

func transferRandomAddr(t *testing.T, client *ethclient.Client, w wallet.Wallet) common.TxHash { //nolint: unused
ctx := context.Background()
toAddr := datagenerator.RandomAddress()
Expand Down Expand Up @@ -475,7 +585,7 @@ func transferETHToAddress(client *ethclient.Client, wallet wallet.Wallet, toAddr
return integrationCommon.AwaitReceiptEth(context.Background(), client, signedTx.Hash(), 2*time.Second)
}

func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Hash, client *ethclient.Client, logs *[]types.Log) {
func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Hash, client *ethclient.Client, logs *[]types.Log) ethereum.Subscription { //nolint:unparam
// Make a subscription
filterQuery := ethereum.FilterQuery{
Addresses: addresses,
Expand All @@ -489,8 +599,6 @@ func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Has
if err != nil {
fmt.Printf("Failed to subscribe to filter logs: %v\n", err)
}
// todo (@ziga) - unsubscribe when it is fixed...
// defer subscription.Unsubscribe() // cleanup

// Listen for logs in a goroutine
go func() {
Expand All @@ -505,4 +613,6 @@ func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Has
}
}
}()

return subscription
}
13 changes: 13 additions & 0 deletions tools/walletextension/accountmanager/account_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ func (m *AccountManager) AddClient(address gethcommon.Address, client *rpc.EncRP
// ProxyRequest tries to identify the correct EncRPCClient to proxy the request to the Obscuro node, or it will attempt
// the request with all clients until it succeeds
func (m *AccountManager) ProxyRequest(rpcReq *wecommon.RPCRequest, rpcResp *interface{}, userConn userconn.UserConn) error {
// We need to handle a special case for subscribing and unsubscribing from events,
// because we need to handle multiple accounts with a single user request
if rpcReq.Method == rpc.Subscribe {
clients, err := m.suggestSubscriptionClient(rpcReq)
if err != nil {
Expand All @@ -73,6 +75,17 @@ func (m *AccountManager) ProxyRequest(rpcReq *wecommon.RPCRequest, rpcResp *inte
}
return nil
}
if rpcReq.Method == rpc.Unsubscribe {
if len(rpcReq.Params) != 1 {
return fmt.Errorf("one parameter (subscriptionID) expected, %d parameters received", len(rpcReq.Params))
}
subscriptionID, ok := rpcReq.Params[0].(string)
if !ok {
return fmt.Errorf("subscriptionID needs to be a string. Got: %v", rpcReq.Params[0])
}
m.subscriptionsManager.HandleUnsubscribe(subscriptionID, rpcResp)
return nil
}
return m.executeCall(rpcReq, rpcResp)
}

Expand Down
Loading

0 comments on commit 9e3cf21

Please sign in to comment.