From 3154afade940aa8e4fa51c61baa21821146aea17 Mon Sep 17 00:00:00 2001 From: Aleh Strakachuk Date: Mon, 1 Jul 2024 22:59:22 +0300 Subject: [PATCH] Allow default db name #24 (#26) * Allow default db name (from db URL) --------- Co-authored-by: Aleh Strakachuk --- src/clickhouse_migrations/__init__.py | 2 +- .../clickhouse_cluster.py | 26 ++++++++++++++++--- src/clickhouse_migrations/command_line.py | 3 +-- src/tests/test_clickhouse_migration.py | 13 ++++++++++ src/tests/test_init_clickhouse_cluster.py | 7 +++-- 5 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/clickhouse_migrations/__init__.py b/src/clickhouse_migrations/__init__.py index 8276820..aebd650 100644 --- a/src/clickhouse_migrations/__init__.py +++ b/src/clickhouse_migrations/__init__.py @@ -1,4 +1,4 @@ """ Simple file-based migrations for clickhouse """ -__version__ = "0.7.0" +__version__ = "0.7.1" diff --git a/src/clickhouse_migrations/clickhouse_cluster.py b/src/clickhouse_migrations/clickhouse_cluster.py index 25414c4..bd27b0b 100644 --- a/src/clickhouse_migrations/clickhouse_cluster.py +++ b/src/clickhouse_migrations/clickhouse_cluster.py @@ -16,12 +16,16 @@ def __init__( db_password: str = DB_PASSWORD, db_port: str = DB_PORT, db_url: Optional[str] = None, + db_name: Optional[str] = None, **kwargs, ): self.db_url: Optional[str] = db_url + self.default_db_name: Optional[str] = db_name + if db_url: parts = self.db_url.split("/") if len(parts) == 4: + self.default_db_name = parts[-1] parts = parts[0:-1] self.db_url = "/".join(parts) @@ -32,7 +36,9 @@ def __init__( self.db_password = db_password self.connection_kwargs = kwargs - def connection(self, db_name: str) -> Client: + def connection(self, db_name: Optional[str] = None) -> Client: + db_name = db_name if db_name is not None else self.default_db_name + if self.db_url: db_url = self.db_url if db_name: @@ -49,7 +55,11 @@ def connection(self, db_name: str) -> Client: ) return ch_client - def create_db(self, db_name, cluster_name=None): + def create_db( + self, db_name: Optional[str] = None, cluster_name: Optional[str] = None + ): + db_name = db_name if db_name is not None else self.default_db_name + with self.connection("") as conn: if cluster_name is None: conn.execute(f'CREATE DATABASE IF NOT EXISTS "{db_name}"') @@ -58,25 +68,33 @@ def create_db(self, db_name, cluster_name=None): f'CREATE DATABASE IF NOT EXISTS "{db_name}" ON CLUSTER "{cluster_name}"' ) - def init_schema(self, db_name, cluster_name=None): + def init_schema( + self, db_name: Optional[str] = None, cluster_name: Optional[str] = None + ): + db_name = db_name if db_name is not None else self.default_db_name + with self.connection(db_name) as conn: migrator = Migrator(conn) migrator.init_schema(cluster_name) def show_tables(self, db_name): + db_name = db_name if db_name is not None else self.default_db_name + with self.connection(db_name) as conn: result = conn.execute("show tables") return [t[0] for t in result] def migrate( self, - db_name: str, + db_name: Optional[str], migration_path: Path, cluster_name: Optional[str] = None, create_db_if_no_exists: bool = True, multi_statement: bool = True, dryrun: bool = False, ): + db_name = db_name if db_name is not None else self.default_db_name + storage = MigrationStorage(migration_path) migrations = storage.migrations() diff --git a/src/clickhouse_migrations/command_line.py b/src/clickhouse_migrations/command_line.py index 8664a6a..78c5228 100644 --- a/src/clickhouse_migrations/command_line.py +++ b/src/clickhouse_migrations/command_line.py @@ -7,7 +7,6 @@ from clickhouse_migrations.clickhouse_cluster import ClickhouseCluster from clickhouse_migrations.defaults import ( DB_HOST, - DB_NAME, DB_PASSWORD, DB_PORT, DB_USER, @@ -64,7 +63,7 @@ def get_context(args): ) parser.add_argument( "--db-name", - default=os.environ.get("DB_NAME", DB_NAME), + default=os.environ.get("DB_NAME", None), help="Clickhouse database name", ) parser.add_argument( diff --git a/src/tests/test_clickhouse_migration.py b/src/tests/test_clickhouse_migration.py index 8207c58..cc56a3e 100644 --- a/src/tests/test_clickhouse_migration.py +++ b/src/tests/test_clickhouse_migration.py @@ -159,6 +159,19 @@ def test_main_pass_db_name_ok(): ) +def test_main_pass_db_url_ok(): + migrate( + get_context( + [ + "--db-url", + "clickhouse://default:@localhost:9000/pytest", + "--migrations-dir", + str(TESTS_DIR / "migrations"), + ] + ) + ) + + def test_check_multistatement_arg(): context = get_context(["--multi-statement", "false"]) assert context.multi_statement is False diff --git a/src/tests/test_init_clickhouse_cluster.py b/src/tests/test_init_clickhouse_cluster.py index c8c5414..7bfa1d9 100644 --- a/src/tests/test_init_clickhouse_cluster.py +++ b/src/tests/test_init_clickhouse_cluster.py @@ -3,7 +3,6 @@ import pytest from clickhouse_migrations.clickhouse_cluster import ClickhouseCluster -from clickhouse_migrations.defaults import DB_URL from clickhouse_migrations.migration import Migration, MigrationStorage TESTS_DIR = Path(__file__).parent @@ -12,13 +11,13 @@ @pytest.fixture def cluster(): - return ClickhouseCluster(db_url=DB_URL) + return ClickhouseCluster(db_url="clickhouse://default:@localhost:9000/pytest") def test_apply_new_migration_ok(cluster): - cluster.init_schema("pytest") + cluster.init_schema() - with cluster.connection("pytest") as conn: + with cluster.connection() as conn: conn.execute( "INSERT INTO schema_versions(version, script, md5) VALUES", [{"version": 1, "script": "SHOW TABLES", "md5": "12345"}],