Skip to content

Commit

Permalink
Merge branch 'main' into docs-starter
Browse files Browse the repository at this point in the history
  • Loading branch information
Alc-Alc authored Jan 29, 2025
2 parents ca34237 + f60b713 commit 6039781
Show file tree
Hide file tree
Showing 10 changed files with 492 additions and 256 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs-preview.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
uses: actions/checkout@v4

- name: Download artifact
uses: dawidd6/action-download-artifact@v7
uses: dawidd6/action-download-artifact@v8
with:
workflow_conclusion: success
run_id: ${{ github.event.workflow_run.id }}
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
- id: unasyncd
additional_dependencies: ["ruff"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.9.2"
rev: "v0.9.3"
hooks:
# Run the linter.
- id: ruff
Expand All @@ -32,7 +32,7 @@ repos:
- id: ruff-format
types_or: [ python, pyi ]
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.0
hooks:
- id: codespell
exclude: "uv.lock|examples/us_state_lookup.json"
Expand Down
13 changes: 10 additions & 3 deletions advanced_alchemy/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,15 @@ def __post_init__(self) -> None:

event.listen(Session, "before_flush", touch_updated_timestamp)

def __hash__(self) -> int:
return hash((self.__class__.__qualname__, self.bind_key))
def __hash__(self) -> int: # pragma: no cover
return hash(
(
self.__class__.__qualname__,
self.connection_string,
self.engine_config.__class__.__qualname__,
self.bind_key,
)
)

def __eq__(self, other: object) -> bool:
return self.__hash__() == other.__hash__()
Expand Down Expand Up @@ -250,7 +257,7 @@ def get_engine(self) -> EngineT:
del engine_config["json_serializer"]
return self.create_engine_callable(self.connection_string, **engine_config)

def create_session_maker(self) -> Callable[[], SessionT]:
def create_session_maker(self) -> Callable[[], SessionT]: # pragma: no cover
"""Get a session maker. If none exists yet, create one.
Returns:
Expand Down
6 changes: 1 addition & 5 deletions advanced_alchemy/extensions/starlette/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,6 @@ def init_app(self, app: Starlette) -> None:
)

app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
app.add_event_handler("startup", self.on_startup) # pyright: ignore[reportUnknownMemberType]
app.add_event_handler("shutdown", self.on_shutdown) # pyright: ignore[reportUnknownMemberType]

async def on_startup(self) -> None:
"""Initialize the Starlette application with this configuration."""
Expand Down Expand Up @@ -288,8 +286,6 @@ def init_app(self, app: Starlette) -> None:
)
_ = self.create_session_maker()
app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
app.add_event_handler("startup", self.on_startup) # pyright: ignore[reportUnknownMemberType]
app.add_event_handler("shutdown", self.on_shutdown) # pyright: ignore[reportUnknownMemberType]

async def on_startup(self) -> None:
"""Initialize the Starlette application with this configuration."""
Expand Down Expand Up @@ -381,5 +377,5 @@ async def on_shutdown(self) -> None: # pragma: no cover
if self.app is not None:
with contextlib.suppress(AttributeError, KeyError):
delattr(self.app.state, self.engine_key)
delattr(self.app.state, self.session_key)
delattr(self.app.state, self.session_maker_key)
delattr(self.app.state, self.session_key)
73 changes: 56 additions & 17 deletions advanced_alchemy/extensions/starlette/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@

import contextlib
from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Generator, Sequence, Union, overload
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Generator,
Sequence,
Union,
overload,
)

from starlette.applications import Starlette
from starlette.requests import Request # noqa: TC002
Expand Down Expand Up @@ -68,8 +77,33 @@ def init_app(self, app: Starlette) -> None:

app.state.advanced_alchemy = self

original_lifespan = app.router.lifespan_context

@asynccontextmanager
async def wrapped_lifespan(app: Starlette) -> AsyncGenerator[Any, None]: # pragma: no cover
async with self.lifespan(app), original_lifespan(app) as state:
yield state

app.router.lifespan_context = wrapped_lifespan

@asynccontextmanager
async def lifespan(self, app: Starlette) -> AsyncGenerator[Any, None]: # pragma: no cover
"""Context manager for lifespan events.
Args:
app: The starlette application.
Yields:
None
"""
await self.on_startup()
try:
yield
finally:
await self.on_shutdown()

@property
def app(self) -> Starlette:
def app(self) -> Starlette: # pragma: no cover
"""Returns the Starlette application instance.
Raises:
Expand All @@ -79,7 +113,7 @@ def app(self) -> Starlette:
Returns:
starlette.applications.Starlette: The Starlette application instance.
"""
if self._app is None:
if self._app is None: # pragma: no cover
msg = "Application not initialized. Did you forget to call init_app?"
raise ImproperConfigurationError(msg)

Expand Down Expand Up @@ -119,36 +153,38 @@ def get_config(self, key: str | None = None) -> Union[SQLAlchemyAsyncConfig, SQL
if key == "default" and len(self.config) == 1:
key = self.config[0].bind_key or "default"
config = self._mapped_configs.get(key)
if config is None:
if config is None: # pragma: no cover
msg = f"Config with key {key} not found"
raise ImproperConfigurationError(msg)
return config

def get_async_config(self, key: str | None = None) -> SQLAlchemyAsyncConfig:
"""Get the async config for the given key."""
config = self.get_config(key)
if not isinstance(config, SQLAlchemyAsyncConfig):
if not isinstance(config, SQLAlchemyAsyncConfig): # pragma: no cover
msg = "Expected an async config, but got a sync config"
raise ImproperConfigurationError(msg)
return config

def get_sync_config(self, key: str | None = None) -> SQLAlchemySyncConfig:
"""Get the sync config for the given key."""
config = self.get_config(key)
if not isinstance(config, SQLAlchemySyncConfig):
if not isinstance(config, SQLAlchemySyncConfig): # pragma: no cover
msg = "Expected a sync config, but got an async config"
raise ImproperConfigurationError(msg)
return config

@asynccontextmanager
async def with_async_session(self, key: str | None = None) -> AsyncGenerator[AsyncSession, None]:
async def with_async_session(
self, key: str | None = None
) -> AsyncGenerator[AsyncSession, None]: # pragma: no cover
"""Context manager for getting an async session."""
config = self.get_async_config(key)
async with config.get_session() as session:
yield session

@contextmanager
def with_sync_session(self, key: str | None = None) -> Generator[Session, None]:
def with_sync_session(self, key: str | None = None) -> Generator[Session, None]: # pragma: no cover
"""Context manager for getting a sync session."""
config = self.get_sync_config(key)
with config.get_session() as session:
Expand All @@ -164,7 +200,8 @@ def _get_session_from_request(request: Request, config: SQLAlchemySyncConfig) ->

@staticmethod
def _get_session_from_request(
request: Request, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig
request: Request,
config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, # pragma: no cover
) -> Session | AsyncSession: # pragma: no cover
"""Get the session for the given key."""
session = getattr(request.state, config.session_key, None)
Expand All @@ -173,22 +210,24 @@ def _get_session_from_request(
setattr(request.state, config.session_key, session)
return session

def get_session(self, request: Request, key: str | None = None) -> Session | AsyncSession:
def get_session(self, request: Request, key: str | None = None) -> Session | AsyncSession: # pragma: no cover
"""Get the session for the given key."""
config = self.get_config(key)
return self._get_session_from_request(request, config)

def get_async_session(self, request: Request, key: str | None = None) -> AsyncSession:
def get_async_session(self, request: Request, key: str | None = None) -> AsyncSession: # pragma: no cover
"""Get the async session for the given key."""
config = self.get_async_config(key)
return self._get_session_from_request(request, config)

def get_sync_session(self, request: Request, key: str | None = None) -> Session:
def get_sync_session(self, request: Request, key: str | None = None) -> Session: # pragma: no cover
"""Get the sync session for the given key."""
config = self.get_sync_config(key)
return self._get_session_from_request(request, config)

def provide_session(self, key: str | None = None) -> Callable[[Request], Session | AsyncSession]:
def provide_session(
self, key: str | None = None
) -> Callable[[Request], Session | AsyncSession]: # pragma: no cover
"""Get the session for the given key."""
config = self.get_config(key)

Expand All @@ -197,7 +236,7 @@ def _get_session(request: Request) -> Session | AsyncSession:

return _get_session

def provide_async_session(self, key: str | None = None) -> Callable[[Request], AsyncSession]:
def provide_async_session(self, key: str | None = None) -> Callable[[Request], AsyncSession]: # pragma: no cover
"""Get the async session for the given key."""
config = self.get_async_config(key)

Expand All @@ -206,7 +245,7 @@ def _get_session(request: Request) -> AsyncSession:

return _get_session

def provide_sync_session(self, key: str | None = None) -> Callable[[Request], Session]:
def provide_sync_session(self, key: str | None = None) -> Callable[[Request], Session]: # pragma: no cover
"""Get the sync session for the given key."""
config = self.get_sync_config(key)

Expand All @@ -220,12 +259,12 @@ def get_engine(self, key: str | None = None) -> Engine | AsyncEngine: # pragma:
config = self.get_config(key)
return config.get_engine()

def get_async_engine(self, key: str | None = None) -> AsyncEngine:
def get_async_engine(self, key: str | None = None) -> AsyncEngine: # pragma: no cover
"""Get the async engine for the given key."""
config = self.get_async_config(key)
return config.get_engine()

def get_sync_engine(self, key: str | None = None) -> Engine:
def get_sync_engine(self, key: str | None = None) -> Engine: # pragma: no cover
"""Get the sync engine for the given key."""
config = self.get_sync_config(key)
return config.get_engine()
Expand Down
26 changes: 26 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,32 @@
0.x Changelog
=============

.. changelog:: 0.30.3
:date: 2025-01-26

.. change:: add `wrap_exceptions` option to exception handler.
:type: feature
:pr: 363
:issue: 356

When `wrap_exceptions` is `False`, the original SQLAlchemy error message will be raised instead of the wrapped Repository error

Fixes #356 (Bug: `wrap_sqlalchemy_exception` masks db errors)

.. change:: simplify configuration hash
:type: feature
:pr: 366

The hashing method on the SQLAlchemy configs can be simplified. This should be enough to define a unique configuration.

.. change:: use `lifespan` context manager in Starlette and FastAPI
:type: bugfix
:pr: 368
:issue: 367

Modifies the Starlette and FastAPI integrations to use the `lifespan` context manager instead of the `startup`\`shutdown` hooks. If the application already has a lifespan set, it is wrapped so that both execute.


.. changelog:: 0.30.2
:date: 2025-01-21

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ maintainers = [
name = "advanced_alchemy"
readme = "README.md"
requires-python = ">=3.8"
version = "0.30.2"
version = "0.30.3"

[project.urls]
Changelog = "https://docs.advanced-alchemy.litestar.dev/latest/changelog"
Expand Down Expand Up @@ -168,7 +168,7 @@ test = [
allow_dirty = true
commit = true
commit_args = "--no-verify"
current_version = "0.30.2"
current_version = "0.30.3"
ignore_missing_files = false
ignore_missing_version = false
message = "chore(release): bump to v{new_version}"
Expand Down
41 changes: 40 additions & 1 deletion tests/unit/test_extensions/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from typing import TYPE_CHECKING, Callable, Generator, Literal, Union, cast
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Generator, Literal, Union, cast
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -356,3 +357,41 @@ def handler(
)
assert mock.call_args_list[0].kwargs["session"] is not mock.call_args_list[1].kwargs["session"]
assert mock.call_args_list[0].kwargs["engine"] is not mock.call_args_list[1].kwargs["engine"]


async def test_lifespan_startup_shutdown_called_fastapi(mocker: MockerFixture, app: FastAPI, config: AnyConfig) -> None:
mock_startup = mocker.patch.object(AdvancedAlchemy, "on_startup")
mock_shutdown = mocker.patch.object(AdvancedAlchemy, "on_shutdown")
_alchemy = AdvancedAlchemy(config, app=app)

with TestClient(app=app) as _client: # TestClient context manager triggers lifespan events
pass # App starts up and shuts down within this context

mock_startup.assert_called_once()
mock_shutdown.assert_called_once()


async def test_lifespan_with_custom_lifespan_fastapi(mocker: MockerFixture, app: FastAPI, config: AnyConfig) -> None:
mock_aa_startup = mocker.patch.object(AdvancedAlchemy, "on_startup")
mock_aa_shutdown = mocker.patch.object(AdvancedAlchemy, "on_shutdown")
mock_custom_startup = mocker.MagicMock()
mock_custom_shutdown = mocker.MagicMock()

@asynccontextmanager
async def custom_lifespan(app_in: FastAPI) -> AsyncGenerator[None, None]:
mock_custom_startup()
yield
mock_custom_shutdown()

app.router.lifespan_context = custom_lifespan # type: ignore[assignment] # Set a custom lifespan on the app
_alchemy = AdvancedAlchemy(config, app=app)

with TestClient(app=app) as _client: # TestClient context manager triggers lifespan events
pass # App starts up and shuts down within this context

mock_aa_startup.assert_called_once()
mock_aa_shutdown.assert_called_once()
mock_custom_startup.assert_called_once()
mock_custom_shutdown.assert_called_once()

# Optionally assert the order of calls if needed, e.g., using mocker.call_order
Loading

0 comments on commit 6039781

Please sign in to comment.