From 845ed5088ae57c90ef2ab415534c8f9b10c44746 Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Thu, 21 Nov 2024 11:01:21 +0100 Subject: [PATCH] Add a helper compile_query method to the SQLAlchemy session object --- lms/db/__init__.py | 31 +++++++++++++++++++++++++++--- tests/unit/conftest.py | 9 +++++++++ tests/unit/lms/db/__init___test.py | 15 +++++++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 tests/unit/lms/db/__init___test.py diff --git a/lms/db/__init__.py b/lms/db/__init__.py index 5dc12e65db..cecdfe2f39 100644 --- a/lms/db/__init__.py +++ b/lms/db/__init__.py @@ -4,9 +4,10 @@ import alembic.config import sqlalchemy import zope.sqlalchemy -from sqlalchemy import text +from sqlalchemy import Select, text +from sqlalchemy.dialects import postgresql from sqlalchemy.inspection import inspect -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.orm import Query, Session, declarative_base, sessionmaker from sqlalchemy.orm.properties import ColumnProperty from lms.db._columns import varchar_enum @@ -40,7 +41,31 @@ def create_engine(database_url): return sqlalchemy.create_engine(database_url) -SESSION = sessionmaker() +class CustomSession(Session): + """Our own session object based on the default orm.Session.""" + + def compile_query(self, query: Query | Select, literal_binds: bool = True) -> str: + """ + Return the SQL representation of `query` for postgres. + + :param literal_binds: Whether or not replace the query parameters by their values. + """ + if isinstance(query, Query): + # Support for SQLAlchemy 1.X style queryies, eg: db.query(Model).filter_by() + statement = query.statement + else: + # SQLALchemy 2.X style, eg: select(Model).where() + statement = query + + return str( + statement.compile( + self.get_bind(), + compile_kwargs={"literal_binds": literal_binds}, + ) + ) + + +SESSION = sessionmaker(class_=CustomSession) def _session(request): # pragma: no cover diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index d88d456c46..bfff08ca14 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,6 +7,7 @@ from pyramid import testing from pyramid.request import apply_request_extensions +from lms import db from lms.models import ApplicationSettings, LTIParams from lms.models.lti_role import Role, RoleScope, RoleType from lms.product import Product @@ -18,6 +19,14 @@ TEST_SETTINGS["database_url"] = environ["DATABASE_URL"] +@pytest.fixture(scope="session") +def db_sessionfactory(): + """Overwrite h-testkit default fixture to customizt the session class""" + from sqlalchemy.orm import sessionmaker # noqa: PLC0415 + + return sessionmaker(class_=db.CustomSession) + + @pytest.fixture def lti_v11_params(): return { diff --git a/tests/unit/lms/db/__init___test.py b/tests/unit/lms/db/__init___test.py new file mode 100644 index 0000000000..028ffe854f --- /dev/null +++ b/tests/unit/lms/db/__init___test.py @@ -0,0 +1,15 @@ +from sqlalchemy import select + +from lms.models import Event + + +class TestCustomSession: + def test_compile_query(self, db_session): + old_style_query = db_session.query(Event).filter_by(id=0) + new_style_query = select(Event).where(Event.id == 0) + + old_style_query_compiled = db_session.compile_query(old_style_query) + new_style_query_complied = db_session.compile_query(new_style_query) + + assert old_style_query_compiled == new_style_query_complied + assert old_style_query_compiled.startswith("SELECT event.id")