diff --git a/gphotos/LocalData.py b/gphotos/LocalData.py index d636679b..a64ffecd 100644 --- a/gphotos/LocalData.py +++ b/gphotos/LocalData.py @@ -1,10 +1,10 @@ -#!/usr/bin/env python3 -# coding: utf8 -from pathlib import Path +import logging import platform import sqlite3 as lite -from sqlite3.dbapi2 import Connection, Cursor from datetime import datetime +from os import unlink +from pathlib import Path +from sqlite3.dbapi2 import Connection, Cursor from typing import Iterator, Type # todo this module could be tidied quite a bit @@ -12,13 +12,11 @@ # also the generic functions seem a bit ugly and could do with rework import gphotos.Queries as Queries from gphotos import Utils +from gphotos.DatabaseMedia import DatabaseMedia +from gphotos.DbRow import DbRow from gphotos.GoogleAlbumsRow import GoogleAlbumsRow -from gphotos.LocalFilesRow import LocalFilesRow from gphotos.GooglePhotosRow import GooglePhotosRow -from gphotos.DbRow import DbRow -from gphotos.DatabaseMedia import DatabaseMedia - -import logging +from gphotos.LocalFilesRow import LocalFilesRow log = logging.getLogger(__name__) @@ -44,7 +42,7 @@ def __init__(self, root_folder: Path, flush_index: bool = False): clean_db = True elif flush_index: clean_db = True - self.db_file.rename(self.db_file.parent / (self.db_file.name + ".previous")) + self.backup_sql_file() self.con: Connection = lite.connect(str(self.db_file), check_same_thread=False) self.con.row_factory = lite.Row @@ -59,6 +57,12 @@ def __init__(self, root_folder: Path, flush_index: bool = False): def __enter__(self): return self + def backup_sql_file(self): + backup = self.db_file.parent / (self.db_file.name + ".previous") + if backup.exists(): + backup.unlink() + self.db_file.rename(backup) + def __exit__(self, exc_type, exc_val, exc_tb): """ Always clean up and close the connection when this object is destroyed. """ @@ -84,7 +88,7 @@ def check_schema_version(self): ) self.con.commit() self.con.close() - self.db_file.rename(self.db_file.parent / (self.db_file.name + ".previous")) + self.backup_sql_file() self.con = lite.connect(str(self.db_file)) self.con.row_factory = lite.Row self.cur = self.con.cursor()