From dab8f3eef8dd4d554aed57f0d0916b2027f194f0 Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Tue, 3 Dec 2024 15:18:18 +0100 Subject: [PATCH] Add compile_query as a helper function in lms.db._util Change the approach on the previous commit, instead of a method on the session, just use a helper method importable from lms.db. --- lms/db/__init__.py | 29 +++-------------------------- lms/db/_util.py | 24 ++++++++++++++++++++++++ tests/unit/conftest.py | 9 --------- tests/unit/lms/db/__init___test.py | 15 --------------- tests/unit/lms/db/_util_test.py | 20 ++++++++++++++++++++ 5 files changed, 47 insertions(+), 50 deletions(-) create mode 100644 lms/db/_util.py delete mode 100644 tests/unit/lms/db/__init___test.py create mode 100644 tests/unit/lms/db/_util_test.py diff --git a/lms/db/__init__.py b/lms/db/__init__.py index cecdfe2f39..506a44c242 100644 --- a/lms/db/__init__.py +++ b/lms/db/__init__.py @@ -13,8 +13,9 @@ from lms.db._columns import varchar_enum from lms.db._locks import CouldNotAcquireLock, LockType, try_advisory_transaction_lock from lms.db._text_search import full_text_match +from lms.db._util import compile_query -__all__ = ("Base", "create_engine", "varchar_enum") +__all__ = ("Base", "compile_query", "create_engine", "varchar_enum") Base = declarative_base( @@ -41,31 +42,7 @@ def create_engine(database_url): return sqlalchemy.create_engine(database_url) -class CustomSession(Session): - """Our own session object based on the default orm.Session.""" - - def compile_query(self, query: Query | Select, literal_binds: bool = True) -> str: - """ - Return the SQL representation of `query` for postgres. - - :param literal_binds: Whether or not replace the query parameters by their values. - """ - if isinstance(query, Query): - # Support for SQLAlchemy 1.X style queryies, eg: db.query(Model).filter_by() - statement = query.statement - else: - # SQLALchemy 2.X style, eg: select(Model).where() - statement = query - - return str( - statement.compile( - self.get_bind(), - compile_kwargs={"literal_binds": literal_binds}, - ) - ) - - -SESSION = sessionmaker(class_=CustomSession) +SESSION = sessionmaker() def _session(request): # pragma: no cover diff --git a/lms/db/_util.py b/lms/db/_util.py new file mode 100644 index 0000000000..964f25c23c --- /dev/null +++ b/lms/db/_util.py @@ -0,0 +1,24 @@ +from sqlalchemy import Select +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import Query + + +def compile_query(query: Query | Select, literal_binds: bool = True) -> str: + """ + Return the SQL representation of `query` for postgres. + + :param literal_binds: Whether or not replace the query parameters by their values. + """ + if isinstance(query, Query): + # Support for SQLAlchemy 1.X style queryies, eg: db.query(Model).filter_by() + statement = query.statement + else: + # SQLALchemy 2.X style, eg: select(Model).where() + statement = query + + return str( + statement.compile( + dialect=postgresql.dialect(), + compile_kwargs={"literal_binds": literal_binds}, + ) + ) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index bfff08ca14..d88d456c46 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,7 +7,6 @@ from pyramid import testing from pyramid.request import apply_request_extensions -from lms import db from lms.models import ApplicationSettings, LTIParams from lms.models.lti_role import Role, RoleScope, RoleType from lms.product import Product @@ -19,14 +18,6 @@ TEST_SETTINGS["database_url"] = environ["DATABASE_URL"] -@pytest.fixture(scope="session") -def db_sessionfactory(): - """Overwrite h-testkit default fixture to customizt the session class""" - from sqlalchemy.orm import sessionmaker # noqa: PLC0415 - - return sessionmaker(class_=db.CustomSession) - - @pytest.fixture def lti_v11_params(): return { diff --git a/tests/unit/lms/db/__init___test.py b/tests/unit/lms/db/__init___test.py deleted file mode 100644 index 028ffe854f..0000000000 --- a/tests/unit/lms/db/__init___test.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqlalchemy import select - -from lms.models import Event - - -class TestCustomSession: - def test_compile_query(self, db_session): - old_style_query = db_session.query(Event).filter_by(id=0) - new_style_query = select(Event).where(Event.id == 0) - - old_style_query_compiled = db_session.compile_query(old_style_query) - new_style_query_complied = db_session.compile_query(new_style_query) - - assert old_style_query_compiled == new_style_query_complied - assert old_style_query_compiled.startswith("SELECT event.id") diff --git a/tests/unit/lms/db/_util_test.py b/tests/unit/lms/db/_util_test.py new file mode 100644 index 0000000000..756a5e5417 --- /dev/null +++ b/tests/unit/lms/db/_util_test.py @@ -0,0 +1,20 @@ +from sqlalchemy import select + +from lms.db import compile_query +from lms.models import Event + + +def test_compile_query(db_session): + old_style_query = db_session.query(Event).filter_by(id=0) + new_style_query = select(Event).where(Event.id == 0) + + old_style_query_compiled = compile_query(old_style_query) + new_style_query_complied = compile_query(new_style_query) + + assert old_style_query_compiled == new_style_query_complied + assert ( + old_style_query_compiled + == """SELECT event.id, event.timestamp, event.type_id, event.application_instance_id, event.course_id, event.assignment_id, event.grouping_id +FROM event +WHERE event.id = 0""" # noqa: W291 + )