diff --git a/tools/walletextension/accountmanager/account_manager.go b/tools/walletextension/accountmanager/account_manager.go index 5cce5130fe..af0149a80d 100644 --- a/tools/walletextension/accountmanager/account_manager.go +++ b/tools/walletextension/accountmanager/account_manager.go @@ -51,6 +51,18 @@ func NewAccountManager(unauthedClient rpc.Client, logger gethlog.Logger) *Accoun } } +// GetAllAddressesWithClients returns a list of addresses which already have clients (are in accountClients map) +func (m *AccountManager) GetAllAddressesWithClients() []string { + m.accountsMutex.RLock() + defer m.accountsMutex.RUnlock() + + addresses := make([]string, 0, len(m.accountClients)) + for address := range m.accountClients { + addresses = append(addresses, address.Hex()) + } + return addresses +} + // AddClient adds a client to the list of clients, keyed by account address. func (m *AccountManager) AddClient(address gethcommon.Address, client *rpc.EncRPCClient) { m.accountsMutex.Lock() diff --git a/tools/walletextension/container/walletextension_container.go b/tools/walletextension/container/walletextension_container.go index 8551cf9116..3e2ebc94c4 100644 --- a/tools/walletextension/container/walletextension_container.go +++ b/tools/walletextension/container/walletextension_container.go @@ -7,7 +7,6 @@ import ( "net/http" "os" - "github.com/ethereum/go-ethereum/common" "github.com/ten-protocol/go-ten/go/common/log" "github.com/ten-protocol/go-ten/go/common/stopcontrol" "github.com/ten-protocol/go-ten/go/rpc" @@ -38,18 +37,16 @@ func NewWalletExtensionContainerFromConfig(config config.Config, logger gethlog. unAuthedClient, err := rpc.NewNetworkClient(hostRPCBindAddr) if err != nil { logger.Crit("unable to create temporary client for request ", log.ErrKey, err) + os.Exit(1) } - userAccountManager := useraccountmanager.NewUserAccountManager(unAuthedClient, logger) - // start the database databaseStorage, err := storage.New(config.DBType, config.DBConnectionURL, config.DBPathOverride) if err != nil { logger.Crit("unable to create database to store viewing keys ", log.ErrKey, err) + os.Exit(1) } - - // Get all the data from the database and add all the clients for all users - // todo (@ziga) - implement lazy loading for clients to reduce number of connections and speed up loading + userAccountManager := useraccountmanager.NewUserAccountManager(unAuthedClient, logger, databaseStorage, hostRPCBindAddr) // add default user (when no UserID is provided in the query parameter - for WE endpoints) userAccountManager.AddAndReturnAccountManager(hex.EncodeToString([]byte(wecommon.DefaultUser))) @@ -62,21 +59,8 @@ func NewWalletExtensionContainerFromConfig(config config.Config, logger gethlog. // iterate over users create accountManagers and add all accounts to them per user for _, user := range allUsers { - currentUserAccountManager := userAccountManager.AddAndReturnAccountManager(hex.EncodeToString(user.UserID)) - - accounts, err := databaseStorage.GetAccounts(user.UserID) - if err != nil { - logger.Error(fmt.Errorf("error getting accounts for user: %s, %w", hex.EncodeToString(user.UserID), err).Error()) - } - for _, account := range accounts { - encClient, err := wecommon.CreateEncClient(hostRPCBindAddr, account.AccountAddress, user.PrivateKey, account.Signature, logger) - if err != nil { - logger.Error(fmt.Errorf("error creating new client, %w", err).Error()) - } - - // add client to current userAccountManager - currentUserAccountManager.AddClient(common.BytesToAddress(account.AccountAddress), encClient) - } + userAccountManager.AddAndReturnAccountManager(hex.EncodeToString(user.UserID)) + logger.Info(fmt.Sprintf("account manager added for user: %s", hex.EncodeToString(user.UserID))) } // captures version in the env vars diff --git a/tools/walletextension/useraccountmanager/user_account_manager.go b/tools/walletextension/useraccountmanager/user_account_manager.go index ca26a21705..0c787a8b07 100644 --- a/tools/walletextension/useraccountmanager/user_account_manager.go +++ b/tools/walletextension/useraccountmanager/user_account_manager.go @@ -1,23 +1,31 @@ package useraccountmanager import ( + "encoding/hex" "fmt" + "github.com/ethereum/go-ethereum/common" gethlog "github.com/ethereum/go-ethereum/log" "github.com/ten-protocol/go-ten/go/rpc" "github.com/ten-protocol/go-ten/tools/walletextension/accountmanager" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/storage" ) type UserAccountManager struct { userAccountManager map[string]*accountmanager.AccountManager unauthenticatedClient rpc.Client + storage storage.Storage + hostRPCBinAddr string logger gethlog.Logger } -func NewUserAccountManager(unauthenticatedClient rpc.Client, logger gethlog.Logger) UserAccountManager { +func NewUserAccountManager(unauthenticatedClient rpc.Client, logger gethlog.Logger, storage storage.Storage, hostRPCBindAddr string) UserAccountManager { return UserAccountManager{ userAccountManager: make(map[string]*accountmanager.AccountManager), unauthenticatedClient: unauthenticatedClient, + storage: storage, + hostRPCBinAddr: hostRPCBindAddr, logger: logger, } } @@ -34,14 +42,62 @@ func (m *UserAccountManager) AddAndReturnAccountManager(userID string) *accountm } // GetUserAccountManager retrieves the UserAccountManager associated with the given userID. -// It returns the UserAccountManager and nil error if one exists. +// it returns the UserAccountManager and nil error if one exists. +// before returning it checks the database and creates all missing clients for that userID +// (we are not loading all of them at startup to limit the number of established connections) // If a UserAccountManager does not exist for the userID, it returns nil and an error. func (m *UserAccountManager) GetUserAccountManager(userID string) (*accountmanager.AccountManager, error) { userAccManager, exists := m.userAccountManager[userID] - if exists { + if !exists { + return nil, fmt.Errorf("UserAccountManager doesn't exist for user: %s", userID) + } + + // we have userAccountManager as expected. + // now we need to create all clients that don't exist there yet + addressesWithClients := userAccManager.GetAllAddressesWithClients() + + // get all addresses for current userID + userIDbytes, err := hex.DecodeString(userID) + if err != nil { + return nil, err + } + + // log that we don't have a storage, but still return existing userAccountManager + // this should never happen, but is useful for tests + if m.storage == nil { + m.logger.Error("storage is nil in UserAccountManager") return userAccManager, nil } - return nil, fmt.Errorf("UserAccountManager doesn't exist for user: %s", userID) + + databaseAccounts, err := m.storage.GetAccounts(userIDbytes) + if err != nil { + return nil, err + } + + userPrivateKey, err := m.storage.GetUserPrivateKey(userIDbytes) + if err != nil { + return nil, err + } + + for _, account := range databaseAccounts { + addressHexString := common.BytesToAddress(account.AccountAddress).Hex() + // check if a client for the current address already exists (and skip it if it does) + if addressAlreadyExists(addressHexString, addressesWithClients) { + continue + } + + // create a new client + encClient, err := wecommon.CreateEncClient(m.hostRPCBinAddr, account.AccountAddress, userPrivateKey, account.Signature, m.logger) + if err != nil { + m.logger.Error(fmt.Errorf("error creating new client, %w", err).Error()) + } + + // add a client to requested userAccountManager + userAccManager.AddClient(common.BytesToAddress(account.AccountAddress), encClient) + addressesWithClients = append(addressesWithClients, addressHexString) + } + + return userAccManager, nil } // DeleteUserAccountManager removes the UserAccountManager associated with the given userID. @@ -54,3 +110,13 @@ func (m *UserAccountManager) DeleteUserAccountManager(userID string) error { delete(m.userAccountManager, userID) return nil } + +// addressAlreadyExists is a helper function to check if an address is already present in a list of existing addresses +func addressAlreadyExists(str string, list []string) bool { + for _, v := range list { + if v == str { + return true + } + } + return false +} diff --git a/tools/walletextension/useraccountmanager/user_account_manager_test.go b/tools/walletextension/useraccountmanager/user_account_manager_test.go index d6fd4aee48..6451600790 100644 --- a/tools/walletextension/useraccountmanager/user_account_manager_test.go +++ b/tools/walletextension/useraccountmanager/user_account_manager_test.go @@ -9,9 +9,9 @@ import ( func TestAddingAndGettingUserAccountManagers(t *testing.T) { unauthedClient, _ := rpc.NewNetworkClient("ws://test") - userAccountManager := NewUserAccountManager(unauthedClient, log.New()) - userID1 := "user1" - userID2 := "user2" + userAccountManager := NewUserAccountManager(unauthedClient, log.New(), nil, "ws://test") + userID1 := "4A6F686E20446F65" + userID2 := "7A65746F65A2676F" // Test adding and getting account manager for userID1 userAccountManager.AddAndReturnAccountManager(userID1) @@ -21,7 +21,6 @@ func TestAddingAndGettingUserAccountManagers(t *testing.T) { } // We should get error if we try to get Account manager for User2 _, err = userAccountManager.GetUserAccountManager(userID2) - if err == nil { t.Fatal("expecting error when trying to get AccountManager for user that doesn't exist.") } @@ -51,7 +50,7 @@ func TestAddingAndGettingUserAccountManagers(t *testing.T) { func TestDeletingUserAccountManagers(t *testing.T) { unauthedClient, _ := rpc.NewNetworkClient("ws://test") - userAccountManager := NewUserAccountManager(unauthedClient, log.New()) + userAccountManager := NewUserAccountManager(unauthedClient, log.New(), nil, "") userID := "user1" // Add an account manager for the user diff --git a/tools/walletextension/wallet_extension.go b/tools/walletextension/wallet_extension.go index f7aec8aaa1..eb3b9aaf26 100644 --- a/tools/walletextension/wallet_extension.go +++ b/tools/walletextension/wallet_extension.go @@ -144,6 +144,11 @@ func (w *WalletExtension) SubmitViewingKey(address gethcommon.Address, signature signature[64] -= 27 vk.Signature = signature + + err := w.storage.AddUser([]byte(common.DefaultUser), crypto.FromECDSA(vk.PrivateKey.ExportECDSA())) + if err != nil { + return fmt.Errorf("error saving user: %s", common.DefaultUser) + } // create an encrypted RPC client with the signed VK and register it with the enclave // todo (@ziga) - Create the clients lazily, to reduce connections to the host. client, err := rpc.NewEncNetworkClient(w.hostAddr, vk, w.logger) @@ -157,11 +162,6 @@ func (w *WalletExtension) SubmitViewingKey(address gethcommon.Address, signature defaultAccountManager.AddClient(address, client) - err = w.storage.AddUser([]byte(common.DefaultUser), crypto.FromECDSA(vk.PrivateKey.ExportECDSA())) - if err != nil { - return fmt.Errorf("error saving user: %s", common.DefaultUser) - } - err = w.storage.AddAccount([]byte(common.DefaultUser), vk.Account.Bytes(), vk.Signature) if err != nil { return fmt.Errorf("error saving account %s for user %s", vk.Account.Hex(), common.DefaultUser)