Skip to content

Commit

Permalink
try to get userID from URL if not present in query params
Browse files Browse the repository at this point in the history
  • Loading branch information
zkokelj committed Sep 20, 2023
1 parent 1a149de commit f9ed8fb
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 5 deletions.
14 changes: 10 additions & 4 deletions tools/walletextension/api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func ethRequestHandler(walletExt *walletextension.WalletExtension, conn userconn
}

// Get userID
hexUserID, err := getQueryParameter(conn.ReadRequestParams(), common.UserQueryParameter)
hexUserID, err := getUserID(conn, 1)
if err != nil {
walletExt.Logger().Error(fmt.Errorf("user not found in the query params: %w. Using the default user", err).Error())
hexUserID = hex.EncodeToString([]byte(common.DefaultUser)) // todo (@ziga) - this can be removed once old WE endpoints are removed
Expand Down Expand Up @@ -317,7 +317,7 @@ func authenticateRequestHandler(walletExt *walletextension.WalletExtension, user
}

// read userID from query params
hexUserID, err := getQueryParameter(userConn.ReadRequestParams(), common.UserQueryParameter)
hexUserID, err := getUserID(userConn, 2)
if err != nil {
userConn.HandleError("Malformed query: 'u' required - representing userID")
walletExt.Logger().Error(fmt.Errorf("user not found in the query params: %w", err).Error())
Expand Down Expand Up @@ -350,7 +350,7 @@ func queryRequestHandler(walletExt *walletextension.WalletExtension, userConn us
return
}

hexUserID, err := getQueryParameter(userConn.ReadRequestParams(), common.UserQueryParameter)
hexUserID, err := getUserID(userConn, 2)
if err != nil {
userConn.HandleError("user ('u') not found in query parameters")
walletExt.Logger().Error(fmt.Errorf("user not found in the query params: %w", err).Error())
Expand All @@ -362,6 +362,12 @@ func queryRequestHandler(walletExt *walletextension.WalletExtension, userConn us
walletExt.Logger().Error(fmt.Errorf("address not found in the query params: %w", err).Error())
return
}
// check if address length is correct
if len(address) != common.EthereumAddressLen {
userConn.HandleError(fmt.Sprintf("provided address length is %d, expected: %d", len(address), common.EthereumAddressLen))
walletExt.Logger().Error(fmt.Errorf(fmt.Sprintf("provided address length is %d, expected: %d", len(address), common.EthereumAddressLen)).Error())
return
}

// check if this account is registered with given user
found, err := walletExt.UserHasAccount(hexUserID, address)
Expand Down Expand Up @@ -399,7 +405,7 @@ func revokeRequestHandler(walletExt *walletextension.WalletExtension, userConn u
return
}

hexUserID, err := getQueryParameter(userConn.ReadRequestParams(), common.UserQueryParameter)
hexUserID, err := getUserID(userConn, 2)
if err != nil {
userConn.HandleError("user ('u') not found in query parameters")
walletExt.Logger().Error(fmt.Errorf("user not found in the query params: %w", err).Error())
Expand Down
33 changes: 32 additions & 1 deletion tools/walletextension/api/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package api
import (
"encoding/json"
"fmt"

"github.com/obscuronet/go-obscuro/tools/walletextension/accountmanager"
"github.com/obscuronet/go-obscuro/tools/walletextension/common"
"github.com/obscuronet/go-obscuro/tools/walletextension/userconn"
"strings"
)

func parseRequest(body []byte) (*accountmanager.RPCRequest, error) {
Expand Down Expand Up @@ -46,3 +47,33 @@ func getQueryParameter(params map[string]string, selectedParameter string) (stri

return value, nil
}

func getUserID(conn userconn.UserConn, userIDPosition int) (string, error) {
// try getting userID from query parameters and return it if successful
userID, err := getQueryParameter(conn.ReadRequestParams(), common.UserQueryParameter)
if err == nil {
if len(userID) != common.MessageUserIDLen {
return "", fmt.Errorf(fmt.Sprintf("wrong length of userID from URL. Got: %d, Expected: %d", len(userID), common.MessageUserIDLen))
}
return userID, err
}

// Alternatively, try to get it from URL path
path := conn.GetHttpRequest().URL.Path
path = strings.Trim(path, "/")
parts := strings.Split(path, "/")

// our URLs, which require userID, have following pattern: <version>/<endpoint (*optional)>/<userID (*optional)>
// userID can be only on second or third position
if len(parts) != userIDPosition+1 {
return "", fmt.Errorf("URL structure of the request looks wrong")
}
userID = parts[userIDPosition]

// Check if userID has the correct length
if len(userID) != common.MessageUserIDLen {
return "", fmt.Errorf(fmt.Sprintf("wrong length of userID from URL. Got: %d, Expected: %d", len(userID), common.MessageUserIDLen))
}

return userID, nil
}
1 change: 1 addition & 0 deletions tools/walletextension/common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const (
MessageFormatRegex = `^Register\s(\w+)\sfor\s(\w+)$`
MessageUserIDLen = 64
SignatureLen = 65
EthereumAddressLen = 42
PersonalSignMessagePrefix = "\x19Ethereum Signed Message:\n%d%s"
GetStorageAtUserIDRequestMethodName = "getUserID"
SuccessMsg = "success"
Expand Down
9 changes: 9 additions & 0 deletions tools/walletextension/userconn/user_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type UserConn interface {
HandleError(msg string)
SupportsSubscriptions() bool
IsClosed() bool
GetHttpRequest() *http.Request
}

// Represents a user's connection over HTTP.
Expand Down Expand Up @@ -106,6 +107,10 @@ func (h *userConnHTTP) ReadRequestParams() map[string]string {
return getQueryParams(h.req.URL.Query())
}

func (h *userConnHTTP) GetHttpRequest() *http.Request {
return h.req
}

func (w *userConnWS) ReadRequest() ([]byte, error) {
_, msg, err := w.conn.ReadMessage()
if err != nil {
Expand Down Expand Up @@ -166,6 +171,10 @@ func (w *userConnWS) ReadRequestParams() map[string]string {
return getQueryParams(w.req.URL.Query())
}

func (w *userConnWS) GetHttpRequest() *http.Request {
return w.req
}

// Logs the error, prints it to the console, and returns the error over HTTP.
func httpLogAndSendErr(resp http.ResponseWriter, msg string) {
http.Error(resp, msg, httpCodeErr)
Expand Down

0 comments on commit f9ed8fb

Please sign in to comment.