Skip to content

Commit

Permalink
Merge pull request #1102 from obscuronet/pedro/fix_we_tests
Browse files Browse the repository at this point in the history
Fixes for the we tests
  • Loading branch information
otherview authored Feb 9, 2023
2 parents de1d8bf + d381e98 commit 55f4fdc
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 47 deletions.
35 changes: 29 additions & 6 deletions tools/walletextension/test/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ import (

const jsonID = "1"

var dummyAPI = NewDummyAPI()

func createWalExtCfg(connectPort, wallHTTPPort, wallWSPort int) *walletextension.Config {
testPersistencePath, err := os.CreateTemp("", "")
if err != nil {
Expand All @@ -50,7 +48,6 @@ func createWalExt(t *testing.T, walExtCfg *walletextension.Config) func() {
logger := log.New(log.WalletExtCmp, int(gethlog.LvlInfo), log.SysOut)

walExt := walletextension.NewWalletExtension(*walExtCfg, logger)
t.Cleanup(walExt.Shutdown)
go walExt.Serve(common.Localhost, walExtCfg.WalletExtensionPort, walExtCfg.WalletExtensionPortWS)

err := waitForEndpoint(fmt.Sprintf("http://%s:%d%s", common.Localhost, walExtCfg.WalletExtensionPort, walletextension.PathReady))
Expand All @@ -62,7 +59,8 @@ func createWalExt(t *testing.T, walExtCfg *walletextension.Config) func() {
}

// Creates an RPC layer that the wallet extension can connect to. Returns a handle to shut down the host.
func createDummyHost(t *testing.T, wsRPCPort int) {
func createDummyHost(t *testing.T, wsRPCPort int) (*DummyAPI, func() error) {
dummyAPI := NewDummyAPI()
cfg := gethnode.Config{
WSHost: common.Localhost,
WSPort: wsRPCPort,
Expand Down Expand Up @@ -92,6 +90,7 @@ func createDummyHost(t *testing.T, wsRPCPort int) {
if err != nil {
t.Fatalf(fmt.Sprintf("could not create new client server. Cause: %s", err))
}
return dummyAPI, rpcServerNode.Close
}

// Waits for the endpoint to be available. Times out after three seconds.
Expand Down Expand Up @@ -122,6 +121,25 @@ func makeWSEthJSONReq(port int, method string, params interface{}) ([]byte, *web
return makeRequestWS(fmt.Sprintf("ws://%s:%d", common.Localhost, port), reqBody)
}

func makeWSEthJSONReqWithConn(conn *websocket.Conn, method string, params interface{}) []byte {
reqBody := prepareRequestBody(method, params)
return issueRequestWS(conn, reqBody)
}

func openWSConn(port int) (*websocket.Conn, error) {
conn, dialResp, err := websocket.DefaultDialer.Dial(fmt.Sprintf("ws://%s:%d", common.Localhost, port), nil)
if dialResp != nil && dialResp.Body != nil {
defer dialResp.Body.Close()
}
if err != nil {
if conn != nil {
conn.Close()
}
panic(fmt.Errorf("received error response from wallet extension: %w", err))
}
return conn, err
}

// Formats a method and its parameters as a Ethereum JSON RPC request.
func prepareRequestBody(method string, params interface{}) []byte {
reqBodyBytes, err := json.Marshal(map[string]interface{}{
Expand Down Expand Up @@ -229,7 +247,12 @@ func makeRequestWS(url string, body []byte) ([]byte, *websocket.Conn) {
panic(fmt.Errorf("received error response from wallet extension: %w", err))
}

err = conn.WriteMessage(websocket.TextMessage, body)
return issueRequestWS(conn, body), conn
}

// issues request on an existing ws connection
func issueRequestWS(conn *websocket.Conn, body []byte) []byte {
err := conn.WriteMessage(websocket.TextMessage, body)
if err != nil {
panic(err)
}
Expand All @@ -238,7 +261,7 @@ func makeRequestWS(url string, body []byte) ([]byte, *websocket.Conn) {
if err != nil {
panic(err)
}
return reqResp, conn
return reqResp
}

// Reads messages from the connection for the provided duration, and returns the read messages.
Expand Down
109 changes: 69 additions & 40 deletions tools/walletextension/test/wallet_extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,20 @@ import (
const (
errFailedDecrypt = "could not decrypt bytes with viewing key"
dummyParams = "dummyParams"
magicNumber = 123789
jsonKeyTopics = "topics"
_hostWSPort = integration.StartPortWalletExtensionUnitTest
_testOffset = 100 // offset each test by a multiplier of the offset to avoid port colision. ie: hostPort := _hostWSPort + _testOffset*2
)

var dummyHash = gethcommon.BigToHash(big.NewInt(magicNumber))

type testHelper struct {
hostPort int
walletHTTPPort int
walletWSPort int
hostAPI *DummyAPI
}

func TestWalletExtension(t *testing.T) {
createDummyHost(t, _hostWSPort)
createWalExt(t, createWalExtCfg(_hostWSPort, _hostWSPort+1, _hostWSPort+2))

h := &testHelper{
hostPort: _hostWSPort,
walletHTTPPort: _hostWSPort + 1,
walletWSPort: _hostWSPort + 2,
}

i := 0
for name, test := range map[string]func(t *testing.T, testHelper *testHelper){
"canInvokeNonSensitiveMethodsWithoutViewingKey": canInvokeNonSensitiveMethodsWithoutViewingKey,
"canInvokeSensitiveMethodsWithViewingKey": canInvokeSensitiveMethodsWithViewingKey,
Expand All @@ -56,13 +46,33 @@ func TestWalletExtension(t *testing.T) {
"canRegisterViewingKeyAndMakeRequestsOverWebsockets": canRegisterViewingKeyAndMakeRequestsOverWebsockets,
} {
t.Run(name, func(t *testing.T) {
hostPort := _hostWSPort + i*_testOffset
dummyAPI, shutDownHost := createDummyHost(t, hostPort)
shutdownWallet := createWalExt(t, createWalExtCfg(hostPort, hostPort+1, hostPort+2))

h := &testHelper{
hostPort: hostPort,
walletHTTPPort: hostPort + 1,
walletWSPort: hostPort + 2,
hostAPI: dummyAPI,
}

test(t, h)

shutdownWallet()
err := shutDownHost()
if err != nil {
t.Fatal(err)
}
})
i++
}
}

func canInvokeNonSensitiveMethodsWithoutViewingKey(t *testing.T, testHelper *testHelper) {
respBody, _ := makeWSEthJSONReq(testHelper.hostPort, rpc.ChainID, []interface{}{})
respBody, wsConnWE := makeWSEthJSONReq(testHelper.hostPort, rpc.ChainID, []interface{}{})
defer wsConnWE.Close()

validateJSONResponse(t, respBody)

if !strings.Contains(string(respBody), l2ChainIDHex) {
Expand All @@ -72,7 +82,7 @@ func canInvokeNonSensitiveMethodsWithoutViewingKey(t *testing.T, testHelper *tes

func canInvokeSensitiveMethodsWithViewingKey(t *testing.T, testHelper *testHelper) {
viewingKeyBytes := registerPrivateKey(t, testHelper.walletHTTPPort, testHelper.walletWSPort, false)
dummyAPI.setViewingKey(viewingKeyBytes)
testHelper.hostAPI.setViewingKey(viewingKeyBytes)

for _, method := range rpc.SensitiveMethods {
// Subscriptions have to be tested separately, as they return results differently.
Expand All @@ -98,7 +108,7 @@ func cannotInvokeSensitiveMethodsWithViewingKeyForAnotherAccount(t *testing.T, t
t.Fatalf(fmt.Sprintf("failed to generate private key. Cause: %s", err))
}
arbitraryPublicKeyBytesHex := hex.EncodeToString(crypto.CompressPubkey(&arbitraryPrivateKey.PublicKey))
dummyAPI.setViewingKey([]byte(arbitraryPublicKeyBytesHex))
testHelper.hostAPI.setViewingKey([]byte(arbitraryPublicKeyBytesHex))

for _, method := range rpc.SensitiveMethods {
// Subscriptions have to be tested separately, as they return results differently.
Expand All @@ -123,7 +133,7 @@ func canInvokeSensitiveMethodsAfterSubmittingMultipleViewingKeys(t *testing.T, t

// We set the API to decrypt with an arbitrary key from the list we just generated.
arbitraryViewingKey := viewingKeys[len(viewingKeys)/2]
dummyAPI.setViewingKey(arbitraryViewingKey)
testHelper.hostAPI.setViewingKey(arbitraryViewingKey)

respBody := makeHTTPEthJSONReq(testHelper.walletHTTPPort, rpc.GetBalance, []interface{}{map[string]interface{}{"params": dummyParams}})
validateJSONResponse(t, respBody)
Expand All @@ -143,57 +153,72 @@ func cannotSubscribeOverHTTP(t *testing.T, testHelper *testHelper) {

func canRegisterViewingKeyAndMakeRequestsOverWebsockets(t *testing.T, testHelper *testHelper) {
viewingKeyBytes := registerPrivateKey(t, testHelper.walletHTTPPort, testHelper.walletWSPort, true)
dummyAPI.setViewingKey(viewingKeyBytes)
testHelper.hostAPI.setViewingKey(viewingKeyBytes)

for _, method := range rpc.SensitiveMethods {
// Subscriptions have to be tested separately, as they return results differently.
if method == rpc.Subscribe {
continue
}
conn, err := openWSConn(testHelper.walletWSPort)
if err != nil {
t.Fatal(err)
}

respBody, _ := makeWSEthJSONReq(testHelper.walletWSPort, method, []interface{}{map[string]interface{}{"params": dummyParams}})
validateJSONResponse(t, respBody)
respBody := makeWSEthJSONReqWithConn(conn, rpc.GetTransactionReceipt, []interface{}{map[string]interface{}{"params": dummyParams}})
validateJSONResponse(t, respBody)

if !strings.Contains(string(respBody), dummyParams) {
t.Fatalf("expected response containing '%s', got '%s'", dummyParams, string(respBody))
}
if !strings.Contains(string(respBody), dummyParams) {
t.Fatalf("expected response containing '%s', got '%s'", dummyParams, string(respBody))
}

return // We only need to test a single sensitive method.
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}

func TestCannotInvokeSensitiveMethodsWithoutViewingKey(t *testing.T) {
hostPort := _hostWSPort + _testOffset
hostPort := _hostWSPort + _testOffset*7
walletHTTPPort := hostPort + 1
walletWSPort := hostPort + 2

createDummyHost(t, hostPort)
createWalExt(t, createWalExtCfg(hostPort, walletHTTPPort, walletWSPort))
_, shutdownHost := createDummyHost(t, hostPort)
defer shutdownHost() //nolint: errcheck

shutdownWallet := createWalExt(t, createWalExtCfg(hostPort, walletHTTPPort, walletWSPort))
defer shutdownWallet()

conn, err := openWSConn(walletWSPort)
if err != nil {
t.Fatal(err)
}

for _, method := range rpc.SensitiveMethods {
// We use a websocket request because one of the sensitive methods, eth_subscribe, requires it.
respBody, _ := makeWSEthJSONReq(walletWSPort, method, []interface{}{})
respBody := makeWSEthJSONReqWithConn(conn, method, []interface{}{})
if !strings.Contains(string(respBody), fmt.Sprintf(accountmanager.ErrNoViewingKey, method)) {
t.Fatalf("expected response containing '%s', got '%s'", fmt.Sprintf(accountmanager.ErrNoViewingKey, method), string(respBody))
}
}
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}

func TestKeysAreReloadedWhenWalletExtensionRestarts(t *testing.T) {
hostPort := _hostWSPort + _testOffset*2
hostPort := _hostWSPort + _testOffset*8
walletHTTPPort := hostPort + 1
walletWSPort := hostPort + 2

createDummyHost(t, hostPort)
dummyAPI, shutdownHost := createDummyHost(t, hostPort)
defer shutdownHost() //nolint: errcheck
walExtCfg := createWalExtCfg(hostPort, walletHTTPPort, walletWSPort)
shutdown := createWalExt(t, walExtCfg)
shutdownWallet := createWalExt(t, walExtCfg)

viewingKeyBytes := registerPrivateKey(t, walletHTTPPort, walletWSPort, false)
dummyAPI.setViewingKey(viewingKeyBytes)

// We shut down the wallet extension and restart it with the same config, forcing the viewing keys to be reloaded.
shutdown()
createWalExt(t, walExtCfg)
shutdownWallet()
shutdownWallet = createWalExt(t, walExtCfg)
defer shutdownWallet()

respBody := makeHTTPEthJSONReq(walletHTTPPort, rpc.GetBalance, []interface{}{map[string]interface{}{"params": dummyParams}})
validateJSONResponse(t, respBody)
Expand All @@ -204,12 +229,16 @@ func TestKeysAreReloadedWhenWalletExtensionRestarts(t *testing.T) {
}

func TestCanSubscribeForLogsOverWebsockets(t *testing.T) {
hostPort := _hostWSPort + _testOffset*3
hostPort := _hostWSPort + _testOffset*9
walletHTTPPort := hostPort + 1
walletWSPort := hostPort + 2

createDummyHost(t, hostPort)
createWalExt(t, createWalExtCfg(hostPort, walletHTTPPort, walletWSPort))
dummyHash := gethcommon.BigToHash(big.NewInt(1234))

dummyAPI, shutdownHost := createDummyHost(t, hostPort)
defer shutdownHost() //nolint: errcheck
shutdownWallet := createWalExt(t, createWalExtCfg(hostPort, walletHTTPPort, walletWSPort))
defer shutdownWallet()

viewingKeyBytes := registerPrivateKey(t, walletHTTPPort, walletWSPort, false)
dummyAPI.setViewingKey(viewingKeyBytes)
Expand Down
3 changes: 2 additions & 1 deletion tools/walletextension/userconn/user_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"strings"

gethlog "github.com/ethereum/go-ethereum/log"

Expand Down Expand Up @@ -137,7 +138,7 @@ func (w *userConnWS) HandleError(msg string) {

err = w.conn.WriteMessage(websocket.TextMessage, errMsg)
if err != nil {
if websocket.IsCloseError(err) {
if websocket.IsCloseError(err) || strings.Contains(msg, "EOF") {
w.isClosed = true
}
w.logger.Error("could not write error message to websocket", log.ErrKey, err)
Expand Down
14 changes: 14 additions & 0 deletions tools/walletextension/wallet_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io/fs"
"net/http"
"sync/atomic"
"time"

gethcommon "github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -52,8 +53,14 @@ type WalletExtension struct {
serverWSShutdown func(ctx context.Context) error
persistence *persistence.Persistence
logger gethlog.Logger
isShutDown atomicBool
}

type atomicBool int32

func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }

func NewWalletExtension(config Config, logger gethlog.Logger) *WalletExtension {
unauthedClient, err := rpc.NewNetworkClient(wsProtocol + config.NodeRPCWebsocketAddress)
if err != nil {
Expand Down Expand Up @@ -102,6 +109,7 @@ func (we *WalletExtension) Serve(host string, httpPort int, wsPort int) {
}

func (we *WalletExtension) Shutdown() {
we.isShutDown.setTrue()
if we.serverHTTPShutdown != nil {
err := we.serverHTTPShutdown(context.Background())
if err != nil {
Expand Down Expand Up @@ -185,6 +193,9 @@ func (we *WalletExtension) handleSubmitViewingKeyWS(resp http.ResponseWriter, re

// Creates an HTTP connection to handle the request.
func (we *WalletExtension) handleRequestHTTP(resp http.ResponseWriter, req *http.Request, fun func(conn userconn.UserConn)) {
if we.isShutDown.isSet() {
return
}
if httputil.EnableCORS(resp, req) {
return
}
Expand All @@ -194,6 +205,9 @@ func (we *WalletExtension) handleRequestHTTP(resp http.ResponseWriter, req *http

// Creates a websocket connection to handle the request.
func (we *WalletExtension) handleRequestWS(resp http.ResponseWriter, req *http.Request, fun func(conn userconn.UserConn)) {
if we.isShutDown.isSet() {
return
}
userConn, err := userconn.NewUserConnWS(resp, req, we.logger)
if err != nil {
return
Expand Down

0 comments on commit 55f4fdc

Please sign in to comment.