diff --git a/tools/walletextension/test/utils.go b/tools/walletextension/test/utils.go index b851f5f231..dbcee37fff 100644 --- a/tools/walletextension/test/utils.go +++ b/tools/walletextension/test/utils.go @@ -30,8 +30,6 @@ import ( const jsonID = "1" -var dummyAPI = NewDummyAPI() - func createWalExtCfg(connectPort, wallHTTPPort, wallWSPort int) *walletextension.Config { testPersistencePath, err := os.CreateTemp("", "") if err != nil { @@ -50,7 +48,6 @@ func createWalExt(t *testing.T, walExtCfg *walletextension.Config) func() { logger := log.New(log.WalletExtCmp, int(gethlog.LvlInfo), log.SysOut) walExt := walletextension.NewWalletExtension(*walExtCfg, logger) - t.Cleanup(walExt.Shutdown) go walExt.Serve(common.Localhost, walExtCfg.WalletExtensionPort, walExtCfg.WalletExtensionPortWS) err := waitForEndpoint(fmt.Sprintf("http://%s:%d%s", common.Localhost, walExtCfg.WalletExtensionPort, walletextension.PathReady)) @@ -62,7 +59,8 @@ func createWalExt(t *testing.T, walExtCfg *walletextension.Config) func() { } // Creates an RPC layer that the wallet extension can connect to. Returns a handle to shut down the host. -func createDummyHost(t *testing.T, wsRPCPort int) { +func createDummyHost(t *testing.T, wsRPCPort int) (*DummyAPI, func() error) { + dummyAPI := NewDummyAPI() cfg := gethnode.Config{ WSHost: common.Localhost, WSPort: wsRPCPort, @@ -92,6 +90,7 @@ func createDummyHost(t *testing.T, wsRPCPort int) { if err != nil { t.Fatalf(fmt.Sprintf("could not create new client server. Cause: %s", err)) } + return dummyAPI, rpcServerNode.Close } // Waits for the endpoint to be available. Times out after three seconds. @@ -122,6 +121,25 @@ func makeWSEthJSONReq(port int, method string, params interface{}) ([]byte, *web return makeRequestWS(fmt.Sprintf("ws://%s:%d", common.Localhost, port), reqBody) } +func makeWSEthJSONReqWithConn(conn *websocket.Conn, method string, params interface{}) []byte { + reqBody := prepareRequestBody(method, params) + return issueRequestWS(conn, reqBody) +} + +func openWSConn(port int) (*websocket.Conn, error) { + conn, dialResp, err := websocket.DefaultDialer.Dial(fmt.Sprintf("ws://%s:%d", common.Localhost, port), nil) + if dialResp != nil && dialResp.Body != nil { + defer dialResp.Body.Close() + } + if err != nil { + if conn != nil { + conn.Close() + } + panic(fmt.Errorf("received error response from wallet extension: %w", err)) + } + return conn, err +} + // Formats a method and its parameters as a Ethereum JSON RPC request. func prepareRequestBody(method string, params interface{}) []byte { reqBodyBytes, err := json.Marshal(map[string]interface{}{ @@ -229,7 +247,12 @@ func makeRequestWS(url string, body []byte) ([]byte, *websocket.Conn) { panic(fmt.Errorf("received error response from wallet extension: %w", err)) } - err = conn.WriteMessage(websocket.TextMessage, body) + return issueRequestWS(conn, body), conn +} + +// issues request on an existing ws connection +func issueRequestWS(conn *websocket.Conn, body []byte) []byte { + err := conn.WriteMessage(websocket.TextMessage, body) if err != nil { panic(err) } @@ -238,7 +261,7 @@ func makeRequestWS(url string, body []byte) ([]byte, *websocket.Conn) { if err != nil { panic(err) } - return reqResp, conn + return reqResp } // Reads messages from the connection for the provided duration, and returns the read messages. diff --git a/tools/walletextension/test/wallet_extension_test.go b/tools/walletextension/test/wallet_extension_test.go index 9915c72e5e..4db2bfd1ff 100644 --- a/tools/walletextension/test/wallet_extension_test.go +++ b/tools/walletextension/test/wallet_extension_test.go @@ -23,30 +23,20 @@ import ( const ( errFailedDecrypt = "could not decrypt bytes with viewing key" dummyParams = "dummyParams" - magicNumber = 123789 jsonKeyTopics = "topics" _hostWSPort = integration.StartPortWalletExtensionUnitTest _testOffset = 100 // offset each test by a multiplier of the offset to avoid port colision. ie: hostPort := _hostWSPort + _testOffset*2 ) -var dummyHash = gethcommon.BigToHash(big.NewInt(magicNumber)) - type testHelper struct { hostPort int walletHTTPPort int walletWSPort int + hostAPI *DummyAPI } func TestWalletExtension(t *testing.T) { - createDummyHost(t, _hostWSPort) - createWalExt(t, createWalExtCfg(_hostWSPort, _hostWSPort+1, _hostWSPort+2)) - - h := &testHelper{ - hostPort: _hostWSPort, - walletHTTPPort: _hostWSPort + 1, - walletWSPort: _hostWSPort + 2, - } - + i := 0 for name, test := range map[string]func(t *testing.T, testHelper *testHelper){ "canInvokeNonSensitiveMethodsWithoutViewingKey": canInvokeNonSensitiveMethodsWithoutViewingKey, "canInvokeSensitiveMethodsWithViewingKey": canInvokeSensitiveMethodsWithViewingKey, @@ -56,13 +46,33 @@ func TestWalletExtension(t *testing.T) { "canRegisterViewingKeyAndMakeRequestsOverWebsockets": canRegisterViewingKeyAndMakeRequestsOverWebsockets, } { t.Run(name, func(t *testing.T) { + hostPort := _hostWSPort + i*_testOffset + dummyAPI, shutDownHost := createDummyHost(t, hostPort) + shutdownWallet := createWalExt(t, createWalExtCfg(hostPort, hostPort+1, hostPort+2)) + + h := &testHelper{ + hostPort: hostPort, + walletHTTPPort: hostPort + 1, + walletWSPort: hostPort + 2, + hostAPI: dummyAPI, + } + test(t, h) + + shutdownWallet() + err := shutDownHost() + if err != nil { + t.Fatal(err) + } }) + i++ } } func canInvokeNonSensitiveMethodsWithoutViewingKey(t *testing.T, testHelper *testHelper) { - respBody, _ := makeWSEthJSONReq(testHelper.hostPort, rpc.ChainID, []interface{}{}) + respBody, wsConnWE := makeWSEthJSONReq(testHelper.hostPort, rpc.ChainID, []interface{}{}) + defer wsConnWE.Close() + validateJSONResponse(t, respBody) if !strings.Contains(string(respBody), l2ChainIDHex) { @@ -72,7 +82,7 @@ func canInvokeNonSensitiveMethodsWithoutViewingKey(t *testing.T, testHelper *tes func canInvokeSensitiveMethodsWithViewingKey(t *testing.T, testHelper *testHelper) { viewingKeyBytes := registerPrivateKey(t, testHelper.walletHTTPPort, testHelper.walletWSPort, false) - dummyAPI.setViewingKey(viewingKeyBytes) + testHelper.hostAPI.setViewingKey(viewingKeyBytes) for _, method := range rpc.SensitiveMethods { // Subscriptions have to be tested separately, as they return results differently. @@ -98,7 +108,7 @@ func cannotInvokeSensitiveMethodsWithViewingKeyForAnotherAccount(t *testing.T, t t.Fatalf(fmt.Sprintf("failed to generate private key. Cause: %s", err)) } arbitraryPublicKeyBytesHex := hex.EncodeToString(crypto.CompressPubkey(&arbitraryPrivateKey.PublicKey)) - dummyAPI.setViewingKey([]byte(arbitraryPublicKeyBytesHex)) + testHelper.hostAPI.setViewingKey([]byte(arbitraryPublicKeyBytesHex)) for _, method := range rpc.SensitiveMethods { // Subscriptions have to be tested separately, as they return results differently. @@ -123,7 +133,7 @@ func canInvokeSensitiveMethodsAfterSubmittingMultipleViewingKeys(t *testing.T, t // We set the API to decrypt with an arbitrary key from the list we just generated. arbitraryViewingKey := viewingKeys[len(viewingKeys)/2] - dummyAPI.setViewingKey(arbitraryViewingKey) + testHelper.hostAPI.setViewingKey(arbitraryViewingKey) respBody := makeHTTPEthJSONReq(testHelper.walletHTTPPort, rpc.GetBalance, []interface{}{map[string]interface{}{"params": dummyParams}}) validateJSONResponse(t, respBody) @@ -143,57 +153,72 @@ func cannotSubscribeOverHTTP(t *testing.T, testHelper *testHelper) { func canRegisterViewingKeyAndMakeRequestsOverWebsockets(t *testing.T, testHelper *testHelper) { viewingKeyBytes := registerPrivateKey(t, testHelper.walletHTTPPort, testHelper.walletWSPort, true) - dummyAPI.setViewingKey(viewingKeyBytes) + testHelper.hostAPI.setViewingKey(viewingKeyBytes) - for _, method := range rpc.SensitiveMethods { - // Subscriptions have to be tested separately, as they return results differently. - if method == rpc.Subscribe { - continue - } + conn, err := openWSConn(testHelper.walletWSPort) + if err != nil { + t.Fatal(err) + } - respBody, _ := makeWSEthJSONReq(testHelper.walletWSPort, method, []interface{}{map[string]interface{}{"params": dummyParams}}) - validateJSONResponse(t, respBody) + respBody := makeWSEthJSONReqWithConn(conn, rpc.GetTransactionReceipt, []interface{}{map[string]interface{}{"params": dummyParams}}) + validateJSONResponse(t, respBody) - if !strings.Contains(string(respBody), dummyParams) { - t.Fatalf("expected response containing '%s', got '%s'", dummyParams, string(respBody)) - } + if !strings.Contains(string(respBody), dummyParams) { + t.Fatalf("expected response containing '%s', got '%s'", dummyParams, string(respBody)) + } - return // We only need to test a single sensitive method. + err = conn.Close() + if err != nil { + t.Fatal(err) } } func TestCannotInvokeSensitiveMethodsWithoutViewingKey(t *testing.T) { - hostPort := _hostWSPort + _testOffset + hostPort := _hostWSPort + _testOffset*7 walletHTTPPort := hostPort + 1 walletWSPort := hostPort + 2 - createDummyHost(t, hostPort) - createWalExt(t, createWalExtCfg(hostPort, walletHTTPPort, walletWSPort)) + _, shutdownHost := createDummyHost(t, hostPort) + defer shutdownHost() //nolint: errcheck + + shutdownWallet := createWalExt(t, createWalExtCfg(hostPort, walletHTTPPort, walletWSPort)) + defer shutdownWallet() + + conn, err := openWSConn(walletWSPort) + if err != nil { + t.Fatal(err) + } for _, method := range rpc.SensitiveMethods { // We use a websocket request because one of the sensitive methods, eth_subscribe, requires it. - respBody, _ := makeWSEthJSONReq(walletWSPort, method, []interface{}{}) + respBody := makeWSEthJSONReqWithConn(conn, method, []interface{}{}) if !strings.Contains(string(respBody), fmt.Sprintf(accountmanager.ErrNoViewingKey, method)) { t.Fatalf("expected response containing '%s', got '%s'", fmt.Sprintf(accountmanager.ErrNoViewingKey, method), string(respBody)) } } + err = conn.Close() + if err != nil { + t.Fatal(err) + } } func TestKeysAreReloadedWhenWalletExtensionRestarts(t *testing.T) { - hostPort := _hostWSPort + _testOffset*2 + hostPort := _hostWSPort + _testOffset*8 walletHTTPPort := hostPort + 1 walletWSPort := hostPort + 2 - createDummyHost(t, hostPort) + dummyAPI, shutdownHost := createDummyHost(t, hostPort) + defer shutdownHost() //nolint: errcheck walExtCfg := createWalExtCfg(hostPort, walletHTTPPort, walletWSPort) - shutdown := createWalExt(t, walExtCfg) + shutdownWallet := createWalExt(t, walExtCfg) viewingKeyBytes := registerPrivateKey(t, walletHTTPPort, walletWSPort, false) dummyAPI.setViewingKey(viewingKeyBytes) // We shut down the wallet extension and restart it with the same config, forcing the viewing keys to be reloaded. - shutdown() - createWalExt(t, walExtCfg) + shutdownWallet() + shutdownWallet = createWalExt(t, walExtCfg) + defer shutdownWallet() respBody := makeHTTPEthJSONReq(walletHTTPPort, rpc.GetBalance, []interface{}{map[string]interface{}{"params": dummyParams}}) validateJSONResponse(t, respBody) @@ -204,12 +229,16 @@ func TestKeysAreReloadedWhenWalletExtensionRestarts(t *testing.T) { } func TestCanSubscribeForLogsOverWebsockets(t *testing.T) { - hostPort := _hostWSPort + _testOffset*3 + hostPort := _hostWSPort + _testOffset*9 walletHTTPPort := hostPort + 1 walletWSPort := hostPort + 2 - createDummyHost(t, hostPort) - createWalExt(t, createWalExtCfg(hostPort, walletHTTPPort, walletWSPort)) + dummyHash := gethcommon.BigToHash(big.NewInt(1234)) + + dummyAPI, shutdownHost := createDummyHost(t, hostPort) + defer shutdownHost() //nolint: errcheck + shutdownWallet := createWalExt(t, createWalExtCfg(hostPort, walletHTTPPort, walletWSPort)) + defer shutdownWallet() viewingKeyBytes := registerPrivateKey(t, walletHTTPPort, walletWSPort, false) dummyAPI.setViewingKey(viewingKeyBytes) diff --git a/tools/walletextension/userconn/user_conn.go b/tools/walletextension/userconn/user_conn.go index 7dc18b55b9..0b9446d8d3 100644 --- a/tools/walletextension/userconn/user_conn.go +++ b/tools/walletextension/userconn/user_conn.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "strings" gethlog "github.com/ethereum/go-ethereum/log" @@ -137,7 +138,7 @@ func (w *userConnWS) HandleError(msg string) { err = w.conn.WriteMessage(websocket.TextMessage, errMsg) if err != nil { - if websocket.IsCloseError(err) { + if websocket.IsCloseError(err) || strings.Contains(msg, "EOF") { w.isClosed = true } w.logger.Error("could not write error message to websocket", log.ErrKey, err) diff --git a/tools/walletextension/wallet_extension.go b/tools/walletextension/wallet_extension.go index 3546ce3388..a24c981823 100644 --- a/tools/walletextension/wallet_extension.go +++ b/tools/walletextension/wallet_extension.go @@ -9,6 +9,7 @@ import ( "fmt" "io/fs" "net/http" + "sync/atomic" "time" gethcommon "github.com/ethereum/go-ethereum/common" @@ -52,8 +53,14 @@ type WalletExtension struct { serverWSShutdown func(ctx context.Context) error persistence *persistence.Persistence logger gethlog.Logger + isShutDown atomicBool } +type atomicBool int32 + +func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } +func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } + func NewWalletExtension(config Config, logger gethlog.Logger) *WalletExtension { unauthedClient, err := rpc.NewNetworkClient(wsProtocol + config.NodeRPCWebsocketAddress) if err != nil { @@ -102,6 +109,7 @@ func (we *WalletExtension) Serve(host string, httpPort int, wsPort int) { } func (we *WalletExtension) Shutdown() { + we.isShutDown.setTrue() if we.serverHTTPShutdown != nil { err := we.serverHTTPShutdown(context.Background()) if err != nil { @@ -185,6 +193,9 @@ func (we *WalletExtension) handleSubmitViewingKeyWS(resp http.ResponseWriter, re // Creates an HTTP connection to handle the request. func (we *WalletExtension) handleRequestHTTP(resp http.ResponseWriter, req *http.Request, fun func(conn userconn.UserConn)) { + if we.isShutDown.isSet() { + return + } if httputil.EnableCORS(resp, req) { return } @@ -194,6 +205,9 @@ func (we *WalletExtension) handleRequestHTTP(resp http.ResponseWriter, req *http // Creates a websocket connection to handle the request. func (we *WalletExtension) handleRequestWS(resp http.ResponseWriter, req *http.Request, fun func(conn userconn.UserConn)) { + if we.isShutDown.isSet() { + return + } userConn, err := userconn.NewUserConnWS(resp, req, we.logger) if err != nil { return