Skip to content

Commit

Permalink
Session Keys (#2131)
Browse files Browse the repository at this point in the history
* implement session keys

* implement cosmos for sk

* fix merge

* fix merge

* refactor

* addres pr comments

* addres pr comments
  • Loading branch information
tudor-malene authored Nov 8, 2024
1 parent cb0cc3c commit 65e25e0
Show file tree
Hide file tree
Showing 18 changed files with 635 additions and 205 deletions.
4 changes: 4 additions & 0 deletions go/common/custom_query_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import "github.com/ethereum/go-ethereum/common"
const (
UserIDRequestCQMethod = "0x0000000000000000000000000000000000000001"
ListPrivateTransactionsCQMethod = "0x0000000000000000000000000000000000000002"
CreateSessionKeyCQMethod = "0x0000000000000000000000000000000000000003"
ActivateSessionKeyCQMethod = "0x0000000000000000000000000000000000000004"
DeactivateSessionKeyCQMethod = "0x0000000000000000000000000000000000000005"
DeleteSessionKeyCQMethod = "0x0000000000000000000000000000000000000006"
)

type ListPrivateTransactionsQueryParams struct {
Expand Down
2 changes: 1 addition & 1 deletion integration/tengateway/tengateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ func testGetStorageAtForReturningUserID(t *testing.T, _ int, httpURL, wsURL stri
t.Error("Unable to unmarshal response")
}
if !bytes.Equal(gethcommon.FromHex(response.Result), user.tgClient.UserIDBytes()) {
t.Errorf("Wrong UserID returned. Expected: %s, received: %s", user.tgClient.UserID(), response.Result)
t.Errorf("Wrong ID returned. Expected: %s, received: %s", user.tgClient.UserID(), response.Result)
}

// make a request to GetStorageAt with correct parameters to get userID, but with wrong userID
Expand Down
1 change: 1 addition & 0 deletions tools/walletextension/common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const (
PathQuery = "/query/"
PathRevoke = "/revoke/"
PathHealth = "/health/"
PathSessionKeys = "/session-key/"
PathNetworkHealth = "/network-health/"
PathNetworkConfig = "/network-config/"
WSProtocol = "ws://"
Expand Down
27 changes: 22 additions & 5 deletions tools/walletextension/common/types.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,42 @@
package common

import (
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ten-protocol/go-ten/go/common/viewingkey"
"golang.org/x/exp/maps"

"github.com/ethereum/go-ethereum/common"
)

// GWSessionKey - an account key-pair registered for a user
type GWSessionKey struct {
Account *GWAccount
PrivateKey *ecies.PrivateKey // the private key corresponding to the account
}

type GWAccount struct {
User *GWUser
Address *common.Address
Signature []byte
Signature []byte // the signature by the account over the userId - which is derived from the VK
SignatureType viewingkey.SignatureType
}

type GWUser struct {
UserID []byte
Accounts map[common.Address]*GWAccount
UserKey []byte
ID []byte
Accounts map[common.Address]*GWAccount
UserKey []byte
SessionKey *GWSessionKey
ActiveSK bool // the session key is active, and it must be used to sign all incoming transactions, and used as the preferred account
}

func (u GWUser) AllAccounts() map[common.Address]*GWAccount {
res := maps.Clone(u.Accounts)
if u.SessionKey != nil {
res[*u.SessionKey.Account.Address] = u.SessionKey.Account
}
return res
}

func (u GWUser) GetAllAddresses() []common.Address {
return maps.Keys(u.Accounts)
return maps.Keys(u.AllAccounts())
}
98 changes: 98 additions & 0 deletions tools/walletextension/httpapi/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,26 @@ func NewHTTPRoutes(walletExt *services.Services) []node.Route {
Name: common.APIVersion1 + common.PathNetworkConfig,
Func: httpHandler(walletExt, networkConfigRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "create",
Func: httpHandler(walletExt, createSKRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "activate",
Func: httpHandler(walletExt, activateSKRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "deactivate",
Func: httpHandler(walletExt, deactivateSKRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "delete",
Func: httpHandler(walletExt, deleteSKRequestHandler),
},
{
Name: common.APIVersion1 + common.PathSessionKeys + "list",
Func: httpHandler(walletExt, listSKRequestHandler),
},
}
}

Expand Down Expand Up @@ -178,6 +198,7 @@ func authenticateRequestHandler(walletExt *services.Services, conn UserConn) {
}
}

// todo - is this needed?
// This function handles request to /query endpoint.
// In the query parameters address and userID are required. We check if provided address is registered for given userID
// and return true/false in json response
Expand Down Expand Up @@ -517,3 +538,80 @@ func getMessageRequestHandler(walletExt *services.Services, conn UserConn) {
walletExt.Logger().Error("error writing success response", log.ErrKey, err)
}
}

func listSKRequestHandler(walletExt *services.Services, conn UserConn) {
}

func createSKRequestHandler(walletExt *services.Services, conn UserConn) {
withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) {
sk, err := walletExt.SKManager.CreateSessionKey(user)
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("could not create session key: %w", err))
return nil, err
}
return []byte(hexutils.BytesToHex(sk.Account.Address.Bytes())), nil
})
}

func deleteSKRequestHandler(walletExt *services.Services, conn UserConn) {
withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) {
err := walletExt.Storage.RemoveSessionKey(user.ID)
if err != nil {
return nil, err
}
return nil, nil
})
}

func activateSKRequestHandler(walletExt *services.Services, conn UserConn) {
withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) {
err := walletExt.Storage.ActivateSessionKey(user.ID, true)
if err != nil {
return nil, err
}
return []byte{1}, nil
})
}

func deactivateSKRequestHandler(walletExt *services.Services, conn UserConn) {
withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) {
err := walletExt.Storage.ActivateSessionKey(user.ID, false)
if err != nil {
return nil, err
}
return nil, nil
})
}

// extracts the user from the request, and writes the response to the connection
func withUser(walletExt *services.Services, conn UserConn, withUser func(user *common.GWUser) ([]byte, error)) {
_, err := conn.ReadRequest()
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("error reading request: %w", err))
return
}

userID, err := getUserID(conn)
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("user ('u') not found in query parameters"))
walletExt.Logger().Info("user not found in the query params", log.ErrKey, err)
return
}

user, err := walletExt.Storage.GetUser(userID)
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("could not get user: %w", err))
return
}

resp, err := withUser(user)
if err != nil {
handleError(conn, walletExt.Logger(), fmt.Errorf("could not process request: %w", err))
return
}

err = conn.WriteResponse(resp)
if err != nil {
walletExt.Logger().Error("error writing success response", log.ErrKey, err)
}
}
53 changes: 37 additions & 16 deletions tools/walletextension/rpcapi/blockchain_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (api *BlockChainAPI) GetBalance(ctx context.Context, address gethcommon.Add
return ExecAuthRPC[hexutil.Big](
ctx,
api.we,
&ExecCfg{
&AuthExecCfg{
cacheCfg: &cache.Cfg{
DynamicType: func() cache.Strategy {
return cacheBlockNumberOrHash(blockNrOrHash)
Expand Down Expand Up @@ -180,25 +180,21 @@ func (api *BlockChainAPI) GetCode(ctx context.Context, address gethcommon.Addres
//
// In future, we can support both CustomQueries and some debug version of eth_getStorageAt if needed.
func (api *BlockChainAPI) GetStorageAt(ctx context.Context, address gethcommon.Address, params string, _ rpc.BlockNumberOrHash) (hexutil.Bytes, error) {
switch address.Hex() {
case common.UserIDRequestCQMethod:
userID, err := extractUserID(ctx, api.we)
if err != nil {
return nil, err
}
user, err := extractUserForRequest(ctx, api.we)
if err != nil {
return nil, err
}

_, err = api.we.Storage.GetUser(userID)
if err != nil {
return nil, err
}
return userID, nil
switch address.Hex() {
case common.UserIDRequestCQMethod: // todo - review whether we need this endpoint
return user.ID, nil
case common.ListPrivateTransactionsCQMethod:
// sensitive CustomQuery methods use the convention of having "address" at the top level of the params json
userAddr, err := extractCustomQueryAddress(params)
if err != nil {
return nil, fmt.Errorf("unable to extract address from custom query params: %w", err)
}
resp, err := ExecAuthRPC[any](ctx, api.we, &ExecCfg{account: userAddr}, "scan_getPersonalTransactions", params)
resp, err := ExecAuthRPC[any](ctx, api.we, &AuthExecCfg{account: userAddr}, "scan_getPersonalTransactions", params)
if err != nil {
return nil, fmt.Errorf("unable to execute custom query: %w", err)
}
Expand All @@ -208,8 +204,33 @@ func (api *BlockChainAPI) GetStorageAt(ctx context.Context, address gethcommon.A
return nil, fmt.Errorf("unable to marshal response object: %w", err)
}
return serialised, nil
case common.CreateSessionKeyCQMethod:
sk, err := api.we.SKManager.CreateSessionKey(user)
if err != nil {
return nil, fmt.Errorf("unable to create session key: %w", err)
}
return sk.Account.Address.Bytes(), nil
case common.ActivateSessionKeyCQMethod:
err := api.we.Storage.ActivateSessionKey(user.ID, true)
if err != nil {
return nil, err
}
return []byte{1}, nil

case common.DeactivateSessionKeyCQMethod:
err := api.we.Storage.ActivateSessionKey(user.ID, false)
if err != nil {
return nil, err
}
return []byte{1}, nil
case common.DeleteSessionKeyCQMethod:
err := api.we.Storage.RemoveSessionKey(user.ID)
if err != nil {
return nil, err
}
return nil, nil
default: // address was not a recognised custom query method address
resp, err := ExecAuthRPC[any](ctx, api.we, &ExecCfg{tryUntilAuthorised: true}, "eth_getStorageAt", address, params, nil)
resp, err := ExecAuthRPC[any](ctx, api.we, &AuthExecCfg{tryUntilAuthorised: true}, "eth_getStorageAt", address, params, nil)
if err != nil {
return nil, fmt.Errorf("unable to execute eth_getStorageAt: %w", err)
}
Expand Down Expand Up @@ -251,7 +272,7 @@ type (
)

func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, blockOverrides *BlockOverrides) (hexutil.Bytes, error) {
resp, err := ExecAuthRPC[hexutil.Bytes](ctx, api.we, &ExecCfg{
resp, err := ExecAuthRPC[hexutil.Bytes](ctx, api.we, &AuthExecCfg{
cacheCfg: &cache.Cfg{
DynamicType: func() cache.Strategy {
return cacheBlockNumberOrHash(blockNrOrHash)
Expand All @@ -273,7 +294,7 @@ func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs
}

func (api *BlockChainAPI) EstimateGas(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Uint64, error) {
resp, err := ExecAuthRPC[hexutil.Uint64](ctx, api.we, &ExecCfg{
resp, err := ExecAuthRPC[hexutil.Uint64](ctx, api.we, &AuthExecCfg{
cacheCfg: &cache.Cfg{
DynamicType: func() cache.Strategy {
if blockNrOrHash != nil {
Expand Down
2 changes: 1 addition & 1 deletion tools/walletextension/rpcapi/debug_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (api *DebugAPI) EventLogRelevancy(ctx context.Context, crit common.FilterCr
l, err := ExecAuthRPC[[]*common.DebugLogVisibility](
ctx,
api.we,
&ExecCfg{
&AuthExecCfg{
cacheCfg: &cache.Cfg{
Type: cache.NoCache,
},
Expand Down
21 changes: 8 additions & 13 deletions tools/walletextension/rpcapi/filter_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rp
errorChannels := make([]<-chan error, 0)
backendSubscriptions := make([]*rpc.ClientSubscription, 0)
for _, address := range candidateAddresses {
rpcWSClient, err := api.we.BackendRPC.ConnectWS(ctx, user.Accounts[address])
rpcWSClient, err := api.we.BackendRPC.ConnectWS(ctx, user.AllAccounts()[address])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -196,13 +196,13 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) (
method := "eth_getLogs"
audit(api.we, "RPC start method=%s args=%v", method, ctx)
requestStartTime := time.Now()
userID, err := extractUserID(ctx, api.we)
user, err := extractUserForRequest(ctx, api.we)
if err != nil {
return nil, err
}

rateLimitAllowed, requestUUID := api.we.RateLimiter.Allow(gethcommon.Address(userID))
defer api.we.RateLimiter.SetRequestEnd(gethcommon.Address(userID), requestUUID)
rateLimitAllowed, requestUUID := api.we.RateLimiter.Allow(gethcommon.Address(user.ID))
defer api.we.RateLimiter.SetRequestEnd(gethcommon.Address(user.ID), requestUUID)
if !rateLimitAllowed {
return nil, fmt.Errorf("rate limit exceeded")
}
Expand All @@ -221,18 +221,13 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) (
return cache.LatestBatch
},
},
generateCacheKey([]any{userID, method, common.SerializableFilterCriteria(crit)}),
generateCacheKey([]any{user.ID, method, common.SerializableFilterCriteria(crit)}),
func() (*[]*types.Log, error) { // called when there is no entry in the cache
user, err := api.we.Storage.GetUser(userID)
if err != nil {
return nil, err
}

allEventLogsMap := make(map[LogKey]*types.Log)
// for each account registered for the current user
// execute the get_Logs function
// dedupe and concatenate the results
for _, acct := range user.Accounts {
for _, acct := range user.AllAccounts() {
eventLogs, err := services.WithEncRPCConnection(ctx, api.we.BackendRPC, acct, func(rpcClient *tenrpc.EncRPCClient) (*[]*types.Log, error) {
var result []*types.Log

Expand Down Expand Up @@ -271,7 +266,7 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) (
if err != nil {
return nil, err
}
audit(api.we, "RPC call. uid=%s, method=%s args=%v result=%v error=%v time=%d", hexutils.BytesToHex(userID), method, crit, res, err, time.Since(requestStartTime).Milliseconds())
audit(api.we, "RPC call. uid=%s, method=%s args=%v result=%v error=%v time=%d", hexutils.BytesToHex(user.ID), method, crit, res, err, time.Since(requestStartTime).Milliseconds())
return *res, err
}

Expand All @@ -281,7 +276,7 @@ func (api *FilterAPI) UninstallFilter(id rpc.ID) bool {
}

func (api *FilterAPI) GetFilterLogs(ctx context.Context, id rpc.ID) ([]*types.Log, error) {
//txRec, err := ExecAuthRPC[[]*types.Log](ctx, api.we, "GetFilterLogs", ExecCfg{account: args.From}, id)
//txRec, err := ExecAuthRPC[[]*types.Log](ctx, api.we, "GetFilterLogs", AuthExecCfg{account: args.From}, id)
//if txRec != nil {
// return *txRec, err
//}
Expand Down
Loading

0 comments on commit 65e25e0

Please sign in to comment.