diff --git a/pyrekordbox/db6/database.py b/pyrekordbox/db6/database.py index 3b5783b..2348439 100644 --- a/pyrekordbox/db6/database.py +++ b/pyrekordbox/db6/database.py @@ -8,7 +8,7 @@ from uuid import uuid4 from pathlib import Path from typing import Optional -from sqlalchemy import create_engine, or_, event +from sqlalchemy import create_engine, or_, event, MetaData from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.exc import NoResultFound from packaging import version @@ -1763,3 +1763,48 @@ def json_serial(obj): data = self.to_dict(verbose=verbose) with open(file, "w") as fp: json.dump(data, fp, indent=indent, sort_keys=sort_keys, default=json_serial) + + def copy_unlocked(self, output_file): + src_metadata = MetaData(bind=self.engine) + exclude_tables = ("sqlite_master", "sqlite_sequence", "sqlite_temp_master") + + dst_engine = create_engine(f"sqlite:///{output_file}") + dst_metadata = MetaData(bind=dst_engine) + + @event.listens_for(src_metadata, "column_reflect") + def genericize_datatypes(inspector, tablename, column_dict): + column_dict["type"] = column_dict["type"].as_generic(allow_nulltype=True) + + dst_engine.connect() + dst_metadata.reflect() + + # drop all tables in target database + for table in reversed(dst_metadata.sorted_tables): + if table.name not in exclude_tables: + print("dropping table =", table.name) + table.drop() + + dst_metadata.clear() + dst_metadata.reflect() + src_metadata.reflect() + + # create all tables in target database + for table in src_metadata.sorted_tables: + if table.name not in exclude_tables: + table.create(bind=dst_engine) + + # refresh metadata before you can copy data + dst_metadata.clear() + dst_metadata.reflect() + + # Copy all data from src to target + print("Copying data...") + string = "\rCopying table {name}: Inserting row {row}" + for table in dst_metadata.sorted_tables: + src_table = src_metadata.tables[table.name] + stmt = table.insert() + index = 0 + for index, row in enumerate(src_table.select().execute()): + print(string.format(name=table.name, row=index), end="", flush=True) + stmt.execute(row._asdict()) # noqa + print(f"\rCopying table {table.name}: Inserted {index} rows", flush=True) diff --git a/tests/test_db6.py b/tests/test_db6.py index 97f526a..5fcfb8a 100644 --- a/tests/test_db6.py +++ b/tests/test_db6.py @@ -13,8 +13,10 @@ from pyrekordbox.db6 import tables TEST_ROOT = Path(__file__).parent.parent / ".testdata" +LOCKED = TEST_ROOT / "rekordbox 6" / "master_locked.db" UNLOCKED = TEST_ROOT / "rekordbox 6" / "master_unlocked.db" UNLOCKED_COPY = TEST_ROOT / "rekordbox 6" / "master_unlocked_copy.db" +UNLOCKED_OUT = TEST_ROOT / "rekordbox 6" / "master_unlocked_out.db" MASTER_PLAYLIST_SRC = TEST_ROOT / "rekordbox 6" / "masterPlaylists6_template.xml" MASTER_PLAYLIST_DST = TEST_ROOT / "rekordbox 6" / "masterPlaylists6.xml" # Create a copy of the masterPlaylists6.xml file @@ -955,3 +957,25 @@ def test_to_json(): finally: tmp.close() os.remove(tmp.name) + + +def test_copy_unlocked(): + db = Rekordbox6Database(UNLOCKED, unlock=False) + db.copy_unlocked(UNLOCKED_OUT) + + db2 = Rekordbox6Database(UNLOCKED_OUT, unlock=False) + # Check everything got copied + for name in tables.__all__: + if name.startswith("Stats") or name == "Base": + continue + table = getattr(tables, name) + for row in db.query(table): + data = row.to_dict() + if name == "AgentRegistry": + query = db2.query(table).filter_by(registry_id=row.registry_id) + elif name == "DjmdProperty": + query = db2.query(table).filter_by(DBID=row.DBID) + else: + query = db2.query(table).filter_by(ID=row.ID) + data2 = query.one().to_dict() + assert data == data2