From 56f62e9510fb11a4611c60ad875ff7fa17382fa4 Mon Sep 17 00:00:00 2001 From: Tudor Malene Date: Tue, 5 Nov 2024 10:08:06 +0000 Subject: [PATCH] clear responsibilities --- tools/walletextension/cache/RistrettoCache.go | 12 +- tools/walletextension/cache/cache.go | 4 +- tools/walletextension/common/db_types.go | 13 -- tools/walletextension/common/types.go | 28 ++++ tools/walletextension/httpapi/routes.go | 2 +- .../walletextension/rpcapi/blockchain_api.go | 14 +- tools/walletextension/rpcapi/filter_api.go | 14 +- tools/walletextension/rpcapi/utils.go | 20 +-- tools/walletextension/services/conn_utils.go | 128 +++++++++++++--- tools/walletextension/services/gw_user.go | 61 -------- .../services/wallet_extension.go | 142 ++++++------------ .../storage/database/common/db_types.go | 40 +++++ .../storage/database/cosmosdb/cosmosdb.go | 24 +-- .../storage/database/sqlite/sqlite.go | 37 ++--- tools/walletextension/storage/storage.go | 20 ++- tools/walletextension/storage/storage_test.go | 10 +- .../storage/storage_with_cache.go | 64 +++----- .../walletextension_container.go | 4 +- 18 files changed, 319 insertions(+), 318 deletions(-) delete mode 100644 tools/walletextension/common/db_types.go create mode 100644 tools/walletextension/common/types.go delete mode 100644 tools/walletextension/services/gw_user.go create mode 100644 tools/walletextension/storage/database/common/db_types.go diff --git a/tools/walletextension/cache/RistrettoCache.go b/tools/walletextension/cache/RistrettoCache.go index 6893aac2f8..979965570b 100644 --- a/tools/walletextension/cache/RistrettoCache.go +++ b/tools/walletextension/cache/RistrettoCache.go @@ -10,10 +10,8 @@ import ( ) const ( - numCounters = 1e7 // number of keys to track frequency of (10M). - maxCost = 1_000_000 // 1 million entries - bufferItems = 64 // number of keys per Get buffer. - defaultCost = 1 // default cost of cache. + bufferItems = 64 // number of keys per Get buffer. + defaultCost = 1 // default cost of cache. ) type ristrettoCache struct { @@ -23,10 +21,10 @@ type ristrettoCache struct { } // NewRistrettoCacheWithEviction returns a new ristrettoCache. -func NewRistrettoCacheWithEviction(logger log.Logger) (Cache, error) { +func NewRistrettoCacheWithEviction(nrElems int, logger log.Logger) (Cache, error) { cache, err := ristretto.NewCache(&ristretto.Config{ - NumCounters: numCounters, - MaxCost: maxCost, + NumCounters: int64(nrElems * 10), + MaxCost: int64(nrElems), BufferItems: bufferItems, Metrics: true, }) diff --git a/tools/walletextension/cache/cache.go b/tools/walletextension/cache/cache.go index 58df4489b8..722a5e60bf 100644 --- a/tools/walletextension/cache/cache.go +++ b/tools/walletextension/cache/cache.go @@ -19,8 +19,8 @@ type Cache interface { Remove(key []byte) } -func NewCache(logger log.Logger) (Cache, error) { - return NewRistrettoCacheWithEviction(logger) +func NewCache(nrElems int, logger log.Logger) (Cache, error) { + return NewRistrettoCacheWithEviction(nrElems, logger) } type Strategy uint8 diff --git a/tools/walletextension/common/db_types.go b/tools/walletextension/common/db_types.go deleted file mode 100644 index 04993bbd75..0000000000 --- a/tools/walletextension/common/db_types.go +++ /dev/null @@ -1,13 +0,0 @@ -package common - -type GWUserDB struct { - UserId []byte `json:"userId"` - PrivateKey []byte `json:"privateKey"` - Accounts []GWAccountDB `json:"accounts"` -} - -type GWAccountDB struct { - AccountAddress []byte `json:"accountAddress"` - Signature []byte `json:"signature"` - SignatureType int `json:"signatureType"` -} diff --git a/tools/walletextension/common/types.go b/tools/walletextension/common/types.go new file mode 100644 index 0000000000..4d793bbf86 --- /dev/null +++ b/tools/walletextension/common/types.go @@ -0,0 +1,28 @@ +package common + +import ( + "github.com/ten-protocol/go-ten/go/common/viewingkey" + + "github.com/ethereum/go-ethereum/common" +) + +type GWAccount struct { + User *GWUser + Address *common.Address + Signature []byte + SignatureType viewingkey.SignatureType +} + +type GWUser struct { + UserID []byte + Accounts map[common.Address]*GWAccount + UserKey []byte +} + +func (u GWUser) GetAllAddresses() []*common.Address { + accts := make([]*common.Address, 0) + for _, acc := range u.Accounts { + accts = append(accts, acc.Address) + } + return accts +} diff --git a/tools/walletextension/httpapi/routes.go b/tools/walletextension/httpapi/routes.go index 1d1e935577..efac92d1d4 100644 --- a/tools/walletextension/httpapi/routes.go +++ b/tools/walletextension/httpapi/routes.go @@ -249,7 +249,7 @@ func revokeRequestHandler(walletExt *services.Services, conn UserConn) { } // delete user and accounts associated with it from the database - err = walletExt.DeleteUser(userID) + err = walletExt.Storage.DeleteUser(userID) if err != nil { handleError(conn, walletExt.Logger(), fmt.Errorf("internal error")) walletExt.Logger().Error("unable to delete user", "userID", userID, log.ErrKey, err) diff --git a/tools/walletextension/rpcapi/blockchain_api.go b/tools/walletextension/rpcapi/blockchain_api.go index 961b6ac614..d94f3b0d87 100644 --- a/tools/walletextension/rpcapi/blockchain_api.go +++ b/tools/walletextension/rpcapi/blockchain_api.go @@ -6,6 +6,8 @@ import ( "encoding/json" "fmt" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/cache" "github.com/ten-protocol/go-ten/tools/walletextension/services" @@ -185,7 +187,7 @@ func (api *BlockChainAPI) GetStorageAt(ctx context.Context, address gethcommon.A return nil, err } - _, err = api.we.GetUser(userID) + _, err = api.we.Storage.GetUser(userID) if err != nil { return nil, err } @@ -255,10 +257,10 @@ func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs return cacheBlockNumberOrHash(blockNrOrHash) }, }, - computeFromCallback: func(user *services.GWUser) *gethcommon.Address { + computeFromCallback: func(user *wecommon.GWUser) *gethcommon.Address { return searchFromAndData(user.GetAllAddresses(), args) }, - adjustArgs: func(acct *services.GWAccount) []any { + adjustArgs: func(acct *wecommon.GWAccount) []any { argsClone := populateFrom(acct, args) return []any{argsClone, blockNrOrHash, overrides, blockOverrides} }, @@ -280,10 +282,10 @@ func (api *BlockChainAPI) EstimateGas(ctx context.Context, args gethapi.Transact return cache.LatestBatch }, }, - computeFromCallback: func(user *services.GWUser) *gethcommon.Address { + computeFromCallback: func(user *wecommon.GWUser) *gethcommon.Address { return searchFromAndData(user.GetAllAddresses(), args) }, - adjustArgs: func(acct *services.GWAccount) []any { + adjustArgs: func(acct *wecommon.GWAccount) []any { argsClone := populateFrom(acct, args) return []any{argsClone, blockNrOrHash, overrides} }, @@ -296,7 +298,7 @@ func (api *BlockChainAPI) EstimateGas(ctx context.Context, args gethapi.Transact return *resp, err } -func populateFrom(acct *services.GWAccount, args gethapi.TransactionArgs) gethapi.TransactionArgs { +func populateFrom(acct *wecommon.GWAccount, args gethapi.TransactionArgs) gethapi.TransactionArgs { // clone the args argsClone := cloneArgs(args) // set the from diff --git a/tools/walletextension/rpcapi/filter_api.go b/tools/walletextension/rpcapi/filter_api.go index de8fb1fa6c..15e6ec63a8 100644 --- a/tools/walletextension/rpcapi/filter_api.go +++ b/tools/walletextension/rpcapi/filter_api.go @@ -85,7 +85,7 @@ func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rp errorChannels := make([]<-chan error, 0) backendSubscriptions := make([]*rpc.ClientSubscription, 0) for _, address := range candidateAddresses { - rpcWSClient, err := services.ConnectWS(ctx, user.Accounts[*address], api.we.Logger()) + rpcWSClient, err := api.we.BackendRPC.ConnectWS(ctx, user.Accounts[*address]) if err != nil { return nil, err } @@ -152,11 +152,11 @@ func (api *FilterAPI) closeConnections(backendSubscriptions []*rpc.ClientSubscri backendSub.Unsubscribe() } for _, connection := range backendWSConnections { - _ = services.ReturnConn(api.we.RpcWSConnPool, connection.BackingClient(), api.logger) + _ = api.we.BackendRPC.ReturnConnWS(connection.BackingClient()) } } -func getUserAndNotifier(ctx context.Context, api *FilterAPI) (*rpc.Notifier, *services.GWUser, error) { +func getUserAndNotifier(ctx context.Context, api *FilterAPI) (*rpc.Notifier, *wecommon.GWUser, error) { subNotifier, supported := rpc.NotifierFromContext(ctx) if !supported { return nil, nil, fmt.Errorf("creation of subscriptions is not supported") @@ -167,7 +167,7 @@ func getUserAndNotifier(ctx context.Context, api *FilterAPI) (*rpc.Notifier, *se return nil, nil, fmt.Errorf("illegal access") } - user, err := api.we.GetUser(subNotifier.UserID) + user, err := api.we.Storage.GetUser(subNotifier.UserID) if err != nil { return nil, nil, fmt.Errorf("illegal access: %s, %w", subNotifier.UserID, err) } @@ -208,7 +208,7 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) ( } res, err := cache.WithCache( - api.we.Cache, + api.we.RPCResponsesCache, &cache.Cfg{ DynamicType: func() cache.Strategy { if crit.ToBlock != nil && crit.ToBlock.Int64() > 0 { @@ -223,7 +223,7 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) ( }, generateCacheKey([]any{userID, method, common.SerializableFilterCriteria(crit)}), func() (*[]*types.Log, error) { // called when there is no entry in the cache - user, err := api.we.GetUser(userID) + user, err := api.we.Storage.GetUser(userID) if err != nil { return nil, err } @@ -233,7 +233,7 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) ( // execute the get_Logs function // dedupe and concatenate the results for _, acct := range user.Accounts { - eventLogs, err := services.WithEncRPCConnection(ctx, api.we, acct, func(rpcClient *tenrpc.EncRPCClient) (*[]*types.Log, error) { + eventLogs, err := services.WithEncRPCConnection(ctx, api.we.BackendRPC, acct, func(rpcClient *tenrpc.EncRPCClient) (*[]*types.Log, error) { var result []*types.Log // wrap the context with a timeout to prevent long executions diff --git a/tools/walletextension/rpcapi/utils.go b/tools/walletextension/rpcapi/utils.go index 82392b8ac2..6ee345b9b3 100644 --- a/tools/walletextension/rpcapi/utils.go +++ b/tools/walletextension/rpcapi/utils.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/cache" "github.com/ten-protocol/go-ten/tools/walletextension/services" @@ -42,11 +44,11 @@ var rpcNotImplemented = fmt.Errorf("rpc endpoint not implemented") type ExecCfg struct { // these 4 fields specify the account(s) that should make the backend call account *gethcommon.Address - computeFromCallback func(user *services.GWUser) *gethcommon.Address + computeFromCallback func(user *common.GWUser) *gethcommon.Address tryAll bool tryUntilAuthorised bool - adjustArgs func(acct *services.GWAccount) []any + adjustArgs func(acct *common.GWAccount) []any cacheCfg *cache.Cfg timeout time.Duration } @@ -60,8 +62,8 @@ func UnauthenticatedTenRPCCall[R any](ctx context.Context, w *services.Services, cacheArgs := []any{method} cacheArgs = append(cacheArgs, args...) - res, err := cache.WithCache(w.Cache, cfg, generateCacheKey(cacheArgs), func() (*R, error) { - return services.WithPlainRPCConnection(ctx, w, func(client *rpc.Client) (*R, error) { + res, err := cache.WithCache(w.RPCResponsesCache, cfg, generateCacheKey(cacheArgs), func() (*R, error) { + return services.WithPlainRPCConnection(ctx, w.BackendRPC, func(client *rpc.Client) (*R, error) { var resp *R var err error @@ -94,8 +96,8 @@ func ExecAuthRPC[R any](ctx context.Context, w *services.Services, cfg *ExecCfg, cacheArgs := []any{userID, method} cacheArgs = append(cacheArgs, args...) - res, err := cache.WithCache(w.Cache, cfg.cacheCfg, generateCacheKey(cacheArgs), func() (*R, error) { - user, err := w.GetUser(userID) + res, err := cache.WithCache(w.RPCResponsesCache, cfg.cacheCfg, generateCacheKey(cacheArgs), func() (*R, error) { + user, err := w.Storage.GetUser(userID) if err != nil { return nil, err } @@ -112,7 +114,7 @@ func ExecAuthRPC[R any](ctx context.Context, w *services.Services, cfg *ExecCfg, var rpcErr error for i := range candidateAccts { acct := candidateAccts[i] - result, err := services.WithEncRPCConnection(ctx, w, acct, func(rpcClient *tenrpc.EncRPCClient) (*R, error) { + result, err := services.WithEncRPCConnection(ctx, w.BackendRPC, acct, func(rpcClient *tenrpc.EncRPCClient) (*R, error) { var result *R adjustedArgs := args if cfg.adjustArgs != nil { @@ -151,8 +153,8 @@ func ExecAuthRPC[R any](ctx context.Context, w *services.Services, cfg *ExecCfg, return res, err } -func getCandidateAccounts(user *services.GWUser, _ *services.Services, cfg *ExecCfg) ([]*services.GWAccount, error) { - candidateAccts := make([]*services.GWAccount, 0) +func getCandidateAccounts(user *common.GWUser, _ *services.Services, cfg *ExecCfg) ([]*common.GWAccount, error) { + candidateAccts := make([]*common.GWAccount, 0) // for users with multiple accounts try to determine a candidate account based on the available information switch { case cfg.account != nil: diff --git a/tools/walletextension/services/conn_utils.go b/tools/walletextension/services/conn_utils.go index ee3a6ec87f..4f1cde628a 100644 --- a/tools/walletextension/services/conn_utils.go +++ b/tools/walletextension/services/conn_utils.go @@ -11,51 +11,129 @@ import ( "github.com/ten-protocol/go-ten/go/enclave/core" tenrpc "github.com/ten-protocol/go-ten/go/rpc" "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" ) -func ConnectWS(ctx context.Context, account *GWAccount, logger gethlog.Logger) (*tenrpc.EncRPCClient, error) { - return connect(ctx, account.user.services.RpcWSConnPool, account, logger) +type BackendRPC struct { + // the OG maintains a connection pool of rpc connections to underlying nodes + rpcHTTPConnPool *pool.ObjectPool + rpcWSConnPool *pool.ObjectPool + logger gethlog.Logger } -func connect(ctx context.Context, p *pool.ObjectPool, account *GWAccount, logger gethlog.Logger) (*tenrpc.EncRPCClient, error) { +// todo - tweak the number of backend connections +const poolSize = 200 + +func NewBackendRPC(hostAddrHTTP string, hostAddrWS string, logger gethlog.Logger) *BackendRPC { + factoryHTTP := pool.NewPooledObjectFactory( + func(context.Context) (interface{}, error) { + rpcClient, err := gethrpc.Dial(hostAddrHTTP) + if err != nil { + return nil, fmt.Errorf("could not create RPC client on %s. Cause: %w", hostAddrHTTP, err) + } + return rpcClient, nil + }, func(ctx context.Context, object *pool.PooledObject) error { + client := object.Object.(*gethrpc.Client) + client.Close() + return nil + }, nil, nil, nil) + + factoryWS := pool.NewPooledObjectFactory( + func(context.Context) (interface{}, error) { + rpcClient, err := gethrpc.Dial(hostAddrWS) + if err != nil { + return nil, fmt.Errorf("could not create RPC client on %s. Cause: %w", hostAddrWS, err) + } + return rpcClient, nil + }, func(ctx context.Context, object *pool.PooledObject) error { + client := object.Object.(*gethrpc.Client) + client.Close() + return nil + }, nil, nil, nil) + + cfg := pool.NewDefaultPoolConfig() + cfg.MaxTotal = poolSize + + return &BackendRPC{ + rpcHTTPConnPool: pool.NewObjectPool(context.Background(), factoryHTTP, cfg), + rpcWSConnPool: pool.NewObjectPool(context.Background(), factoryWS, cfg), + logger: logger, + } +} + +func (rpc *BackendRPC) ConnectWS(ctx context.Context, account *wecommon.GWAccount) (*tenrpc.EncRPCClient, error) { + return connect(ctx, rpc.rpcWSConnPool, account, rpc.logger) +} + +func (rpc *BackendRPC) ReturnConnWS(conn tenrpc.Client) error { + return returnConn(rpc.rpcWSConnPool, conn, rpc.logger) +} + +func (rpc *BackendRPC) ConnectHttp(ctx context.Context, account *wecommon.GWAccount) (*tenrpc.EncRPCClient, error) { + return connect(ctx, rpc.rpcHTTPConnPool, account, rpc.logger) +} + +func (rpc *BackendRPC) PlainConnectWs(ctx context.Context) (*gethrpc.Client, error) { + return connectPlain(ctx, rpc.rpcWSConnPool, rpc.logger) +} + +func (rpc *BackendRPC) ReturnConn(conn tenrpc.Client) error { + return returnConn(rpc.rpcHTTPConnPool, conn, rpc.logger) +} + +func (rpc *BackendRPC) Stop() { + rpc.rpcHTTPConnPool.Close(context.Background()) + rpc.rpcWSConnPool.Close(context.Background()) +} + +func WithEncRPCConnection[R any](ctx context.Context, rpc *BackendRPC, acct *wecommon.GWAccount, execute func(*tenrpc.EncRPCClient) (*R, error)) (*R, error) { + rpcClient, err := connect(ctx, rpc.rpcHTTPConnPool, acct, rpc.logger) + if err != nil { + return nil, fmt.Errorf("could not connect to backed. Cause: %w", err) + } + defer rpc.ReturnConn(rpcClient.BackingClient()) + return execute(rpcClient) +} + +func WithPlainRPCConnection[R any](ctx context.Context, b *BackendRPC, execute func(client *rpc.Client) (*R, error)) (*R, error) { + connectionObj, err := connectPlain(ctx, b.rpcHTTPConnPool, b.logger) + if err != nil { + return nil, fmt.Errorf("cannot fetch rpc connection to backend node %w", err) + } + defer b.ReturnConn(connectionObj) + return execute(connectionObj) +} + +func connectPlain(ctx context.Context, p *pool.ObjectPool, logger gethlog.Logger) (*rpc.Client, error) { + defer core.LogMethodDuration(logger, measure.NewStopwatch(), "get rpc connection") + connectionObj, err := p.BorrowObject(ctx) + if err != nil { + return nil, fmt.Errorf("cannot fetch rpc connection to backend node %w", err) + } + conn := connectionObj.(*rpc.Client) + return conn, nil +} + +func connect(ctx context.Context, p *pool.ObjectPool, account *wecommon.GWAccount, logger gethlog.Logger) (*tenrpc.EncRPCClient, error) { defer core.LogMethodDuration(logger, measure.NewStopwatch(), "get rpc connection") connectionObj, err := p.BorrowObject(ctx) if err != nil { return nil, fmt.Errorf("cannot fetch rpc connection to backend node %w", err) } conn := connectionObj.(*rpc.Client) - encClient, err := wecommon.CreateEncClient(conn, account.Address.Bytes(), account.user.userKey, account.signature, account.signatureType, logger) + encClient, err := wecommon.CreateEncClient(conn, account.Address.Bytes(), account.User.UserKey, account.Signature, account.SignatureType, logger) if err != nil { - _ = ReturnConn(p, conn, logger) + _ = returnConn(p, conn, logger) return nil, fmt.Errorf("error creating new client, %w", err) } return encClient, nil } -func ReturnConn(p *pool.ObjectPool, conn tenrpc.Client, logger gethlog.Logger) error { +func returnConn(p *pool.ObjectPool, conn tenrpc.Client, logger gethlog.Logger) error { err := p.ReturnObject(context.Background(), conn) if err != nil { logger.Error("Error returning connection to pool", log.ErrKey, err) } return err } - -func WithEncRPCConnection[R any](ctx context.Context, w *Services, acct *GWAccount, execute func(*tenrpc.EncRPCClient) (*R, error)) (*R, error) { - rpcClient, err := connect(ctx, acct.user.services.RpcHTTPConnPool, acct, w.logger) - if err != nil { - return nil, fmt.Errorf("could not connect to backed. Cause: %w", err) - } - defer ReturnConn(w.RpcHTTPConnPool, rpcClient.BackingClient(), w.logger) - return execute(rpcClient) -} - -func WithPlainRPCConnection[R any](ctx context.Context, w *Services, execute func(client *rpc.Client) (*R, error)) (*R, error) { - connectionObj, err := w.RpcHTTPConnPool.BorrowObject(ctx) - if err != nil { - return nil, fmt.Errorf("cannot fetch rpc connection to backend node %w", err) - } - rpcClient := connectionObj.(*rpc.Client) - defer ReturnConn(w.RpcHTTPConnPool, rpcClient, w.logger) - return execute(rpcClient) -} diff --git a/tools/walletextension/services/gw_user.go b/tools/walletextension/services/gw_user.go deleted file mode 100644 index 92a64dd476..0000000000 --- a/tools/walletextension/services/gw_user.go +++ /dev/null @@ -1,61 +0,0 @@ -package services - -import ( - "github.com/ten-protocol/go-ten/go/common/viewingkey" - - "github.com/ethereum/go-ethereum/common" - wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" -) - -var userCacheKeyPrefix = []byte{0x0, 0x1, 0x2, 0x3} - -type GWAccount struct { - user *GWUser - Address *common.Address - signature []byte - signatureType viewingkey.SignatureType -} - -type GWUser struct { - userID []byte - services *Services - Accounts map[common.Address]*GWAccount - userKey []byte -} - -func (u GWUser) GetAllAddresses() []*common.Address { - accts := make([]*common.Address, 0) - for _, acc := range u.Accounts { - accts = append(accts, acc.Address) - } - return accts -} - -func gwUserFromDB(userDB wecommon.GWUserDB, s *Services) (*GWUser, error) { - result := &GWUser{ - userID: userDB.UserId, - services: s, - Accounts: make(map[common.Address]*GWAccount), - userKey: userDB.PrivateKey, - } - - for _, accountDB := range userDB.Accounts { - address := common.BytesToAddress(accountDB.AccountAddress) - gwAccount := &GWAccount{ - user: result, - Address: &address, - signature: accountDB.Signature, - signatureType: viewingkey.SignatureType(accountDB.SignatureType), - } - result.Accounts[address] = gwAccount - } - - return result, nil -} - -func userCacheKey(userID []byte) []byte { - var key []byte - key = append(key, userCacheKeyPrefix...) - key = append(key, userID...) - return key -} diff --git a/tools/walletextension/services/wallet_extension.go b/tools/walletextension/services/wallet_extension.go index aac8986d99..507a1cadf3 100644 --- a/tools/walletextension/services/wallet_extension.go +++ b/tools/walletextension/services/wallet_extension.go @@ -7,6 +7,8 @@ import ( "fmt" "time" + gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + "github.com/ten-protocol/go-ten/go/common/log" "github.com/ten-protocol/go-ten/go/common/retry" @@ -17,9 +19,6 @@ import ( "github.com/ten-protocol/go-ten/go/obsclient" - pool "github.com/jolestar/go-commons-pool/v2" - gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" - "github.com/status-im/keycard-go/hexutils" "github.com/ten-protocol/go-ten/tools/walletextension/cache" @@ -37,82 +36,53 @@ import ( // Services handles the various business logic for the api endpoints type Services struct { - HostAddrHTTP string // The HTTP address on which the TEN host can be reached - HostAddrWS string // The WS address on which the TEN host can be reached - Storage storage.Storage - logger gethlog.Logger - stopControl *stopcontrol.StopControl - version string - Cache cache.Cache - RateLimiter *ratelimiter.RateLimiter - // the OG maintains a connection pool of rpc connections to underlying nodes - RpcHTTPConnPool *pool.ObjectPool - RpcWSConnPool *pool.ObjectPool - Config *common.Config - NewHeadsService *subscriptioncommon.NewHeadsService + HostAddrHTTP string // The HTTP address on which the TEN host can be reached + HostAddrWS string // The WS address on which the TEN host can be reached + Storage storage.UserStorage + logger gethlog.Logger + stopControl *stopcontrol.StopControl + version string + RPCResponsesCache cache.Cache + BackendRPC *BackendRPC + RateLimiter *ratelimiter.RateLimiter + Config *common.Config + NewHeadsService *subscriptioncommon.NewHeadsService } type NewHeadNotifier interface { onNewHead(header *tencommon.BatchHeader) } -func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage, stopControl *stopcontrol.StopControl, version string, logger gethlog.Logger, config *common.Config) *Services { - newGatewayCache, err := cache.NewCache(logger) +// number of rpc responses to cache +const rpcResponseCacheSize = 1_000_000 + +func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.UserStorage, stopControl *stopcontrol.StopControl, version string, logger gethlog.Logger, config *common.Config) *Services { + newGatewayCache, err := cache.NewCache(rpcResponseCacheSize, logger) if err != nil { logger.Error(fmt.Errorf("could not create cache. Cause: %w", err).Error()) panic(err) } - factoryHTTP := pool.NewPooledObjectFactory( - func(context.Context) (interface{}, error) { - rpcClient, err := gethrpc.Dial(hostAddrHTTP) - if err != nil { - return nil, fmt.Errorf("could not create RPC client on %s. Cause: %w", hostAddrHTTP, err) - } - return rpcClient, nil - }, func(ctx context.Context, object *pool.PooledObject) error { - client := object.Object.(*gethrpc.Client) - client.Close() - return nil - }, nil, nil, nil) - - factoryWS := pool.NewPooledObjectFactory( - func(context.Context) (interface{}, error) { - rpcClient, err := gethrpc.Dial(hostAddrWS) - if err != nil { - return nil, fmt.Errorf("could not create RPC client on %s. Cause: %w", hostAddrWS, err) - } - return rpcClient, nil - }, func(ctx context.Context, object *pool.PooledObject) error { - client := object.Object.(*gethrpc.Client) - client.Close() - return nil - }, nil, nil, nil) - - cfg := pool.NewDefaultPoolConfig() - cfg.MaxTotal = 200 // todo - what is the right number - rateLimiter := ratelimiter.NewRateLimiter(config.RateLimitUserComputeTime, config.RateLimitWindow, uint32(config.RateLimitMaxConcurrentRequests), logger) services := Services{ - HostAddrHTTP: hostAddrHTTP, - HostAddrWS: hostAddrWS, - Storage: storage, - logger: logger, - stopControl: stopControl, - version: version, - Cache: newGatewayCache, - RateLimiter: rateLimiter, - RpcHTTPConnPool: pool.NewObjectPool(context.Background(), factoryHTTP, cfg), - RpcWSConnPool: pool.NewObjectPool(context.Background(), factoryWS, cfg), - Config: config, + HostAddrHTTP: hostAddrHTTP, + HostAddrWS: hostAddrWS, + Storage: storage, + logger: logger, + stopControl: stopControl, + version: version, + RPCResponsesCache: newGatewayCache, + BackendRPC: NewBackendRPC(hostAddrHTTP, hostAddrWS, logger), + RateLimiter: rateLimiter, + Config: config, } services.NewHeadsService = subscriptioncommon.NewNewHeadsService( func() (chan *tencommon.BatchHeader, <-chan error, error) { logger.Info("Connecting to new heads service...") // clear the cache to avoid returning stale data during reconnecting. - services.Cache.EvictShortLiving() + services.RPCResponsesCache.EvictShortLiving() ch := make(chan *tencommon.BatchHeader) errCh, err := subscribeToNewHeadsWithRetry(ch, services, logger) logger.Info("Connected to new heads service.", log.ErrKey, err) @@ -121,7 +91,7 @@ func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage true, logger, func(newHead *tencommon.BatchHeader) error { - services.Cache.EvictShortLiving() + services.RPCResponsesCache.EvictShortLiving() return nil }) @@ -132,15 +102,14 @@ func subscribeToNewHeadsWithRetry(ch chan *tencommon.BatchHeader, services Servi var sub *gethrpc.ClientSubscription err := retry.Do( func() error { - connectionObj, err := services.RpcWSConnPool.BorrowObject(context.Background()) + connectionObj, err := services.BackendRPC.PlainConnectWs(context.Background()) if err != nil { return fmt.Errorf("cannot fetch rpc connection to backend node %w", err) } - rpcClient := connectionObj.(rpc.Client) - sub, err = rpcClient.Subscribe(context.Background(), rpc.SubscribeNamespace, ch, rpc.SubscriptionTypeNewHeads) + sub, err = connectionObj.Subscribe(context.Background(), rpc.SubscribeNamespace, ch, rpc.SubscriptionTypeNewHeads) if err != nil { logger.Info("could not subscribe for new head blocks", log.ErrKey, err) - _ = ReturnConn(services.RpcWSConnPool, rpcClient, logger) + _ = services.BackendRPC.ReturnConnWS(connectionObj) } return err }, @@ -212,7 +181,6 @@ func (w *Services) AddAddressToUser(userID []byte, address string, signature []b return err } - w.Cache.Remove(userCacheKey(userID)) audit(w, "Storing new address for user: %s, address: %s, duration: %d ", hexutils.BytesToHex(userID), address, time.Since(requestStartTime).Milliseconds()) return nil } @@ -238,37 +206,24 @@ func (w *Services) UserHasAccount(userID []byte, address string) (bool, error) { // check if any of the account matches given account found := false for _, account := range accounts { - if bytes.Equal(account.AccountAddress, addressBytes) { + if bytes.Equal(account.Address.Bytes(), addressBytes) { found = true } } return found, nil } -// DeleteUser deletes user and accounts associated with user from the database for given userID -func (w *Services) DeleteUser(userID []byte) error { - audit(w, "Deleting user: %s", hexutils.BytesToHex(userID)) - - err := w.Storage.DeleteUser(userID) - if err != nil { - w.Logger().Error(fmt.Errorf("error deleting user (%s), %w", userID, err).Error()) - return err - } - w.Cache.Remove(userCacheKey(userID)) - return nil -} - func (w *Services) UserExists(userID []byte) bool { audit(w, "Checking if user exists: %s", userID) // Check if user exists and don't log error if user doesn't exist, because we expect this to happen in case of // user revoking encryption token or using different testnet. // todo add a counter here in the future - users, err := w.Storage.GetUser(userID) + user, err := w.Storage.GetUser(userID) if err != nil { return false } - return len(users.PrivateKey) > 0 + return len(user.UserKey) > 0 } func (w *Services) Version() string { @@ -277,7 +232,7 @@ func (w *Services) Version() string { func (w *Services) GetTenNodeHealthStatus() (bool, error) { audit(w, "Getting TEN node health status") - res, err := WithPlainRPCConnection[bool](context.Background(), w, func(client *gethrpc.Client) (*bool, error) { + res, err := WithPlainRPCConnection[bool](context.Background(), w.BackendRPC, func(client *gethrpc.Client) (*bool, error) { res, err := obsclient.NewObsClient(client).Health() return &res, err }) @@ -286,10 +241,12 @@ func (w *Services) GetTenNodeHealthStatus() (bool, error) { func (w *Services) GetTenNetworkConfig() (tencommon.TenNetworkInfo, error) { audit(w, "Getting TEN network config") - res, err := WithPlainRPCConnection[tencommon.TenNetworkInfo](context.Background(), w, func(client *gethrpc.Client) (*tencommon.TenNetworkInfo, error) { - res, err := obsclient.NewObsClient(client).GetConfig() - return res, err + res, err := WithPlainRPCConnection[tencommon.TenNetworkInfo](context.Background(), w.BackendRPC, func(client *gethrpc.Client) (*tencommon.TenNetworkInfo, error) { + return obsclient.NewObsClient(client).GetConfig() }) + if err != nil { + return tencommon.TenNetworkInfo{}, err + } return *res, err } @@ -310,19 +267,6 @@ func (w *Services) GenerateUserMessageToSign(encryptionToken []byte, formatsSlic return string(message), nil } -func (w *Services) GetUser(userID []byte) (*GWUser, error) { - return cache.WithCache(w.Cache, &cache.Cfg{Type: cache.LongLiving}, userCacheKey(userID), func() (*GWUser, error) { - // todo - use storage with cache - user, err := w.Storage.GetUser(userID) - if err != nil { - return nil, fmt.Errorf("user %s not found. %w", hexutils.BytesToHex(userID), err) - } - result, err := gwUserFromDB(user, w) - return result, err - }) -} - func (w *Services) Stop() { - w.RpcHTTPConnPool.Close(context.Background()) - w.RpcWSConnPool.Close(context.Background()) + w.BackendRPC.Stop() } diff --git a/tools/walletextension/storage/database/common/db_types.go b/tools/walletextension/storage/database/common/db_types.go new file mode 100644 index 0000000000..06e4433864 --- /dev/null +++ b/tools/walletextension/storage/database/common/db_types.go @@ -0,0 +1,40 @@ +package common + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ten-protocol/go-ten/go/common/viewingkey" + common2 "github.com/ten-protocol/go-ten/tools/walletextension/common" +) + +type GWUserDB struct { + UserId []byte `json:"userId"` + PrivateKey []byte `json:"privateKey"` + Accounts []GWAccountDB `json:"accounts"` +} + +type GWAccountDB struct { + AccountAddress []byte `json:"accountAddress"` + Signature []byte `json:"signature"` + SignatureType int `json:"signatureType"` +} + +func (userDB *GWUserDB) ToGWUser() *common2.GWUser { + result := &common2.GWUser{ + UserID: userDB.UserId, + Accounts: make(map[common.Address]*common2.GWAccount), + UserKey: userDB.PrivateKey, + } + + for _, accountDB := range userDB.Accounts { + address := common.BytesToAddress(accountDB.AccountAddress) + gwAccount := &common2.GWAccount{ + User: result, + Address: &address, + Signature: accountDB.Signature, + SignatureType: viewingkey.SignatureType(accountDB.SignatureType), + } + result.Accounts[address] = gwAccount + } + + return result +} diff --git a/tools/walletextension/storage/database/cosmosdb/cosmosdb.go b/tools/walletextension/storage/database/cosmosdb/cosmosdb.go index e8296f2523..cb8aead713 100644 --- a/tools/walletextension/storage/database/cosmosdb/cosmosdb.go +++ b/tools/walletextension/storage/database/cosmosdb/cosmosdb.go @@ -7,6 +7,8 @@ import ( "fmt" "strings" + dbcommon "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/common" + "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" @@ -85,10 +87,10 @@ func NewCosmosDB(connectionString string, encryptionKey []byte) (*CosmosDB, erro } func (c *CosmosDB) AddUser(userID []byte, privateKey []byte) error { - user := common.GWUserDB{ + user := dbcommon.GWUserDB{ UserId: userID, PrivateKey: privateKey, - Accounts: []common.GWAccountDB{}, + Accounts: []dbcommon.GWAccountDB{}, } userJSON, err := json.Marshal(user) if err != nil { @@ -158,14 +160,14 @@ func (c *CosmosDB) AddAccount(userID []byte, accountAddress []byte, signature [] return fmt.Errorf("failed to decrypt data: %w", err) } - var user common.GWUserDB + var user dbcommon.GWUserDB err = json.Unmarshal(data, &user) if err != nil { return fmt.Errorf("failed to unmarshal user data: %w", err) } // Add the new account - newAccount := common.GWAccountDB{ + newAccount := dbcommon.GWAccountDB{ AccountAddress: accountAddress, Signature: signature, SignatureType: int(signatureType), @@ -198,7 +200,7 @@ func (c *CosmosDB) AddAccount(userID []byte, accountAddress []byte, signature [] return nil } -func (c *CosmosDB) GetUser(userID []byte) (common.GWUserDB, error) { +func (c *CosmosDB) GetUser(userID []byte) (*common.GWUser, error) { key := c.encryptor.HashWithHMAC(userID) keyString := hex.EncodeToString(key) partitionKey := azcosmos.NewPartitionKeyString(keyString) @@ -206,24 +208,24 @@ func (c *CosmosDB) GetUser(userID []byte) (common.GWUserDB, error) { itemResponse, err := c.usersContainer.ReadItem(ctx, partitionKey, keyString, nil) if err != nil { - return common.GWUserDB{}, errutil.ErrNotFound + return nil, errutil.ErrNotFound } var doc EncryptedDocument err = json.Unmarshal(itemResponse.Value, &doc) if err != nil { - return common.GWUserDB{}, fmt.Errorf("failed to unmarshal document: %w", err) + return nil, fmt.Errorf("failed to unmarshal document: %w", err) } data, err := c.encryptor.Decrypt(doc.Data) if err != nil { - return common.GWUserDB{}, fmt.Errorf("failed to decrypt data: %w", err) + return nil, fmt.Errorf("failed to decrypt data: %w", err) } - var user common.GWUserDB + var user dbcommon.GWUserDB err = json.Unmarshal(data, &user) if err != nil { - return common.GWUserDB{}, fmt.Errorf("failed to unmarshal user data: %w", err) + return nil, fmt.Errorf("failed to unmarshal user data: %w", err) } - return user, nil + return user.ToGWUser(), nil } diff --git a/tools/walletextension/storage/database/sqlite/sqlite.go b/tools/walletextension/storage/database/sqlite/sqlite.go index f08086bf5b..e58926a104 100644 --- a/tools/walletextension/storage/database/sqlite/sqlite.go +++ b/tools/walletextension/storage/database/sqlite/sqlite.go @@ -9,10 +9,13 @@ package sqlite import ( "database/sql" "encoding/json" + "errors" "fmt" "os" "path/filepath" + dbcommon "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/common" + "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/ten-protocol/go-ten/tools/walletextension/common" @@ -21,11 +24,11 @@ import ( "github.com/ten-protocol/go-ten/go/common/errutil" ) -type Database struct { +type SqliteDB struct { db *sql.DB } -func NewSqliteDatabase(dbPath string) (*Database, error) { +func NewSqliteDatabase(dbPath string) (*SqliteDB, error) { // load the db file dbFilePath, err := createOrLoad(dbPath) if err != nil { @@ -56,14 +59,14 @@ func NewSqliteDatabase(dbPath string) (*Database, error) { // Remove the accounts table as it will be stored within the user_data JSON - return &Database{db: db}, nil + return &SqliteDB{db: db}, nil } -func (s *Database) AddUser(userID []byte, privateKey []byte) error { - user := common.GWUserDB{ +func (s *SqliteDB) AddUser(userID []byte, privateKey []byte) error { + user := dbcommon.GWUserDB{ UserId: userID, PrivateKey: privateKey, - Accounts: []common.GWAccountDB{}, + Accounts: []dbcommon.GWAccountDB{}, } userJSON, err := json.Marshal(user) if err != nil { @@ -84,7 +87,7 @@ func (s *Database) AddUser(userID []byte, privateKey []byte) error { return nil } -func (s *Database) DeleteUser(userID []byte) error { +func (s *SqliteDB) DeleteUser(userID []byte) error { stmt, err := s.db.Prepare("DELETE FROM users WHERE id = ?") if err != nil { return err @@ -99,20 +102,20 @@ func (s *Database) DeleteUser(userID []byte) error { return nil } -func (s *Database) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { +func (s *SqliteDB) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { var userDataJSON string err := s.db.QueryRow("SELECT user_data FROM users WHERE id = ?", string(userID)).Scan(&userDataJSON) if err != nil { return fmt.Errorf("failed to get user: %w", err) } - var user common.GWUserDB + var user dbcommon.GWUserDB err = json.Unmarshal([]byte(userDataJSON), &user) if err != nil { return fmt.Errorf("failed to unmarshal user data: %w", err) } - newAccount := common.GWAccountDB{ + newAccount := dbcommon.GWAccountDB{ AccountAddress: accountAddress, Signature: signature, SignatureType: int(signatureType), @@ -139,23 +142,23 @@ func (s *Database) AddAccount(userID []byte, accountAddress []byte, signature [] return nil } -func (s *Database) GetUser(userID []byte) (common.GWUserDB, error) { +func (s *SqliteDB) GetUser(userID []byte) (*common.GWUser, error) { var userDataJSON string err := s.db.QueryRow("SELECT user_data FROM users WHERE id = ?", string(userID)).Scan(&userDataJSON) if err != nil { - if err == sql.ErrNoRows { - return common.GWUserDB{}, fmt.Errorf("failed to get user: %w", errutil.ErrNotFound) + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("failed to get user: %w", errutil.ErrNotFound) } - return common.GWUserDB{}, fmt.Errorf("failed to get user: %w", err) + return nil, fmt.Errorf("failed to get user: %w", err) } - var user common.GWUserDB + var user dbcommon.GWUserDB err = json.Unmarshal([]byte(userDataJSON), &user) if err != nil { - return common.GWUserDB{}, fmt.Errorf("failed to unmarshal user data: %w", err) + return nil, fmt.Errorf("failed to unmarshal user data: %w", err) } - return user, nil + return user.ToGWUser(), nil } func createOrLoad(dbPath string) (string, error) { diff --git a/tools/walletextension/storage/storage.go b/tools/walletextension/storage/storage.go index cb1fb6304b..ba8836734d 100644 --- a/tools/walletextension/storage/storage.go +++ b/tools/walletextension/storage/storage.go @@ -3,6 +3,7 @@ package storage import ( "fmt" + gethlog "github.com/ethereum/go-ethereum/log" "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/ten-protocol/go-ten/tools/walletextension/common" @@ -10,20 +11,27 @@ import ( "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/sqlite" ) -type Storage interface { +type UserStorage interface { AddUser(userID []byte, privateKey []byte) error DeleteUser(userID []byte) error AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error - GetUser(userID []byte) (common.GWUserDB, error) + GetUser(userID []byte) (*common.GWUser, error) } -func New(dbType string, dbConnectionURL, dbPath string, randomKey []byte) (Storage, error) { +func New(dbType, dbConnectionURL, dbPath string, randomKey []byte, logger gethlog.Logger) (UserStorage, error) { + var underlyingStorage UserStorage + var err error switch dbType { case "sqlite": - return sqlite.NewSqliteDatabase(dbPath) + underlyingStorage, err = sqlite.NewSqliteDatabase(dbPath) case "cosmosDB": - return cosmosdb.NewCosmosDB(dbConnectionURL, randomKey) + underlyingStorage, err = cosmosdb.NewCosmosDB(dbConnectionURL, randomKey) + default: + panic(fmt.Sprintf("unknown db type: %s", dbType)) + } + if err != nil { + return nil, fmt.Errorf("failed to initialize underlying storage: %w", err) } - return nil, fmt.Errorf("unknown db %s", dbType) + return NewUserStorageWithCache(underlyingStorage, logger) } diff --git a/tools/walletextension/storage/storage_test.go b/tools/walletextension/storage/storage_test.go index 2053fc37c5..1913555007 100644 --- a/tools/walletextension/storage/storage_test.go +++ b/tools/walletextension/storage/storage_test.go @@ -13,7 +13,7 @@ import ( wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" ) -var tests = map[string]func(storage Storage, t *testing.T){ +var tests = map[string]func(storage UserStorage, t *testing.T){ "testAddAndGetUser": testAddAndGetUser, "testAddAccounts": testAddAccounts, "testDeleteUser": testDeleteUser, @@ -35,7 +35,7 @@ func TestGatewayStorage(t *testing.T) { } } -func testAddAndGetUser(storage Storage, t *testing.T) { +func testAddAndGetUser(storage UserStorage, t *testing.T) { // Generate random user ID and private key userID := make([]byte, 20) _, err := rand.Read(userID) @@ -66,7 +66,7 @@ func testAddAndGetUser(storage Storage, t *testing.T) { } } -func testAddAccounts(storage Storage, t *testing.T) { +func testAddAccounts(storage UserStorage, t *testing.T) { // Generate random user ID, private key, and account details userID := make([]byte, 20) rand.Read(userID) @@ -136,7 +136,7 @@ func testAddAccounts(storage Storage, t *testing.T) { } } -func testDeleteUser(storage Storage, t *testing.T) { +func testDeleteUser(storage UserStorage, t *testing.T) { // Generate random user ID and private key userID := make([]byte, 20) rand.Read(userID) @@ -163,7 +163,7 @@ func testDeleteUser(storage Storage, t *testing.T) { } } -func testGetUser(storage Storage, t *testing.T) { +func testGetUser(storage UserStorage, t *testing.T) { // Generate random user ID and private key userID := make([]byte, 20) rand.Read(userID) diff --git a/tools/walletextension/storage/storage_with_cache.go b/tools/walletextension/storage/storage_with_cache.go index 7c83d75423..08fbbc7cc9 100644 --- a/tools/walletextension/storage/storage_with_cache.go +++ b/tools/walletextension/storage/storage_with_cache.go @@ -1,50 +1,39 @@ package storage import ( - "sync" - "time" - "github.com/ethereum/go-ethereum/log" "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/ten-protocol/go-ten/tools/walletextension/cache" wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" ) -// StorageWithCache implements the Storage interface with caching -type StorageWithCache struct { - storage Storage +// UserStorageWithCache implements the UserStorage interface with caching +type UserStorageWithCache struct { + storage UserStorage cache cache.Cache - mu sync.RWMutex } -// NewStorageWithCache creates a new StorageWithCache instance -func NewStorageWithCache(storage Storage, logger log.Logger) (*StorageWithCache, error) { - c, err := cache.NewCache(logger) +const UserCacheSize = 10_000 + +// NewUserStorageWithCache creates a new UserStorageWithCache instance +func NewUserStorageWithCache(storage UserStorage, logger log.Logger) (*UserStorageWithCache, error) { + c, err := cache.NewCache(UserCacheSize, logger) if err != nil { return nil, err } - return &StorageWithCache{ + return &UserStorageWithCache{ storage: storage, cache: c, }, nil } // AddUser adds a new user and invalidates the cache for the userID -func (s *StorageWithCache) AddUser(userID []byte, privateKey []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - err := s.storage.AddUser(userID, privateKey) - if err != nil { - return err - } - s.cache.Remove(userID) - return nil +func (s *UserStorageWithCache) AddUser(userID []byte, privateKey []byte) error { + return s.storage.AddUser(userID, privateKey) } // DeleteUser deletes a user and invalidates the cache for the userID -func (s *StorageWithCache) DeleteUser(userID []byte) error { - s.mu.Lock() - defer s.mu.Unlock() +func (s *UserStorageWithCache) DeleteUser(userID []byte) error { err := s.storage.DeleteUser(userID) if err != nil { return err @@ -54,9 +43,7 @@ func (s *StorageWithCache) DeleteUser(userID []byte) error { } // AddAccount adds an account to a user and invalidates the cache for the userID -func (s *StorageWithCache) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { - s.mu.Lock() - defer s.mu.Unlock() +func (s *UserStorageWithCache) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { err := s.storage.AddAccount(userID, accountAddress, signature, signatureType) if err != nil { return err @@ -66,25 +53,8 @@ func (s *StorageWithCache) AddAccount(userID []byte, accountAddress []byte, sign } // GetUser retrieves a user from the cache or underlying storage -func (s *StorageWithCache) GetUser(userID []byte) (wecommon.GWUserDB, error) { - s.mu.RLock() - // Check if the user is in the cache - if cachedUser, found := s.cache.Get(userID); found { - s.mu.RUnlock() - return cachedUser.(wecommon.GWUserDB), nil - } - s.mu.RUnlock() - - // If not in cache, retrieve from storage - user, err := s.storage.GetUser(userID) - if err != nil { - return wecommon.GWUserDB{}, err - } - - // Store the retrieved user in the cache - s.mu.Lock() - s.cache.Set(userID, user, 5*time.Minute) - s.mu.Unlock() - - return user, nil +func (s *UserStorageWithCache) GetUser(userID []byte) (*wecommon.GWUser, error) { + return cache.WithCache(s.cache, &cache.Cfg{Type: cache.LongLiving}, userID, func() (*wecommon.GWUser, error) { + return s.storage.GetUser(userID) + }) } diff --git a/tools/walletextension/walletextension_container.go b/tools/walletextension/walletextension_container.go index 64f55b7c7d..87915d2a37 100644 --- a/tools/walletextension/walletextension_container.go +++ b/tools/walletextension/walletextension_container.go @@ -47,7 +47,7 @@ func NewContainerFromConfig(config wecommon.Config, logger gethlog.Logger) *Cont } // start the database with the encryption key - databaseStorage, err := storage.New(config.DBType, config.DBConnectionURL, config.DBPathOverride, encryptionKey) + userStorage, err := storage.New(config.DBType, config.DBConnectionURL, config.DBPathOverride, encryptionKey, logger) if err != nil { logger.Crit("unable to create database to store viewing keys ", log.ErrKey, err) os.Exit(1) @@ -60,7 +60,7 @@ func NewContainerFromConfig(config wecommon.Config, logger gethlog.Logger) *Cont } stopControl := stopcontrol.New() - walletExt := services.NewServices(hostRPCBindAddrHTTP, hostRPCBindAddrWS, databaseStorage, stopControl, version, logger, &config) + walletExt := services.NewServices(hostRPCBindAddrHTTP, hostRPCBindAddrWS, userStorage, stopControl, version, logger, &config) cfg := &node.RPCConfig{ EnableHTTP: true, HTTPPort: config.WalletExtensionPortHTTP,