-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: BI-5954 BaseEntrySchemaMigration class (#735)
* feat: BI-5954 BaseEntrySchemaMigration class * feat: BI-5954 BaseLinearEntrySchemaMigration class * feat: BI-5954 just one migration class * feat: BI-5954 migrations without ClassVar * feat: BI-5954 return new object as result
- Loading branch information
1 parent
2e891e7
commit 624c150
Showing
3 changed files
with
281 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from copy import deepcopy | ||
from datetime import datetime | ||
import logging | ||
from typing import ( | ||
Any, | ||
Callable, | ||
) | ||
|
||
import attr | ||
from typing_extensions import Self | ||
|
||
from dl_api_commons.base_models import RequestContextInfo | ||
from dl_app_tools.profiling_base import generic_profiler | ||
from dl_core.services_registry import ServicesRegistry | ||
|
||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
def isoformat_validator(instance: Any, attribute: attr.Attribute, value: str) -> None: | ||
datetime.fromisoformat(value) | ||
|
||
|
||
@attr.s | ||
class Migration: | ||
version: str = attr.ib(validator=isoformat_validator) | ||
name: str = attr.ib() | ||
function: Callable[[dict], dict] = attr.ib() | ||
id: int = attr.ib(init=False) | ||
|
||
def __attrs_post_init__(self) -> None: | ||
self.id = int(datetime.fromisoformat(self.version).timestamp()) | ||
|
||
def __lt__(self, other: Self) -> bool: | ||
return self.id < other.id | ||
|
||
def migrate(self, entry: dict) -> dict: | ||
entry = self.function(entry) | ||
entry["data"]["schema_version"] = self.version | ||
return entry | ||
|
||
|
||
@attr.s | ||
class BaseEntrySchemaMigration: | ||
bi_context: RequestContextInfo | None = attr.ib(default=None) | ||
services_registry: ServicesRegistry | None = attr.ib(default=None) | ||
strict_migration: bool = attr.ib(default=False) | ||
|
||
@property | ||
def migrations(self) -> list[Migration]: | ||
return [] | ||
|
||
@property | ||
def sorted_migrations(self) -> list[Migration]: | ||
return sorted(self.migrations) | ||
|
||
@staticmethod | ||
def _get_entry_schema_id(entry: dict) -> int: | ||
entry_data = entry.get("data") | ||
if not isinstance(entry_data, dict): | ||
raise ValueError(f"Invalid entry: 'data' should be a dict, got {type(entry_data).__name__}") | ||
|
||
schema_version = entry_data.get("schema_version", "") | ||
entry_schema_id = 1 | ||
if schema_version != "1": | ||
entry_schema_id = int(datetime.fromisoformat(schema_version).timestamp()) | ||
return entry_schema_id | ||
|
||
def _migrate(self, entry: dict) -> dict: | ||
seen_versions = set() | ||
entry_schema_id = self._get_entry_schema_id(entry) | ||
|
||
for migration in self.sorted_migrations: | ||
if migration.id <= entry_schema_id: | ||
continue | ||
if migration.version in seen_versions: | ||
raise ValueError(f"Double migration detected for migration version: {migration.version}") | ||
seen_versions.add(migration.version) | ||
LOGGER.info(f"Apply migration ver={migration.version}, {migration.name}") | ||
entry = migration.migrate(entry) | ||
return entry | ||
|
||
@generic_profiler("migrate_entry") | ||
def migrate(self, entry: dict) -> dict: | ||
entry_copy = deepcopy(entry) | ||
|
||
try: | ||
return self._migrate(entry_copy) | ||
except Exception as exc: | ||
if self.strict_migration: | ||
raise exc | ||
LOGGER.warning("Entry migration failed", exc_info=True) | ||
return deepcopy(entry) |
188 changes: 188 additions & 0 deletions
188
lib/dl_core/dl_core_tests/unit/us_manager/test_migration.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
from copy import deepcopy | ||
|
||
import attr | ||
import pytest | ||
|
||
from dl_core.us_manager.schema_migration.base import ( | ||
BaseEntrySchemaMigration, | ||
Migration, | ||
) | ||
|
||
|
||
@attr.s | ||
class Level1EntrySchemaMigration(BaseEntrySchemaMigration): | ||
@property | ||
def migrations(self) -> list[Migration]: | ||
migrations = [ | ||
Migration( | ||
"2022-12-04 13:00:00", | ||
"Second level 1 migration", | ||
Level1EntrySchemaMigration._migrate_v2_to_v3, | ||
), | ||
Migration( | ||
"2022-12-01 12:00:00", | ||
"First level 1 migration", | ||
Level1EntrySchemaMigration._migrate_v1_to_v2, | ||
), | ||
] | ||
migrations.extend(super().migrations) | ||
return migrations | ||
|
||
@staticmethod | ||
def _migrate_v1_to_v2(entry: dict) -> dict: | ||
entry["data"]["new_field"] = entry["data"].pop("old_field", "default_value") | ||
return entry | ||
|
||
@staticmethod | ||
def _migrate_v2_to_v3(entry: dict) -> dict: | ||
entry["data"]["l1_field"] = "added_in_l1" | ||
return entry | ||
|
||
|
||
@attr.s | ||
class Level2EntrySchemaMigration(Level1EntrySchemaMigration): | ||
@property | ||
def migrations(self) -> list[Migration]: | ||
migrations = [ | ||
Migration( | ||
"2022-12-03 13:00:00", | ||
"Second level 2 migration", | ||
Level2EntrySchemaMigration._migrate_v2_to_v3, | ||
), | ||
Migration( | ||
"2022-12-02 12:00:00", | ||
"First level 2 migration", | ||
Level2EntrySchemaMigration._migrate_v1_to_v2, | ||
), | ||
] | ||
migrations.extend(super().migrations) | ||
return migrations | ||
|
||
@staticmethod | ||
def _migrate_v1_to_v2(entry: dict) -> dict: | ||
entry["data"]["new_field"] = "default_value" | ||
return entry | ||
|
||
@staticmethod | ||
def _migrate_v2_to_v3(entry: dict) -> dict: | ||
entry["data"]["l2_field"] = "added_in_l2" | ||
return entry | ||
|
||
|
||
@attr.s | ||
class Level3EntrySchemaMigration(Level2EntrySchemaMigration): | ||
@property | ||
def migrations(self) -> list[Migration]: | ||
migrations = [ | ||
Migration("2022-12-03 13:00:00", "Third level 2 migration", Level3EntrySchemaMigration._migrate_v3_to_v1), | ||
] | ||
migrations.extend(super().migrations) | ||
return migrations | ||
|
||
@staticmethod | ||
def _migrate_v3_to_v1(entry: dict) -> dict: | ||
entry["data"]["abs_field"] = "one more new value" | ||
return entry | ||
|
||
|
||
@pytest.fixture | ||
def l2_migrator(): | ||
return Level2EntrySchemaMigration(strict_migration=True) | ||
|
||
|
||
@pytest.fixture | ||
def l3_migrator(): | ||
return Level3EntrySchemaMigration(strict_migration=True) | ||
|
||
|
||
@pytest.fixture | ||
def l3_nonstrict_migrator(): | ||
return Level3EntrySchemaMigration() | ||
|
||
|
||
def test_successful_migration(l2_migrator): | ||
entry = { | ||
"data": { | ||
"old_field": "value1", | ||
"schema_version": "1", | ||
} | ||
} | ||
expected = { | ||
"data": { | ||
"new_field": "default_value", | ||
"l1_field": "added_in_l1", | ||
"l2_field": "added_in_l2", | ||
"schema_version": "2022-12-04 13:00:00", | ||
} | ||
} | ||
result = l2_migrator.migrate(entry) | ||
assert result == expected | ||
|
||
|
||
def test_no_migration_needed(l2_migrator): | ||
entry = { | ||
"data": { | ||
"new_field": "default_value", | ||
"l1_field": "added_in_l1", | ||
"l2_field": "added_in_l2", | ||
"schema_version": "2022-12-04 13:00:00", | ||
} | ||
} | ||
result = l2_migrator.migrate(entry) | ||
assert result == entry | ||
|
||
|
||
def test_invalid_data_format(l2_migrator): | ||
entry = { | ||
"data": "invalid_data_format", | ||
} | ||
with pytest.raises(ValueError, match="Invalid entry: 'data' should be a dict"): | ||
l2_migrator.migrate(entry) | ||
|
||
|
||
def test_missing_data_key(l2_migrator): | ||
entry = {} | ||
with pytest.raises(ValueError, match="Invalid entry: 'data' should be a dict"): | ||
l2_migrator.migrate(entry) | ||
|
||
|
||
def test_cyclic_migration(l3_migrator): | ||
entry = { | ||
"data": { | ||
"old_field": "value1", | ||
"schema_version": "1", | ||
} | ||
} | ||
with pytest.raises(ValueError, match="Double migration detected"): | ||
l3_migrator.migrate(entry) | ||
|
||
|
||
def test_migration_failure(l3_nonstrict_migrator): | ||
entry = { | ||
"data": { | ||
"old_field": "value1", | ||
"schema_version": "1", | ||
} | ||
} | ||
original_entry = deepcopy(entry) | ||
|
||
result = l3_nonstrict_migrator.migrate(entry) | ||
assert result == original_entry | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"migration_version", | ||
( | ||
"some string", | ||
"2022-13-03 23:00:00", | ||
"2022-12-40 23:00:00", | ||
"2022-12-03 33:00:00", | ||
), | ||
) | ||
def test_wrong_migration_version(migration_version): | ||
with pytest.raises(ValueError): | ||
Migration( | ||
migration_version, | ||
"Broken version migration", | ||
lambda x: x, | ||
) |