Skip to content

Commit

Permalink
Add a helper compile_query method to the SQLAlchemy session object
Browse files Browse the repository at this point in the history
  • Loading branch information
marcospri committed Dec 6, 2024
1 parent b864748 commit 845ed50
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
31 changes: 28 additions & 3 deletions lms/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import alembic.config
import sqlalchemy
import zope.sqlalchemy
from sqlalchemy import text
from sqlalchemy import Select, text
from sqlalchemy.dialects import postgresql
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm import Query, Session, declarative_base, sessionmaker
from sqlalchemy.orm.properties import ColumnProperty

from lms.db._columns import varchar_enum
Expand Down Expand Up @@ -40,7 +41,31 @@ def create_engine(database_url):
return sqlalchemy.create_engine(database_url)


SESSION = sessionmaker()
class CustomSession(Session):
"""Our own session object based on the default orm.Session."""

def compile_query(self, query: Query | Select, literal_binds: bool = True) -> str:
"""
Return the SQL representation of `query` for postgres.
:param literal_binds: Whether or not replace the query parameters by their values.
"""
if isinstance(query, Query):
# Support for SQLAlchemy 1.X style queryies, eg: db.query(Model).filter_by()
statement = query.statement
else:
# SQLALchemy 2.X style, eg: select(Model).where()
statement = query

return str(
statement.compile(
self.get_bind(),
compile_kwargs={"literal_binds": literal_binds},
)
)


SESSION = sessionmaker(class_=CustomSession)


def _session(request): # pragma: no cover
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pyramid import testing
from pyramid.request import apply_request_extensions

from lms import db
from lms.models import ApplicationSettings, LTIParams
from lms.models.lti_role import Role, RoleScope, RoleType
from lms.product import Product
Expand All @@ -18,6 +19,14 @@
TEST_SETTINGS["database_url"] = environ["DATABASE_URL"]


@pytest.fixture(scope="session")
def db_sessionfactory():
"""Overwrite h-testkit default fixture to customizt the session class"""
from sqlalchemy.orm import sessionmaker # noqa: PLC0415

return sessionmaker(class_=db.CustomSession)


@pytest.fixture
def lti_v11_params():
return {
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/lms/db/__init___test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sqlalchemy import select

from lms.models import Event


class TestCustomSession:
def test_compile_query(self, db_session):
old_style_query = db_session.query(Event).filter_by(id=0)
new_style_query = select(Event).where(Event.id == 0)

old_style_query_compiled = db_session.compile_query(old_style_query)
new_style_query_complied = db_session.compile_query(new_style_query)

assert old_style_query_compiled == new_style_query_complied
assert old_style_query_compiled.startswith("SELECT event.id")

0 comments on commit 845ed50

Please sign in to comment.