Skip to content

Commit

Permalink
Add compile_query as a helper function in lms.db._util
Browse files Browse the repository at this point in the history
Change the approach on the previous commit, instead of a method on the
session, just use a helper method importable from lms.db.
  • Loading branch information
marcospri committed Dec 6, 2024
1 parent 845ed50 commit dab8f3e
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 50 deletions.
29 changes: 3 additions & 26 deletions lms/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from lms.db._columns import varchar_enum
from lms.db._locks import CouldNotAcquireLock, LockType, try_advisory_transaction_lock
from lms.db._text_search import full_text_match
from lms.db._util import compile_query

__all__ = ("Base", "create_engine", "varchar_enum")
__all__ = ("Base", "compile_query", "create_engine", "varchar_enum")


Base = declarative_base(
Expand All @@ -41,31 +42,7 @@ def create_engine(database_url):
return sqlalchemy.create_engine(database_url)


class CustomSession(Session):
"""Our own session object based on the default orm.Session."""

def compile_query(self, query: Query | Select, literal_binds: bool = True) -> str:
"""
Return the SQL representation of `query` for postgres.
:param literal_binds: Whether or not replace the query parameters by their values.
"""
if isinstance(query, Query):
# Support for SQLAlchemy 1.X style queryies, eg: db.query(Model).filter_by()
statement = query.statement
else:
# SQLALchemy 2.X style, eg: select(Model).where()
statement = query

return str(
statement.compile(
self.get_bind(),
compile_kwargs={"literal_binds": literal_binds},
)
)


SESSION = sessionmaker(class_=CustomSession)
SESSION = sessionmaker()


def _session(request): # pragma: no cover
Expand Down
24 changes: 24 additions & 0 deletions lms/db/_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from sqlalchemy import Select
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Query


def compile_query(query: Query | Select, literal_binds: bool = True) -> str:
"""
Return the SQL representation of `query` for postgres.
:param literal_binds: Whether or not replace the query parameters by their values.
"""
if isinstance(query, Query):
# Support for SQLAlchemy 1.X style queryies, eg: db.query(Model).filter_by()
statement = query.statement
else:
# SQLALchemy 2.X style, eg: select(Model).where()
statement = query

return str(
statement.compile(
dialect=postgresql.dialect(),
compile_kwargs={"literal_binds": literal_binds},
)
)
9 changes: 0 additions & 9 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pyramid import testing
from pyramid.request import apply_request_extensions

from lms import db
from lms.models import ApplicationSettings, LTIParams
from lms.models.lti_role import Role, RoleScope, RoleType
from lms.product import Product
Expand All @@ -19,14 +18,6 @@
TEST_SETTINGS["database_url"] = environ["DATABASE_URL"]


@pytest.fixture(scope="session")
def db_sessionfactory():
"""Overwrite h-testkit default fixture to customizt the session class"""
from sqlalchemy.orm import sessionmaker # noqa: PLC0415

return sessionmaker(class_=db.CustomSession)


@pytest.fixture
def lti_v11_params():
return {
Expand Down
15 changes: 0 additions & 15 deletions tests/unit/lms/db/__init___test.py

This file was deleted.

20 changes: 20 additions & 0 deletions tests/unit/lms/db/_util_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from sqlalchemy import select

from lms.db import compile_query
from lms.models import Event


def test_compile_query(db_session):
old_style_query = db_session.query(Event).filter_by(id=0)
new_style_query = select(Event).where(Event.id == 0)

old_style_query_compiled = compile_query(old_style_query)
new_style_query_complied = compile_query(new_style_query)

assert old_style_query_compiled == new_style_query_complied
assert (
old_style_query_compiled
== """SELECT event.id, event.timestamp, event.type_id, event.application_instance_id, event.course_id, event.assignment_id, event.grouping_id
FROM event
WHERE event.id = 0""" # noqa: W291
)

0 comments on commit dab8f3e

Please sign in to comment.