Skip to content

Commit

Permalink
fix: migrate to SqlAlchemy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanljones committed Oct 3, 2023
1 parent 2b0a4f3 commit 28d868d
Show file tree
Hide file tree
Showing 3 changed files with 458 additions and 360 deletions.
51 changes: 28 additions & 23 deletions pyrekordbox/db6/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from pathlib import Path
from typing import Optional
from sqlalchemy import create_engine, or_, event, MetaData
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.orm import Session
from sqlalchemy.exc import NoResultFound
from sqlalchemy.sql.sqltypes import DateTime, String
from packaging import version
from ..utils import get_rekordbox_pid
from ..config import get_config
Expand Down Expand Up @@ -87,12 +88,12 @@ def open_rekordbox_database(path=None, key="", unlock=True, sql_driver=None):
To use the ``pysqlcipher3`` package as SQLite driver, either import it as
>>> from pysqlcipher3 import dbapi2 as sqlite3 # noqa
>>> from sqlcipher3 import dbapi2 as sqlite3 # noqa
>>> db = open_rekordbox_database("path/to/master_copy.db")
or supply the package as driver:
>>> from pysqlcipher3 import dbapi2
>>> from sqlcipher3 import dbapi2 # noqa
>>> db = open_rekordbox_database("path/to/master_copy.db", sql_driver=dbapi2)
"""
if not path:
Expand Down Expand Up @@ -231,7 +232,6 @@ def __init__(self, path=None, db_dir="", key="", unlock=True):
raise FileNotFoundError(f"Database directory '{db_dir}' does not exist!")

self.engine = engine
self._Session = sessionmaker(bind=self.engine)
self.session: Optional[Session] = None

self.registry = RekordboxAgentRegistry(self)
Expand Down Expand Up @@ -272,7 +272,7 @@ def open(self):
>>> db.open()
"""
if self.session is None:
self.session = self._Session()
self.session = Session(bind=self.engine)
self.registry.clear_buffer()

def close(self):
Expand Down Expand Up @@ -1761,46 +1761,51 @@ def json_serial(obj):
json.dump(data, fp, indent=indent, sort_keys=sort_keys, default=json_serial)

def copy_unlocked(self, output_file):
    """Copy the contents of this (encrypted) database to a plain SQLite file.

    Every table except SQLite's internal bookkeeping tables is dropped from
    the target database, re-created from the source schema, and all rows are
    copied over. Progress is printed to stdout.

    Parameters
    ----------
    output_file : str or Path
        Path of the destination (unencrypted) SQLite database file. It is
        created if it does not exist; any existing tables are dropped first.
    """
    src_engine = self.engine
    src_metadata = MetaData()
    exclude_tables = ("sqlite_master", "sqlite_sequence", "sqlite_temp_master")

    dst_engine = create_engine(f"sqlite:///{output_file}")
    dst_metadata = MetaData()

    @event.listens_for(src_metadata, "column_reflect")
    def genericize_datatypes(inspector, tablename, column_dict):
        # Replace dialect-specific column types with generic ones so the
        # schema can be re-created on a plain SQLite database. DateTime
        # columns are stored as strings to avoid value-conversion errors
        # when copying the raw rows.
        type_ = column_dict["type"].as_generic(allow_nulltype=True)
        if isinstance(type_, DateTime):
            type_ = String
        column_dict["type"] = type_

    src_conn = src_engine.connect()
    dst_conn = dst_engine.connect()
    try:
        # Drop all existing tables in the target database
        dst_metadata.reflect(bind=dst_engine)
        for table in reversed(dst_metadata.sorted_tables):
            if table.name not in exclude_tables:
                print("dropping table =", table.name)
                table.drop(bind=dst_engine)

        # Re-create all tables in the target database from the source schema
        dst_metadata.clear()
        dst_metadata.reflect(bind=dst_engine)
        src_metadata.reflect(bind=src_engine)
        for table in src_metadata.sorted_tables:
            if table.name not in exclude_tables:
                table.create(bind=dst_engine)

        # Refresh the target metadata so the freshly created tables are
        # visible before copying data into them
        dst_metadata.clear()
        dst_metadata.reflect(bind=dst_engine)

        # Copy all rows from the source to the target database
        print("Copying data...")
        string = "\rCopying table {name}: Inserting row {row}"
        for table in dst_metadata.sorted_tables:
            src_table = src_metadata.tables[table.name]
            stmt = table.insert()
            count = 0  # explicit counter: reports the true number of rows
            for row in src_conn.execute(src_table.select()):
                count += 1
                print(string.format(name=table.name, row=count), end="", flush=True)
                dst_conn.execute(stmt.values(row))
            print(f"\rCopying table {table.name}: Inserted {count} rows", flush=True)

        dst_conn.commit()
    finally:
        # Always release both connections, even if the copy fails midway
        src_conn.close()
        dst_conn.close()
Loading

0 comments on commit 28d868d

Please sign in to comment.