diff --git a/go/rpc/client.go b/go/rpc/client.go index d5ba448a12..768ffaa126 100644 --- a/go/rpc/client.go +++ b/go/rpc/client.go @@ -34,6 +34,7 @@ const ( Attestation = "obscuroscan_attestation" StopHost = "test_stopHost" Subscribe = "eth_subscribe" + Unsubscribe = "eth_unsubscribe" SubscribeNamespace = "eth" SubscriptionTypeLogs = "logs" diff --git a/go/rpc/encrypted_client.go b/go/rpc/encrypted_client.go index 7045376d2b..55389b7b7a 100644 --- a/go/rpc/encrypted_client.go +++ b/go/rpc/encrypted_client.go @@ -89,7 +89,7 @@ func (c *EncRPCClient) CallContext(ctx context.Context, result interface{}, meth return c.executeSensitiveCall(ctx, result, method, args...) } -func (c *EncRPCClient) Subscribe(ctx context.Context, result interface{}, namespace string, ch interface{}, args ...interface{}) (*rpc.ClientSubscription, error) { +func (c *EncRPCClient) Subscribe(ctx context.Context, _ interface{}, namespace string, ch interface{}, args ...interface{}) (*rpc.ClientSubscription, error) { if len(args) == 0 { return nil, fmt.Errorf("subscription did not specify its type") } @@ -125,15 +125,6 @@ func (c *EncRPCClient) Subscribe(ctx context.Context, result interface{}, namesp return nil, err } - // We need to return the subscription ID, to allow unsubscribing. However, the client API has already converted - // from a subscription ID to a subscription object under the hood, so we can't retrieve the subscription ID. - // To hack around this, we always return the subscription ID as the first message on the newly-created subscription. - err = c.setResultToSubID(clientChannel, result, subscriptionToObscuro) - if err != nil { - subscriptionToObscuro.Unsubscribe() - return nil, err - } - go c.forwardLogs(clientChannel, logCh, subscriptionToObscuro) return subscriptionToObscuro, nil @@ -214,24 +205,6 @@ func (c *EncRPCClient) createAuthenticatedLogSubscription(args []interface{}) (* return logSubscription, nil } -func (c *EncRPCClient) setResultToSubID(clientChannel chan common.IDAndEncLog, result interface{}, subscription *rpc.ClientSubscription) error { - select { - case idAndEncLog := <-clientChannel: - if idAndEncLog.SubID == "" || idAndEncLog.EncLog != nil { - return fmt.Errorf("expected an initial subscription response with the subscription ID only") - } - if result != nil { - err := c.setResult([]byte(idAndEncLog.SubID), result) - if err != nil { - return fmt.Errorf("failed to extract result from subscription response: %w", err) - } - } - case <-subscription.Err(): - return fmt.Errorf("did not receive the initial subscription response with the subscription ID") - } - return nil -} - func (c *EncRPCClient) executeSensitiveCall(ctx context.Context, result interface{}, method string, args ...interface{}) error { // encode the params into a json blob and encrypt them encryptedParams, err := c.encryptArgs(args...) @@ -361,27 +334,6 @@ func (c *EncRPCClient) decryptResponse(encryptedBytes []byte) ([]byte, error) { return decryptedResult, nil } -// setResult tries to cast/unmarshal data into the result pointer, based on its type -func (c *EncRPCClient) setResult(data []byte, result interface{}) error { - switch result := result.(type) { - case *string: - *result = string(data) - return nil - - case *interface{}: - err := json.Unmarshal(data, result) - if err != nil { - // if unmarshal failed with generic return we can try to send it back as a string - *result = string(data) - } - return nil - - default: - // for any other type we attempt to json unmarshal it - return json.Unmarshal(data, result) - } -} - // IsSensitiveMethod indicates whether the RPC method's requests and responses should be encrypted. func IsSensitiveMethod(method string) bool { for _, m := range SensitiveMethods { diff --git a/integration/obscurogateway/obscurogateway_test.go b/integration/obscurogateway/obscurogateway_test.go index 5b07dd19b6..96773f8aa9 100644 --- a/integration/obscurogateway/obscurogateway_test.go +++ b/integration/obscurogateway/obscurogateway_test.go @@ -90,9 +90,11 @@ func TestObscuroGateway(t *testing.T) { // run the tests against the exis for name, test := range map[string]func(*testing.T, string, string, wallet.Wallet){ //"testAreTxsMinted": testAreTxsMinted, this breaks the other tests bc, enable once concurency issues are fixed - "testErrorHandling": testErrorHandling, - "testMultipleAccountsSubscription": testMultipleAccountsSubscription, - "testErrorsRevertedArePassed": testErrorsRevertedArePassed, + "testErrorHandling": testErrorHandling, + "testMultipleAccountsSubscription": testMultipleAccountsSubscription, + "testErrorsRevertedArePassed": testErrorsRevertedArePassed, + "testUnsubscribe": testUnsubscribe, + "testClosingConnectionWhileSubscribed": testClosingConnectionWhileSubscribed, } { t.Run(name, func(t *testing.T) { test(t, httpURL, wsURL, w) @@ -383,6 +385,114 @@ func testErrorsRevertedArePassed(t *testing.T, httpURL, wsURL string, w wallet.W require.Equal(t, err.Error(), "execution reverted") } +func testUnsubscribe(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { + // create a user with multiple accounts + user, err := NewUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.ObscuroChainID)}, httpURL, wsURL) + require.NoError(t, err) + fmt.Printf("Created user with UserID: %s\n", user.ogClient.UserID()) + + // register all the accounts for the user + err = user.RegisterAccounts() + require.NoError(t, err) + + // deploy events contract + deployTx := &types.LegacyTx{ + Nonce: w.GetNonceAndIncrement(), + Gas: uint64(1_000_000), + GasPrice: gethcommon.Big1, + Data: gethcommon.FromHex(eventsContractBytecode), + } + + signedTx, err := w.SignTransaction(deployTx) + require.NoError(t, err) + + err = user.HTTPClient.SendTransaction(context.Background(), signedTx) + require.NoError(t, err) + + contractReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), user.HTTPClient, signedTx.Hash(), time.Minute) + require.NoError(t, err) + + fmt.Println("Deployed contract address: ", contractReceipt.ContractAddress) + + // subscribe to an event + var userLogs []types.Log + subscription := subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user.WSClient, &userLogs) + + // make an action that will trigger events + _, err = integrationCommon.InteractWithSmartContract(user.HTTPClient, user.Wallets[0], eventsContractABI, "setMessage", "foo", contractReceipt.ContractAddress) + require.NoError(t, err) + + assert.Equal(t, 1, len(userLogs)) + + // Unsubscribe from events + subscription.Unsubscribe() + + // make another action that will trigger events + _, err = integrationCommon.InteractWithSmartContract(user.HTTPClient, user.Wallets[0], eventsContractABI, "setMessage", "bar", contractReceipt.ContractAddress) + require.NoError(t, err) + + // check that we are not receiving events after unsubscribing + assert.Equal(t, 1, len(userLogs)) +} + +func testClosingConnectionWhileSubscribed(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { + // create a user with multiple accounts + user, err := NewUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.ObscuroChainID)}, httpURL, wsURL) + require.NoError(t, err) + fmt.Printf("Created user with UserID: %s\n", user.ogClient.UserID()) + + // register all the accounts for the user + err = user.RegisterAccounts() + require.NoError(t, err) + + // deploy events contract + deployTx := &types.LegacyTx{ + Nonce: w.GetNonceAndIncrement(), + Gas: uint64(1_000_000), + GasPrice: gethcommon.Big1, + Data: gethcommon.FromHex(eventsContractBytecode), + } + + signedTx, err := w.SignTransaction(deployTx) + require.NoError(t, err) + + err = user.HTTPClient.SendTransaction(context.Background(), signedTx) + require.NoError(t, err) + + contractReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), user.HTTPClient, signedTx.Hash(), time.Minute) + require.NoError(t, err) + + fmt.Println("Deployed contract address: ", contractReceipt.ContractAddress) + + // subscribe to an event + var userLogs []types.Log + subscription := subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user.WSClient, &userLogs) + + // Close the websocket connection and make sure nothing breaks, but user does not receive events + user.WSClient.Close() + + // make an action that will emmit events + _, err = integrationCommon.InteractWithSmartContract(user.HTTPClient, user.Wallets[0], eventsContractABI, "setMessage2", "foo", contractReceipt.ContractAddress) + require.NoError(t, err) + // but with closed connection we don't receive any logs + assert.Equal(t, 0, len(userLogs)) + + // re-establish connection + wsClient, err := ethclient.Dial(wsURL + "/v1/" + "?u=" + user.ogClient.UserID()) + require.NoError(t, err) + user.WSClient = wsClient + + // make an action that will emmit events again + _, err = integrationCommon.InteractWithSmartContract(user.HTTPClient, user.Wallets[0], eventsContractABI, "setMessage2", "foo", contractReceipt.ContractAddress) + require.NoError(t, err) + + // closing connection (above) unsubscribes, and we still should see no logs + assert.Equal(t, 0, len(userLogs)) + + // Call unsubscribe (should handle it without issues even if it is already unsubscribed by closing the channel) + subscription.Unsubscribe() +} + func transferRandomAddr(t *testing.T, client *ethclient.Client, w wallet.Wallet) common.TxHash { //nolint: unused ctx := context.Background() toAddr := datagenerator.RandomAddress() @@ -475,7 +585,7 @@ func transferETHToAddress(client *ethclient.Client, wallet wallet.Wallet, toAddr return integrationCommon.AwaitReceiptEth(context.Background(), client, signedTx.Hash(), 2*time.Second) } -func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Hash, client *ethclient.Client, logs *[]types.Log) { +func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Hash, client *ethclient.Client, logs *[]types.Log) ethereum.Subscription { //nolint:unparam // Make a subscription filterQuery := ethereum.FilterQuery{ Addresses: addresses, @@ -489,8 +599,6 @@ func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Has if err != nil { fmt.Printf("Failed to subscribe to filter logs: %v\n", err) } - // todo (@ziga) - unsubscribe when it is fixed... - // defer subscription.Unsubscribe() // cleanup // Listen for logs in a goroutine go func() { @@ -505,4 +613,6 @@ func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Has } } }() + + return subscription } diff --git a/tools/walletextension/accountmanager/account_manager.go b/tools/walletextension/accountmanager/account_manager.go index a92dcf10e9..8b8c8edb1d 100644 --- a/tools/walletextension/accountmanager/account_manager.go +++ b/tools/walletextension/accountmanager/account_manager.go @@ -61,6 +61,8 @@ func (m *AccountManager) AddClient(address gethcommon.Address, client *rpc.EncRP // ProxyRequest tries to identify the correct EncRPCClient to proxy the request to the Obscuro node, or it will attempt // the request with all clients until it succeeds func (m *AccountManager) ProxyRequest(rpcReq *wecommon.RPCRequest, rpcResp *interface{}, userConn userconn.UserConn) error { + // We need to handle a special case for subscribing and unsubscribing from events, + // because we need to handle multiple accounts with a single user request if rpcReq.Method == rpc.Subscribe { clients, err := m.suggestSubscriptionClient(rpcReq) if err != nil { @@ -73,6 +75,17 @@ func (m *AccountManager) ProxyRequest(rpcReq *wecommon.RPCRequest, rpcResp *inte } return nil } + if rpcReq.Method == rpc.Unsubscribe { + if len(rpcReq.Params) != 1 { + return fmt.Errorf("one parameter (subscriptionID) expected, %d parameters received", len(rpcReq.Params)) + } + subscriptionID, ok := rpcReq.Params[0].(string) + if !ok { + return fmt.Errorf("subscriptionID needs to be a string. Got: %v", rpcReq.Params[0]) + } + m.subscriptionsManager.HandleUnsubscribe(subscriptionID, rpcResp) + return nil + } return m.executeCall(rpcReq, rpcResp) } diff --git a/tools/walletextension/subscriptions/subscriptions.go b/tools/walletextension/subscriptions/subscriptions.go index 6e1a29636c..f6e203e0b2 100644 --- a/tools/walletextension/subscriptions/subscriptions.go +++ b/tools/walletextension/subscriptions/subscriptions.go @@ -19,14 +19,14 @@ import ( ) type SubscriptionManager struct { - subscriptionMappings map[string][]string + subscriptionMappings map[string][]*gethrpc.ClientSubscription logger gethlog.Logger mu sync.Mutex } func New(logger gethlog.Logger) *SubscriptionManager { return &SubscriptionManager{ - subscriptionMappings: make(map[string][]string), + subscriptionMappings: make(map[string][]*gethrpc.ClientSubscription), logger: logger, } } @@ -55,17 +55,11 @@ func (sm *SubscriptionManager) HandleNewSubscriptions(clients []rpc.Client, req if err != nil { return fmt.Errorf("could not call %s with params %v. Cause: %w", req.Method, req.Params, err) } + sm.UpdateSubscriptionMapping(string(userSubscriptionID), subscription) // We periodically check if the websocket is closed, and terminate the subscription. - // TODO: test this feature in integration test - go checkIfUserConnIsClosedAndUnsubscribe(userConn, subscription) - - // Make a connection between subscriptionID returned from node for current request and subscriptionID returned to user - if currentNodeSubscriptionID, ok := (*resp).(string); ok { - sm.UpdateSubscriptionMapping(string(userSubscriptionID), currentNodeSubscriptionID) - } else { - sm.logger.Error("Unable to read subscriptionID") - } + // TODO: Check if it will be much more efficient to create just one go routine for all clients together + go sm.checkIfUserConnIsClosedAndUnsubscribe(userConn, subscription, string(userSubscriptionID)) } // We return subscriptionID with resp interface. We want to use userSubscriptionID to allow unsubscribing @@ -106,39 +100,70 @@ func readFromChannelAndWriteToUserConn(channel chan common.IDAndLog, userConn us } } -func checkIfUserConnIsClosedAndUnsubscribe(userConn userconn.UserConn, subscription *gethrpc.ClientSubscription) { - for { - if userConn.IsClosed() { - subscription.Unsubscribe() - return +func (sm *SubscriptionManager) unsubscribeAndRemove(userSubscriptionID string, subscription *gethrpc.ClientSubscription) { + sm.mu.Lock() + defer sm.mu.Unlock() + + subscription.Unsubscribe() + + subscriptions, exists := sm.subscriptionMappings[userSubscriptionID] + if !exists { + sm.logger.Error("subscription that needs to be removed is not present in subscriptionMappings for userSubscriptionID: %s", userSubscriptionID) + return + } + + for i, s := range subscriptions { + if s != subscription { + continue + } + + // Remove the subscription from the slice + lastIndex := len(subscriptions) - 1 + subscriptions[i] = subscriptions[lastIndex] + subscriptions = subscriptions[:lastIndex] + + // If the slice is empty, delete the key from the map + if len(subscriptions) == 0 { + delete(sm.subscriptionMappings, userSubscriptionID) + } else { + sm.subscriptionMappings[userSubscriptionID] = subscriptions } + break + } +} + +func (sm *SubscriptionManager) checkIfUserConnIsClosedAndUnsubscribe(userConn userconn.UserConn, subscription *gethrpc.ClientSubscription, userSubscriptionID string) { + for !userConn.IsClosed() { time.Sleep(100 * time.Millisecond) } + + sm.unsubscribeAndRemove(userSubscriptionID, subscription) } -func (sm *SubscriptionManager) UpdateSubscriptionMapping(userSubscriptionID string, obscuroNodeSubscriptionID string) { +func (sm *SubscriptionManager) UpdateSubscriptionMapping(userSubscriptionID string, subscription *gethrpc.ClientSubscription) { // Ensure there is no concurrent map writes sm.mu.Lock() defer sm.mu.Unlock() - existingUserIDs, exists := sm.subscriptionMappings[userSubscriptionID] + // Check if the userSubscriptionID already exists in the map + subscriptions, exists := sm.subscriptionMappings[userSubscriptionID] + // If it doesn't exist, create a new slice for it if !exists { - sm.subscriptionMappings[userSubscriptionID] = []string{obscuroNodeSubscriptionID} - return + subscriptions = []*gethrpc.ClientSubscription{} } - // Check if obscuroNodeSubscriptionID already exists to avoid duplication - alreadyExists := false - for _, existingID := range existingUserIDs { - if obscuroNodeSubscriptionID == existingID { - alreadyExists = true + // Check if the subscription is already in the slice, if not, add it + subscriptionExists := false + for _, sub := range subscriptions { + if sub == subscription { + subscriptionExists = true break } } - if !alreadyExists { - sm.subscriptionMappings[userSubscriptionID] = append(existingUserIDs, obscuroNodeSubscriptionID) + if !subscriptionExists { + sm.subscriptionMappings[userSubscriptionID] = append(subscriptions, subscription) } } @@ -159,3 +184,19 @@ func prepareLogResponse(idAndLog common.IDAndLog, userSubscriptionID gethrpc.ID) } return jsonResponse, nil } + +func (sm *SubscriptionManager) HandleUnsubscribe(userSubscriptionID string, rpcResp *interface{}) { + subscriptions, exists := sm.subscriptionMappings[userSubscriptionID] + if !exists { + *rpcResp = false + return + } + + sm.mu.Lock() + defer sm.mu.Unlock() + for _, sub := range subscriptions { + sub.Unsubscribe() + } + delete(sm.subscriptionMappings, userSubscriptionID) + *rpcResp = true +}