From 0a7936f0456989c184dbc2c760034ec4ab047bbb Mon Sep 17 00:00:00 2001 From: edisj Date: Wed, 3 Jul 2024 17:39:26 -0700 Subject: [PATCH] capitalize sql keywords --- mdaadb/database.py | 232 +++++++++++++++++++++++++++++++++++---------- mdaadb/query.py | 119 +++++++++++++---------- 2 files changed, 253 insertions(+), 98 deletions(-) diff --git a/mdaadb/database.py b/mdaadb/database.py index 6229144..a5fb81d 100644 --- a/mdaadb/database.py +++ b/mdaadb/database.py @@ -3,8 +3,9 @@ import sqlite3 import pandas as pd from dataclasses import dataclass -from collections import namedtuple +from collections import namedtuple, UserDict from typing import List, Iterable, Any + from query import Query @@ -14,6 +15,11 @@ def _namedtuple_factory(cursor, row): return Row(*row) +class Tables(UserDict): + def __init__(self, db, *args, **kwargs): + self.db = db + + class Database: def __init__(self, database): @@ -35,9 +41,9 @@ def tables(self): def _get_table_names(self) -> List[str]: """""" table_names = ( - self.select("name") - .from_("sqlite_schema") - .where("type = 'table'") + self.SELECT("name") + .FROM("sqlite_schema") + .WHERE("type = 'table'") .execute() .fetchall() ) @@ -45,39 +51,123 @@ def _get_table_names(self) -> List[str]: @property def schema(self) -> pd.DataFrame: - return Query(db=self).select("*").from_("sqlite_schema").to_df() - - def select(self, *fields): """""" - return Query(db=self).select(*fields) - - def pragma(self, *fields): - """""" - return Query(db=self).pragma(*fields) - - def insert_into(self, *tables): - """""" - return Query(db=self).insert_into(*tables) - - def create_table(self, *fields): - """""" - return Query(db=self).create_table(*fields) + return ( + Query(db=self) + .SELECT("*") + .FROM("sqlite_schema") + .to_df() + ) def insert_row_into_table(self, table, row) -> None: + """""" if isinstance(table, str): table = Table(table, self) table.insert_row(row) def insert_array_into_table(self, table, array) -> None: + """""" if isinstance(table, str): table = Table(table, self) table.insert_array(array) + def CREATE_TABLE(self, *fields) -> Query: + """""" + return ( + Query(db=self) + .CREATE_TABLE(*fields) + ) + + def DELETE(self, *fields) -> Query: + """""" + return ( + Query(db=self) + .DELETE(*fields) + ) + + def FROM(self, *fields) -> Query: + """""" + return ( + Query(db=self) + .FROM(*fields) + ) + + def INNER_JOIN(self, *fields) -> Query: + """""" + return ( + Query(db=self) + .INNER_JOIN(*fields) + ) + + def INSERT(self, row) -> Query: + """""" + return ( + Query(db=self) + .INSERT(row) + ) + + def INTO(self, table) -> Query: + """""" + return ( + Query(db=self) + .INTO(table) + ) + + def LIMIT(self, limit) -> Query: + """""" + return ( + Query(db=self) + .LIMIT(limit) + ) + + def ON(self, *fields) -> Query: + """""" + return ( + Query(db=self) + .ON(*fields) + ) + + def ORDER_BY(self, *fields) -> Query: + """""" + return ( + Query(db=self) + .ORDER_BY(*fields) + ) + + def PRAGMA(self, *fields) -> Query: + """""" + return ( + Query(db=self) + .PRAGMA(*fields) + ) + + def SELECT(self, *fields) -> Query: + """""" + return ( + Query(db=self) + .SELECT(*fields) + ) + + def VALUES(self, *values) -> Query: + """""" + return ( + Query(db=self) + .VALUES(*values) + ) + + def WHERE(self, *fields) -> Query: + """""" + return ( + Query(db=self) + .WHERE(*fields) + ) + def __iter__(self): return iter(self.tables) def __contains__(self, table): - return table.database == self + return table.db == self + @dataclass @@ -88,12 +178,23 @@ class Table: @property def schema(self) -> str: """""" - return self.db.select("sql").from_("sqlite_schema").where(f"name = '{self.name}'").execute().fetchone()[0] + return ( + self.db + .SELECT("sql") + .FROM("sqlite_schema") + .WHERE(f"name = '{self.name}'") + .execute() + .fetchone()[0] + ) @property def info(self) -> pd.DataFrame: """""" - return self.db.pragma(f"table_info('{self.name}')").to_df() + return ( + self.db + .PRAGMA(f"table_info('{self.name}')") + .to_df() + ) def row(self, id:int) -> Row: """""" @@ -130,14 +231,19 @@ def columns(self) -> Iterable[Column]: def _get_row_ids(self) -> List[int]: """Helper function that returns a list of integer indices that correspond to the primary key of this table.""" - row_ids = self.select(self.primary_key).execute().fetchall() + row_ids = ( + self.SELECT(self.primary_key) + .execute() + .fetchall() + ) return [id[0] for id in row_ids] def _get_column_names(self) -> List[str]: """Helper function that returns a list of column names for this table..""" column_names = ( - self.db.select("name") - .from_(f"pragma_table_info('simulations')") + self.db + .SELECT("name") + .FROM(f"pragma_table_info('simulations')") .execute() .fetchall() ) @@ -147,22 +253,26 @@ def _get_column_names(self) -> List[str]: def primary_key(self) -> str: """Returns the primary key of this table as a string.""" result = ( - self.db.select("name") - .from_(f"pragma_table_info('{self.name}') as ti") - .where("ti.pk = 1") + self.db + .SELECT("name") + .FROM(f"pragma_table_info('{self.name}') as tblinfo") + .WHERE("tblinfo.pk = 1") .execute() .fetchall() ) assert len(result[0]) == 1 return result[0][0] - def select(self, *columns) -> Query: - """""" - return self.db.select(*columns).from_(self.name) - def insert_row(self, row) -> None: """""" - return self.db.insert_into(self.name).values(row).execute().commit() + return ( + self.db + .INSERT("") + .INTO(self.name) + .VALUES(row) + .execute() + .commit() + ) def insert_array(self, array) -> None: """""" @@ -170,21 +280,39 @@ def insert_array(self, array) -> None: values = "(" + ", ".join(["?" for _ in range(n_cols)]) + ")" return ( self.db - .insert_into(self.name) - .values(values) + .INSERT("") + .INTO(self.name) + .VALUES(values) .executemany(array) .commit() ) - def delete_row(self, row) -> None: - ... + def DELETE(self, *rows) -> Query: + return ( + self.db + .DELETE(*rows) + .FROM(self.name) + ) - def delete_row_by_id(self, id) -> None: - ... + def INSERT(self, *rows) -> Query: + """""" + return ( + self.db + .INSERT(*rows) + .INTO(self.name) + ) + + def SELECT(self, *columns) -> Query: + """Builds query with SELECT keyword""" + return ( + self.db + .SELECT(*columns) + .FROM(self.name) + ) def to_sql(self) -> str: """Returns the sqlite query that generates this table.""" - return self.select("*").to_sql() + return self.SELECT("*").to_sql() def to_df(self) -> pd.DataFrame: """Returns a pandas DataFrame of this table.""" @@ -212,8 +340,8 @@ def data(self): """""" return ( self.table - .select("*") - .where(f"{self.table.primary_key}={self.id}") + .SELECT("*") + .WHERE(f"{self.table.primary_key}={self.id}") .execute() .fetchall() ) @@ -222,8 +350,8 @@ def to_sql(self) -> str: """Returns the sqlite query that generates this row.""" return ( self.table - .select("*") - .where(f"{self.table.primary_key}={self.id}") + .SELECT("*") + .WHERE(f"{self.table.primary_key}={self.id}") .to_sql() ) @@ -247,9 +375,9 @@ def type_(self) -> str: """""" _type = ( Query(db=self.db) - .select("type") - .from_(f"pragma_table_info('{self.table.name}') as t_info") - .where(f"t_info.name='{self.name}'") + .SELECT("type") + .FROM(f"pragma_table_info('{self.table.name}') as t_info") + .WHERE(f"t_info.name='{self.name}'") .execute() .fetchall() ) @@ -258,12 +386,16 @@ def type_(self) -> str: @property def data(self) -> List[Any]: """""" - _data = self.table.select(self.name).execute().fetchall() + _data = self.table.SELECT(self.name).execute().fetchall() return [data[0] for data in _data] def to_sql(self) -> str: """""" - return self.table.select(self.name).to_sql() + return ( + self.table + .SELECT(self.name) + .to_sql() + ) def to_df(self) -> pd.DataFrame: """""" diff --git a/mdaadb/query.py b/mdaadb/query.py index f0452dc..3520b6b 100644 --- a/mdaadb/query.py +++ b/mdaadb/query.py @@ -1,6 +1,10 @@ from functools import wraps -from typing import Self +from collections import UserDict import pandas as pd +try: + from typing import Self +except ImportError: + from typing_extensions import Self # def validate_fields(expected_type): @@ -18,24 +22,27 @@ # # return decorator -def register_fields(prefix): - print(prefix) - def decorator(func): +def register_fields(func): - @wraps(func) - def wrapper(self, *fields): - print(f"registering {prefix} fields: {fields}") - self._registered_fields[prefix] = [] - for field in fields: - if not isinstance(field, str): - raise ValueError("only strings work atm") - self._registered_fields[prefix].append(field) - return func(self) + @wraps(func) + def wrapper(self, *fields): + prefix = wrapper.__name__ + print(prefix) + print(f"registering {prefix} fields: {fields}") + self._registered_fields[prefix] = [] + for field in fields: + if not isinstance(field, str): + raise ValueError("only strings work atm") + self._registered_fields[prefix].append(field) - return wrapper + return func(self) - return decorator + return wrapper + + +class Fields(UserDict): + ... class Query: @@ -44,52 +51,62 @@ def __init__(self, db): self.db = db self._registered_fields = {} - @register_fields("SELECT") - def select(self, *columns) -> Self: + @register_fields + def CREATE_TABLE(self, *fields) -> Self: + return self + + @register_fields + def DELETE(self, *fields) -> Self: return self - # @register_fields("SELECT") - # def select_table(self, *tables) -> Self: - # return self + @register_fields + def FROM(self, *fields) -> Self: + return self - @register_fields("PRAGMA") - def pragma(self, *fields) -> Self: + @register_fields + def INNER_JOIN(self, *fields) -> Self: return self - @register_fields("FROM") - def from_(self, *tables) -> Self: + @register_fields + def INTO(self, *fields) -> Self: return self - @register_fields("WHERE") - def where(self, *conditions) -> Self: + @register_fields + def INSERT(self, *fields) -> Self: return self - @register_fields("INSERT_INTO") - def insert_into(self, *tables) -> Self: + @register_fields + def LIMIT(self, *fields) -> Self: return self - @register_fields("VALUES") - def values(self, *values) -> Self: + @register_fields + def ON(self, *fields) -> Self: return self - @register_fields("INNER_JOIN") - def inner_join(self, *joins) -> Self: + @register_fields + def ORDER_BY(self, *fields) -> Self: return self - @register_fields("DELETE") - def delete(self, *statements) -> Self: + @register_fields + def PRAGMA(self, *fields) -> Self: return self - @register_fields("CREATE_TABLE") - def create_table(self, *fields) -> Self: + @register_fields + def SELECT(self, *fields) -> Self: return self - @register_fields("ORDER_BY") - def order_by(self, *orders) -> Self: + @register_fields + def VALUES(self, *fields) -> Self: + return self + + @register_fields + def WHERE(self, *fields) -> Self: return self def to_sql(self) -> str: - """""" + """Converts the query from it's internal `Query` representation + into a string that is the exact SQL query being executed.""" + statements = [] for prefix, fields in self._registered_fields.items(): statement = prefix + " " + ",".join(fields) @@ -98,16 +115,19 @@ def to_sql(self) -> str: sort_by = ( "SELECT", "PRAGMA", - "FROM", - #"AS", - "WHERE", - "INNER_JOIN", - "INSERT_INTO", - "VALUES", + "INSERT", "DELETE", "CREATE_TABLE", - "ORDER_BY" + "VALUES", + "FROM", + "INTO", + "INNER_JOIN", + "ON", + "WHERE", + "ORDER_BY", + "LIMIT", ) + statements = sorted(statements, key=lambda x: sort_by.index(x.split(" ")[0])) statements = [statement.replace("_", " ") for statement in statements] sql_query = "\n".join(statements) + ";" @@ -115,11 +135,11 @@ def to_sql(self) -> str: return sql_query def to_df(self) -> pd.DataFrame: - """""" + """Executes the current query and outputs result as pandas DataFrame.""" return pd.read_sql(self.to_sql(), self.db.connection) def execute(self): - """""" + """Executes the current query and returns Cursor iterator.""" return self.db.cursor.execute(self.to_sql()) def executemany(self, rows): @@ -129,3 +149,6 @@ def executemany(self, rows): def commit(self): """""" return self.db.connection.commit() + + def __repr__(self): + return self.to_sql()