diff --git a/tools/walletextension/container/walletextension_container.go b/tools/walletextension/container/walletextension_container.go index 6931189e4c..cc5578560a 100644 --- a/tools/walletextension/container/walletextension_container.go +++ b/tools/walletextension/container/walletextension_container.go @@ -23,7 +23,7 @@ import ( type WalletExtensionContainer struct { hostAddr string userAccountManager *useraccountmanager.UserAccountManager - storage *storage.Storage + storage storage.Storage stopControl *stopcontrol.StopControl logger gethlog.Logger walletExt *walletextension.WalletExtension @@ -101,7 +101,7 @@ func NewWalletExtensionContainer( hostAddr string, walletExt *walletextension.WalletExtension, userAccountManager *useraccountmanager.UserAccountManager, - storage *storage.Storage, + storage storage.Storage, stopControl *stopcontrol.StopControl, httpServer *api.Server, wsServer *api.Server, diff --git a/tools/walletextension/storage/sqlite.go b/tools/walletextension/storage/database/sqlite.go similarity index 75% rename from tools/walletextension/storage/sqlite.go rename to tools/walletextension/storage/database/sqlite.go index a8a53c5c1b..1c9cddbec8 100644 --- a/tools/walletextension/storage/sqlite.go +++ b/tools/walletextension/storage/database/sqlite.go @@ -1,8 +1,13 @@ -package storage +package database import ( "database/sql" "fmt" + "os" + "path/filepath" + + obscurocommon "github.com/obscuronet/go-obscuro/go/common" + "github.com/obscuronet/go-obscuro/go/common/errutil" _ "github.com/mattn/go-sqlite3" // sqlite driver for sql.Open() common "github.com/obscuronet/go-obscuro/tools/walletextension/common" @@ -12,8 +17,15 @@ type SqliteDatabase struct { db *sql.DB } -func NewSqliteDatabase(dbName string) (*SqliteDatabase, error) { - db, err := sql.Open("sqlite3", dbName) +func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) { + // load the db file + dbFilePath, err := createOrLoad(dbPath) + if err != nil { + return nil, err + } + + // open the db + db, err := sql.Open("sqlite3", dbFilePath) if err != nil { fmt.Println("Error opening database: ", err) return nil, err @@ -86,7 +98,7 @@ func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) { if err != nil { if err == sql.ErrNoRows { // No rows found for the given userID - return nil, nil + return nil, errutil.ErrNotFound } return nil, err } @@ -154,3 +166,25 @@ func (s *SqliteDatabase) GetAllUsers() ([]common.UserDB, error) { return users, nil } + +func createOrLoad(dbPath string) (string, error) { + // If path is empty we create a random throwaway temp file, otherwise we use the path to the database + if dbPath == "" { + tempDir := filepath.Join("/tmp", "obscuro_gateway", obscurocommon.RandomStr(8)) + err := os.MkdirAll(tempDir, os.ModePerm) + if err != nil { + fmt.Println("Error creating directory: ", tempDir, err) + return "", err + } + dbPath = filepath.Join(tempDir, "gateway_databse.db") + } else { + dir := filepath.Dir(dbPath) + err := os.MkdirAll(dir, 0o755) + if err != nil { + fmt.Println("Error creating directories:", err) + return "", err + } + } + + return dbPath, nil +} diff --git a/tools/walletextension/storage/storage.go b/tools/walletextension/storage/storage.go index 1808d2b977..fa261f0bc8 100644 --- a/tools/walletextension/storage/storage.go +++ b/tools/walletextension/storage/storage.go @@ -1,91 +1,19 @@ package storage import ( - "fmt" - "os" - "path/filepath" - "github.com/obscuronet/go-obscuro/tools/walletextension/common" - - obscurocommon "github.com/obscuronet/go-obscuro/go/common" + "github.com/obscuronet/go-obscuro/tools/walletextension/storage/database" ) -type Storage struct { - db *SqliteDatabase -} - -func New(dbPath string) (*Storage, error) { - // If path is empty we create a random throwaway temp file, otherwise we use the path to the database - if dbPath == "" { - tempDir := filepath.Join("/tmp", "obscuro_gateway", obscurocommon.RandomStr(8)) - err := os.MkdirAll(tempDir, os.ModePerm) - if err != nil { - fmt.Println("Error creating directory: ", tempDir, err) - return nil, err - } - dbPath = filepath.Join(tempDir, "gateway_databse.db") - } else { - dir := filepath.Dir(dbPath) - err := os.MkdirAll(dir, 0o755) - if err != nil { - fmt.Println("Error creating directories:", err) - return nil, err - } - } - - newDB, err := NewSqliteDatabase(dbPath) - if err != nil { - fmt.Println("Error creating database:", err) - return nil, err - } - - return &Storage{db: newDB}, nil -} - -func (s *Storage) AddUser(userID []byte, privateKey []byte) error { - err := s.db.AddUser(userID, privateKey) - if err != nil { - return err - } - return nil -} - -func (s *Storage) DeleteUser(userID []byte) error { - err := s.db.DeleteUser(userID) - if err != nil { - return err - } - return nil -} - -func (s *Storage) GetUserPrivateKey(userID []byte) ([]byte, error) { - privateKey, err := s.db.GetUserPrivateKey(userID) - if err != nil { - return nil, err - } - return privateKey, nil -} - -func (s *Storage) AddAccount(userID []byte, accountAddress []byte, signature []byte) error { - err := s.db.AddAccount(userID, accountAddress, signature) - if err != nil { - return err - } - return nil -} - -func (s *Storage) GetAccounts(userID []byte) ([]common.AccountDB, error) { - accounts, err := s.db.GetAccounts(userID) - if err != nil { - return nil, err - } - return accounts, nil +type Storage interface { + AddUser(userID []byte, privateKey []byte) error + DeleteUser(userID []byte) error + GetUserPrivateKey(userID []byte) ([]byte, error) + AddAccount(userID []byte, accountAddress []byte, signature []byte) error + GetAccounts(userID []byte) ([]common.AccountDB, error) + GetAllUsers() ([]common.UserDB, error) } -func (s *Storage) GetAllUsers() ([]common.UserDB, error) { - users, err := s.db.GetAllUsers() - if err != nil { - return nil, err - } - return users, nil +func New(dbPath string) (Storage, error) { + return database.NewSqliteDatabase(dbPath) } diff --git a/tools/walletextension/storage/storage_test.go b/tools/walletextension/storage/storage_test.go index 69b16209c5..fcab6985e4 100644 --- a/tools/walletextension/storage/storage_test.go +++ b/tools/walletextension/storage/storage_test.go @@ -2,19 +2,36 @@ package storage import ( "bytes" + "errors" "testing" + + "github.com/obscuronet/go-obscuro/go/common/errutil" + "github.com/stretchr/testify/require" ) -func TestAddAndGetUser(t *testing.T) { - storage, err := New("") - if err != nil { - t.Fatal(err) +var tests = map[string]func(storage Storage, t *testing.T){ + "testAddAndGetUser": testAddAndGetUser, + "testAddAndGetAccounts": testAddAndGetAccounts, + "testDeleteUser": testDeleteUser, + "testGetAllUsers": testGetAllUsers, +} + +func TestSQLiteGatewayDB(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { + storage, err := New("") + require.NoError(t, err) + + test(storage, t) + }) } +} +func testAddAndGetUser(storage Storage, t *testing.T) { userID := []byte("userID") privateKey := []byte("privateKey") - err = storage.AddUser(userID, privateKey) + err := storage.AddUser(userID, privateKey) if err != nil { t.Fatal(err) } @@ -29,18 +46,13 @@ func TestAddAndGetUser(t *testing.T) { } } -func TestAddAndGetAccounts(t *testing.T) { - storage, err := New("") - if err != nil { - t.Fatal(err) - } - +func testAddAndGetAccounts(storage Storage, t *testing.T) { userID := []byte("userID") privateKey := []byte("privateKey") accountAddress1 := []byte("accountAddress1") signature1 := []byte("signature1") - err = storage.AddUser(userID, privateKey) + err := storage.AddUser(userID, privateKey) if err != nil { t.Fatal(err) } @@ -87,3 +99,47 @@ func TestAddAndGetAccounts(t *testing.T) { t.Errorf("Account 2 was not found in the result") } } + +func testDeleteUser(storage Storage, t *testing.T) { + userID := []byte("testDeleteUserID") + privateKey := []byte("testDeleteUserPrivateKey") + + err := storage.AddUser(userID, privateKey) + if err != nil { + t.Fatal(err) + } + + err = storage.DeleteUser(userID) + if err != nil { + t.Fatal(err) + } + + _, err = storage.GetUserPrivateKey(userID) + if err == nil || !errors.Is(err, errutil.ErrNotFound) { + t.Fatal("Expected error when getting deleted user, but got none") + } +} + +func testGetAllUsers(storage Storage, t *testing.T) { + initialUsers, err := storage.GetAllUsers() + if err != nil { + t.Fatal(err) + } + + userID := []byte("getAllUsersTestID") + privateKey := []byte("getAllUsersTestPrivateKey") + + err = storage.AddUser(userID, privateKey) + if err != nil { + t.Fatal(err) + } + + afterInsertUsers, err := storage.GetAllUsers() + if err != nil { + t.Fatal(err) + } + + if len(afterInsertUsers) != len(initialUsers)+1 { + t.Errorf("Expected user count to increase by 1. Got %d initially and %d after insert", len(initialUsers), len(afterInsertUsers)) + } +} diff --git a/tools/walletextension/wallet_extension.go b/tools/walletextension/wallet_extension.go index 77fd39456e..b5d5061c4e 100644 --- a/tools/walletextension/wallet_extension.go +++ b/tools/walletextension/wallet_extension.go @@ -8,6 +8,8 @@ import ( "fmt" "math/big" + "github.com/obscuronet/go-obscuro/go/common/log" + "github.com/obscuronet/go-obscuro/tools/walletextension/useraccountmanager" "github.com/ethereum/go-ethereum/crypto" @@ -32,7 +34,7 @@ type WalletExtension struct { hostAddr string // The address on which the Obscuro host can be reached. userAccountManager *useraccountmanager.UserAccountManager unsignedVKs map[gethcommon.Address]*viewingkey.ViewingKey // Map temporarily holding VKs that have been generated but not yet signed - storage *storage.Storage + storage storage.Storage logger gethlog.Logger stopControl *stopcontrol.StopControl } @@ -40,7 +42,7 @@ type WalletExtension struct { func New( hostAddr string, userAccountManager *useraccountmanager.UserAccountManager, - storage *storage.Storage, + storage storage.Storage, stopControl *stopcontrol.StopControl, logger gethlog.Logger, ) *WalletExtension { @@ -367,9 +369,9 @@ func (w *WalletExtension) getStorageAtInterceptor(request *accountmanager.RPCReq return nil } - key, err := w.storage.GetUserPrivateKey(userID) - if err != nil || len(key) == 0 { - w.logger.Info("Trying to get userID, but it is not present in our database: ") + _, err = w.storage.GetUserPrivateKey(userID) + if err != nil { + w.logger.Info("Trying to get userID, but it is not present in our database: ", log.ErrKey, err) return nil } response := map[string]interface{}{}