diff --git a/go/common/custom_query_types.go b/go/common/custom_query_types.go index e3e483e63b..662c71d6c9 100644 --- a/go/common/custom_query_types.go +++ b/go/common/custom_query_types.go @@ -21,6 +21,10 @@ import "github.com/ethereum/go-ethereum/common" const ( UserIDRequestCQMethod = "0x0000000000000000000000000000000000000001" ListPrivateTransactionsCQMethod = "0x0000000000000000000000000000000000000002" + CreateSessionKeyCQMethod = "0x0000000000000000000000000000000000000003" + ActivateSessionKeyCQMethod = "0x0000000000000000000000000000000000000004" + DeactivateSessionKeyCQMethod = "0x0000000000000000000000000000000000000005" + DeleteSessionKeyCQMethod = "0x0000000000000000000000000000000000000006" ) type ListPrivateTransactionsQueryParams struct { diff --git a/integration/tengateway/tengateway_test.go b/integration/tengateway/tengateway_test.go index a7908f459b..a63b8fde86 100644 --- a/integration/tengateway/tengateway_test.go +++ b/integration/tengateway/tengateway_test.go @@ -778,7 +778,7 @@ func testGetStorageAtForReturningUserID(t *testing.T, _ int, httpURL, wsURL stri t.Error("Unable to unmarshal response") } if !bytes.Equal(gethcommon.FromHex(response.Result), user.tgClient.UserIDBytes()) { - t.Errorf("Wrong UserID returned. Expected: %s, received: %s", user.tgClient.UserID(), response.Result) + t.Errorf("Wrong ID returned. Expected: %s, received: %s", user.tgClient.UserID(), response.Result) } // make a request to GetStorageAt with correct parameters to get userID, but with wrong userID diff --git a/tools/walletextension/common/constants.go b/tools/walletextension/common/constants.go index f0d3b83192..e39f2938e7 100644 --- a/tools/walletextension/common/constants.go +++ b/tools/walletextension/common/constants.go @@ -22,6 +22,7 @@ const ( PathQuery = "/query/" PathRevoke = "/revoke/" PathHealth = "/health/" + PathSessionKeys = "/session-key/" PathNetworkHealth = "/network-health/" PathNetworkConfig = "/network-config/" WSProtocol = "ws://" diff --git a/tools/walletextension/common/types.go b/tools/walletextension/common/types.go index 1326e26203..b54ca9dd11 100644 --- a/tools/walletextension/common/types.go +++ b/tools/walletextension/common/types.go @@ -1,25 +1,42 @@ package common import ( + "github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ten-protocol/go-ten/go/common/viewingkey" "golang.org/x/exp/maps" "github.com/ethereum/go-ethereum/common" ) +// GWSessionKey - an account key-pair registered for a user +type GWSessionKey struct { + Account *GWAccount + PrivateKey *ecies.PrivateKey // the private key corresponding to the account +} + type GWAccount struct { User *GWUser Address *common.Address - Signature []byte + Signature []byte // the signature by the account over the userId - which is derived from the VK SignatureType viewingkey.SignatureType } type GWUser struct { - UserID []byte - Accounts map[common.Address]*GWAccount - UserKey []byte + ID []byte + Accounts map[common.Address]*GWAccount + UserKey []byte + SessionKey *GWSessionKey + ActiveSK bool // the session key is active, and it must be used to sign all incoming transactions, and used as the preferred account +} + +func (u GWUser) AllAccounts() map[common.Address]*GWAccount { + res := maps.Clone(u.Accounts) + if u.SessionKey != nil { + res[*u.SessionKey.Account.Address] = u.SessionKey.Account + } + return res } func (u GWUser) GetAllAddresses() []common.Address { - return maps.Keys(u.Accounts) + return maps.Keys(u.AllAccounts()) } diff --git a/tools/walletextension/httpapi/routes.go b/tools/walletextension/httpapi/routes.go index efac92d1d4..4134c5c53b 100644 --- a/tools/walletextension/httpapi/routes.go +++ b/tools/walletextension/httpapi/routes.go @@ -62,6 +62,26 @@ func NewHTTPRoutes(walletExt *services.Services) []node.Route { Name: common.APIVersion1 + common.PathNetworkConfig, Func: httpHandler(walletExt, networkConfigRequestHandler), }, + { + Name: common.APIVersion1 + common.PathSessionKeys + "create", + Func: httpHandler(walletExt, createSKRequestHandler), + }, + { + Name: common.APIVersion1 + common.PathSessionKeys + "activate", + Func: httpHandler(walletExt, activateSKRequestHandler), + }, + { + Name: common.APIVersion1 + common.PathSessionKeys + "deactivate", + Func: httpHandler(walletExt, deactivateSKRequestHandler), + }, + { + Name: common.APIVersion1 + common.PathSessionKeys + "delete", + Func: httpHandler(walletExt, deleteSKRequestHandler), + }, + { + Name: common.APIVersion1 + common.PathSessionKeys + "list", + Func: httpHandler(walletExt, listSKRequestHandler), + }, } } @@ -178,6 +198,7 @@ func authenticateRequestHandler(walletExt *services.Services, conn UserConn) { } } +// todo - is this needed? // This function handles request to /query endpoint. // In the query parameters address and userID are required. We check if provided address is registered for given userID // and return true/false in json response @@ -517,3 +538,80 @@ func getMessageRequestHandler(walletExt *services.Services, conn UserConn) { walletExt.Logger().Error("error writing success response", log.ErrKey, err) } } + +func listSKRequestHandler(walletExt *services.Services, conn UserConn) { +} + +func createSKRequestHandler(walletExt *services.Services, conn UserConn) { + withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) { + sk, err := walletExt.SKManager.CreateSessionKey(user) + if err != nil { + handleError(conn, walletExt.Logger(), fmt.Errorf("could not create session key: %w", err)) + return nil, err + } + return []byte(hexutils.BytesToHex(sk.Account.Address.Bytes())), nil + }) +} + +func deleteSKRequestHandler(walletExt *services.Services, conn UserConn) { + withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) { + err := walletExt.Storage.RemoveSessionKey(user.ID) + if err != nil { + return nil, err + } + return nil, nil + }) +} + +func activateSKRequestHandler(walletExt *services.Services, conn UserConn) { + withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) { + err := walletExt.Storage.ActivateSessionKey(user.ID, true) + if err != nil { + return nil, err + } + return []byte{1}, nil + }) +} + +func deactivateSKRequestHandler(walletExt *services.Services, conn UserConn) { + withUser(walletExt, conn, func(user *common.GWUser) ([]byte, error) { + err := walletExt.Storage.ActivateSessionKey(user.ID, false) + if err != nil { + return nil, err + } + return nil, nil + }) +} + +// extracts the user from the request, and writes the response to the connection +func withUser(walletExt *services.Services, conn UserConn, withUser func(user *common.GWUser) ([]byte, error)) { + _, err := conn.ReadRequest() + if err != nil { + handleError(conn, walletExt.Logger(), fmt.Errorf("error reading request: %w", err)) + return + } + + userID, err := getUserID(conn) + if err != nil { + handleError(conn, walletExt.Logger(), fmt.Errorf("user ('u') not found in query parameters")) + walletExt.Logger().Info("user not found in the query params", log.ErrKey, err) + return + } + + user, err := walletExt.Storage.GetUser(userID) + if err != nil { + handleError(conn, walletExt.Logger(), fmt.Errorf("could not get user: %w", err)) + return + } + + resp, err := withUser(user) + if err != nil { + handleError(conn, walletExt.Logger(), fmt.Errorf("could not process request: %w", err)) + return + } + + err = conn.WriteResponse(resp) + if err != nil { + walletExt.Logger().Error("error writing success response", log.ErrKey, err) + } +} diff --git a/tools/walletextension/rpcapi/blockchain_api.go b/tools/walletextension/rpcapi/blockchain_api.go index d94f3b0d87..04f467d1a9 100644 --- a/tools/walletextension/rpcapi/blockchain_api.go +++ b/tools/walletextension/rpcapi/blockchain_api.go @@ -52,7 +52,7 @@ func (api *BlockChainAPI) GetBalance(ctx context.Context, address gethcommon.Add return ExecAuthRPC[hexutil.Big]( ctx, api.we, - &ExecCfg{ + &AuthExecCfg{ cacheCfg: &cache.Cfg{ DynamicType: func() cache.Strategy { return cacheBlockNumberOrHash(blockNrOrHash) @@ -180,25 +180,21 @@ func (api *BlockChainAPI) GetCode(ctx context.Context, address gethcommon.Addres // // In future, we can support both CustomQueries and some debug version of eth_getStorageAt if needed. func (api *BlockChainAPI) GetStorageAt(ctx context.Context, address gethcommon.Address, params string, _ rpc.BlockNumberOrHash) (hexutil.Bytes, error) { - switch address.Hex() { - case common.UserIDRequestCQMethod: - userID, err := extractUserID(ctx, api.we) - if err != nil { - return nil, err - } + user, err := extractUserForRequest(ctx, api.we) + if err != nil { + return nil, err + } - _, err = api.we.Storage.GetUser(userID) - if err != nil { - return nil, err - } - return userID, nil + switch address.Hex() { + case common.UserIDRequestCQMethod: // todo - review whether we need this endpoint + return user.ID, nil case common.ListPrivateTransactionsCQMethod: // sensitive CustomQuery methods use the convention of having "address" at the top level of the params json userAddr, err := extractCustomQueryAddress(params) if err != nil { return nil, fmt.Errorf("unable to extract address from custom query params: %w", err) } - resp, err := ExecAuthRPC[any](ctx, api.we, &ExecCfg{account: userAddr}, "scan_getPersonalTransactions", params) + resp, err := ExecAuthRPC[any](ctx, api.we, &AuthExecCfg{account: userAddr}, "scan_getPersonalTransactions", params) if err != nil { return nil, fmt.Errorf("unable to execute custom query: %w", err) } @@ -208,8 +204,33 @@ func (api *BlockChainAPI) GetStorageAt(ctx context.Context, address gethcommon.A return nil, fmt.Errorf("unable to marshal response object: %w", err) } return serialised, nil + case common.CreateSessionKeyCQMethod: + sk, err := api.we.SKManager.CreateSessionKey(user) + if err != nil { + return nil, fmt.Errorf("unable to create session key: %w", err) + } + return sk.Account.Address.Bytes(), nil + case common.ActivateSessionKeyCQMethod: + err := api.we.Storage.ActivateSessionKey(user.ID, true) + if err != nil { + return nil, err + } + return []byte{1}, nil + + case common.DeactivateSessionKeyCQMethod: + err := api.we.Storage.ActivateSessionKey(user.ID, false) + if err != nil { + return nil, err + } + return []byte{1}, nil + case common.DeleteSessionKeyCQMethod: + err := api.we.Storage.RemoveSessionKey(user.ID) + if err != nil { + return nil, err + } + return nil, nil default: // address was not a recognised custom query method address - resp, err := ExecAuthRPC[any](ctx, api.we, &ExecCfg{tryUntilAuthorised: true}, "eth_getStorageAt", address, params, nil) + resp, err := ExecAuthRPC[any](ctx, api.we, &AuthExecCfg{tryUntilAuthorised: true}, "eth_getStorageAt", address, params, nil) if err != nil { return nil, fmt.Errorf("unable to execute eth_getStorageAt: %w", err) } @@ -251,7 +272,7 @@ type ( ) func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, blockOverrides *BlockOverrides) (hexutil.Bytes, error) { - resp, err := ExecAuthRPC[hexutil.Bytes](ctx, api.we, &ExecCfg{ + resp, err := ExecAuthRPC[hexutil.Bytes](ctx, api.we, &AuthExecCfg{ cacheCfg: &cache.Cfg{ DynamicType: func() cache.Strategy { return cacheBlockNumberOrHash(blockNrOrHash) @@ -273,7 +294,7 @@ func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs } func (api *BlockChainAPI) EstimateGas(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Uint64, error) { - resp, err := ExecAuthRPC[hexutil.Uint64](ctx, api.we, &ExecCfg{ + resp, err := ExecAuthRPC[hexutil.Uint64](ctx, api.we, &AuthExecCfg{ cacheCfg: &cache.Cfg{ DynamicType: func() cache.Strategy { if blockNrOrHash != nil { diff --git a/tools/walletextension/rpcapi/debug_api.go b/tools/walletextension/rpcapi/debug_api.go index c3b1891d32..1ab39ebfa2 100644 --- a/tools/walletextension/rpcapi/debug_api.go +++ b/tools/walletextension/rpcapi/debug_api.go @@ -61,7 +61,7 @@ func (api *DebugAPI) EventLogRelevancy(ctx context.Context, crit common.FilterCr l, err := ExecAuthRPC[[]*common.DebugLogVisibility]( ctx, api.we, - &ExecCfg{ + &AuthExecCfg{ cacheCfg: &cache.Cfg{ Type: cache.NoCache, }, diff --git a/tools/walletextension/rpcapi/filter_api.go b/tools/walletextension/rpcapi/filter_api.go index 202e864cd2..612bcbd9b3 100644 --- a/tools/walletextension/rpcapi/filter_api.go +++ b/tools/walletextension/rpcapi/filter_api.go @@ -85,7 +85,7 @@ func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rp errorChannels := make([]<-chan error, 0) backendSubscriptions := make([]*rpc.ClientSubscription, 0) for _, address := range candidateAddresses { - rpcWSClient, err := api.we.BackendRPC.ConnectWS(ctx, user.Accounts[address]) + rpcWSClient, err := api.we.BackendRPC.ConnectWS(ctx, user.AllAccounts()[address]) if err != nil { return nil, err } @@ -196,13 +196,13 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) ( method := "eth_getLogs" audit(api.we, "RPC start method=%s args=%v", method, ctx) requestStartTime := time.Now() - userID, err := extractUserID(ctx, api.we) + user, err := extractUserForRequest(ctx, api.we) if err != nil { return nil, err } - rateLimitAllowed, requestUUID := api.we.RateLimiter.Allow(gethcommon.Address(userID)) - defer api.we.RateLimiter.SetRequestEnd(gethcommon.Address(userID), requestUUID) + rateLimitAllowed, requestUUID := api.we.RateLimiter.Allow(gethcommon.Address(user.ID)) + defer api.we.RateLimiter.SetRequestEnd(gethcommon.Address(user.ID), requestUUID) if !rateLimitAllowed { return nil, fmt.Errorf("rate limit exceeded") } @@ -221,18 +221,13 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) ( return cache.LatestBatch }, }, - generateCacheKey([]any{userID, method, common.SerializableFilterCriteria(crit)}), + generateCacheKey([]any{user.ID, method, common.SerializableFilterCriteria(crit)}), func() (*[]*types.Log, error) { // called when there is no entry in the cache - user, err := api.we.Storage.GetUser(userID) - if err != nil { - return nil, err - } - allEventLogsMap := make(map[LogKey]*types.Log) // for each account registered for the current user // execute the get_Logs function // dedupe and concatenate the results - for _, acct := range user.Accounts { + for _, acct := range user.AllAccounts() { eventLogs, err := services.WithEncRPCConnection(ctx, api.we.BackendRPC, acct, func(rpcClient *tenrpc.EncRPCClient) (*[]*types.Log, error) { var result []*types.Log @@ -271,7 +266,7 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) ( if err != nil { return nil, err } - audit(api.we, "RPC call. uid=%s, method=%s args=%v result=%v error=%v time=%d", hexutils.BytesToHex(userID), method, crit, res, err, time.Since(requestStartTime).Milliseconds()) + audit(api.we, "RPC call. uid=%s, method=%s args=%v result=%v error=%v time=%d", hexutils.BytesToHex(user.ID), method, crit, res, err, time.Since(requestStartTime).Milliseconds()) return *res, err } @@ -281,7 +276,7 @@ func (api *FilterAPI) UninstallFilter(id rpc.ID) bool { } func (api *FilterAPI) GetFilterLogs(ctx context.Context, id rpc.ID) ([]*types.Log, error) { - //txRec, err := ExecAuthRPC[[]*types.Log](ctx, api.we, "GetFilterLogs", ExecCfg{account: args.From}, id) + //txRec, err := ExecAuthRPC[[]*types.Log](ctx, api.we, "GetFilterLogs", AuthExecCfg{account: args.From}, id) //if txRec != nil { // return *txRec, err //} diff --git a/tools/walletextension/rpcapi/transaction_api.go b/tools/walletextension/rpcapi/transaction_api.go index 73bd8288e3..f7f89aad73 100644 --- a/tools/walletextension/rpcapi/transaction_api.go +++ b/tools/walletextension/rpcapi/transaction_api.go @@ -65,7 +65,7 @@ func (s *TransactionAPI) GetTransactionCount(ctx context.Context, address common return ExecAuthRPC[hexutil.Uint64]( ctx, s.we, - &ExecCfg{ + &AuthExecCfg{ account: &address, cacheCfg: &cache.Cfg{ DynamicType: func() cache.Strategy { @@ -80,11 +80,11 @@ func (s *TransactionAPI) GetTransactionCount(ctx context.Context, address common } func (s *TransactionAPI) GetTransactionByHash(ctx context.Context, hash common.Hash) (*rpc.RpcTransaction, error) { - return ExecAuthRPC[rpc.RpcTransaction](ctx, s.we, &ExecCfg{tryAll: true, cacheCfg: &cache.Cfg{Type: cache.LongLiving}}, "eth_getTransactionByHash", hash) + return ExecAuthRPC[rpc.RpcTransaction](ctx, s.we, &AuthExecCfg{tryAll: true, cacheCfg: &cache.Cfg{Type: cache.LongLiving}}, "eth_getTransactionByHash", hash) } func (s *TransactionAPI) GetRawTransactionByHash(ctx context.Context, hash common.Hash) (hexutil.Bytes, error) { - tx, err := ExecAuthRPC[hexutil.Bytes](ctx, s.we, &ExecCfg{tryAll: true, cacheCfg: &cache.Cfg{Type: cache.LongLiving}}, "eth_getRawTransactionByHash", hash) + tx, err := ExecAuthRPC[hexutil.Bytes](ctx, s.we, &AuthExecCfg{tryAll: true, cacheCfg: &cache.Cfg{Type: cache.LongLiving}}, "eth_getRawTransactionByHash", hash) if tx != nil { return *tx, err } @@ -92,7 +92,7 @@ func (s *TransactionAPI) GetRawTransactionByHash(ctx context.Context, hash commo } func (s *TransactionAPI) GetTransactionReceipt(ctx context.Context, hash common.Hash) (map[string]interface{}, error) { - txRec, err := ExecAuthRPC[map[string]interface{}](ctx, s.we, &ExecCfg{tryUntilAuthorised: true, cacheCfg: &cache.Cfg{Type: cache.LongLiving}}, "eth_getTransactionReceipt", hash) + txRec, err := ExecAuthRPC[map[string]interface{}](ctx, s.we, &AuthExecCfg{tryUntilAuthorised: true, cacheCfg: &cache.Cfg{Type: cache.LongLiving}}, "eth_getTransactionReceipt", hash) if err != nil { return nil, err } @@ -103,11 +103,13 @@ func (s *TransactionAPI) GetTransactionReceipt(ctx context.Context, hash common. } func (s *TransactionAPI) SendTransaction(ctx context.Context, args gethapi.TransactionArgs) (common.Hash, error) { - txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &ExecCfg{account: args.From, timeout: sendTransactionDuration}, "eth_sendTransaction", args) - if err != nil { - return common.Hash{}, err - } - return *txRec, err + //txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &AuthExecCfg{account: args.From, timeout: sendTransactionDuration}, "eth_sendTransaction", args) + //if err != nil { + // return common.Hash{}, err + //} + //return *txRec, err + // not implemented for now. We might use this for session keys. + return common.Hash{}, rpcNotImplemented } type SignTransactionResult struct { @@ -120,7 +122,21 @@ func (s *TransactionAPI) FillTransaction(ctx context.Context, args gethapi.Trans } func (s *TransactionAPI) SendRawTransaction(ctx context.Context, input hexutil.Bytes) (common.Hash, error) { - txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &ExecCfg{tryAll: true, timeout: sendTransactionDuration}, "eth_sendRawTransaction", input) + user, err := extractUserForRequest(ctx, s.we) + if err != nil { + return common.Hash{}, err + } + + signedTx := input + // when there is an active Session Key, sign all incoming transactions with that SK + if user.ActiveSK && user.SessionKey != nil { + signedTx, err = s.we.SKManager.SignTx(ctx, user, input) + if err != nil { + return common.Hash{}, err + } + } + + txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &AuthExecCfg{tryAll: true, timeout: sendTransactionDuration}, "eth_sendRawTransaction", signedTx) if err != nil { return common.Hash{}, err } @@ -132,7 +148,7 @@ func (s *TransactionAPI) PendingTransactions() ([]*rpc.RpcTransaction, error) { } func (s *TransactionAPI) Resend(ctx context.Context, sendArgs gethapi.TransactionArgs, gasPrice *hexutil.Big, gasLimit *hexutil.Uint64) (common.Hash, error) { - txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &ExecCfg{account: sendArgs.From}, "eth_resend", sendArgs, gasPrice, gasLimit) + txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &AuthExecCfg{account: sendArgs.From}, "eth_resend", sendArgs, gasPrice, gasLimit) if txRec != nil { return *txRec, err } diff --git a/tools/walletextension/rpcapi/utils.go b/tools/walletextension/rpcapi/utils.go index 9012f3217e..5b1da57589 100644 --- a/tools/walletextension/rpcapi/utils.go +++ b/tools/walletextension/rpcapi/utils.go @@ -41,7 +41,7 @@ const ( var rpcNotImplemented = fmt.Errorf("rpc endpoint not implemented") -type ExecCfg struct { +type AuthExecCfg struct { // these 4 fields specify the account(s) that should make the backend call account *gethcommon.Address computeFromCallback func(user *common.GWUser) *gethcommon.Address @@ -79,29 +79,24 @@ func UnauthenticatedTenRPCCall[R any](ctx context.Context, w *services.Services, return res, err } -func ExecAuthRPC[R any](ctx context.Context, w *services.Services, cfg *ExecCfg, method string, args ...any) (*R, error) { +func ExecAuthRPC[R any](ctx context.Context, w *services.Services, cfg *AuthExecCfg, method string, args ...any) (*R, error) { audit(w, "RPC start method=%s args=%v", method, args) requestStartTime := time.Now() - userID, err := extractUserID(ctx, w) + user, err := extractUserForRequest(ctx, w) if err != nil { return nil, err } - rateLimitAllowed, requestUUID := w.RateLimiter.Allow(gethcommon.Address(userID)) - defer w.RateLimiter.SetRequestEnd(gethcommon.Address(userID), requestUUID) + rateLimitAllowed, requestUUID := w.RateLimiter.Allow(gethcommon.Address(user.ID)) + defer w.RateLimiter.SetRequestEnd(gethcommon.Address(user.ID), requestUUID) if !rateLimitAllowed { return nil, fmt.Errorf("rate limit exceeded") } - cacheArgs := []any{userID, method} + cacheArgs := []any{user.ID, method} cacheArgs = append(cacheArgs, args...) res, err := cache.WithCache(w.RPCResponsesCache, cfg.cacheCfg, generateCacheKey(cacheArgs), func() (*R, error) { - user, err := w.Storage.GetUser(userID) - if err != nil { - return nil, err - } - // determine candidate "from" candidateAccts, err := getCandidateAccounts(user, w, cfg) if err != nil { @@ -149,16 +144,16 @@ func ExecAuthRPC[R any](ctx context.Context, w *services.Services, cfg *ExecCfg, } return nil, rpcErr }) - audit(w, "RPC call. uid=%s, method=%s args=%v result=%s error=%s time=%d", hexutils.BytesToHex(userID), method, args, SafeGenericToString(res), err, time.Since(requestStartTime).Milliseconds()) + audit(w, "RPC call. uid=%s, method=%s args=%v result=%s error=%s time=%d", hexutils.BytesToHex(user.ID), method, args, SafeGenericToString(res), err, time.Since(requestStartTime).Milliseconds()) return res, err } -func getCandidateAccounts(user *common.GWUser, we *services.Services, cfg *ExecCfg) ([]*common.GWAccount, error) { +func getCandidateAccounts(user *common.GWUser, we *services.Services, cfg *AuthExecCfg) ([]*common.GWAccount, error) { candidateAccts := make([]*common.GWAccount, 0) // for users with multiple accounts try to determine a candidate account based on the available information switch { case cfg.account != nil: - acc := user.Accounts[*cfg.account] + acc := user.AllAccounts()[*cfg.account] if acc != nil { candidateAccts = append(candidateAccts, acc) return candidateAccts, nil @@ -167,19 +162,19 @@ func getCandidateAccounts(user *common.GWUser, we *services.Services, cfg *ExecC case cfg.computeFromCallback != nil: suggestedAddress := cfg.computeFromCallback(user) if suggestedAddress != nil { - acc := user.Accounts[*suggestedAddress] + acc := user.AllAccounts()[*suggestedAddress] if acc != nil { candidateAccts = append(candidateAccts, acc) return candidateAccts, nil } else { // this should not happen, because the suggestedAddress is one of the addresses - return nil, fmt.Errorf("should not happen. From: %s . UserId: %s", suggestedAddress.Hex(), hexutils.BytesToHex(user.UserID)) + return nil, fmt.Errorf("should not happen. From: %s . UserId: %s", suggestedAddress.Hex(), hexutils.BytesToHex(user.ID)) } } } if cfg.tryAll || cfg.tryUntilAuthorised { - for _, acc := range user.Accounts { + for _, acc := range user.AllAccounts() { candidateAccts = append(candidateAccts, acc) } } @@ -190,15 +185,27 @@ func getCandidateAccounts(user *common.GWUser, we *services.Services, cfg *ExecC func extractUserID(ctx context.Context, _ *services.Services) ([]byte, error) { token, ok := ctx.Value(rpc.GWTokenKey{}).(string) if !ok { - return nil, fmt.Errorf("invalid userid: %s", ctx.Value(rpc.GWTokenKey{})) + return nil, fmt.Errorf("invalid authentication token: %s", ctx.Value(rpc.GWTokenKey{})) } userID := gethcommon.FromHex(token) if len(userID) != viewingkey.UserIDLength { - return nil, fmt.Errorf("invalid userid: %s", token) + return nil, fmt.Errorf("invalid authentication token: %s", token) } return userID, nil } +func extractUserForRequest(ctx context.Context, w *services.Services) (*common.GWUser, error) { + userID, err := extractUserID(ctx, w) + if err != nil { + return nil, err + } + user, err := w.Storage.GetUser(userID) + if err != nil { + return nil, fmt.Errorf("authentication failed: %w", err) + } + return user, nil +} + // generateCacheKey generates a cache key for the given method, encryptionToken and parameters // encryptionToken is used to generate a unique cache key for each user and empty string should be used for public data func generateCacheKey(params []any) []byte { diff --git a/tools/walletextension/services/sk_manager.go b/tools/walletextension/services/sk_manager.go new file mode 100644 index 0000000000..d87a62c8e3 --- /dev/null +++ b/tools/walletextension/services/sk_manager.go @@ -0,0 +1,108 @@ +package services + +import ( + "context" + "fmt" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/core/types" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" + gethlog "github.com/ethereum/go-ethereum/log" + "github.com/ten-protocol/go-ten/go/common/viewingkey" + "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/storage" +) + +// SKManager - session keys are Private Keys managed by the Gateway +// At the moment, each user can have a single Session Key. Which is either active or inactive +// when the SK is active, then all transactions submitted by that user will be signed with the session key +// The SK is also considered an "Account" of that user +// when the SK is created, it signs over the VK of the user so that it can interact with a node the standard way +// From the POV of the Ten network - a session key is a normal account key +type SKManager interface { + CreateSessionKey(user *common.GWUser) (*common.GWSessionKey, error) + SignTx(ctx context.Context, user *common.GWUser, input hexutil.Bytes) (hexutil.Bytes, error) +} + +type skManager struct { + storage storage.UserStorage + config *common.Config + logger gethlog.Logger +} + +func NewSKManager(storage storage.UserStorage, config *common.Config, logger gethlog.Logger) SKManager { + return &skManager{ + storage: storage, + config: config, + logger: logger, + } +} + +// CreateSessionKey - generates a fresh key and signs over the VK of the user with it +func (m *skManager) CreateSessionKey(user *common.GWUser) (*common.GWSessionKey, error) { + sk, err := m.createSK(user) + if err != nil { + return nil, err + } + err = m.storage.AddSessionKey(user.ID, *sk) + if err != nil { + return nil, err + } + return sk, nil +} + +func (m *skManager) createSK(user *common.GWUser) (*common.GWSessionKey, error) { + // generate new key-pair + sk, err := crypto.GenerateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate key-pair: %w", err) + } + skEcies := ecies.ImportECDSA(sk) + + // Compute the Ethereum address from the public key + address := crypto.PubkeyToAddress(sk.PublicKey) + + // use the viewing key to sign over the session key + msg, err := viewingkey.GenerateMessage(user.ID, int64(m.config.TenChainID), 1, viewingkey.EIP712Signature) + if err != nil { + return nil, fmt.Errorf("cannot generate message. Cause %w", err) + } + + msgHash, err := viewingkey.GetMessageHash(msg, viewingkey.EIP712Signature) + if err != nil { + return nil, fmt.Errorf("cannot generate message hash. Cause %w", err) + } + + // current signature is valid - return account address + sig, err := crypto.Sign(msgHash, sk) + if err != nil { + return nil, fmt.Errorf("cannot sign message with session key. Cause %w", err) + } + + return &common.GWSessionKey{ + PrivateKey: skEcies, + Account: &common.GWAccount{ + User: user, + Address: &address, + Signature: sig, + SignatureType: viewingkey.EIP712Signature, + }, + }, nil +} + +func (m *skManager) SignTx(ctx context.Context, user *common.GWUser, input hexutil.Bytes) (hexutil.Bytes, error) { + tx := new(types.Transaction) + if err := tx.UnmarshalBinary(input); err != nil { + return hexutil.Bytes{}, err + } + + signer := types.NewLondonSigner(tx.ChainId()) + + tx, err := types.SignTx(tx, signer, user.SessionKey.PrivateKey.ExportECDSA()) + if err != nil { + return hexutil.Bytes{}, err + } + return tx.MarshalBinary() +} diff --git a/tools/walletextension/services/wallet_extension.go b/tools/walletextension/services/wallet_extension.go index 507a1cadf3..4fa34a3fb2 100644 --- a/tools/walletextension/services/wallet_extension.go +++ b/tools/walletextension/services/wallet_extension.go @@ -45,6 +45,7 @@ type Services struct { RPCResponsesCache cache.Cache BackendRPC *BackendRPC RateLimiter *ratelimiter.RateLimiter + SKManager SKManager Config *common.Config NewHeadsService *subscriptioncommon.NewHeadsService } @@ -74,6 +75,7 @@ func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.UserSto version: version, RPCResponsesCache: newGatewayCache, BackendRPC: NewBackendRPC(hostAddrHTTP, hostAddrWS, logger), + SKManager: NewSKManager(storage, config, logger), RateLimiter: rateLimiter, Config: config, } @@ -139,13 +141,13 @@ func (w *Services) GenerateAndStoreNewUser() ([]byte, error) { requestStartTime := time.Now() // generate new key-pair viewingKeyPrivate, err := crypto.GenerateKey() - viewingPrivateKeyEcies := ecies.ImportECDSA(viewingKeyPrivate) if err != nil { w.Logger().Error(fmt.Sprintf("could not generate new keypair: %s", err)) return nil, err } + viewingPrivateKeyEcies := ecies.ImportECDSA(viewingKeyPrivate) - // create UserID and store it in the database with the private key + // create ID and store it in the database with the private key userID := viewingkey.CalculateUserID(common.PrivateKeyToCompressedPubKey(viewingPrivateKeyEcies)) err = w.Storage.AddUser(userID, crypto.FromECDSA(viewingPrivateKeyEcies.ExportECDSA())) if err != nil { diff --git a/tools/walletextension/storage/database/common/db_types.go b/tools/walletextension/storage/database/common/db_types.go index af84deed2d..ea1b836047 100644 --- a/tools/walletextension/storage/database/common/db_types.go +++ b/tools/walletextension/storage/database/common/db_types.go @@ -1,15 +1,21 @@ package common import ( + "crypto/x509" + "fmt" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ten-protocol/go-ten/go/common/viewingkey" wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" ) type GWUserDB struct { - UserId []byte `json:"userId"` - PrivateKey []byte `json:"privateKey"` - Accounts []GWAccountDB `json:"accounts"` + UserId []byte `json:"userId"` + PrivateKey []byte `json:"privateKey"` + Accounts []GWAccountDB `json:"accounts"` + SessionKey *GWSessionKeyDB `json:"sessionKey"` + ActiveSK bool `json:"activeSK"` } type GWAccountDB struct { @@ -18,23 +24,50 @@ type GWAccountDB struct { SignatureType int `json:"signatureType"` } -func (userDB *GWUserDB) ToGWUser() *wecommon.GWUser { - result := &wecommon.GWUser{ - UserID: userDB.UserId, +// GWSessionKeyDB - an account key-pair registered for a user +type GWSessionKeyDB struct { + PrivateKey []byte `json:"privateKey"` + Account GWAccountDB `json:"account"` +} + +func (userDB *GWUserDB) ToGWUser() (*wecommon.GWUser, error) { + user := &wecommon.GWUser{ + ID: userDB.UserId, Accounts: make(map[common.Address]*wecommon.GWAccount), UserKey: userDB.PrivateKey, + ActiveSK: userDB.ActiveSK, } for _, accountDB := range userDB.Accounts { address := common.BytesToAddress(accountDB.AccountAddress) gwAccount := wecommon.GWAccount{ - User: result, + User: user, Address: &address, Signature: accountDB.Signature, SignatureType: viewingkey.SignatureType(accountDB.SignatureType), } - result.Accounts[address] = &gwAccount + user.Accounts[address] = &gwAccount + } + + if userDB.SessionKey != nil { + ecdsaPrivateKey, err := x509.ParseECPrivateKey(userDB.SessionKey.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to parse ECDSA private key: %w", err) + } + + // Convert ECDSA private key to ECIES private key + eciesPrivateKey := ecies.ImportECDSA(ecdsaPrivateKey) + acc := userDB.SessionKey.Account + user.SessionKey = &wecommon.GWSessionKey{ + Account: &wecommon.GWAccount{ + User: user, + Address: (*common.Address)(acc.AccountAddress), + Signature: acc.Signature, + SignatureType: viewingkey.SignatureType(acc.SignatureType), + }, + PrivateKey: eciesPrivateKey, + } } - return result + return user, nil } diff --git a/tools/walletextension/storage/database/cosmosdb/cosmosdb.go b/tools/walletextension/storage/database/cosmosdb/cosmosdb.go index cb8aead713..3672e7cb9a 100644 --- a/tools/walletextension/storage/database/cosmosdb/cosmosdb.go +++ b/tools/walletextension/storage/database/cosmosdb/cosmosdb.go @@ -7,12 +7,13 @@ import ( "fmt" "strings" + "github.com/ethereum/go-ethereum/crypto" + dbcommon "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/common" "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" - "github.com/ten-protocol/go-ten/go/common/errutil" "github.com/ten-protocol/go-ten/tools/walletextension/common" "github.com/ten-protocol/go-ten/tools/walletextension/encryption" ) @@ -87,37 +88,19 @@ func NewCosmosDB(connectionString string, encryptionKey []byte) (*CosmosDB, erro } func (c *CosmosDB) AddUser(userID []byte, privateKey []byte) error { + ctx := context.Background() + keyString, partitionKey := c.dbKey(userID) + user := dbcommon.GWUserDB{ UserId: userID, PrivateKey: privateKey, Accounts: []dbcommon.GWAccountDB{}, } - userJSON, err := json.Marshal(user) - if err != nil { - return fmt.Errorf("failed to marshal user: %w", err) - } - - ciphertext, err := c.encryptor.Encrypt(userJSON) - if err != nil { - return fmt.Errorf("failed to encrypt user data: %w", err) - } - - key := c.encryptor.HashWithHMAC(userID) - keyString := hex.EncodeToString(key) - - // Create an EncryptedDocument struct to store in CosmosDB - doc := EncryptedDocument{ - ID: keyString, - Data: ciphertext, - } - - docJSON, err := json.Marshal(doc) + docJSON, err := c.createEncryptedDoc(user, keyString) if err != nil { - return fmt.Errorf("failed to marshal document: %w", err) + return err } - partitionKey := azcosmos.NewPartitionKeyString(keyString) - ctx := context.Background() _, err = c.usersContainer.CreateItem(ctx, partitionKey, docJSON, nil) if err != nil { return fmt.Errorf("failed to create item: %w", err) @@ -126,11 +109,9 @@ func (c *CosmosDB) AddUser(userID []byte, privateKey []byte) error { } func (c *CosmosDB) DeleteUser(userID []byte) error { - key := c.encryptor.HashWithHMAC(userID) - keyString := hex.EncodeToString(key) - partitionKey := azcosmos.NewPartitionKeyString(keyString) ctx := context.Background() + keyString, partitionKey := c.dbKey(userID) _, err := c.usersContainer.DeleteItem(ctx, partitionKey, keyString, nil) if err != nil { return fmt.Errorf("failed to delete user: %w", err) @@ -138,32 +119,52 @@ func (c *CosmosDB) DeleteUser(userID []byte) error { return nil } -func (c *CosmosDB) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { - key := c.encryptor.HashWithHMAC(userID) - keyString := hex.EncodeToString(key) - partitionKey := azcosmos.NewPartitionKeyString(keyString) +func (c *CosmosDB) AddSessionKey(userID []byte, key common.GWSessionKey) error { ctx := context.Background() - itemResponse, err := c.usersContainer.ReadItem(ctx, partitionKey, keyString, nil) + user, err := c.getUserDB(userID) if err != nil { return fmt.Errorf("failed to get user: %w", err) } + user.SessionKey = &dbcommon.GWSessionKeyDB{ + PrivateKey: crypto.FromECDSA(key.PrivateKey.ExportECDSA()), + Account: dbcommon.GWAccountDB{ + AccountAddress: key.Account.Address.Bytes(), + Signature: key.Account.Signature, + SignatureType: int(key.Account.SignatureType), + }, + } + return c.updateUser(ctx, user) +} - var doc EncryptedDocument - err = json.Unmarshal(itemResponse.Value, &doc) +func (c *CosmosDB) ActivateSessionKey(userID []byte, active bool) error { + ctx := context.Background() + + user, err := c.getUserDB(userID) if err != nil { - return fmt.Errorf("failed to unmarshal document: %w", err) + return fmt.Errorf("failed to get user: %w", err) } + user.ActiveSK = active + return c.updateUser(ctx, user) +} - data, err := c.encryptor.Decrypt(doc.Data) +func (c *CosmosDB) RemoveSessionKey(userID []byte) error { + ctx := context.Background() + + user, err := c.getUserDB(userID) if err != nil { - return fmt.Errorf("failed to decrypt data: %w", err) + return fmt.Errorf("failed to get user: %w", err) } + user.SessionKey = nil + return c.updateUser(ctx, user) +} - var user dbcommon.GWUserDB - err = json.Unmarshal(data, &user) +func (c *CosmosDB) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { + ctx := context.Background() + + user, err := c.getUserDB(userID) if err != nil { - return fmt.Errorf("failed to unmarshal user data: %w", err) + return fmt.Errorf("failed to get user: %w", err) } // Add the new account @@ -174,58 +175,89 @@ func (c *CosmosDB) AddAccount(userID []byte, accountAddress []byte, signature [] } user.Accounts = append(user.Accounts, newAccount) - userJSON, err := json.Marshal(user) + return c.updateUser(ctx, user) +} + +func (c *CosmosDB) GetUser(userID []byte) (*common.GWUser, error) { + user, err := c.getUserDB(userID) if err != nil { - return fmt.Errorf("error marshaling updated user: %w", err) + return nil, err } + return user.ToGWUser() +} - ciphertext, err := c.encryptor.Encrypt(userJSON) +func (c *CosmosDB) getUserDB(userID []byte) (dbcommon.GWUserDB, error) { + keyString, partitionKey := c.dbKey(userID) + + ctx := context.Background() + + itemResponse, err := c.usersContainer.ReadItem(ctx, partitionKey, keyString, nil) if err != nil { - return fmt.Errorf("failed to encrypt updated user data: %w", err) + return dbcommon.GWUserDB{}, err } - // Update the document - doc.Data = ciphertext + var doc EncryptedDocument + err = json.Unmarshal(itemResponse.Value, &doc) + if err != nil { + return dbcommon.GWUserDB{}, fmt.Errorf("failed to unmarshal document: %w", err) + } - docJSON, err := json.Marshal(doc) + data, err := c.encryptor.Decrypt(doc.Data) + if err != nil { + return dbcommon.GWUserDB{}, fmt.Errorf("failed to decrypt data: %w", err) + } + + var user dbcommon.GWUserDB + err = json.Unmarshal(data, &user) + if err != nil { + return dbcommon.GWUserDB{}, fmt.Errorf("failed to unmarshal user data: %w", err) + } + return user, nil +} + +func (c *CosmosDB) updateUser(ctx context.Context, user dbcommon.GWUserDB) error { + keyString, partitionKey := c.dbKey(user.UserId) + + encryptedDoc, err := c.createEncryptedDoc(user, keyString) if err != nil { return fmt.Errorf("failed to marshal updated document: %w", err) } // Replace the item in the container - _, err = c.usersContainer.ReplaceItem(ctx, partitionKey, keyString, docJSON, nil) + _, err = c.usersContainer.ReplaceItem(ctx, partitionKey, keyString, encryptedDoc, nil) if err != nil { return fmt.Errorf("failed to update user with new account: %w", err) } return nil } -func (c *CosmosDB) GetUser(userID []byte) (*common.GWUser, error) { - key := c.encryptor.HashWithHMAC(userID) - keyString := hex.EncodeToString(key) - partitionKey := azcosmos.NewPartitionKeyString(keyString) - ctx := context.Background() - - itemResponse, err := c.usersContainer.ReadItem(ctx, partitionKey, keyString, nil) +func (c *CosmosDB) createEncryptedDoc(user dbcommon.GWUserDB, keyString string) ([]byte, error) { + userJSON, err := json.Marshal(user) if err != nil { - return nil, errutil.ErrNotFound + return nil, fmt.Errorf("failed to marshal user: %w", err) } - var doc EncryptedDocument - err = json.Unmarshal(itemResponse.Value, &doc) + ciphertext, err := c.encryptor.Encrypt(userJSON) if err != nil { - return nil, fmt.Errorf("failed to unmarshal document: %w", err) + return nil, fmt.Errorf("failed to encrypt user data: %w", err) } - data, err := c.encryptor.Decrypt(doc.Data) - if err != nil { - return nil, fmt.Errorf("failed to decrypt data: %w", err) + // Create an EncryptedDocument struct to store in CosmosDB + doc := EncryptedDocument{ + ID: keyString, + Data: ciphertext, } - var user dbcommon.GWUserDB - err = json.Unmarshal(data, &user) + docJSON, err := json.Marshal(doc) if err != nil { - return nil, fmt.Errorf("failed to unmarshal user data: %w", err) + return nil, fmt.Errorf("failed to marshal document: %w", err) } - return user.ToGWUser(), nil + return docJSON, nil +} + +func (c *CosmosDB) dbKey(userID []byte) (string, azcosmos.PartitionKey) { + key := c.encryptor.HashWithHMAC(userID) + keyString := hex.EncodeToString(key) + partitionKey := azcosmos.NewPartitionKeyString(keyString) + return keyString, partitionKey } diff --git a/tools/walletextension/storage/database/sqlite/sqlite.go b/tools/walletextension/storage/database/sqlite/sqlite.go index 57ce320a92..07c22c069b 100644 --- a/tools/walletextension/storage/database/sqlite/sqlite.go +++ b/tools/walletextension/storage/database/sqlite/sqlite.go @@ -14,6 +14,8 @@ import ( "os" "path/filepath" + "github.com/ethereum/go-ethereum/crypto" + dbcommon "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/common" "github.com/ten-protocol/go-ten/go/common/viewingkey" @@ -76,104 +78,152 @@ func (s *SqliteDB) AddUser(userID []byte, privateKey []byte) error { return err } - stmt, err := s.db.Prepare("INSERT OR REPLACE INTO users(id, user_data) VALUES (?, ?)") - if err != nil { - return err - } - defer stmt.Close() + return s.withTx(func(dbTx *sql.Tx) error { + stmt, err := dbTx.Prepare("INSERT OR REPLACE INTO users(id, user_data) VALUES (?, ?)") + if err != nil { + return err + } + defer stmt.Close() - _, err = stmt.Exec(string(user.UserId), string(userJSON)) - if err != nil { - return err - } + _, err = stmt.Exec(string(user.UserId), string(userJSON)) + if err != nil { + return err + } - return nil + return nil + }) } func (s *SqliteDB) DeleteUser(userID []byte) error { - stmt, err := s.db.Prepare("DELETE FROM users WHERE id = ?") - if err != nil { - return err - } - defer stmt.Close() + return s.withTx(func(dbTx *sql.Tx) error { + stmt, err := dbTx.Prepare("DELETE FROM users WHERE id = ?") + if err != nil { + return err + } + defer stmt.Close() - _, err = stmt.Exec(string(userID)) - if err != nil { - return fmt.Errorf("failed to delete user: %w", err) - } + _, err = stmt.Exec(string(userID)) + if err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } - return nil + return nil + }) +} + +func (s *SqliteDB) ActivateSessionKey(userID []byte, active bool) error { + return s.withTx(func(dbTx *sql.Tx) error { + user, err := s.readUser(dbTx, userID) + if err != nil { + return err + } + user.ActiveSK = active + return s.updateUser(dbTx, user) + }) +} + +func (s *SqliteDB) AddSessionKey(userID []byte, key common.GWSessionKey) error { + return s.withTx(func(dbTx *sql.Tx) error { + user, err := s.readUser(dbTx, userID) + if err != nil { + return err + } + user.SessionKey = &dbcommon.GWSessionKeyDB{ + PrivateKey: crypto.FromECDSA(key.PrivateKey.ExportECDSA()), + Account: dbcommon.GWAccountDB{ + AccountAddress: key.Account.Address.Bytes(), + Signature: key.Account.Signature, + SignatureType: int(key.Account.SignatureType), + }, + } + return s.updateUser(dbTx, user) + }) +} + +func (s *SqliteDB) RemoveSessionKey(userID []byte) error { + return s.withTx(func(dbTx *sql.Tx) error { + user, err := s.readUser(dbTx, userID) + if err != nil { + return err + } + user.SessionKey = nil + return s.updateUser(dbTx, user) + }) } func (s *SqliteDB) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { - var userDataJSON string - tx, err := s.db.Begin() + return s.withTx(func(dbTx *sql.Tx) error { + user, err := s.readUser(dbTx, userID) + if err != nil { + return err + } + + newAccount := dbcommon.GWAccountDB{ + AccountAddress: accountAddress, + Signature: signature, + SignatureType: int(signatureType), + } + + user.Accounts = append(user.Accounts, newAccount) + + return s.updateUser(dbTx, user) + }) +} + +func (s *SqliteDB) GetUser(userID []byte) (*common.GWUser, error) { + var user dbcommon.GWUserDB + var err error + err = s.withTx(func(dbTx *sql.Tx) error { + user, err = s.readUser(dbTx, userID) + if err != nil { + return err + } + return nil + }) if err != nil { - return err + return nil, err } - defer tx.Rollback() + return user.ToGWUser() +} - err = tx.QueryRow("SELECT user_data FROM users WHERE id = ?", string(userID)).Scan(&userDataJSON) +func (s *SqliteDB) readUser(dbTx *sql.Tx, userID []byte) (dbcommon.GWUserDB, error) { + var userDataJSON string + err := dbTx.QueryRow("SELECT user_data FROM users WHERE id = ?", string(userID)).Scan(&userDataJSON) if err != nil { - return fmt.Errorf("failed to get user: %w", err) + if errors.Is(err, sql.ErrNoRows) { + return dbcommon.GWUserDB{}, fmt.Errorf("failed to get user: %w", errutil.ErrNotFound) + } + return dbcommon.GWUserDB{}, fmt.Errorf("failed to get user: %w", err) } var user dbcommon.GWUserDB err = json.Unmarshal([]byte(userDataJSON), &user) if err != nil { - return fmt.Errorf("failed to unmarshal user data: %w", err) + return dbcommon.GWUserDB{}, fmt.Errorf("failed to unmarshal user data: %w", err) } + return user, nil +} - newAccount := dbcommon.GWAccountDB{ - AccountAddress: accountAddress, - Signature: signature, - SignatureType: int(signatureType), - } - - user.Accounts = append(user.Accounts, newAccount) - +func (s *SqliteDB) updateUser(dbTx *sql.Tx, user dbcommon.GWUserDB) error { updatedUserJSON, err := json.Marshal(user) if err != nil { return fmt.Errorf("error marshaling updated user: %w", err) } - stmt, err := tx.Prepare("UPDATE users SET user_data = ? WHERE id = ?") + stmt, err := dbTx.Prepare("UPDATE users SET user_data = ? WHERE id = ?") if err != nil { return err } defer stmt.Close() - _, err = stmt.Exec(string(updatedUserJSON), string(userID)) + _, err = stmt.Exec(string(updatedUserJSON), string(user.UserId)) if err != nil { return fmt.Errorf("failed to update user with new account: %w", err) } - err = tx.Commit() - if err != nil { - return err - } return nil } -func (s *SqliteDB) GetUser(userID []byte) (*common.GWUser, error) { - var userDataJSON string - err := s.db.QueryRow("SELECT user_data FROM users WHERE id = ?", string(userID)).Scan(&userDataJSON) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, fmt.Errorf("failed to get user: %w", errutil.ErrNotFound) - } - return nil, fmt.Errorf("failed to get user: %w", err) - } - - var user dbcommon.GWUserDB - err = json.Unmarshal([]byte(userDataJSON), &user) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal user data: %w", err) - } - - return user.ToGWUser(), nil -} - func createOrLoad(dbPath string) (string, error) { // If path is empty we create a random throwaway temp file, otherwise we use the path to the database if dbPath == "" { @@ -195,3 +245,18 @@ func createOrLoad(dbPath string) (string, error) { return dbPath, nil } + +func (s *SqliteDB) withTx(fn func(*sql.Tx) error) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + err = fn(tx) + if err != nil { + return err + } + + return tx.Commit() +} diff --git a/tools/walletextension/storage/storage.go b/tools/walletextension/storage/storage.go index ba8836734d..b15f8b8fc1 100644 --- a/tools/walletextension/storage/storage.go +++ b/tools/walletextension/storage/storage.go @@ -11,10 +11,14 @@ import ( "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/sqlite" ) +// todo - pass the Context type UserStorage interface { AddUser(userID []byte, privateKey []byte) error DeleteUser(userID []byte) error AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error + AddSessionKey(userID []byte, key common.GWSessionKey) error + ActivateSessionKey(userID []byte, active bool) error + RemoveSessionKey(userID []byte) error GetUser(userID []byte) (*common.GWUser, error) } diff --git a/tools/walletextension/storage/storage_test.go b/tools/walletextension/storage/storage_test.go index 9fde210cc6..29721fcb9f 100644 --- a/tools/walletextension/storage/storage_test.go +++ b/tools/walletextension/storage/storage_test.go @@ -185,8 +185,8 @@ func testGetUser(storage UserStorage, t *testing.T) { } // Check if retrieved user matches the added user - if !bytes.Equal(user.UserID, userID) { - t.Errorf("Retrieved user ID does not match. Expected %x, got %x", userID, user.UserID) + if !bytes.Equal(user.ID, userID) { + t.Errorf("Retrieved user ID does not match. Expected %x, got %x", userID, user.ID) } if !bytes.Equal(user.UserKey, privateKey) { diff --git a/tools/walletextension/storage/storage_with_cache.go b/tools/walletextension/storage/storage_with_cache.go index 08fbbc7cc9..9dae72ba85 100644 --- a/tools/walletextension/storage/storage_with_cache.go +++ b/tools/walletextension/storage/storage_with_cache.go @@ -42,6 +42,33 @@ func (s *UserStorageWithCache) DeleteUser(userID []byte) error { return nil } +func (s *UserStorageWithCache) ActivateSessionKey(userID []byte, active bool) error { + err := s.storage.ActivateSessionKey(userID, active) + if err != nil { + return err + } + s.cache.Remove(userID) + return nil +} + +func (s *UserStorageWithCache) AddSessionKey(userID []byte, key wecommon.GWSessionKey) error { + err := s.storage.AddSessionKey(userID, key) + if err != nil { + return err + } + s.cache.Remove(userID) + return nil +} + +func (s *UserStorageWithCache) RemoveSessionKey(userID []byte) error { + err := s.storage.RemoveSessionKey(userID) + if err != nil { + return err + } + s.cache.Remove(userID) + return nil +} + // AddAccount adds an account to a user and invalidates the cache for the userID func (s *UserStorageWithCache) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { err := s.storage.AddAccount(userID, accountAddress, signature, signatureType)