Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor storage to allow for different dbs + tests #1517

Merged
merged 3 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tools/walletextension/container/walletextension_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
type WalletExtensionContainer struct {
hostAddr string
userAccountManager *useraccountmanager.UserAccountManager
storage *storage.Storage
storage storage.Storage
stopControl *stopcontrol.StopControl
logger gethlog.Logger
walletExt *walletextension.WalletExtension
Expand Down Expand Up @@ -101,7 +101,7 @@ func NewWalletExtensionContainer(
hostAddr string,
walletExt *walletextension.WalletExtension,
userAccountManager *useraccountmanager.UserAccountManager,
storage *storage.Storage,
storage storage.Storage,
stopControl *stopcontrol.StopControl,
httpServer *api.Server,
wsServer *api.Server,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package storage
package database

import (
"database/sql"
"fmt"
"os"
"path/filepath"

obscurocommon "github.com/obscuronet/go-obscuro/go/common"
"github.com/obscuronet/go-obscuro/go/common/errutil"

_ "github.com/mattn/go-sqlite3" // sqlite driver for sql.Open()
common "github.com/obscuronet/go-obscuro/tools/walletextension/common"
Expand All @@ -12,8 +17,15 @@ type SqliteDatabase struct {
db *sql.DB
}

func NewSqliteDatabase(dbName string) (*SqliteDatabase, error) {
db, err := sql.Open("sqlite3", dbName)
func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) {
// load the db file
dbFilePath, err := createOrLoad(dbPath)
if err != nil {
return nil, err
}

// open the db
db, err := sql.Open("sqlite3", dbFilePath)
if err != nil {
fmt.Println("Error opening database: ", err)
return nil, err
Expand Down Expand Up @@ -86,7 +98,7 @@ func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) {
if err != nil {
if err == sql.ErrNoRows {
// No rows found for the given userID
return nil, nil
return nil, errutil.ErrNotFound
}
return nil, err
}
Expand Down Expand Up @@ -154,3 +166,25 @@ func (s *SqliteDatabase) GetAllUsers() ([]common.UserDB, error) {

return users, nil
}

func createOrLoad(dbPath string) (string, error) {
// If path is empty we create a random throwaway temp file, otherwise we use the path to the database
if dbPath == "" {
tempDir := filepath.Join("/tmp", "obscuro_gateway", obscurocommon.RandomStr(8))
err := os.MkdirAll(tempDir, os.ModePerm)
if err != nil {
fmt.Println("Error creating directory: ", tempDir, err)
return "", err
}
dbPath = filepath.Join(tempDir, "gateway_databse.db")
} else {
dir := filepath.Dir(dbPath)
err := os.MkdirAll(dir, 0o755)
if err != nil {
fmt.Println("Error creating directories:", err)
return "", err
}
}

return dbPath, nil
}
92 changes: 10 additions & 82 deletions tools/walletextension/storage/storage.go
Original file line number Diff line number Diff line change
@@ -1,91 +1,19 @@
package storage

import (
"fmt"
"os"
"path/filepath"

"github.com/obscuronet/go-obscuro/tools/walletextension/common"

obscurocommon "github.com/obscuronet/go-obscuro/go/common"
"github.com/obscuronet/go-obscuro/tools/walletextension/storage/database"
)

type Storage struct {
db *SqliteDatabase
}

func New(dbPath string) (*Storage, error) {
// If path is empty we create a random throwaway temp file, otherwise we use the path to the database
if dbPath == "" {
tempDir := filepath.Join("/tmp", "obscuro_gateway", obscurocommon.RandomStr(8))
err := os.MkdirAll(tempDir, os.ModePerm)
if err != nil {
fmt.Println("Error creating directory: ", tempDir, err)
return nil, err
}
dbPath = filepath.Join(tempDir, "gateway_databse.db")
} else {
dir := filepath.Dir(dbPath)
err := os.MkdirAll(dir, 0o755)
if err != nil {
fmt.Println("Error creating directories:", err)
return nil, err
}
}

newDB, err := NewSqliteDatabase(dbPath)
if err != nil {
fmt.Println("Error creating database:", err)
return nil, err
}

return &Storage{db: newDB}, nil
}

func (s *Storage) AddUser(userID []byte, privateKey []byte) error {
err := s.db.AddUser(userID, privateKey)
if err != nil {
return err
}
return nil
}

func (s *Storage) DeleteUser(userID []byte) error {
err := s.db.DeleteUser(userID)
if err != nil {
return err
}
return nil
}

func (s *Storage) GetUserPrivateKey(userID []byte) ([]byte, error) {
privateKey, err := s.db.GetUserPrivateKey(userID)
if err != nil {
return nil, err
}
return privateKey, nil
}

func (s *Storage) AddAccount(userID []byte, accountAddress []byte, signature []byte) error {
err := s.db.AddAccount(userID, accountAddress, signature)
if err != nil {
return err
}
return nil
}

func (s *Storage) GetAccounts(userID []byte) ([]common.AccountDB, error) {
accounts, err := s.db.GetAccounts(userID)
if err != nil {
return nil, err
}
return accounts, nil
type Storage interface {
AddUser(userID []byte, privateKey []byte) error
DeleteUser(userID []byte) error
GetUserPrivateKey(userID []byte) ([]byte, error)
AddAccount(userID []byte, accountAddress []byte, signature []byte) error
GetAccounts(userID []byte) ([]common.AccountDB, error)
GetAllUsers() ([]common.UserDB, error)
}

func (s *Storage) GetAllUsers() ([]common.UserDB, error) {
users, err := s.db.GetAllUsers()
if err != nil {
return nil, err
}
return users, nil
func New(dbPath string) (Storage, error) {
return database.NewSqliteDatabase(dbPath)
}
80 changes: 68 additions & 12 deletions tools/walletextension/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,36 @@ package storage

import (
"bytes"
"errors"
"testing"

"github.com/obscuronet/go-obscuro/go/common/errutil"
"github.com/stretchr/testify/require"
)

func TestAddAndGetUser(t *testing.T) {
storage, err := New("")
if err != nil {
t.Fatal(err)
var tests = map[string]func(storage Storage, t *testing.T){
"testAddAndGetUser": testAddAndGetUser,
"testAddAndGetAccounts": testAddAndGetAccounts,
"testDeleteUser": testDeleteUser,
"testGetAllUsers": testGetAllUsers,
}

func TestSQLiteGatewayDB(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
storage, err := New("")
require.NoError(t, err)

test(storage, t)
})
}
}

func testAddAndGetUser(storage Storage, t *testing.T) {
userID := []byte("userID")
privateKey := []byte("privateKey")

err = storage.AddUser(userID, privateKey)
err := storage.AddUser(userID, privateKey)
if err != nil {
t.Fatal(err)
}
Expand All @@ -29,18 +46,13 @@ func TestAddAndGetUser(t *testing.T) {
}
}

func TestAddAndGetAccounts(t *testing.T) {
storage, err := New("")
if err != nil {
t.Fatal(err)
}

func testAddAndGetAccounts(storage Storage, t *testing.T) {
userID := []byte("userID")
privateKey := []byte("privateKey")
accountAddress1 := []byte("accountAddress1")
signature1 := []byte("signature1")

err = storage.AddUser(userID, privateKey)
err := storage.AddUser(userID, privateKey)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -87,3 +99,47 @@ func TestAddAndGetAccounts(t *testing.T) {
t.Errorf("Account 2 was not found in the result")
}
}

func testDeleteUser(storage Storage, t *testing.T) {
userID := []byte("testDeleteUserID")
privateKey := []byte("testDeleteUserPrivateKey")

err := storage.AddUser(userID, privateKey)
if err != nil {
t.Fatal(err)
}

err = storage.DeleteUser(userID)
if err != nil {
t.Fatal(err)
}

_, err = storage.GetUserPrivateKey(userID)
if err == nil || !errors.Is(err, errutil.ErrNotFound) {
t.Fatal("Expected error when getting deleted user, but got none")
}
}

func testGetAllUsers(storage Storage, t *testing.T) {
initialUsers, err := storage.GetAllUsers()
if err != nil {
t.Fatal(err)
}

userID := []byte("getAllUsersTestID")
privateKey := []byte("getAllUsersTestPrivateKey")

err = storage.AddUser(userID, privateKey)
if err != nil {
t.Fatal(err)
}

afterInsertUsers, err := storage.GetAllUsers()
if err != nil {
t.Fatal(err)
}

if len(afterInsertUsers) != len(initialUsers)+1 {
t.Errorf("Expected user count to increase by 1. Got %d initially and %d after insert", len(initialUsers), len(afterInsertUsers))
}
}
12 changes: 7 additions & 5 deletions tools/walletextension/wallet_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"fmt"
"math/big"

"github.com/obscuronet/go-obscuro/go/common/log"

"github.com/obscuronet/go-obscuro/tools/walletextension/useraccountmanager"

"github.com/ethereum/go-ethereum/crypto"
Expand All @@ -32,15 +34,15 @@ type WalletExtension struct {
hostAddr string // The address on which the Obscuro host can be reached.
userAccountManager *useraccountmanager.UserAccountManager
unsignedVKs map[gethcommon.Address]*viewingkey.ViewingKey // Map temporarily holding VKs that have been generated but not yet signed
storage *storage.Storage
storage storage.Storage
logger gethlog.Logger
stopControl *stopcontrol.StopControl
}

func New(
hostAddr string,
userAccountManager *useraccountmanager.UserAccountManager,
storage *storage.Storage,
storage storage.Storage,
stopControl *stopcontrol.StopControl,
logger gethlog.Logger,
) *WalletExtension {
Expand Down Expand Up @@ -367,9 +369,9 @@ func (w *WalletExtension) getStorageAtInterceptor(request *accountmanager.RPCReq
return nil
}

key, err := w.storage.GetUserPrivateKey(userID)
if err != nil || len(key) == 0 {
w.logger.Info("Trying to get userID, but it is not present in our database: ")
_, err = w.storage.GetUserPrivateKey(userID)
if err != nil {
otherview marked this conversation as resolved.
Show resolved Hide resolved
w.logger.Info("Trying to get userID, but it is not present in our database: ", log.ErrKey, err)
return nil
}
response := map[string]interface{}{}
Expand Down