Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Session Keys #2131

Merged
merged 8 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions go/common/custom_query_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import "github.com/ethereum/go-ethereum/common"
const (
UserIDRequestCQMethod = "0x0000000000000000000000000000000000000001"
ListPrivateTransactionsCQMethod = "0x0000000000000000000000000000000000000002"
CreateSessionKeyCQMethod = "0x0000000000000000000000000000000000000003"
ActivateSessionKeyCQMethod = "0x0000000000000000000000000000000000000004"
DeactivateSessionKeyCQMethod = "0x0000000000000000000000000000000000000005"
DeleteSessionKeyCQMethod = "0x0000000000000000000000000000000000000006"
)

type ListPrivateTransactionsQueryParams struct {
Expand Down
2 changes: 1 addition & 1 deletion integration/tengateway/tengateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ func testGetStorageAtForReturningUserID(t *testing.T, _ int, httpURL, wsURL stri
t.Error("Unable to unmarshal response")
}
if !bytes.Equal(gethcommon.FromHex(response.Result), user.tgClient.UserIDBytes()) {
t.Errorf("Wrong UserID returned. Expected: %s, received: %s", user.tgClient.UserID(), response.Result)
t.Errorf("Wrong ID returned. Expected: %s, received: %s", user.tgClient.UserID(), response.Result)
}

// make a request to GetStorageAt with correct parameters to get userID, but with wrong userID
Expand Down
1 change: 1 addition & 0 deletions tools/walletextension/common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const (
PathQuery = "/query/"
PathRevoke = "/revoke/"
PathHealth = "/health/"
PathSessionKeys = "/session-key/"
PathNetworkHealth = "/network-health/"
PathNetworkConfig = "/network-config/"
WSProtocol = "ws://"
Expand Down
27 changes: 22 additions & 5 deletions tools/walletextension/common/types.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,42 @@
package common

import (
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ten-protocol/go-ten/go/common/viewingkey"
"golang.org/x/exp/maps"

"github.com/ethereum/go-ethereum/common"
)

// GWSessionKey - an account key-pair registered for a user
type GWSessionKey struct {
Account *GWAccount
PrivateKey *ecies.PrivateKey // the private key corresponding to the account
}

type GWAccount struct {
User *GWUser
Address *common.Address
Signature []byte
Signature []byte // the signature by the account over the userId - which is derived from the VK
SignatureType viewingkey.SignatureType
}

type GWUser struct {
UserID []byte
Accounts map[common.Address]*GWAccount
UserKey []byte
ID []byte
Accounts map[common.Address]*GWAccount
UserKey []byte
SessionKey *GWSessionKey
ActiveSK bool // the session key is active, and it must be used to sign all incoming transactions, and used as the preferred account
}

func (u GWUser) AllAccounts() map[common.Address]*GWAccount {
res := maps.Clone(u.Accounts)
if u.SessionKey != nil {
res[*u.SessionKey.Account.Address] = u.SessionKey.Account
}
return res
}

func (u GWUser) GetAllAddresses() []common.Address {
return maps.Keys(u.Accounts)
return maps.Keys(u.AllAccounts())
}
98 changes: 98 additions & 0 deletions tools/walletextension/httpapi/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,26 @@ func NewHTTPRoutes(walletExt *services.Services) []node.Route {
Name: common.APIVersion1 + common.PathNetworkConfig,
Func: httpHandler(walletExt, networkConfigRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "create",
Func: httpHandler(walletExt, createSKRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "activate",
Func: httpHandler(walletExt, activateSKRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "deactivate",
Func: httpHandler(walletExt, deactivateSKRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "delete",
Func: httpHandler(walletExt, deleteSKRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "list",
Func: httpHandler(walletExt, listSKRequestHandler),
},
}
}

Expand Down Expand Up @@ -178,6 +198,7 @@ func authenticateRequestHandler(walletExt *services.Services, conn UserConn) {
}
}

// todo - is this needed?
// This function handles request to /query endpoint.
// In the query parameters address and userID are required. We check if provided address is registered for given userID
// and return true/false in json response
Expand Down Expand Up @@ -517,3 +538,80 @@ func getMessageRequestHandler(walletExt *services.Services, conn UserConn) {
walletExt.Logger().Error("error writing success response", log.ErrKey, err)
}
}

func listSKRequestHandler(walletExt *services.Services, conn UserConn) {
}

func createSKRequestHandler(walletExt *services.Services, conn UserConn) {
withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) {
sk, err := walletExt.SKManager.CreateSessionKey(user)
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("could not create session key: %w", err))
return nil, err
}
return []byte(hexutils.BytesToHex(sk.Account.Address.Bytes())), nil
})
}

func deleteSKRequestHandler(walletExt *services.Services, conn UserConn) {
withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) {
err := walletExt.Storage.RemoveSessionKey(user.ID)
if err != nil {
return nil, err
}
return nil, nil
})
}

func activateSKRequestHandler(walletExt *services.Services, conn UserConn) {
withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) {
err := walletExt.Storage.ActivateSessionKey(user.ID, true)
if err != nil {
return nil, err
}
return []byte{1}, nil
})
}

func deactivateSKRequestHandler(walletExt *services.Services, conn UserConn) {
withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) {
err := walletExt.Storage.ActivateSessionKey(user.ID, false)
if err != nil {
return nil, err
}
return nil, nil
})
}

// extracts the user from the request, and writes the response to the connection
func withUser(walletExt *services.Services, conn UserConn, withUser func(user *common.GWUser) ([]byte, error)) {
_, err := conn.ReadRequest()
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("error reading request: %w", err))
return
}

userID, err := getUserID(conn)
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("user ('u') not found in query parameters"))
walletExt.Logger().Info("user not found in the query params", log.ErrKey, err)
return
}

user, err := walletExt.Storage.GetUser(userID)
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("could not get user: %w", err))
return
}

resp, err := withUser(user)
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("could not process request: %w", err))
return
}

err = conn.WriteResponse(resp)
if err != nil {
walletExt.Logger().Error("error writing success response", log.ErrKey, err)
}
}
53 changes: 37 additions & 16 deletions tools/walletextension/rpcapi/blockchain_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (api *BlockChainAPI) GetBalance(ctx context.Context, address gethcommon.Add
return ExecAuthRPC[hexutil.Big](
ctx,
api.we,
&ExecCfg{
&AuthExecCfg{
cacheCfg: &cache.Cfg{
DynamicType: func() cache.Strategy {
return cacheBlockNumberOrHash(blockNrOrHash)
Expand Down Expand Up @@ -180,25 +180,21 @@ func (api *BlockChainAPI) GetCode(ctx context.Context, address gethcommon.Addres
//
// In future, we can support both CustomQueries and some debug version of eth_getStorageAt if needed.
func (api *BlockChainAPI) GetStorageAt(ctx context.Context, address gethcommon.Address, params string, _ rpc.BlockNumberOrHash) (hexutil.Bytes, error) {
switch address.Hex() {
case common.UserIDRequestCQMethod:
userID, err := extractUserID(ctx, api.we)
if err != nil {
return nil, err
}
user, err := extractUserForRequest(ctx, api.we)
if err != nil {
return nil, err
}

_, err = api.we.Storage.GetUser(userID)
if err != nil {
return nil, err
}
return userID, nil
switch address.Hex() {
case common.UserIDRequestCQMethod: // todo - review whether we need this endpoint
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This endpoint is needed by frontend and is used in wallet-provider.tsx : const fetchedToken = await getToken(providerInstance);

But on the other hand we might need to think about this one again as it can be easy to exploit in my opinion in the following scenario:

User connects to a website with TEN (like Battleships ) and if this website is malicious it can store the data (userID) it gets from eth_getStorageAt call and then it can use the gateway to decrypt your transactions right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exactly. It's a vulnerability. What does the wallet-provider do with it?

return user.ID, nil
case common.ListPrivateTransactionsCQMethod:
// sensitive CustomQuery methods use the convention of having "address" at the top level of the params json
userAddr, err := extractCustomQueryAddress(params)
if err != nil {
return nil, fmt.Errorf("unable to extract address from custom query params: %w", err)
}
resp, err := ExecAuthRPC[any](ctx, api.we, &ExecCfg{account: userAddr}, "scan_getPersonalTransactions", params)
resp, err := ExecAuthRPC[any](ctx, api.we, &AuthExecCfg{account: userAddr}, "scan_getPersonalTransactions", params)
if err != nil {
return nil, fmt.Errorf("unable to execute custom query: %w", err)
}
Expand All @@ -208,8 +204,33 @@ func (api *BlockChainAPI) GetStorageAt(ctx context.Context, address gethcommon.A
return nil, fmt.Errorf("unable to marshal response object: %w", err)
}
return serialised, nil
case common.CreateSessionKeyCQMethod:
sk, err := api.we.SKManager.CreateSessionKey(user)
if err != nil {
return nil, fmt.Errorf("unable to create session key: %w", err)
}
return sk.Account.Address.Bytes(), nil
case common.ActivateSessionKeyCQMethod:
err := api.we.Storage.ActivateSessionKey(user.ID, true)
if err != nil {
return nil, err
}
return []byte{1}, nil

case common.DeactivateSessionKeyCQMethod:
err := api.we.Storage.ActivateSessionKey(user.ID, false)
if err != nil {
return nil, err
}
return []byte{1}, nil
case common.DeleteSessionKeyCQMethod:
err := api.we.Storage.RemoveSessionKey(user.ID)
if err != nil {
return nil, err
}
return nil, nil
default: // address was not a recognised custom query method address
resp, err := ExecAuthRPC[any](ctx, api.we, &ExecCfg{tryUntilAuthorised: true}, "eth_getStorageAt", address, params, nil)
resp, err := ExecAuthRPC[any](ctx, api.we, &AuthExecCfg{tryUntilAuthorised: true}, "eth_getStorageAt", address, params, nil)
if err != nil {
return nil, fmt.Errorf("unable to execute eth_getStorageAt: %w", err)
}
Expand Down Expand Up @@ -251,7 +272,7 @@ type (
)

func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, blockOverrides *BlockOverrides) (hexutil.Bytes, error) {
resp, err := ExecAuthRPC[hexutil.Bytes](ctx, api.we, &ExecCfg{
resp, err := ExecAuthRPC[hexutil.Bytes](ctx, api.we, &AuthExecCfg{
cacheCfg: &cache.Cfg{
DynamicType: func() cache.Strategy {
return cacheBlockNumberOrHash(blockNrOrHash)
Expand All @@ -273,7 +294,7 @@ func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs
}

func (api *BlockChainAPI) EstimateGas(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Uint64, error) {
resp, err := ExecAuthRPC[hexutil.Uint64](ctx, api.we, &ExecCfg{
resp, err := ExecAuthRPC[hexutil.Uint64](ctx, api.we, &AuthExecCfg{
cacheCfg: &cache.Cfg{
DynamicType: func() cache.Strategy {
if blockNrOrHash != nil {
Expand Down
2 changes: 1 addition & 1 deletion tools/walletextension/rpcapi/debug_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (api *DebugAPI) EventLogRelevancy(ctx context.Context, crit common.FilterCr
l, err := ExecAuthRPC[[]*common.DebugLogVisibility](
ctx,
api.we,
&ExecCfg{
&AuthExecCfg{
cacheCfg: &cache.Cfg{
Type: cache.NoCache,
},
Expand Down
21 changes: 8 additions & 13 deletions tools/walletextension/rpcapi/filter_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rp
errorChannels := make([]<-chan error, 0)
backendSubscriptions := make([]*rpc.ClientSubscription, 0)
for _, address := range candidateAddresses {
rpcWSClient, err := api.we.BackendRPC.ConnectWS(ctx, user.Accounts[address])
rpcWSClient, err := api.we.BackendRPC.ConnectWS(ctx, user.AllAccounts()[address])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -196,13 +196,13 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) (
method := "eth_getLogs"
audit(api.we, "RPC start method=%s args=%v", method, ctx)
requestStartTime := time.Now()
userID, err := extractUserID(ctx, api.we)
user, err := extractUserForRequest(ctx, api.we)
if err != nil {
return nil, err
}

rateLimitAllowed, requestUUID := api.we.RateLimiter.Allow(gethcommon.Address(userID))
defer api.we.RateLimiter.SetRequestEnd(gethcommon.Address(userID), requestUUID)
rateLimitAllowed, requestUUID := api.we.RateLimiter.Allow(gethcommon.Address(user.ID))
defer api.we.RateLimiter.SetRequestEnd(gethcommon.Address(user.ID), requestUUID)
if !rateLimitAllowed {
return nil, fmt.Errorf("rate limit exceeded")
}
Expand All @@ -221,18 +221,13 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) (
return cache.LatestBatch
},
},
generateCacheKey([]any{userID, method, common.SerializableFilterCriteria(crit)}),
generateCacheKey([]any{user.ID, method, common.SerializableFilterCriteria(crit)}),
func() (*[]*types.Log, error) { // called when there is no entry in the cache
user, err := api.we.Storage.GetUser(userID)
if err != nil {
return nil, err
}

allEventLogsMap := make(map[LogKey]*types.Log)
// for each account registered for the current user
// execute the get_Logs function
// dedupe and concatenate the results
for _, acct := range user.Accounts {
for _, acct := range user.AllAccounts() {
eventLogs, err := services.WithEncRPCConnection(ctx, api.we.BackendRPC, acct, func(rpcClient *tenrpc.EncRPCClient) (*[]*types.Log, error) {
var result []*types.Log

Expand Down Expand Up @@ -271,7 +266,7 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) (
if err != nil {
return nil, err
}
audit(api.we, "RPC call. uid=%s, method=%s args=%v result=%v error=%v time=%d", hexutils.BytesToHex(userID), method, crit, res, err, time.Since(requestStartTime).Milliseconds())
audit(api.we, "RPC call. uid=%s, method=%s args=%v result=%v error=%v time=%d", hexutils.BytesToHex(user.ID), method, crit, res, err, time.Since(requestStartTime).Milliseconds())
return *res, err
}

Expand All @@ -281,7 +276,7 @@ func (api *FilterAPI) UninstallFilter(id rpc.ID) bool {
}

func (api *FilterAPI) GetFilterLogs(ctx context.Context, id rpc.ID) ([]*types.Log, error) {
//txRec, err := ExecAuthRPC[[]*types.Log](ctx, api.we, "GetFilterLogs", ExecCfg{account: args.From}, id)
//txRec, err := ExecAuthRPC[[]*types.Log](ctx, api.we, "GetFilterLogs", AuthExecCfg{account: args.From}, id)
//if txRec != nil {
// return *txRec, err
//}
Expand Down
Loading
Loading