Skip to content

Commit

Permalink
Last refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Oleg Strokachuk committed Jan 3, 2022
1 parent cdc6724 commit 4272483
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 182 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ You can install from pypi using `pip install clickhouse-migrations`.
### Usage

```python
from migration_lib.migrate import migrate
from clickhouse_migrations.clickhouse_cluster import ClickhouseCluster

migrator = Migrator(db_host, db_user, db_password)
migrator.migrate(db_name, migrations_home, create_db_if_no_exists)
cluster = ClickhouseCluster(db_host, db_user, db_password)
cluster.migrate(db_name, migrations_home, create_db_if_no_exists)
```

Parameter | Description | Default
Expand Down
2 changes: 1 addition & 1 deletion src/clickhouse_migrations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
Simple file-based migrations for clickhouse
"""
__version__ = "0.1.4"
__version__ = "0.1.5"
70 changes: 70 additions & 0 deletions src/clickhouse_migrations/clickhouse_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging
from pathlib import Path
from typing import List

from clickhouse_driver import Client

from clickhouse_migrations.defaults import DB_HOST, DB_PASSWORD, DB_USER
from clickhouse_migrations.migrator import Migrator
from clickhouse_migrations.types import Migration, MigrationStorage


class ClickhouseCluster:
def __init__(
self,
db_host: str = DB_HOST,
db_user: str = DB_USER,
db_password: str = DB_PASSWORD,
):
self.db_host = db_host
self.db_user = db_user
self.db_password = db_password

def connection(self, db_name: str) -> Client:
return Client(
self.db_host, user=self.db_user, password=self.db_password, database=db_name
)

def create_db(self, db_name):
with self.connection("") as conn:
conn.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}")

def init_schema(self, db_name):
with self.connection(db_name) as conn:
migrator = Migrator(conn)
migrator.init_schema()

def show_tables(self, db_name):
with self.connection(db_name) as conn:
result = conn.execute("show tables")
return [t[0] for t in result]

def migrate(
self,
db_name: str,
migration_path: Path,
create_db_if_no_exists: bool = True,
):
storage = MigrationStorage(migration_path)
migrations = storage.migrations()

return self.apply_migrations(
db_name, migrations, create_db_if_no_exists=create_db_if_no_exists
)

def apply_migrations(
self,
db_name: str,
migrations: List[Migration],
create_db_if_no_exists: bool = True,
) -> List[Migration]:

if create_db_if_no_exists:
self.create_db(db_name)

logging.info("Total migrations to apply: %d", len(migrations))

with self.connection(db_name) as conn:
migrator = Migrator(conn)
migrator.init_schema()
return migrator.apply_migration(migrations)
6 changes: 3 additions & 3 deletions src/clickhouse_migrations/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from argparse import ArgumentParser
from pathlib import Path

from clickhouse_migrations.clickhouse_cluster import ClickhouseCluster
from clickhouse_migrations.defaults import (
DB_HOST,
DB_NAME,
DB_PASSWORD,
DB_USER,
MIGRATIONS_DIR,
)
from clickhouse_migrations.migrator import Migrator


def get_context(args):
Expand Down Expand Up @@ -46,8 +46,8 @@ def get_context(args):


def migrate(context) -> int:
migrator = Migrator(context.db_host, context.db_user, context.db_password)
migrator.migrate(context.db_name, Path(context.migrations_dir))
cluster = ClickhouseCluster(context.db_host, context.db_user, context.db_password)
cluster.migrate(context.db_name, Path(context.migrations_dir))
return 0


Expand Down
73 changes: 0 additions & 73 deletions src/clickhouse_migrations/migrate.py

This file was deleted.

117 changes: 77 additions & 40 deletions src/clickhouse_migrations/migrator.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,18 @@
import logging
from pathlib import Path
from typing import List

import pandas as pd
import pandas
from clickhouse_driver import Client

from .defaults import DB_HOST, DB_PASSWORD, DB_USER
from .migrate import apply_migration, migrations_to_apply
from .types import MigrationStorage
from clickhouse_migrations.types import Migration


class Migrator:
def __init__(
self,
db_host: str = DB_HOST,
db_user: str = DB_USER,
db_password: str = DB_PASSWORD,
):
self.db_host = db_host
self.db_user = db_user
self.db_password = db_password

def connection(self, db_name: str) -> Client:
return Client(
self.db_host, user=self.db_user, password=self.db_password, database=db_name
)

def create_db(self, db_name):
with self.connection("") as conn:
conn.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}")
def __init__(self, conn: Client):
self._conn: Client = conn

@classmethod
def init_schema(cls, conn):
conn.execute(
def init_schema(self):
self._conn.execute(
"CREATE TABLE IF NOT EXISTS schema_versions ("
"version UInt32, "
"md5 String, "
Expand All @@ -40,20 +21,76 @@ def init_schema(cls, conn):
") ENGINE = MergeTree ORDER BY tuple(created_at)"
)

def migrate(
self,
db_name: str,
migration_path: Path,
create_db_if_no_exists: bool = True,
):
if create_db_if_no_exists:
self.create_db(db_name)
def execute_and_inflate(self, query) -> pandas.DataFrame:
result = self._conn.execute(query, with_column_types=True)
column_names = [c[0] for c in result[len(result) - 1]]
return pandas.DataFrame([dict(zip(column_names, d)) for d in result[0]])

def migrations_to_apply(self, migrations: List[Migration]) -> List[Migration]:
applied_migrations = self.execute_and_inflate(
"SELECT version AS version, script AS c_script, md5 as c_md5 from schema_versions",
)

if applied_migrations.empty:
return migrations

incoming = pandas.DataFrame(migrations)
if len(incoming) == 0 or len(incoming) < len(applied_migrations):
raise AssertionError(
"Migrations have gone missing, "
"your code base should not truncate migrations, "
"use migrations to correct older migrations"
)

applied_migrations = applied_migrations.astype({"version": "int32"})
incoming = incoming.astype({"version": "int32"})
exec_stat = pandas.merge(
applied_migrations, incoming, on="version", how="outer"
)
committed_and_absconded = exec_stat[
exec_stat.c_md5.notnull() & exec_stat.md5.isnull()
]
if len(committed_and_absconded) > 0:
raise AssertionError(
"Migrations have gone missing, "
"your code base should not truncate migrations, "
"use migrations to correct older migrations"
)

index = (
exec_stat.c_md5.notnull()
& exec_stat.md5.notnull()
& ~(exec_stat.md5 == exec_stat.c_md5)
)
terms_violated = exec_stat[index]
if len(terms_violated) > 0:
raise AssertionError(
"Do not edit migrations once run, "
"use migrations to correct older migrations"
)
versions_to_apply = exec_stat[exec_stat.c_md5.isnull()][["version"]]
return [m for m in migrations if m.version in versions_to_apply.values]

def apply_migration(self, migrations: List[Migration]) -> List[Migration]:
new_migrations = self.migrations_to_apply(migrations)
if not new_migrations:
return []

for migration in new_migrations:
logging.info("Execute migration %s", migration)
self._conn.execute(migration.script)

storage = MigrationStorage(migration_path)
migrations = storage.migrations()
logging.info("Total migrations: %d", len(migrations))
logging.info("Migration applied")

with self.connection(db_name) as conn:
self.init_schema(conn)
self._conn.execute(
"INSERT INTO schema_versions(version, script, md5) VALUES",
[
{
"version": migration.version,
"script": migration.script,
"md5": migration.md5,
}
],
)

apply_migration(conn, migrations_to_apply(conn, pd.DataFrame(migrations)))
return new_migrations
2 changes: 1 addition & 1 deletion src/clickhouse_migrations/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def migrations(self) -> List[Migration]:
for full_path in self.filenames():
migration = Migration(
version=int(full_path.name.split("_")[0].replace("V", "")),
script=str(full_path),
script=str(full_path.read_text(encoding="utf8")),
md5=hashlib.md5(full_path.read_bytes()).hexdigest(),
)

Expand Down
Loading

0 comments on commit 4272483

Please sign in to comment.