-
Notifications
You must be signed in to change notification settings - Fork 1
/
migrator.py
126 lines (107 loc) · 4.17 KB
/
migrator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import importlib.util
from qdrant_client import QdrantClient, models
import logging
# Emit INFO-and-above records through the root logger's default handler
# (basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level="INFO")
# Shared logger used by the migration helpers in this module.
logger = logging.getLogger(__name__)
# Quiet the "httpx" logger (presumably the HTTP layer used by qdrant_client)
# so that only its warnings and errors are shown.
logging.getLogger("httpx").setLevel("WARNING")
def initialize_qdrant_client(url, api_key):
    """Build and return a QdrantClient for the given server URL and API key."""
    return QdrantClient(url=url, api_key=api_key)
def check_and_create_migrations_collection(client):
    """Ensure the "migrations" bookkeeping collection exists.

    If the collection is missing, create it with 2-dimensional cosine vectors.
    Then seed it with one point (fixed UUID) whose payload records version 0.
    The stored vector appears to be a placeholder: only the payload's
    "version" is read back by get_current_version. An existing collection is
    left untouched.
    """
    existing_names = {
        collection.name for collection in client.get_collections().collections
    }
    if "migrations" in existing_names:
        return

    client.create_collection(
        collection_name="migrations",
        vectors_config=models.VectorParams(size=2, distance=models.Distance.COSINE),
    )
    seed_point = models.PointStruct(
        id="5c56c793-69f3-4fbf-87e6-c4bf54c28c26",
        payload={"version": 0},
        vector=[0.0, 0.1],
    )
    client.upsert(collection_name="migrations", points=[seed_point])
def get_current_version(client):
    """Return the schema version recorded in the "migrations" collection.

    ``client.scroll`` returns a ``(records, next_page_offset)`` tuple. The
    records list must be unpacked before testing for emptiness, because the
    tuple itself is always truthy. Only one point is needed, so ``limit=1``.

    Returns:
        The ``"version"`` payload value of the first stored point, or 0 when
        the collection holds no points.
    """
    records, _next_offset = client.scroll(
        "migrations",
        limit=1,
        with_payload=True,
    )
    if not records:
        return 0
    return records[0].payload["version"]
def set_current_version(client, version):
    """Record `version` as the current schema version.

    Upserts the bookkeeping point with the fixed UUID into the "migrations"
    collection, with a payload holding the new version number.
    """
    version_point = models.PointStruct(
        id="5c56c793-69f3-4fbf-87e6-c4bf54c28c26",
        payload={"version": version},
        vector=[0.0, 0.1],
    )
    client.upsert(collection_name="migrations", points=[version_point])
def get_migration_files(migration_folder):
    """List migration scripts in `migration_folder`, ordered by numeric prefix.

    A migration file is a ``.py`` file whose name begins with an integer,
    either followed by ``_`` (e.g. ``3_add_index.py``) or on its own
    (``3.py``). Other ``.py`` files, such as ``__init__.py``, are skipped with
    a warning. Previously they aborted the whole run with ``ValueError``.

    Returns:
        A list of ``(index, filename)`` tuples sorted by index. Ties are
        broken by filename, so the order never depends on ``os.listdir``.
    """
    files = []
    for file in os.listdir(migration_folder):
        stem, extension = os.path.splitext(file)
        if extension != ".py":
            continue
        prefix = stem.split("_", 1)[0]
        try:
            index = int(prefix)
        except ValueError:
            logging.getLogger(__name__).warning(
                "Ignoring %s: name does not start with a numeric migration index",
                file,
            )
            continue
        files.append((index, file))
    files.sort()
    return files
def run_migrations(client, migration_folder, current_version, target_version=None):
    """Apply every pending migration in ascending index order.

    A migration is pending when its index is above `current_version` and,
    if `target_version` is given, not above `target_version`. Each pending
    file is loaded as a module and its ``forward(client)`` is called. The
    stored version is then advanced to that index straight away, so an
    exception part-way leaves the recorded version at the last migration
    that succeeded.
    """
    for index, file in get_migration_files(migration_folder):
        is_pending = index > current_version and (
            target_version is None or index <= target_version
        )
        if not is_pending:
            logger.info(f"Skipping migration {file}")
            continue

        migration_path = os.path.join(migration_folder, file)
        spec = importlib.util.spec_from_file_location(
            file.replace(".py", ""), migration_path
        )
        migration = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(migration)

        migration.forward(client)
        set_current_version(client, index)
        logger.info(
            f"Migration completed successfully for {file}! Enjoy your migration :D"
        )
def rollback_migrations(client, migration_folder, current_version, target_version):
    """Revert applied migrations newest-first, down to `target_version`.

    A migration is reverted when ``target_version < index <= current_version``.
    Its file is loaded, ``backward(client)`` is called, and the stored version
    is lowered immediately. An exception part-way therefore leaves the
    recorded version matching what has actually been undone.

    After undoing a migration, the stored version becomes the index of the
    previous migration file rather than ``index - 1``. With gapped numbering
    (e.g. 1, 2, 5, 10), ``index - 1`` would record 9 after undoing 10. A
    migration later added in the gap (say 7) would then be skipped by
    run_migrations. For the lowest-numbered file ``index - 1`` is still used,
    which lies below every file index ("none applied").
    """
    migration_files = get_migration_files(migration_folder)
    for position in reversed(range(len(migration_files))):
        index, file = migration_files[position]
        if not (index <= current_version and index > target_version):
            logger.info(f"Skipping rollback {file}")
            continue

        spec = importlib.util.spec_from_file_location(
            os.path.splitext(file)[0], os.path.join(migration_folder, file)
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        module.backward(client)

        # The newest migration still applied is the file just before this one.
        if position > 0:
            new_version = migration_files[position - 1][0]
        else:
            new_version = index - 1
        set_current_version(client, new_version)
        logger.info(
            f"Rollback completed successfully for {file}! Enjoy your migration :D"
        )
def migrate(url, api_key, migration_folder, target_version=None):
    """Connect to Qdrant and apply pending migrations from `migration_folder`.

    Ensures the "migrations" bookkeeping collection exists, reads the stored
    version, and runs every newer migration via run_migrations.

    Args:
        url: Qdrant server URL.
        api_key: API key passed to the client.
        migration_folder: Directory holding the numbered migration scripts.
        target_version: Optional highest migration index to apply. It mirrors
            the parameter already supported by run_migrations and the one
            rollback takes. The default ``None`` applies all pending
            migrations, as before.
    """
    client = initialize_qdrant_client(url, api_key)
    check_and_create_migrations_collection(client)
    current_version = get_current_version(client)
    run_migrations(client, migration_folder, current_version, target_version)
def rollback(url, api_key, migration_folder, target_version):
    """Connect to Qdrant and undo applied migrations down to `target_version`.

    Migrations whose index is above `target_version` and at or below the
    stored version are reverted newest-first. The migration numbered
    `target_version` itself stays applied.
    """
    qdrant = initialize_qdrant_client(url, api_key)
    check_and_create_migrations_collection(qdrant)
    rollback_migrations(
        qdrant,
        migration_folder,
        get_current_version(qdrant),
        target_version,
    )