Skip to content

Commit

Permalink
create database migration for the gateway (#1768)
Browse files Browse the repository at this point in the history
  • Loading branch information
zkokelj authored Feb 28, 2024
1 parent 17b6298 commit b17c332
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 30 deletions.
5 changes: 5 additions & 0 deletions tools/walletextension/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,10 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
# Lightweight final build stage. Includes bare minimum to start wallet extension
FROM alpine:3.18

# copy over the gateway executable
COPY --from=build-wallet /home/obscuro/go-obscuro/tools/walletextension/bin /home/obscuro/go-obscuro/tools/walletextension/bin

# copy over the .sql migration files
COPY --from=build-wallet /home/obscuro/go-obscuro/tools/walletextension/storage/database /home/obscuro/go-obscuro/tools/walletextension/storage/database

WORKDIR /home/obscuro/go-obscuro/tools/walletextension/bin
22 changes: 9 additions & 13 deletions tools/walletextension/storage/database/001_init.sql
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
CREATE DATABASE ogdb;
/*
This file is used to create the database and set the necessary permissions for the user that will be used by the gateway.
*/

USE ogdb;
-- Create the database
CREATE DATABASE IF NOT EXISTS ogdb;

GRANT SELECT, INSERT, UPDATE, DELETE ON ogdb.* TO 'obscurouser';
-- Grant the necessary permissions
GRANT SELECT, INSERT, UPDATE, DELETE, CREATE, DROP ON ogdb.* TO 'obscurouser';

CREATE TABLE IF NOT EXISTS ogdb.users (
user_id varbinary(20) PRIMARY KEY,
private_key varbinary(32)
);
CREATE TABLE IF NOT EXISTS ogdb.accounts (
user_id varbinary(20),
account_address varbinary(20),
signature varbinary(65),
FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
-- Reload the privileges from the grant tables in the mysql database
FLUSH PRIVILEGES;
15 changes: 15 additions & 0 deletions tools/walletextension/storage/database/mariadb/001_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
This is a migration file for MariaDB and is executed when the Gateway is started to make sure the database schema is up to date.
*/

CREATE TABLE IF NOT EXISTS ogdb.users (
user_id varbinary(20) PRIMARY KEY,
private_key varbinary(32)
);

CREATE TABLE IF NOT EXISTS ogdb.accounts (
user_id varbinary(20),
account_address varbinary(20),
signature varbinary(65),
FOREIGN KEY(user_id) REFERENCES ogdb.users(user_id) ON DELETE CASCADE
);
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package database
package mariadb

import (
"database/sql"
"fmt"
"path/filepath"
"runtime"

_ "github.com/go-sql-driver/mysql" // Importing MariaDB driver
"github.com/ten-protocol/go-ten/go/common/errutil"
"github.com/ten-protocol/go-ten/tools/walletextension/common"
"github.com/ten-protocol/go-ten/tools/walletextension/storage/database"
)

type MariaDB struct {
Expand All @@ -15,11 +18,23 @@ type MariaDB struct {

// NewMariaDB creates a new MariaDB connection instance
func NewMariaDB(dbURL string) (*MariaDB, error) {
db, err := sql.Open("mysql", dbURL)
db, err := sql.Open("mysql", dbURL+"?multiStatements=true")
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}

// get the path to the migrations (they are always in the same directory as file containing connection function)
_, filename, _, ok := runtime.Caller(0)
if !ok {
return nil, fmt.Errorf("failed to get current directory")
}
migrationsDir := filepath.Dir(filename)

// apply migrations
if err = database.ApplyMigrations(db, migrationsDir); err != nil {
return nil, err
}

return &MariaDB{db: db}, nil
}

Expand Down
48 changes: 48 additions & 0 deletions tools/walletextension/storage/database/migration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package database

import (
"database/sql"
"fmt"
"os"
"path/filepath"
"sort"
)

func ApplyMigrations(db *sql.DB, migrationsPath string) error {
files, err := os.ReadDir(migrationsPath)
if err != nil {
return err
}

var sqlFiles []string
for _, file := range files {
if filepath.Ext(file.Name()) == ".sql" {
sqlFiles = append(sqlFiles, filepath.Join(migrationsPath, file.Name()))
}
}

sort.Strings(sqlFiles) // Sort files lexicographically to apply migrations in order

for _, file := range sqlFiles {
fmt.Println("Executing db migration file: ", file)
if err = executeSQLFile(db, file); err != nil {
return err
}
}

return nil
}

func executeSQLFile(db *sql.DB, filePath string) error {
content, err := os.ReadFile(filePath)
if err != nil {
return err
}

_, err = db.Exec(string(content))
if err != nil {
return fmt.Errorf("failed to execute %s: %w", filePath, err)
}

return nil
}
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
package database
package sqlite

import (
"database/sql"
"fmt"
"os"
"path/filepath"

_ "github.com/mattn/go-sqlite3" // sqlite driver for sql.Open()
obscurocommon "github.com/ten-protocol/go-ten/go/common"
"github.com/ten-protocol/go-ten/go/common/errutil"

_ "github.com/mattn/go-sqlite3" // sqlite driver for sql.Open()
common "github.com/ten-protocol/go-ten/tools/walletextension/common"
)

type SqliteDatabase struct {
type Database struct {
db *sql.DB
}

func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) {
func NewSqliteDatabase(dbPath string) (*Database, error) {
// load the db file
dbFilePath, err := createOrLoad(dbPath)
if err != nil {
Expand Down Expand Up @@ -59,10 +58,10 @@ func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) {
return nil, err
}

return &SqliteDatabase{db: db}, nil
return &Database{db: db}, nil
}

func (s *SqliteDatabase) AddUser(userID []byte, privateKey []byte) error {
func (s *Database) AddUser(userID []byte, privateKey []byte) error {
stmt, err := s.db.Prepare("INSERT OR REPLACE INTO users(user_id, private_key) VALUES (?, ?)")
if err != nil {
return err
Expand All @@ -77,7 +76,7 @@ func (s *SqliteDatabase) AddUser(userID []byte, privateKey []byte) error {
return nil
}

func (s *SqliteDatabase) DeleteUser(userID []byte) error {
func (s *Database) DeleteUser(userID []byte) error {
stmt, err := s.db.Prepare("DELETE FROM users WHERE user_id = ?")
if err != nil {
return err
Expand All @@ -92,7 +91,7 @@ func (s *SqliteDatabase) DeleteUser(userID []byte) error {
return nil
}

func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) {
func (s *Database) GetUserPrivateKey(userID []byte) ([]byte, error) {
var privateKey []byte
err := s.db.QueryRow("SELECT private_key FROM users WHERE user_id = ?", userID).Scan(&privateKey)
if err != nil {
Expand All @@ -106,7 +105,7 @@ func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) {
return privateKey, nil
}

func (s *SqliteDatabase) AddAccount(userID []byte, accountAddress []byte, signature []byte) error {
func (s *Database) AddAccount(userID []byte, accountAddress []byte, signature []byte) error {
stmt, err := s.db.Prepare("INSERT INTO accounts(user_id, account_address, signature) VALUES (?, ?, ?)")
if err != nil {
return err
Expand All @@ -121,7 +120,7 @@ func (s *SqliteDatabase) AddAccount(userID []byte, accountAddress []byte, signat
return nil
}

func (s *SqliteDatabase) GetAccounts(userID []byte) ([]common.AccountDB, error) {
func (s *Database) GetAccounts(userID []byte) ([]common.AccountDB, error) {
rows, err := s.db.Query("SELECT account_address, signature FROM accounts WHERE user_id = ?", userID)
if err != nil {
return nil, err
Expand All @@ -143,7 +142,7 @@ func (s *SqliteDatabase) GetAccounts(userID []byte) ([]common.AccountDB, error)
return accounts, nil
}

func (s *SqliteDatabase) GetAllUsers() ([]common.UserDB, error) {
func (s *Database) GetAllUsers() ([]common.UserDB, error) {
rows, err := s.db.Query("SELECT user_id, private_key FROM users")
if err != nil {
return nil, err
Expand Down
8 changes: 5 additions & 3 deletions tools/walletextension/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package storage
import (
"fmt"

"github.com/ten-protocol/go-ten/tools/walletextension/storage/database/mariadb"
"github.com/ten-protocol/go-ten/tools/walletextension/storage/database/sqlite"

"github.com/ten-protocol/go-ten/tools/walletextension/common"
"github.com/ten-protocol/go-ten/tools/walletextension/storage/database"
)

type Storage interface {
Expand All @@ -19,9 +21,9 @@ type Storage interface {
func New(dbType string, dbConnectionURL, dbPath string) (Storage, error) {
switch dbType {
case "mariaDB":
return database.NewMariaDB(dbConnectionURL)
return mariadb.NewMariaDB(dbConnectionURL)
case "sqlite":
return database.NewSqliteDatabase(dbPath)
return sqlite.NewSqliteDatabase(dbPath)
}
return nil, fmt.Errorf("unknown db %s", dbType)
}

0 comments on commit b17c332

Please sign in to comment.