diff --git a/.github/workflows/manual-deploy-obscuro-gateway.yml b/.github/workflows/manual-deploy-obscuro-gateway.yml index dd0bf607f5..1257cd279d 100644 --- a/.github/workflows/manual-deploy-obscuro-gateway.yml +++ b/.github/workflows/manual-deploy-obscuro-gateway.yml @@ -63,7 +63,7 @@ jobs: location: 'uksouth' restart-policy: 'Never' environment-variables: PORT=80 - command-line: ./wallet_extension_linux -host=0.0.0.0 -port=80 -portWS=81 -nodeHost=${{ env.OBSCURO_GATEWAY_NODE_HOST }} - ports: 81 80 + command-line: ./wallet_extension_linux -host=0.0.0.0 -port=80 -portWS=81 -nodeHost=${{ env.OBSCURO_GATEWAY_NODE_HOST }} -dbType=mariaDB -dbConnectionURL=obscurouser:${{ secrets.OBSCURO_GATEWAY_MARIADB_USER_PWD }}@tcp(obscurogateway-mariadb-${{ github.event.inputs.testnet_type }}.uksouth.cloudapp.azure.com:3306)/ogdb + ports: 80 81 cpu: 2 memory: 2 diff --git a/tools/walletextension/config/config.go b/tools/walletextension/config/config.go index 60d60fc745..2ee16d3b65 100644 --- a/tools/walletextension/config/config.go +++ b/tools/walletextension/config/config.go @@ -10,4 +10,6 @@ type Config struct { LogPath string DBPathOverride string // Overrides the database file location. Used in tests. VerboseFlag bool + DBType string + DBConnectionURL string } diff --git a/tools/walletextension/container/walletextension_container.go b/tools/walletextension/container/walletextension_container.go index cc5578560a..4c4cdd80b9 100644 --- a/tools/walletextension/container/walletextension_container.go +++ b/tools/walletextension/container/walletextension_container.go @@ -42,7 +42,7 @@ func NewWalletExtensionContainerFromConfig(config config.Config, logger gethlog. userAccountManager := useraccountmanager.NewUserAccountManager(unAuthedClient, logger) // start the database - databaseStorage, err := storage.New(config.DBPathOverride) + databaseStorage, err := storage.New(config.DBType, config.DBConnectionURL, config.DBPathOverride) if err != nil { logger.Crit("unable to create database to store viewing keys ", log.ErrKey, err) } diff --git a/tools/walletextension/main/cli.go b/tools/walletextension/main/cli.go index b7d2734b80..5e932251a4 100644 --- a/tools/walletextension/main/cli.go +++ b/tools/walletextension/main/cli.go @@ -43,6 +43,14 @@ const ( verboseFlagName = "verbose" verboseFlagDefault = false verboseFlagUsage = "Flag to enable verbose logging of wallet extension traffic" + + dbTypeFlagName = "dbType" + dbTypeFlagDefault = "sqlite" + dbTypeFlagUsage = "Defined the db type (sqlite or mariaDB)" + + dbConnectionURLFlagName = "dbConnectionURL" + dbConnectionURLFlagDefault = "" + dbConnectionURLFlagUsage = "If dbType is set to mariaDB, this must be set. ex: obscurouser:password@tcp(127.0.0.1:3306)/ogdb" ) func parseCLIArgs() config.Config { @@ -55,6 +63,8 @@ func parseCLIArgs() config.Config { logPath := flag.String(logPathName, logPathDefault, logPathUsage) databasePath := flag.String(databasePathName, databasePathDefault, databasePathUsage) verboseFlag := flag.Bool(verboseFlagName, verboseFlagDefault, verboseFlagUsage) + dbType := flag.String(dbTypeFlagName, dbTypeFlagDefault, dbTypeFlagUsage) + dbConnectionURL := flag.String(dbConnectionURLFlagName, dbConnectionURLFlagDefault, dbConnectionURLFlagUsage) flag.Parse() return config.Config{ @@ -66,5 +76,7 @@ func parseCLIArgs() config.Config { LogPath: *logPath, DBPathOverride: *databasePath, VerboseFlag: *verboseFlag, + DBType: *dbType, + DBConnectionURL: *dbConnectionURL, } } diff --git a/tools/walletextension/storage/database/001_init.sql b/tools/walletextension/storage/database/001_init.sql index 59c2851be0..20087a4c5a 100644 --- a/tools/walletextension/storage/database/001_init.sql +++ b/tools/walletextension/storage/database/001_init.sql @@ -5,12 +5,12 @@ USE ogdb; GRANT SELECT, INSERT, UPDATE, DELETE ON ogdb.* TO 'obscurouser'; CREATE TABLE IF NOT EXISTS ogdb.users ( - user_id binary(32) PRIMARY KEY, - private_key binary(32) + user_id varbinary(32) PRIMARY KEY, + private_key varbinary(32) ); CREATE TABLE IF NOT EXISTS ogdb.accounts ( - user_id binary(32), - account_address binary(20), - signature binary(65), + user_id varbinary(32), + account_address varbinary(20), + signature varbinary(65), FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE ); \ No newline at end of file diff --git a/tools/walletextension/storage/database/mariadb.go b/tools/walletextension/storage/database/mariadb.go new file mode 100644 index 0000000000..5d33b9a4ee --- /dev/null +++ b/tools/walletextension/storage/database/mariadb.go @@ -0,0 +1,130 @@ +package database + +import ( + "database/sql" + "fmt" + + _ "github.com/go-sql-driver/mysql" // Importing MariaDB driver + "github.com/obscuronet/go-obscuro/go/common/errutil" + "github.com/obscuronet/go-obscuro/tools/walletextension/common" +) + +type MariaDB struct { + db *sql.DB +} + +// NewMariaDB creates a new MariaDB connection instance +func NewMariaDB(dbURL string) (*MariaDB, error) { + db, err := sql.Open("mysql", dbURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + return &MariaDB{db: db}, nil +} + +func (m *MariaDB) AddUser(userID []byte, privateKey []byte) error { + stmt, err := m.db.Prepare("REPLACE INTO users(user_id, private_key) VALUES (?, ?)") + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(userID, privateKey) + if err != nil { + return err + } + + return nil +} + +func (m *MariaDB) DeleteUser(userID []byte) error { + stmt, err := m.db.Prepare("DELETE FROM users WHERE user_id = ?") + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(userID) + if err != nil { + return err + } + + return nil +} + +func (m *MariaDB) GetUserPrivateKey(userID []byte) ([]byte, error) { + var privateKey []byte + err := m.db.QueryRow("SELECT private_key FROM users WHERE user_id = ?", userID).Scan(&privateKey) + if err != nil { + if err == sql.ErrNoRows { + // No rows found for the given userID + return nil, errutil.ErrNotFound + } + return nil, err + } + + return privateKey, nil +} + +func (m *MariaDB) AddAccount(userID []byte, accountAddress []byte, signature []byte) error { + stmt, err := m.db.Prepare("INSERT INTO accounts(user_id, account_address, signature) VALUES (?, ?, ?)") + if err != nil { + return err + } + defer stmt.Close() + + res, err := stmt.Exec(userID, accountAddress, signature) + if err != nil { + return err + } + fmt.Println(res) + + return nil +} + +func (m *MariaDB) GetAccounts(userID []byte) ([]common.AccountDB, error) { + rows, err := m.db.Query("SELECT account_address, signature FROM accounts WHERE user_id = ?", userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var accounts []common.AccountDB + for rows.Next() { + var account common.AccountDB + if err := rows.Scan(&account.AccountAddress, &account.Signature); err != nil { + return nil, err + } + accounts = append(accounts, account) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return accounts, nil +} + +func (m *MariaDB) GetAllUsers() ([]common.UserDB, error) { + rows, err := m.db.Query("SELECT user_id, private_key FROM users") + if err != nil { + return nil, err + } + defer rows.Close() + + var users []common.UserDB + for rows.Next() { + var user common.UserDB + err = rows.Scan(&user.UserID, &user.PrivateKey) + if err != nil { + return nil, err + } + users = append(users, user) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return users, nil +} diff --git a/tools/walletextension/storage/storage.go b/tools/walletextension/storage/storage.go index fa261f0bc8..5b58d29fb7 100644 --- a/tools/walletextension/storage/storage.go +++ b/tools/walletextension/storage/storage.go @@ -1,6 +1,8 @@ package storage import ( + "fmt" + "github.com/obscuronet/go-obscuro/tools/walletextension/common" "github.com/obscuronet/go-obscuro/tools/walletextension/storage/database" ) @@ -14,6 +16,12 @@ type Storage interface { GetAllUsers() ([]common.UserDB, error) } -func New(dbPath string) (Storage, error) { - return database.NewSqliteDatabase(dbPath) +func New(dbType string, dbConnectionURL, dbPath string) (Storage, error) { + switch dbType { + case "mariaDB": + return database.NewMariaDB(dbConnectionURL) + case "sqlite": + return database.NewSqliteDatabase(dbPath) + } + return nil, fmt.Errorf("unknown db %s", dbType) } diff --git a/tools/walletextension/storage/storage_test.go b/tools/walletextension/storage/storage_test.go index fcab6985e4..c08dccf0ef 100644 --- a/tools/walletextension/storage/storage_test.go +++ b/tools/walletextension/storage/storage_test.go @@ -19,7 +19,8 @@ var tests = map[string]func(storage Storage, t *testing.T){ func TestSQLiteGatewayDB(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - storage, err := New("") + // storage, err := New("mariaDB", "obscurouser:password@tcp(127.0.0.1:3306)/ogdb", "") allows to run tests against a local instance of MariaDB + storage, err := New("sqlite", "", "") require.NoError(t, err) test(storage, t) diff --git a/tools/walletextension/test/utils.go b/tools/walletextension/test/utils.go index bfaa37759d..c6926ebe6f 100644 --- a/tools/walletextension/test/utils.go +++ b/tools/walletextension/test/utils.go @@ -42,6 +42,7 @@ func createWalExtCfg(connectPort, wallHTTPPort, wallWSPort int) *config.Config { DBPathOverride: testDBPath.Name(), WalletExtensionPortHTTP: wallHTTPPort, WalletExtensionPortWS: wallWSPort, + DBType: "sqlite", } }