# mdaadb/database.py
# Reconstructed from patch 5e08698545338d192e88a0eb40644d2bd11f0b9e
# ("Add query and database modules", edisj, 2024-07-03).
from __future__ import annotations

import sqlite3
import pandas as pd
from dataclasses import dataclass
from collections import namedtuple
from functools import lru_cache
from typing import List, Iterable, Any
# NOTE(review): this absolute import only resolves when the directory holding
# query.py is on sys.path; if mdaadb is imported as a package it presumably
# has to become `from .query import Query` -- confirm how the module is loaded.
from query import Query


# Name of SQLite's schema table. `sqlite_master` is the spelling accepted by
# every SQLite release; `sqlite_schema` is an alias that older builds reject.
_SCHEMA_TABLE = "sqlite_master"


def _sql_literal(text: str) -> str:
    """Return *text* as a single-quoted SQL string literal (``'`` doubled)."""
    return "'" + text.replace("'", "''") + "'"


def _sql_identifier(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier (``"`` doubled)."""
    return '"' + name.replace('"', '""') + '"'


@lru_cache(maxsize=128)
def _row_class(fields):
    """Return the namedtuple class for a tuple of result-column names.

    Cached so that one class is built per distinct column set instead of one
    per fetched row. ``rename=True`` lets namedtuple replace column names
    that are not valid field names (e.g. ``count(*)``) with positional ones
    instead of raising.
    """
    return namedtuple("Row", fields, rename=True)


def _namedtuple_factory(cursor, row):
    """sqlite3 ``row_factory`` that returns each row as a namedtuple.

    Field names are taken from the first element of every entry in
    ``cursor.description``.
    """
    fields = tuple(col[0] for col in cursor.description)
    return _row_class(fields)(*row)


class Database:
    """Wrapper around one sqlite3 connection.

    Holds the connection (``self.connection``) and a single cursor
    (``self.cursor``) that every `Query` built from this object executes on.
    Rows are returned as namedtuples.
    """

    def __init__(self, database):
        """Open *database* (passed straight to ``sqlite3.connect``)."""
        self.connection = sqlite3.connect(database)
        self.connection.row_factory = _namedtuple_factory
        self.cursor = self.connection.cursor()

    def close(self) -> None:
        """Close the underlying sqlite3 connection."""
        self.connection.close()

    def __enter__(self) -> Database:
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def table(self, name):
        """Return the `Table` called *name*.

        Raises
        ------
        ValueError
            If no table with that name is listed in the schema table.
        """
        if name not in self._get_table_names():
            raise ValueError("invalid table")
        return Table(name, self)

    @property
    def tables(self):
        """A generator of `Table` objects, one per table in the schema."""
        return (Table(name, self) for name in self._get_table_names())

    def _get_table_names(self) -> List[str]:
        """Return the names of all entries of type 'table' in the schema."""
        found = (
            self.select("name")
            .from_(_SCHEMA_TABLE)
            .where("type = 'table'")
            .execute()
            .fetchall()
        )
        return [entry[0] for entry in found]

    @property
    def schema(self) -> pd.DataFrame:
        """The full schema table as a pandas DataFrame."""
        return self.select("*").from_(_SCHEMA_TABLE).to_df()

    def select(self, *fields):
        """Start a `Query` with a SELECT clause over *fields*."""
        return Query(db=self).select(*fields)

    def pragma(self, *fields):
        """Start a `Query` with a PRAGMA clause over *fields*."""
        return Query(db=self).pragma(*fields)

    def insert_into(self, *tables):
        """Start a `Query` with an INSERT INTO clause over *tables*."""
        return Query(db=self).insert_into(*tables)

    def create_table(self, *fields):
        """Start a `Query` with a CREATE TABLE clause over *fields*."""
        return Query(db=self).create_table(*fields)

    def insert_row_into_table(self, table, row) -> None:
        """Insert *row* into *table* (a `Table` or a table name).

        A name is wrapped in a `Table` without checking that it exists.
        """
        if isinstance(table, str):
            table = Table(table, self)
        table.insert_row(row)

    def insert_array_into_table(self, table, array) -> None:
        """Insert every row of *array* into *table* (a `Table` or a name)."""
        if isinstance(table, str):
            table = Table(table, self)
        table.insert_array(array)

    def __iter__(self):
        return iter(self.tables)

    def __contains__(self, table) -> bool:
        """``table in db`` for a `Table` bound to this database or a name.

        The original compared ``table.database``, an attribute `Table` does
        not have; the owning database is stored in ``Table.db``.
        """
        if isinstance(table, str):
            return table in self._get_table_names()
        if not isinstance(table, Table):
            return False
        return table.db is self and table.name in self._get_table_names()


@dataclass
class Table:
    """Handle for the table ``name`` of the database ``db``."""

    name: str
    db: Database

    @property
    def schema(self) -> str:
        """The ``sql`` column of this table's entry in the schema table.

        Raises
        ------
        ValueError
            If the schema table has no entry with this name.
        """
        entry = (
            self.db.select("sql")
            .from_(_SCHEMA_TABLE)
            .where(f"name = {_sql_literal(self.name)}")
            .execute()
            .fetchone()
        )
        if entry is None:
            raise ValueError("invalid table")
        return entry[0]

    @property
    def info(self) -> pd.DataFrame:
        """Result of ``PRAGMA table_info`` for this table as a DataFrame."""
        return self.db.pragma(f"table_info({_sql_literal(self.name)})").to_df()

    def row(self, id: int) -> Row:
        """Return the `Row` whose primary-key value is *id*.

        Raises
        ------
        ValueError
            If *id* is not among the table's primary-key values.
        """
        if id not in self._get_row_ids():
            raise ValueError("id not in valid ids")
        return Row(id, self)

    def column(self, name: str) -> Column:
        """Return the `Column` called *name*.

        Raises
        ------
        ValueError
            If the table has no column with that name.
        """
        if name not in self._get_column_names():
            raise ValueError("name not in column names")
        return Column(name, self)

    @property
    def n_rows(self) -> int:
        """Total number of rows in this table."""
        return len(self._get_row_ids())

    @property
    def n_cols(self) -> int:
        """Total number of columns in this table."""
        return len(self._get_column_names())

    @property
    def rows(self) -> Iterable[Row]:
        """An iterable of Rows contained in this table."""
        return (Row(row_id, self) for row_id in self._get_row_ids())

    @property
    def columns(self) -> Iterable[Column]:
        """An iterable of Columns contained in this table."""
        return (Column(name, self) for name in self._get_column_names())

    def _get_row_ids(self) -> List[int]:
        """Return the value of the primary-key column for every row."""
        fetched = self.select(self.primary_key).execute().fetchall()
        return [entry[0] for entry in fetched]

    def _get_column_names(self) -> List[str]:
        """Return the column names reported by ``pragma_table_info``.

        The original always inspected the table 'simulations' regardless of
        which table this object represents; the lookup now uses ``self.name``.
        """
        fetched = (
            self.db.select("name")
            .from_(f"pragma_table_info({_sql_literal(self.name)})")
            .execute()
            .fetchall()
        )
        return [entry[0] for entry in fetched]

    @property
    def primary_key(self) -> str:
        """Name of the column whose ``pk`` value in ``table_info`` is 1.

        Raises
        ------
        ValueError
            If no column of this table has ``pk = 1``.
        """
        fetched = (
            self.db.select("name")
            .from_(f"pragma_table_info({_sql_literal(self.name)}) as ti")
            .where("ti.pk = 1")
            .execute()
            .fetchall()
        )
        if not fetched:
            raise ValueError(f"table {self.name!r} has no primary key column")
        return fetched[0][0]

    def select(self, *columns) -> Query:
        """Return a `Query` selecting *columns* from this table."""
        return self.db.select(*columns).from_(self.name)

    def insert_row(self, row) -> None:
        """Insert one row and commit.

        *row* is handed to ``Query.values`` as-is, so it must be something
        that method accepts (currently a SQL string such as ``"(1, 'a')"``).
        """
        self.db.insert_into(self.name).values(row).execute()
        # execute() hands back the sqlite3 cursor, which has no commit();
        # transactions are committed on the connection.
        self.db.connection.commit()

    def insert_array(self, array) -> None:
        """Insert every row of *array* with one ``executemany`` and commit.

        One ``?`` placeholder is generated per table column, so each element
        of *array* must supply a value for every column.
        """
        placeholders = "(" + ", ".join(["?"] * self.n_cols) + ")"
        self.db.insert_into(self.name).values(placeholders).executemany(array)
        self.db.connection.commit()

    def delete_row(self, row) -> None:
        """Delete *row* (an object with an ``id`` attribute, e.g. a `Row`)."""
        self.delete_row_by_id(row.id)

    def delete_row_by_id(self, id) -> None:
        """Delete the row whose primary-key value equals *id* and commit."""
        statement = (
            f"DELETE FROM {_sql_identifier(self.name)} "
            f"WHERE {_sql_identifier(self.primary_key)} = ?"
        )
        self.db.cursor.execute(statement, (id,))
        self.db.connection.commit()

    def to_sql(self) -> str:
        """Returns the sqlite query that generates this table."""
        return self.select("*").to_sql()

    def to_df(self) -> pd.DataFrame:
        """Returns a pandas DataFrame of this table."""
        return pd.read_sql(self.to_sql(), self.db.connection)

    def __len__(self) -> int:
        return self.n_rows

    def __iter__(self) -> Iterable[Row]:
        return iter(self.rows)


@dataclass
class Row:
    """Handle for the row of ``table`` whose primary-key value is ``id``."""

    id: int
    table: Table

    @property
    def db(self) -> Database:
        """Returns the database that this row belongs to."""
        return self.table.db

    def _query(self) -> Query:
        """Build ``SELECT * FROM <table> WHERE <primary key>=<id>``."""
        return self.table.select("*").where(
            f"{self.table.primary_key}={self.id}"
        )

    @property
    def data(self):
        """Everything ``fetchall()`` returns for this row's query."""
        return self._query().execute().fetchall()

    def to_sql(self) -> str:
        """Returns the sqlite query that generates this row."""
        return self._query().to_sql()

    def to_df(self) -> pd.DataFrame:
        """Returns a pandas DataFrame of this row."""
        return pd.read_sql(self.to_sql(), self.db.connection)


@dataclass
class Column:
    """Handle for the column ``name`` of ``table``."""

    name: str
    table: Table

    @property
    def db(self) -> Database:
        """Returns the database that this column belongs to."""
        return self.table.db

    @property
    def type_(self) -> str:
        """The ``type`` reported by ``pragma_table_info`` for this column."""
        fetched = (
            Query(db=self.db)
            .select("type")
            .from_(
                f"pragma_table_info({_sql_literal(self.table.name)}) as t_info"
            )
            .where(f"t_info.name={_sql_literal(self.name)}")
            .execute()
            .fetchall()
        )
        return fetched[0][0]

    @property
    def data(self) -> List[Any]:
        """The value of this column for every row of the table."""
        fetched = self.table.select(self.name).execute().fetchall()
        return [entry[0] for entry in fetched]

    def to_sql(self) -> str:
        """Returns the sqlite query that generates this column."""
        return self.table.select(self.name).to_sql()

    def to_df(self) -> pd.DataFrame:
        """Returns a pandas DataFrame of this column."""
        return pd.read_sql(self.to_sql(), self.db.connection)


# mdaadb/query.py
from functools import wraps
from typing import Self
import pandas as pd
def validate_fields(expected_type): +# +# print(expected_type) +# +# def decorator(func): +# +# @wraps(func) +# def wrapper(self, *fields): +# print(f"validating fields {fields}") +# return func(self) +# +# return wrapper +# +# return decorator + +def register_fields(prefix): + + print(prefix) + def decorator(func): + + @wraps(func) + def wrapper(self, *fields): + print(f"registering {prefix} fields: {fields}") + self._registered_fields[prefix] = [] + for field in fields: + if not isinstance(field, str): + raise ValueError("only strings work atm") + self._registered_fields[prefix].append(field) + return func(self) + + return wrapper + + return decorator + + +class Query: + + def __init__(self, db): + self.db = db + self._registered_fields = {} + + @register_fields("SELECT") + def select(self, *columns) -> Self: + return self + + # @register_fields("SELECT") + # def select_table(self, *tables) -> Self: + # return self + + @register_fields("PRAGMA") + def pragma(self, *fields) -> Self: + return self + + @register_fields("FROM") + def from_(self, *tables) -> Self: + return self + + @register_fields("WHERE") + def where(self, *conditions) -> Self: + return self + + @register_fields("INSERT_INTO") + def insert_into(self, *tables) -> Self: + return self + + @register_fields("VALUES") + def values(self, *values) -> Self: + return self + + @register_fields("INNER_JOIN") + def inner_join(self, *joins) -> Self: + return self + + @register_fields("DELETE") + def delete(self, *statements) -> Self: + return self + + @register_fields("CREATE_TABLE") + def create_table(self, *fields) -> Self: + return self + + @register_fields("ORDER_BY") + def order_by(self, *orders) -> Self: + return self + + def to_sql(self) -> str: + """""" + statements = [] + for prefix, fields in self._registered_fields.items(): + statement = prefix + " " + ",".join(fields) + statements.append(statement) + + sort_by = ( + "SELECT", + "PRAGMA", + "FROM", + #"AS", + "WHERE", + "INNER_JOIN", + "INSERT_INTO", + 
"VALUES", + "DELETE", + "CREATE_TABLE", + "ORDER_BY" + ) + statements = sorted(statements, key=lambda x: sort_by.index(x.split(" ")[0])) + statements = [statement.replace("_", " ") for statement in statements] + sql_query = "\n".join(statements) + ";" + + return sql_query + + def to_df(self) -> pd.DataFrame: + """""" + return pd.read_sql(self.to_sql(), self.db.connection) + + def execute(self): + """""" + return self.db.cursor.execute(self.to_sql()) + + def executemany(self, rows): + """""" + return self.db.cursor.executemany(self.to_sql(), rows) + + def commit(self): + """""" + return self.db.connection.commit()