From fea0ac90dcdea38e67478f86a4a39b990eab90e3 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Sun, 10 Sep 2023 17:23:58 +0200 Subject: [PATCH] Add commands to merge DBs - Fix CLI arg ordering bug - Enable getting banlist entry and backend version in Freqlog module --- src/nexus/Freqlog/Definitions.py | 5 + src/nexus/Freqlog/Freqlog.py | 31 +++- src/nexus/Freqlog/backends/Backend.py | 20 +++ .../Freqlog/backends/SQLite/SQLiteBackend.py | 161 ++++++++++++++++-- src/nexus/GUI.py | 4 +- src/nexus/__main__.py | 47 ++++- 6 files changed, 237 insertions(+), 31 deletions(-) diff --git a/src/nexus/Freqlog/Definitions.py b/src/nexus/Freqlog/Definitions.py index e13013b..ce53816 100644 --- a/src/nexus/Freqlog/Definitions.py +++ b/src/nexus/Freqlog/Definitions.py @@ -37,6 +37,11 @@ class Order(Enum): DESCENDING = False +class Age(Enum): + OLDER = True + NEWER = False + + class WordMetadata: """Metadata for a word""" diff --git a/src/nexus/Freqlog/Freqlog.py b/src/nexus/Freqlog/Freqlog.py index d16403d..598a064 100644 --- a/src/nexus/Freqlog/Freqlog.py +++ b/src/nexus/Freqlog/Freqlog.py @@ -147,14 +147,16 @@ def _log_and_reset_word(min_length: int = 1) -> None: logging.warning("Stopped freqlogging") break - def __init__(self, db_path: str = Defaults.DEFAULT_DB_PATH, loggable: bool = True): + def __init__(self, path: str = Defaults.DEFAULT_DB_PATH, loggable: bool = True): """ Initialize Freqlog - :param db_path: Path to database file + :param path: Path to backend (currently == SQLiteBackend) :param loggable: Whether to create listeners :raises ValueError: If the database version is newer than the current version """ - self.backend: Backend = SQLiteBackend(db_path) + if loggable: + logging.info(f"Logging set to freqlog db at {path}") + self.backend: Backend = SQLiteBackend(path) self.q: Queue = Queue() self.listener: kbd.Listener | None = None self.mouse_listener: mouse.Listener | None = None @@ -203,6 +205,11 @@ def stop_logging(self) -> None: # TODO: find out why this runs twice on one Ctr self.is_logging = False logging.info("Stopped listeners") + def get_backend_version(self) -> str: + """Get backend version""" + logging.info("Getting backend version") + return self.backend.get_version() + def get_word_metadata(self, word: str, case: CaseSensitivity) -> WordMetadata: """Get metadata for a word""" logging.info(f"Getting metadata for '{word}', case {case.name}") @@ -216,6 +223,16 @@ def get_chord_metadata(self, chord: str) -> ChordMetadata | None: logging.info(f"Getting metadata for '{chord}'") return self.backend.get_chord_metadata(chord) + def get_banlist_entry(self, word: str, case: CaseSensitivity) -> BanlistEntry | None: + """ + Get a banlist entry + :param word: Word to get entry for + :param case: Case sensitivity + :return: BanlistEntry if word is banned for the specified case, None otherwise + """ + logging.info(f"Getting banlist entry for '{word}', case {case.name}") + return self.backend.get_banlist_entry(word, case) + def check_banned(self, word: str, case: CaseSensitivity) -> bool: """ Check if a word is banned @@ -357,3 +374,11 @@ def list_banned_words(self, limit: int = -1, sort_by: BanlistAttr = BanlistAttr. """ logging.info(f"Listing banned words, limit {limit}, sort_by {sort_by}, reverse {reverse}") return self.backend.list_banned_words(limit, sort_by, reverse) + + def merge_backends(self, *args, **kwargs): + """ + Merge backends + :raises ValueError: If backend-specific requirements are not met + """ + logging.info(f"Merging backends: {args} {kwargs}") + self.backend.merge_backend(*args, **kwargs) diff --git a/src/nexus/Freqlog/backends/Backend.py b/src/nexus/Freqlog/backends/Backend.py index 331c225..788951e 100644 --- a/src/nexus/Freqlog/backends/Backend.py +++ b/src/nexus/Freqlog/backends/Backend.py @@ -8,6 +8,10 @@ class Backend(ABC): """Base class for all backends""" + @abstractmethod + def get_version(self) -> str: + """Get backend version""" + @abstractmethod def get_word_metadata(self, word: str, case: CaseSensitivity) -> WordMetadata | None: """ @@ -22,6 +26,15 @@ def get_chord_metadata(self, chord: str) -> ChordMetadata | None: :returns: ChordMetadata if chord is found, None otherwise """ + @abstractmethod + def get_banlist_entry(self, word: str, case: CaseSensitivity) -> BanlistEntry | None: + """ + Get a banlist entry + :param word: Word to get entry for + :param case: Case sensitivity + :return: BanlistEntry if word is banned for the specified case, None otherwise + """ + @abstractmethod def log_word(self, word: str, start_time: datetime, end_time: datetime) -> bool: """Log a word entry if not banned, creating it if it doesn't exist""" @@ -95,5 +108,12 @@ def list_banned_words(self, limit: int, sort_by: BanlistAttr, :returns: Tuple of (banned words with case, banned words without case) """ + @abstractmethod + def merge_backend(self, *args, **kwargs): + """ + Merge backends + :raises ValueError: If backend-specific requirements are not met + """ + def close(self) -> None: """Close the backend""" diff --git a/src/nexus/Freqlog/backends/SQLite/SQLiteBackend.py b/src/nexus/Freqlog/backends/SQLite/SQLiteBackend.py index 82d4b53..8fd3263 100644 --- a/src/nexus/Freqlog/backends/SQLite/SQLiteBackend.py +++ b/src/nexus/Freqlog/backends/SQLite/SQLiteBackend.py @@ -1,10 +1,13 @@ +import logging +import os import sqlite3 from datetime import datetime, timedelta +from sqlite3 import Cursor from nexus import __version__ from nexus.Freqlog.backends.Backend import Backend -from nexus.Freqlog.Definitions import BanlistAttr, BanlistEntry, CaseSensitivity, ChordMetadata, ChordMetadataAttr, \ - WordMetadata, WordMetadataAttr +from nexus.Freqlog.Definitions import Age, BanlistAttr, BanlistEntry, CaseSensitivity, ChordMetadata, \ + ChordMetadataAttr, WordMetadata, WordMetadataAttr # WARNING: Directly loaded into SQL query, do not use unsanitized user input SQL_SELECT_STAR_FROM_FREQLOG = "SELECT word, frequency, lastused, avgspeed FROM freqlog" @@ -19,6 +22,27 @@ def encode_version(version: str) -> int: return int(version.split(".")[0]) << 16 | int(version.split(".")[1]) << 8 | int(version.split(".")[2]) +def _init_db(cursor: Cursor, sql_version: int): + """ + Initialize the database + """ + # WARNING: Remember to change _upgrade_database and merge_db when changing DDL + cursor.execute(f"PRAGMA user_version = {sql_version}") + # Freqloq table + cursor.execute("CREATE TABLE IF NOT EXISTS freqlog (word TEXT NOT NULL PRIMARY KEY, frequency INTEGER, " + "lastused timestamp NOT NULL, avgspeed REAL NOT NULL) WITHOUT ROWID") + cursor.execute("CREATE INDEX IF NOT EXISTS freqlog_lower ON freqlog(word COLLATE NOCASE)") + cursor.execute("CREATE INDEX IF NOT EXISTS freqlog_frequency ON freqlog(frequency)") + cursor.execute("CREATE UNIQUE INDEX IF NOT EXISTS freqlog_lastused ON freqlog(lastused)") + cursor.execute("CREATE INDEX IF NOT EXISTS freqlog_avgspeed ON freqlog(avgspeed)") + # Banlist table + cursor.execute( + "CREATE TABLE IF NOT EXISTS banlist (word TEXT PRIMARY KEY, dateadded timestamp NOT NULL) WITHOUT ROWID") + cursor.execute("CREATE INDEX IF NOT EXISTS banlist_dateadded ON banlist(dateadded)") + cursor.execute("CREATE TABLE IF NOT EXISTS banlist_lower (word TEXT PRIMARY KEY COLLATE NOCASE," + "dateadded timestamp NOT NULL) WITHOUT ROWID") + + class SQLiteBackend(Backend): def __init__(self, db_path: str) -> None: @@ -42,22 +66,7 @@ def __init__(self, db_path: str) -> None: raise ValueError( f"Database version {decode_version(old_version)} is newer than the current version {__version__}") - self._execute(f"PRAGMA user_version = {sql_version}") - - # Freqloq table - self._execute("CREATE TABLE IF NOT EXISTS freqlog (word TEXT NOT NULL PRIMARY KEY, frequency INTEGER, " - "lastused timestamp NOT NULL, avgspeed REAL NOT NULL) WITHOUT ROWID") - self._execute("CREATE INDEX IF NOT EXISTS freqlog_lower ON freqlog(word COLLATE NOCASE)") - self._execute("CREATE INDEX IF NOT EXISTS freqlog_frequency ON freqlog(frequency)") - self._execute("CREATE UNIQUE INDEX IF NOT EXISTS freqlog_lastused ON freqlog(lastused)") - self._execute("CREATE INDEX IF NOT EXISTS freqlog_avgspeed ON freqlog(avgspeed)") - - # Banlist table - self._execute( - "CREATE TABLE IF NOT EXISTS banlist (word TEXT PRIMARY KEY, dateadded timestamp NOT NULL) WITHOUT ROWID") - self._execute("CREATE INDEX IF NOT EXISTS banlist_dateadded ON banlist(dateadded)") - self._execute("CREATE TABLE IF NOT EXISTS banlist_lower (word TEXT PRIMARY KEY COLLATE NOCASE," - "dateadded timestamp NOT NULL) WITHOUT ROWID") + _init_db(self.cursor, sql_version) def _execute(self, query: str, params=None) -> None: if params: @@ -86,6 +95,10 @@ def _upgrade_database(self, sql_version: int) -> None: # Remember to warn users to back up their database before upgrading pass + def get_version(self) -> str: + """Get the version of the database""" + return decode_version(self._fetchone("PRAGMA user_version")[0]) + def get_word_metadata(self, word: str, case: CaseSensitivity) -> WordMetadata | None: """ Get metadata for a word @@ -128,6 +141,30 @@ def get_chord_metadata(self, chord: str) -> ChordMetadata | None: """ raise NotImplementedError # TODO: implement + def get_banlist_entry(self, word: str, case: CaseSensitivity) -> BanlistEntry | None: + """ + Get a banlist entry + :param word: Word to get entry for + :param case: Case sensitivity + :return: BanlistEntry if word is banned for the specified case, None otherwise + """ + match case: + case CaseSensitivity.INSENSITIVE: + word = word.lower() + res = self._fetchone(f"{SQL_SELECT_STAR_FROM_BANLIST}_lower WHERE word = ? COLLATE NOCASE", (word,)) + return BanlistEntry(res[0], datetime.fromtimestamp(res[1])) if res else None + case CaseSensitivity.FIRST_CHAR: + word_u = word[0].upper() + word[1:] + word_l = word[0].lower() + word[1:] + res_u = self._fetchone(f"{SQL_SELECT_STAR_FROM_BANLIST} WHERE word=?", (word_u,)) + res_l = self._fetchone(f"{SQL_SELECT_STAR_FROM_BANLIST} WHERE word=?", (word_l,)) + if res_u and res_l: + return BanlistEntry(res_l[0], datetime.fromtimestamp(res_l[1])) + return None # if only one or none are banned + case CaseSensitivity.SENSITIVE: + res = self._fetchone(f"{SQL_SELECT_STAR_FROM_BANLIST} WHERE word=?", (word,)) + return BanlistEntry(res[0], datetime.fromtimestamp(res[1])) if res else None + def log_word(self, word: str, start_time: datetime, end_time: datetime) -> bool: """Log a word entry if not banned, creating it if it doesn't exist""" if self.check_banned(word, CaseSensitivity.SENSITIVE): @@ -143,6 +180,11 @@ def log_word(self, word: str, start_time: datetime, end_time: datetime) -> bool: (word, 1, end_time.timestamp(), (end_time - start_time).total_seconds())) return True + def _insert_word(self, word: str, frequency: int, last_used: datetime, average_speed: timedelta) -> None: + """Insert a word entry""" + self._execute("INSERT INTO freqlog VALUES (?, ?, ?, ?)", + (word, frequency, last_used.timestamp(), average_speed.total_seconds())) + def log_chord(self, word: str, start_time: datetime, end_time: datetime) -> None: raise NotImplementedError # TODO: implement @@ -293,6 +335,89 @@ def list_banned_words(self, limit: int, sort_by: BanlistAttr, return {BanlistEntry(row[0], datetime.fromtimestamp(row[1])) for row in res}, \ {BanlistEntry(row[0], datetime.fromtimestamp(row[1])) for row in res_lower} + def merge_backend(self, src_db_path: str, dst_db_path: str, ban_date: Age) -> None: + """ + Merge another database and this one into a new database + :param src_db_path: Path to the source database + :param dst_db_path: Path to the destination database + :param ban_date: Whether to use older or newer date banned for banlist entries of the same word (OLDER or NEWER) + :requires: src_db_path != dst_db_path != self.db_path + :requires: src_db_path must be a valid Freqlog database and readable + :requires: dst_db_path must not be an existing file but must be writable + :raises ValueError: If requirements are not met + """ + # Assert requirements + if src_db_path == dst_db_path: + raise ValueError("src_db_path and dst_db_path must be different") + if src_db_path == self.db_path: + raise ValueError("src_db_path and self.db_path must be different") + if dst_db_path == self.db_path: + raise ValueError("dst_db_path and self.db_path must be different") + if os.path.isfile(dst_db_path): + raise ValueError("dst_db_path must not be an existing file") + try: # ensure that src is writable (WARNING: Must use 'a' instead of 'w' mode to avoid erasing file!!!) + with open(src_db_path, "a"): + pass + except OSError as e: + raise ValueError("src_db_path must be writable") from e + try: + with open(dst_db_path, "w"): + pass + except OSError as e: + raise ValueError("dst_db_path must be writable") from e + + # DB meta + src_db = SQLiteBackend(src_db_path) + dst_db = SQLiteBackend(dst_db_path) + + # Merge databases + # TODO: optimize this/add progress bars (this takes a long time) + try: + # Merge banlist + logging.info("Merging banlist") + src_banned_words_cased, src_banned_words_uncased = src_db.list_banned_words(0, BanlistAttr.word, False) + banned_words_cased, banned_words_uncased = self.list_banned_words(0, BanlistAttr.word, False) + + # Ban words from self banlist in dst db + for entry in banned_words_cased: + dst_db.ban_word(entry.word, CaseSensitivity.SENSITIVE, entry.date_added) + for entry in banned_words_uncased: + dst_db.ban_word(entry.word, CaseSensitivity.INSENSITIVE, entry.date_added) + + # Ban words from src banlist in dst db + for entry in src_banned_words_cased: + dst_entry = dst_db.get_banlist_entry(entry.word, CaseSensitivity.SENSITIVE) + if dst_entry and ((ban_date == Age.OLDER and entry.date_added < dst_entry.date_added) or + (ban_date == Age.NEWER and entry.date_added > dst_entry.date_added)): + dst_db.unban_word(entry.word, CaseSensitivity.SENSITIVE) + dst_db.ban_word(entry.word, CaseSensitivity.SENSITIVE, entry.date_added) + else: + dst_db.ban_word(entry.word, CaseSensitivity.SENSITIVE, entry.date_added) + for entry in src_banned_words_uncased: + dst_entry = dst_db.get_banlist_entry(entry.word, CaseSensitivity.INSENSITIVE) + if dst_entry and ((ban_date == Age.OLDER and entry.date_added < dst_entry.date_added) or + (ban_date == Age.NEWER and entry.date_added > dst_entry.date_added)): + dst_db.unban_word(entry.word, CaseSensitivity.INSENSITIVE) + dst_db.ban_word(entry.word, CaseSensitivity.INSENSITIVE, entry.date_added) + else: + dst_db.ban_word(entry.word, CaseSensitivity.INSENSITIVE, entry.date_added) + + # Merge freqlog + logging.info("Merging freqlog") + src_words = src_db.list_words(0, WordMetadataAttr.word, False, CaseSensitivity.SENSITIVE) + words = [word.word for word in self.list_words(0, WordMetadataAttr.word, False, CaseSensitivity.SENSITIVE)] + entries = self.list_words(0, WordMetadataAttr.word, False, CaseSensitivity.SENSITIVE) + for src_word in src_words: + if src_word.word in words: + entries[words.index(src_word.word)] |= src_word + else: + entries.append(src_word) + for word in entries: + dst_db._insert_word(word.word, word.frequency, word.last_used, word.average_speed) + finally: # Close databases + src_db.close() + dst_db.close() + def close(self) -> None: """Close the database connection""" self.cursor.close() diff --git a/src/nexus/GUI.py b/src/nexus/GUI.py index 91988ca..503adc6 100644 --- a/src/nexus/GUI.py +++ b/src/nexus/GUI.py @@ -156,7 +156,7 @@ def __init__(self, args: argparse.Namespace): self.set_style('Nexus_Dark') self.freqlog: Freqlog | None = None # for logging - self.temp_freqlog: Freqlog = Freqlog(args.freq_log_path, loggable=False) # for other operations + self.temp_freqlog: Freqlog = Freqlog(args.freqlog_db_path, loggable=False) # for other operations self.logging_thread: Thread | None = None self.start_stop_button_started = False self.args = args @@ -177,7 +177,7 @@ def set_style(self, style: Literal['Nexus_Dark', 'Fusion', 'Default']): def start_logging(self): if not self.freqlog: - self.freqlog = Freqlog(self.args.freq_log_path, loggable=True) + self.freqlog = Freqlog(self.args.freqlog_db_path, loggable=True) self.freqlog.start_logging() def stop_logging(self): diff --git a/src/nexus/__main__.py b/src/nexus/__main__.py index 02660ab..c0d7bfe 100644 --- a/src/nexus/__main__.py +++ b/src/nexus/__main__.py @@ -6,8 +6,9 @@ from pynput import keyboard from nexus import __doc__, __version__, Freqlog -from nexus.Freqlog.Definitions import BanlistAttr, CaseSensitivity, ChordMetadata, ChordMetadataAttr, Defaults, Order, \ - WordMetadata, WordMetadataAttr +from nexus.Freqlog.Definitions import Age, BanlistAttr, CaseSensitivity, ChordMetadata, ChordMetadataAttr, Defaults, \ + Order, WordMetadata, WordMetadataAttr +from nexus.Freqlog.backends.SQLite import SQLiteBackend from nexus.GUI import GUI @@ -22,6 +23,7 @@ def main(): 4: Could not access or write to database 5: Requested word or chord not found 6: Tried to ban already banned word or unban already unbanned word + 7: Merge db requirements not met 11: Python version < 3.11 100: Feature not yet implemented """ @@ -33,11 +35,13 @@ def main(): log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "NONE"] # Common arguments + # Log and path must be SUPPRESS for placement before and after command to work + # (see https://stackoverflow.com/a/62906328/9206488) log_arg = argparse.ArgumentParser(add_help=False) - log_arg.add_argument("-l", "--log-level", default="INFO", help=f"One of {log_levels}", + log_arg.add_argument("-l", "--log-level", default=argparse.SUPPRESS, help=f"One of {log_levels}", metavar="level", choices=log_levels) path_arg = argparse.ArgumentParser(add_help=False) - path_arg.add_argument("--freq-log-path", default=Defaults.DEFAULT_DB_PATH, help="Backend to use") + path_arg.add_argument("--freqlog-db-path", default=argparse.SUPPRESS, help="Path to db backend to use") case_arg = argparse.ArgumentParser(add_help=False) case_arg.add_argument("-c", "--case", default=CaseSensitivity.INSENSITIVE.name, help="Case sensitivity", choices={case.name for case in CaseSensitivity}) @@ -49,9 +53,12 @@ def main(): required=False) # Parse command line arguments - parser = argparse.ArgumentParser(description=__doc__, parents=[log_arg, path_arg], + parser = argparse.ArgumentParser(description=__doc__, epilog="Made with love by CharaChorder, source code available at " "https://github.com/CharaChorder/nexus") + parser.add_argument("-l", "--log-level", default="INFO", help=f"One of {log_levels}", + metavar="level", choices=log_levels) + parser.add_argument("--freqlog-db-path", default=Defaults.DEFAULT_DB_PATH, help="Path to db backend to use") subparsers = parser.add_subparsers(dest="command", title="Commands") # Start freqlogging @@ -122,6 +129,17 @@ def main(): # Stop freqlogging # subparsers.add_parser("stoplog", help="Stop logging", parents=[log_arg]) parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") + + # Merge db + parser_merge = subparsers.add_parser("mergedb", help="Merge two Freqlog databases", parents=[log_arg]) + parser_merge.add_argument("--ban-data-keep", default=Age.OLDER.name, + help=f"Which ban data to keep (default: {Age.OLDER.name})", + choices=[age.name for age in Age]) + parser_merge.add_argument("src1", help="Path to first source database") + parser_merge.add_argument("src2", help="Path to second source database") + parser_merge.add_argument("dst", help="Path to destination database") + + # Parse arguments args = parser.parse_args() # Set up console logging @@ -129,6 +147,7 @@ def main(): logging.disable(logging.CRITICAL) else: logging.basicConfig(level=args.log_level, format="%(asctime)s - %(message)s") + logging.debug(f"Args: {args}") exit_code = 0 @@ -140,7 +159,7 @@ def main(): match args.command: case "startlog": try: # ensure that path is writable (WARNING: Must use 'a' instead of 'w' mode to avoid erasing file!!!) - with open(args.freq_log_path, "a"): + with open(args.freqlog_db_path, "a"): pass except OSError as e: logging.error(e) @@ -168,6 +187,18 @@ def main(): exit_code = 3 # Parse commands + if args.command == "mergedb": # merge databases + # Merge databases + logging.warning("This feature has not been thoroughly tested and is not guaranteed to work. Manually verify" + f"(via an export) that the destination DB ({args.dst}) contains all your data after merging.") + try: + src1: SQLiteBackend = Freqlog.Freqlog(args.src1, loggable=False) + src1.merge_backends(args.src2, args.dst, Age[args.ban_data_keep]) + sys.exit(0) + except ValueError as e: + logging.error(e) + exit_code = 7 + if args.command == "stoplog": # stop freqlogging # Kill freqlogging process logging.warning("This feature hasn't been implemented." + @@ -182,7 +213,7 @@ def main(): # TODO: Some features from this point on may not have been implemented try: # All following commands require a freqlog object - freqlog = Freqlog.Freqlog(args.freq_log_path, loggable=False) + freqlog = Freqlog.Freqlog(args.freqlog_db_path, loggable=False) if args.command == "numwords": # get number of words print(f"{freqlog.num_words(CaseSensitivity[args.case])} words in freqlog") sys.exit(0) @@ -194,7 +225,7 @@ def main(): num = Defaults.DEFAULT_NUM_WORDS_CLI match args.command: case "startlog": # start freqlogging - freqlog = Freqlog.Freqlog(args.freq_log_path, loggable=True) + freqlog = Freqlog.Freqlog(args.freqlog_db_path, loggable=True) signal.signal(signal.SIGINT, lambda: freqlog.stop_logging()) freqlog.start_logging(args.new_word_threshold, args.chord_char_threshold, args.allowed_keys_in_chord, Defaults.DEFAULT_MODIFIER_KEYS - set(args.remove_modifier_key) | set(