diff --git a/tools/walletextension/Dockerfile b/tools/walletextension/Dockerfile index ffa9722f3f..3837c81a54 100644 --- a/tools/walletextension/Dockerfile +++ b/tools/walletextension/Dockerfile @@ -55,5 +55,10 @@ RUN --mount=type=cache,target=/root/.cache/go-build \ # Lightweight final build stage. Includes bare minimum to start wallet extension FROM alpine:3.18 +# copy over the gateway executable COPY --from=build-wallet /home/obscuro/go-obscuro/tools/walletextension/bin /home/obscuro/go-obscuro/tools/walletextension/bin + +# copy over the .sql migration files +COPY --from=build-wallet /home/obscuro/go-obscuro/tools/walletextension/storage/database /home/obscuro/go-obscuro/tools/walletextension/storage/database + WORKDIR /home/obscuro/go-obscuro/tools/walletextension/bin diff --git a/tools/walletextension/storage/database/001_init.sql b/tools/walletextension/storage/database/001_init.sql index 9cf37e7a17..1e432d4ac1 100644 --- a/tools/walletextension/storage/database/001_init.sql +++ b/tools/walletextension/storage/database/001_init.sql @@ -1,16 +1,12 @@ -CREATE DATABASE ogdb; +/* + This file is used to create the database and set the necessary permissions for the user that will be used by the gateway. + */ -USE ogdb; +-- Create the database +CREATE DATABASE IF NOT EXISTS ogdb; -GRANT SELECT, INSERT, UPDATE, DELETE ON ogdb.* TO 'obscurouser'; +-- Grant the necessary permissions +GRANT SELECT, INSERT, UPDATE, DELETE, CREATE, DROP ON ogdb.* TO 'obscurouser'; -CREATE TABLE IF NOT EXISTS ogdb.users ( - user_id varbinary(20) PRIMARY KEY, - private_key varbinary(32) - ); -CREATE TABLE IF NOT EXISTS ogdb.accounts ( - user_id varbinary(20), - account_address varbinary(20), - signature varbinary(65), - FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE - ); \ No newline at end of file +-- Reload the privileges from the grant tables in the mysql database +FLUSH PRIVILEGES; \ No newline at end of file diff --git a/tools/walletextension/storage/database/mariadb/001_init.sql b/tools/walletextension/storage/database/mariadb/001_init.sql new file mode 100644 index 0000000000..95a4e61513 --- /dev/null +++ b/tools/walletextension/storage/database/mariadb/001_init.sql @@ -0,0 +1,15 @@ +/* + This is a migration file for MariaDB and is executed when the Gateway is started to make sure the database schema is up to date. +*/ + +CREATE TABLE IF NOT EXISTS ogdb.users ( + user_id varbinary(20) PRIMARY KEY, + private_key varbinary(32) +); + +CREATE TABLE IF NOT EXISTS ogdb.accounts ( + user_id varbinary(20), + account_address varbinary(20), + signature varbinary(65), + FOREIGN KEY(user_id) REFERENCES ogdb.users(user_id) ON DELETE CASCADE +); diff --git a/tools/walletextension/storage/database/mariadb.go b/tools/walletextension/storage/database/mariadb/mariadb.go similarity index 83% rename from tools/walletextension/storage/database/mariadb.go rename to tools/walletextension/storage/database/mariadb/mariadb.go index 2e8ce75b7e..e68cc009a2 100644 --- a/tools/walletextension/storage/database/mariadb.go +++ b/tools/walletextension/storage/database/mariadb/mariadb.go @@ -1,12 +1,15 @@ -package database +package mariadb import ( "database/sql" "fmt" + "path/filepath" + "runtime" _ "github.com/go-sql-driver/mysql" // Importing MariaDB driver "github.com/ten-protocol/go-ten/go/common/errutil" "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/storage/database" ) type MariaDB struct { @@ -15,11 +18,23 @@ type MariaDB struct { // NewMariaDB creates a new MariaDB connection instance func NewMariaDB(dbURL string) (*MariaDB, error) { - db, err := sql.Open("mysql", dbURL) + db, err := sql.Open("mysql", dbURL+"?multiStatements=true") if err != nil { return nil, fmt.Errorf("failed to connect to database: %w", err) } + // get the path to the migrations (they are always in the same directory as file containing connection function) + _, filename, _, ok := runtime.Caller(0) + if !ok { + return nil, fmt.Errorf("failed to get current directory") + } + migrationsDir := filepath.Dir(filename) + + // apply migrations + if err = database.ApplyMigrations(db, migrationsDir); err != nil { + return nil, err + } + return &MariaDB{db: db}, nil } diff --git a/tools/walletextension/storage/database/migration.go b/tools/walletextension/storage/database/migration.go new file mode 100644 index 0000000000..a89bb0a4b9 --- /dev/null +++ b/tools/walletextension/storage/database/migration.go @@ -0,0 +1,48 @@ +package database + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "sort" +) + +func ApplyMigrations(db *sql.DB, migrationsPath string) error { + files, err := os.ReadDir(migrationsPath) + if err != nil { + return err + } + + var sqlFiles []string + for _, file := range files { + if filepath.Ext(file.Name()) == ".sql" { + sqlFiles = append(sqlFiles, filepath.Join(migrationsPath, file.Name())) + } + } + + sort.Strings(sqlFiles) // Sort files lexicographically to apply migrations in order + + for _, file := range sqlFiles { + fmt.Println("Executing db migration file: ", file) + if err = executeSQLFile(db, file); err != nil { + return err + } + } + + return nil +} + +func executeSQLFile(db *sql.DB, filePath string) error { + content, err := os.ReadFile(filePath) + if err != nil { + return err + } + + _, err = db.Exec(string(content)) + if err != nil { + return fmt.Errorf("failed to execute %s: %w", filePath, err) + } + + return nil +} diff --git a/tools/walletextension/storage/database/sqlite.go b/tools/walletextension/storage/database/sqlite/sqlite.go similarity index 85% rename from tools/walletextension/storage/database/sqlite.go rename to tools/walletextension/storage/database/sqlite/sqlite.go index 0140d99470..5ef0d92b5f 100644 --- a/tools/walletextension/storage/database/sqlite.go +++ b/tools/walletextension/storage/database/sqlite/sqlite.go @@ -1,4 +1,4 @@ -package database +package sqlite import ( "database/sql" @@ -6,18 +6,17 @@ import ( "os" "path/filepath" + _ "github.com/mattn/go-sqlite3" // sqlite driver for sql.Open() obscurocommon "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/errutil" - - _ "github.com/mattn/go-sqlite3" // sqlite driver for sql.Open() common "github.com/ten-protocol/go-ten/tools/walletextension/common" ) -type SqliteDatabase struct { +type Database struct { db *sql.DB } -func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) { +func NewSqliteDatabase(dbPath string) (*Database, error) { // load the db file dbFilePath, err := createOrLoad(dbPath) if err != nil { @@ -59,10 +58,10 @@ func NewSqliteDatabase(dbPath string) (*SqliteDatabase, error) { return nil, err } - return &SqliteDatabase{db: db}, nil + return &Database{db: db}, nil } -func (s *SqliteDatabase) AddUser(userID []byte, privateKey []byte) error { +func (s *Database) AddUser(userID []byte, privateKey []byte) error { stmt, err := s.db.Prepare("INSERT OR REPLACE INTO users(user_id, private_key) VALUES (?, ?)") if err != nil { return err @@ -77,7 +76,7 @@ func (s *SqliteDatabase) AddUser(userID []byte, privateKey []byte) error { return nil } -func (s *SqliteDatabase) DeleteUser(userID []byte) error { +func (s *Database) DeleteUser(userID []byte) error { stmt, err := s.db.Prepare("DELETE FROM users WHERE user_id = ?") if err != nil { return err @@ -92,7 +91,7 @@ func (s *SqliteDatabase) DeleteUser(userID []byte) error { return nil } -func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) { +func (s *Database) GetUserPrivateKey(userID []byte) ([]byte, error) { var privateKey []byte err := s.db.QueryRow("SELECT private_key FROM users WHERE user_id = ?", userID).Scan(&privateKey) if err != nil { @@ -106,7 +105,7 @@ func (s *SqliteDatabase) GetUserPrivateKey(userID []byte) ([]byte, error) { return privateKey, nil } -func (s *SqliteDatabase) AddAccount(userID []byte, accountAddress []byte, signature []byte) error { +func (s *Database) AddAccount(userID []byte, accountAddress []byte, signature []byte) error { stmt, err := s.db.Prepare("INSERT INTO accounts(user_id, account_address, signature) VALUES (?, ?, ?)") if err != nil { return err @@ -121,7 +120,7 @@ func (s *SqliteDatabase) AddAccount(userID []byte, accountAddress []byte, signat return nil } -func (s *SqliteDatabase) GetAccounts(userID []byte) ([]common.AccountDB, error) { +func (s *Database) GetAccounts(userID []byte) ([]common.AccountDB, error) { rows, err := s.db.Query("SELECT account_address, signature FROM accounts WHERE user_id = ?", userID) if err != nil { return nil, err @@ -143,7 +142,7 @@ func (s *SqliteDatabase) GetAccounts(userID []byte) ([]common.AccountDB, error) return accounts, nil } -func (s *SqliteDatabase) GetAllUsers() ([]common.UserDB, error) { +func (s *Database) GetAllUsers() ([]common.UserDB, error) { rows, err := s.db.Query("SELECT user_id, private_key FROM users") if err != nil { return nil, err diff --git a/tools/walletextension/storage/storage.go b/tools/walletextension/storage/storage.go index fbb479671f..57e751b02b 100644 --- a/tools/walletextension/storage/storage.go +++ b/tools/walletextension/storage/storage.go @@ -3,8 +3,10 @@ package storage import ( "fmt" + "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/mariadb" + "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/sqlite" + "github.com/ten-protocol/go-ten/tools/walletextension/common" - "github.com/ten-protocol/go-ten/tools/walletextension/storage/database" ) type Storage interface { @@ -19,9 +21,9 @@ type Storage interface { func New(dbType string, dbConnectionURL, dbPath string) (Storage, error) { switch dbType { case "mariaDB": - return database.NewMariaDB(dbConnectionURL) + return mariadb.NewMariaDB(dbConnectionURL) case "sqlite": - return database.NewSqliteDatabase(dbPath) + return sqlite.NewSqliteDatabase(dbPath) } return nil, fmt.Errorf("unknown db %s", dbType) }