diff --git a/tools/walletextension/storage/database/001_init.sql b/tools/walletextension/storage/database/001_init.sql index 9cf37e7a17..126e1a2a55 100644 --- a/tools/walletextension/storage/database/001_init.sql +++ b/tools/walletextension/storage/database/001_init.sql @@ -1,3 +1,9 @@ +/* + This file is used to initialize the database by the github action manual-deploy-obscuro-gateway-database.yml + + todo (@ziga) : separate the database initialization from the database migration and delete this file and use migration files for each database type instead (sqlite, mariadb, edgelessdb) + */ + CREATE DATABASE ogdb; USE ogdb; diff --git a/tools/walletextension/storage/database/mariadb/001_init.sql b/tools/walletextension/storage/database/mariadb/001_init.sql new file mode 100644 index 0000000000..0461865faf --- /dev/null +++ b/tools/walletextension/storage/database/mariadb/001_init.sql @@ -0,0 +1,26 @@ +/* + This is a migration file for MariaDB and is executed when the Gateway is started to make sure the database schema is up to date. + */ + +CREATE DATABASE ogdb; + +USE ogdb; + +GRANT SELECT, INSERT, UPDATE, DELETE ON ogdb.* TO 'obscurouser'; + +-- Create users table +CREATE TABLE IF NOT EXISTS ogdb.users ( + user_id varbinary(20) PRIMARY KEY, + private_key varbinary(32) + ); + +-- Create accounts table +CREATE TABLE IF NOT EXISTS ogdb.accounts ( + user_id varbinary(20), + account_address varbinary(20), + signature varbinary(65), + FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE + ); + +-- Create transactions table +-- TODO @ziga: Add more fields \ No newline at end of file diff --git a/tools/walletextension/storage/database/mariadb.go b/tools/walletextension/storage/database/mariadb/mariadb.go similarity index 84% rename from tools/walletextension/storage/database/mariadb.go rename to tools/walletextension/storage/database/mariadb/mariadb.go index 2e8ce75b7e..1fb97e2af8 100644 --- a/tools/walletextension/storage/database/mariadb.go +++ b/tools/walletextension/storage/database/mariadb/mariadb.go @@ -1,12 +1,15 @@ -package database +package mariadb import ( "database/sql" "fmt" + "path/filepath" + "runtime" _ "github.com/go-sql-driver/mysql" // Importing MariaDB driver "github.com/ten-protocol/go-ten/go/common/errutil" "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/storage/database" ) type MariaDB struct { @@ -20,6 +23,18 @@ func NewMariaDB(dbURL string) (*MariaDB, error) { return nil, fmt.Errorf("failed to connect to database: %w", err) } + // get the path to the migrations (they are always in the same directory as file containing connection function) + _, filename, _, ok := runtime.Caller(0) + if !ok { + return nil, fmt.Errorf("failed to get current directory") + } + migrationsDir := filepath.Dir(filename) + + // apply migrations + if err = database.ApplyMigrations(db, migrationsDir); err != nil { + return nil, err + } + return &MariaDB{db: db}, nil } diff --git a/tools/walletextension/storage/database/migration.go b/tools/walletextension/storage/database/migration.go new file mode 100644 index 0000000000..a89bb0a4b9 --- /dev/null +++ b/tools/walletextension/storage/database/migration.go @@ -0,0 +1,48 @@ +package database + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "sort" +) + +func ApplyMigrations(db *sql.DB, migrationsPath string) error { + files, err := os.ReadDir(migrationsPath) + if err != nil { + return err + } + + var sqlFiles []string + for _, file := range files { + if filepath.Ext(file.Name()) == ".sql" { + sqlFiles = append(sqlFiles, filepath.Join(migrationsPath, file.Name())) + } + } + + sort.Strings(sqlFiles) // Sort files lexicographically to apply migrations in order + + for _, file := range sqlFiles { + fmt.Println("Executing db migration file: ", file) + if err = executeSQLFile(db, file); err != nil { + return err + } + } + + return nil +} + +func executeSQLFile(db *sql.DB, filePath string) error { + content, err := os.ReadFile(filePath) + if err != nil { + return err + } + + _, err = db.Exec(string(content)) + if err != nil { + return fmt.Errorf("failed to execute %s: %w", filePath, err) + } + + return nil +} diff --git a/tools/walletextension/storage/database/sqlite/001_init.sql b/tools/walletextension/storage/database/sqlite/001_init.sql new file mode 100644 index 0000000000..eb2bcbcbc9 --- /dev/null +++ b/tools/walletextension/storage/database/sqlite/001_init.sql @@ -0,0 +1,16 @@ +-- Enable foreign keys in SQLite +PRAGMA foreign_keys = ON; + +-- Create users table +CREATE TABLE IF NOT EXISTS users ( + user_id binary(20) PRIMARY KEY, + private_key binary(32) + ); + +-- Create accounts table +CREATE TABLE IF NOT EXISTS accounts ( + user_id binary(20), + account_address binary(20), + signature binary(65), + FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE + ); \ No newline at end of file diff --git a/tools/walletextension/storage/database/sqlite.go b/tools/walletextension/storage/database/sqlite/sqlite.go similarity index 72% rename from tools/walletextension/storage/database/sqlite.go rename to tools/walletextension/storage/database/sqlite/sqlite.go index 0140d99470..471172020b 100644 --- a/tools/walletextension/storage/database/sqlite.go +++ b/tools/walletextension/storage/database/sqlite/sqlite.go @@ -1,23 +1,25 @@ -package database +package sqlite import ( "database/sql" "fmt" "os" "path/filepath" + "runtime" obscurocommon "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/errutil" + "github.com/ten-protocol/go-ten/tools/walletextension/storage/database" _ "github.com/mattn/go-sqlite3" // sqlite driver for sql.Open() common "github.com/ten-protocol/go-ten/tools/walletextension/common" ) -type SqliteDatabase struct { +type Database struct { db *sql.DB } -func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) { +func NewSqliteDatabase(dbPath string) (*Database, error) { // load the db file dbFilePath, err := createOrLoad(dbPath) if err != nil { @@ -31,38 +33,22 @@ func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) { return nil, err } - // enable foreign keys in sqlite - _, err = db.Exec("PRAGMA foreign_keys = ON;") - if err != nil { - return nil, err + // get the path to the migrations (they are always in the same directory as file containing connection function) + _, filename, _, ok := runtime.Caller(0) + if !ok { + return nil, fmt.Errorf("failed to get current directory") } + migrationsDir := filepath.Dir(filename) - // create users table - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS users ( - user_id binary(20) PRIMARY KEY, - private_key binary(32) - );`) - - if err != nil { - return nil, err - } - - // create accounts table - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS accounts ( - user_id binary(20), - account_address binary(20), - signature binary(65), - FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE - );`) - - if err != nil { + // apply migrations + if err = database.ApplyMigrations(db, migrationsDir); err != nil { return nil, err } - return &SqliteDatabase{db: db}, nil + return &Database{db: db}, nil } -func (s *SqliteDatabase) AddUser(userID []byte, privateKey []byte) error { +func (s *Database) AddUser(userID []byte, privateKey []byte) error { stmt, err := s.db.Prepare("INSERT OR REPLACE INTO users(user_id, private_key) VALUES (?, ?)") if err != nil { return err @@ -77,7 +63,7 @@ func (s *SqliteDatabase) AddUser(userID []byte, privateKey []byte) error { return nil } -func (s *SqliteDatabase) DeleteUser(userID []byte) error { +func (s *Database) DeleteUser(userID []byte) error { stmt, err := s.db.Prepare("DELETE FROM users WHERE user_id = ?") if err != nil { return err @@ -92,7 +78,7 @@ func (s *SqliteDatabase) DeleteUser(userID []byte) error { return nil } -func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) { +func (s *Database) GetUserPrivateKey(userID []byte) ([]byte, error) { var privateKey []byte err := s.db.QueryRow("SELECT private_key FROM users WHERE user_id = ?", userID).Scan(&privateKey) if err != nil { @@ -106,7 +92,7 @@ func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) { return privateKey, nil } -func (s *SqliteDatabase) AddAccount(userID []byte, accountAddress []byte, signature []byte) error { +func (s *Database) AddAccount(userID []byte, accountAddress []byte, signature []byte) error { stmt, err := s.db.Prepare("INSERT INTO accounts(user_id, account_address, signature) VALUES (?, ?, ?)") if err != nil { return err @@ -121,7 +107,7 @@ func (s *SqliteDatabase) AddAccount(userID []byte, accountAddress []byte, signat return nil } -func (s *SqliteDatabase) GetAccounts(userID []byte) ([]common.AccountDB, error) { +func (s *Database) GetAccounts(userID []byte) ([]common.AccountDB, error) { rows, err := s.db.Query("SELECT account_address, signature FROM accounts WHERE user_id = ?", userID) if err != nil { return nil, err @@ -143,7 +129,7 @@ func (s *SqliteDatabase) GetAccounts(userID []byte) ([]common.AccountDB, error) return accounts, nil } -func (s *SqliteDatabase) GetAllUsers() ([]common.UserDB, error) { +func (s *Database) GetAllUsers() ([]common.UserDB, error) { rows, err := s.db.Query("SELECT user_id, private_key FROM users") if err != nil { return nil, err diff --git a/tools/walletextension/storage/storage.go b/tools/walletextension/storage/storage.go index fbb479671f..57e751b02b 100644 --- a/tools/walletextension/storage/storage.go +++ b/tools/walletextension/storage/storage.go @@ -3,8 +3,10 @@ package storage import ( "fmt" + "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/mariadb" + "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/sqlite" + "github.com/ten-protocol/go-ten/tools/walletextension/common" - "github.com/ten-protocol/go-ten/tools/walletextension/storage/database" ) type Storage interface { @@ -19,9 +21,9 @@ type Storage interface { func New(dbType string, dbConnectionURL, dbPath string) (Storage, error) { switch dbType { case "mariaDB": - return database.NewMariaDB(dbConnectionURL) + return mariadb.NewMariaDB(dbConnectionURL) case "sqlite": - return database.NewSqliteDatabase(dbPath) + return sqlite.NewSqliteDatabase(dbPath) } return nil, fmt.Errorf("unknown db %s", dbType) }