Skip to content

Commit

Permalink
create database migration for the gateway
Browse files Browse the repository at this point in the history
  • Loading branch information
zkokelj committed Jan 30, 2024
1 parent 42d9a38 commit 4f151e1
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 37 deletions.
6 changes: 6 additions & 0 deletions tools/walletextension/storage/database/001_init.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
/*
This file is used to initialize the database by the github action manual-deploy-obscuro-gateway-database.yml
todo (@ziga) : separate the database initialization from the database migration and delete this file and use migration files for each database type instead (sqlite, mariadb, edgelessdb)
*/

CREATE DATABASE ogdb;

USE ogdb;
Expand Down
26 changes: 26 additions & 0 deletions tools/walletextension/storage/database/mariadb/001_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
This is a migration file for MariaDB and is executed when the Gateway is started to make sure the database schema is up to date.
*/

CREATE DATABASE ogdb;

USE ogdb;

GRANT SELECT, INSERT, UPDATE, DELETE ON ogdb.* TO 'obscurouser';

-- Create users table
CREATE TABLE IF NOT EXISTS ogdb.users (
user_id varbinary(20) PRIMARY KEY,
private_key varbinary(32)
);

-- Create accounts table
CREATE TABLE IF NOT EXISTS ogdb.accounts (
user_id varbinary(20),
account_address varbinary(20),
signature varbinary(65),
FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE
);

-- Create transactions table
-- TODO @ziga: Add more fields
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package database
package mariadb

import (
"database/sql"
"fmt"
"path/filepath"
"runtime"

_ "github.com/go-sql-driver/mysql" // Importing MariaDB driver
"github.com/ten-protocol/go-ten/go/common/errutil"
"github.com/ten-protocol/go-ten/tools/walletextension/common"
"github.com/ten-protocol/go-ten/tools/walletextension/storage/database"
)

type MariaDB struct {
Expand All @@ -20,6 +23,18 @@ func NewMariaDB(dbURL string) (*MariaDB, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}

// get the path to the migrations (they are always in the same directory as file containing connection function)
_, filename, _, ok := runtime.Caller(0)
if !ok {
return nil, fmt.Errorf("failed to get current directory")
}
migrationsDir := filepath.Dir(filename)

// apply migrations
if err = database.ApplyMigrations(db, migrationsDir); err != nil {
return nil, err
}

return &MariaDB{db: db}, nil
}

Expand Down
48 changes: 48 additions & 0 deletions tools/walletextension/storage/database/migration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package database

import (
"database/sql"
"fmt"
"os"
"path/filepath"
"sort"
)

func ApplyMigrations(db *sql.DB, migrationsPath string) error {
files, err := os.ReadDir(migrationsPath)
if err != nil {
return err
}

var sqlFiles []string
for _, file := range files {
if filepath.Ext(file.Name()) == ".sql" {
sqlFiles = append(sqlFiles, filepath.Join(migrationsPath, file.Name()))
}
}

sort.Strings(sqlFiles) // Sort files lexicographically to apply migrations in order

for _, file := range sqlFiles {
fmt.Println("Executing db migration file: ", file)
if err = executeSQLFile(db, file); err != nil {
return err
}
}

return nil
}

func executeSQLFile(db *sql.DB, filePath string) error {
content, err := os.ReadFile(filePath)
if err != nil {
return err
}

_, err = db.Exec(string(content))
if err != nil {
return fmt.Errorf("failed to execute %s: %w", filePath, err)
}

return nil
}
16 changes: 16 additions & 0 deletions tools/walletextension/storage/database/sqlite/001_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- Enable foreign keys in SQLite
PRAGMA foreign_keys = ON;

-- Create users table
CREATE TABLE IF NOT EXISTS users (
user_id binary(20) PRIMARY KEY,
private_key binary(32)
);

-- Create accounts table
CREATE TABLE IF NOT EXISTS accounts (
user_id binary(20),
account_address binary(20),
signature binary(65),
FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
package database
package sqlite

import (
"database/sql"
"fmt"
"os"
"path/filepath"
"runtime"

obscurocommon "github.com/ten-protocol/go-ten/go/common"
"github.com/ten-protocol/go-ten/go/common/errutil"
"github.com/ten-protocol/go-ten/tools/walletextension/storage/database"

_ "github.com/mattn/go-sqlite3" // sqlite driver for sql.Open()
common "github.com/ten-protocol/go-ten/tools/walletextension/common"
)

type SqliteDatabase struct {
type Database struct {
db *sql.DB
}

func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) {
func NewSqliteDatabase(dbPath string) (*Database, error) {
// load the db file
dbFilePath, err := createOrLoad(dbPath)
if err != nil {
Expand All @@ -31,38 +33,22 @@ func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) {
return nil, err
}

// enable foreign keys in sqlite
_, err = db.Exec("PRAGMA foreign_keys = ON;")
if err != nil {
return nil, err
// get the path to the migrations (they are always in the same directory as file containing connection function)
_, filename, _, ok := runtime.Caller(0)
if !ok {
return nil, fmt.Errorf("failed to get current directory")
}
migrationsDir := filepath.Dir(filename)

// create users table
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS users (
user_id binary(20) PRIMARY KEY,
private_key binary(32)
);`)

if err != nil {
return nil, err
}

// create accounts table
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS accounts (
user_id binary(20),
account_address binary(20),
signature binary(65),
FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE
);`)

if err != nil {
// apply migrations
if err = database.ApplyMigrations(db, migrationsDir); err != nil {
return nil, err
}

return &SqliteDatabase{db: db}, nil
return &Database{db: db}, nil
}

func (s *SqliteDatabase) AddUser(userID []byte, privateKey []byte) error {
func (s *Database) AddUser(userID []byte, privateKey []byte) error {
stmt, err := s.db.Prepare("INSERT OR REPLACE INTO users(user_id, private_key) VALUES (?, ?)")
if err != nil {
return err
Expand All @@ -77,7 +63,7 @@ func (s *SqliteDatabase) AddUser(userID []byte, privateKey []byte) error {
return nil
}

func (s *SqliteDatabase) DeleteUser(userID []byte) error {
func (s *Database) DeleteUser(userID []byte) error {
stmt, err := s.db.Prepare("DELETE FROM users WHERE user_id = ?")
if err != nil {
return err
Expand All @@ -92,7 +78,7 @@ func (s *SqliteDatabase) DeleteUser(userID []byte) error {
return nil
}

func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) {
func (s *Database) GetUserPrivateKey(userID []byte) ([]byte, error) {
var privateKey []byte
err := s.db.QueryRow("SELECT private_key FROM users WHERE user_id = ?", userID).Scan(&privateKey)
if err != nil {
Expand All @@ -106,7 +92,7 @@ func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) {
return privateKey, nil
}

func (s *SqliteDatabase) AddAccount(userID []byte, accountAddress []byte, signature []byte) error {
func (s *Database) AddAccount(userID []byte, accountAddress []byte, signature []byte) error {
stmt, err := s.db.Prepare("INSERT INTO accounts(user_id, account_address, signature) VALUES (?, ?, ?)")
if err != nil {
return err
Expand All @@ -121,7 +107,7 @@ func (s *SqliteDatabase) AddAccount(userID []byte, accountAddress []byte, signat
return nil
}

func (s *SqliteDatabase) GetAccounts(userID []byte) ([]common.AccountDB, error) {
func (s *Database) GetAccounts(userID []byte) ([]common.AccountDB, error) {
rows, err := s.db.Query("SELECT account_address, signature FROM accounts WHERE user_id = ?", userID)
if err != nil {
return nil, err
Expand All @@ -143,7 +129,7 @@ func (s *SqliteDatabase) GetAccounts(userID []byte) ([]common.AccountDB, error)
return accounts, nil
}

func (s *SqliteDatabase) GetAllUsers() ([]common.UserDB, error) {
func (s *Database) GetAllUsers() ([]common.UserDB, error) {
rows, err := s.db.Query("SELECT user_id, private_key FROM users")
if err != nil {
return nil, err
Expand Down
8 changes: 5 additions & 3 deletions tools/walletextension/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package storage
import (
"fmt"

"github.com/ten-protocol/go-ten/tools/walletextension/storage/database/mariadb"
"github.com/ten-protocol/go-ten/tools/walletextension/storage/database/sqlite"

"github.com/ten-protocol/go-ten/tools/walletextension/common"
"github.com/ten-protocol/go-ten/tools/walletextension/storage/database"
)

type Storage interface {
Expand All @@ -19,9 +21,9 @@ type Storage interface {
func New(dbType string, dbConnectionURL, dbPath string) (Storage, error) {
switch dbType {
case "mariaDB":
return database.NewMariaDB(dbConnectionURL)
return mariadb.NewMariaDB(dbConnectionURL)
case "sqlite":
return database.NewSqliteDatabase(dbPath)
return sqlite.NewSqliteDatabase(dbPath)
}
return nil, fmt.Errorf("unknown db %s", dbType)
}

0 comments on commit 4f151e1

Please sign in to comment.