Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async #16

Open
theobouwman opened this issue Nov 10, 2023 · 1 comment
Open

Async #16

theobouwman opened this issue Nov 10, 2023 · 1 comment

Comments

@theobouwman
Copy link

Is it possible to use this with async SQLAlchemy?

@waza-ari
Copy link

It is possible; I adapted it within my project to work with async SQLAlchemy. A few changes are required, though. Here is a summary:

core/config.py

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Application settings, read from environment variables and a ``.env`` file."""

    # Where the settings are loaded from; the .env file is decoded as UTF-8.
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    # --- Database ---
    # Required: no default is given, so it must be provided by the environment/.env.
    database_url: str = Field(description="The Database URL")


# Module-level singleton; loading happens at import time.
settings = Settings()

core/db.py

from contextlib import asynccontextmanager
from contextvars import ContextVar
from typing import Any, AsyncGenerator, Optional
from uuid import uuid4

from sqlalchemy.ext.asyncio import (AsyncAttrs, AsyncSession,
                                    async_scoped_session, create_async_engine)
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from starlette.middleware.base import (BaseHTTPMiddleware,
                                       RequestResponseEndpoint)
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp

from .config import settings


class Base(AsyncAttrs, DeclarativeBase):
    """Declarative base for all ORM models, mixing in SQLAlchemy's ``AsyncAttrs``."""


# Keyword arguments passed to create_async_engine().
# NOTE(review): "connect_timeout" and "options" are libpq-style connection
# parameters (understood by psycopg); asyncpg uses different names (e.g.
# "timeout", "server_settings"). Confirm they match the driver in database_url.
ENGINE_ARGUMENTS = dict(
    connect_args=dict(connect_timeout=10, options="-c timezone=UTC"),
    pool_pre_ping=True,  # check pooled connections are alive before use
    pool_size=60,
)
# Keyword arguments passed to sessionmaker(); sessions are created as AsyncSession.
SESSION_ARGUMENTS = dict(
    autocommit=False,
    autoflush=True,
    # Keep loaded attributes usable after commit: under asyncio an expired
    # attribute cannot be refreshed implicitly on access.
    expire_on_commit=False,
    class_=AsyncSession,
)


class Database:
    """Set up and hold the async database engine and its scoped sessions.

    This gives the application a uniform place to set up the database, while
    keeping testing and session management simple.

    Session management uses ``async_scoped_session`` with a custom scopefunc,
    because ``threading.local()`` cannot be used under asyncio. A ``ContextVar``
    does the right thing for asyncio and behaves similarly to
    ``threading.local()``. The context variable only stores a random string
    identifying the current scope, and ``async_scoped_session`` does the heavy
    lifting. Starting a new session or getting the existing one therefore uses
    the normal scoped-session mechanics.
    """

    def __init__(self) -> None:
        # Key of the current session scope. Outside any database_scope() it
        # keeps its default "", so all such callers share a single session.
        self.request_context: ContextVar[str] = ContextVar(
            "request_context", default=""
        )
        self.engine = create_async_engine(settings.database_url, **ENGINE_ARGUMENTS)

        self.session_factory = sessionmaker(bind=self.engine, **SESSION_ARGUMENTS)

        # The registry of sessions is keyed by whatever _scopefunc returns.
        self.scoped_session = async_scoped_session(
            self.session_factory, self._scopefunc
        )

    def _scopefunc(self) -> str:
        """Return the key that identifies the current session scope."""
        return self.request_context.get()

    @property
    def session(self) -> AsyncSession:
        """Return the session belonging to the current scope, creating it if needed."""
        return self.scoped_session()

    @asynccontextmanager
    async def database_scope(self, **kwargs: Any) -> AsyncGenerator["Database", None]:
        """Create a new database session scope (one request or one workflow).

        All database access made through ``self.session`` inside the scope uses
        the same session. This should typically only be called in request
        middleware or at the start of a workflow.

        On exit the session is removed (closed) and the previous scope key is
        restored. This also happens when the body raises.

        Args:
            ``**kwargs``: Optional session keyword arguments for this session.

        Yields:
            This ``Database`` instance.
        """
        token = self.request_context.set(str(uuid4()))
        try:
            # Create the scope's session up front so that **kwargs are applied
            # to it (they are only used when the session is first created).
            self.scoped_session(**kwargs)
            yield self
        finally:
            # Cleanup must also run on error. Otherwise the session, and any
            # connection it holds, stays in the registry under a key that is
            # never used again, and the context variable is never reset.
            try:
                await self.scoped_session.remove()
            finally:
                self.request_context.reset(token)


class DBSessionMiddleware(BaseHTTPMiddleware):
    """Run each HTTP request inside its own ``Database.database_scope()``.

    NOTE(review): the scope is closed as soon as ``call_next`` returns the
    response object. Code that touches the database while a streamed response
    body is being sent would run after the scope ended. Confirm that no such
    endpoints exist.
    """

    def __init__(self, app: ASGIApp, database: Database):
        super().__init__(app)
        self.database = database

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        request_scope = self.database.database_scope()
        async with request_scope:
            # The scope is exited (session removed) before the response is returned.
            return await call_next(request)

Adding the middleware in main.py works the same way as before. Keep in mind that relationships cannot be lazy-loaded implicitly under asyncio: accessing an unloaded relationship raises an error instead of issuing a query. If you use relationships in nested Pydantic models, you may therefore want to configure eager loading for them. Here is an example from my budget model, which has a relationship to a category:

from __future__ import annotations

import uuid
from typing import TYPE_CHECKING

from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship

from ..core import Base
from .mixins.base_model import BaseModel

if TYPE_CHECKING:
    from .assignment import Assignment
    from .category import Category


class Budget(Base, BaseModel):
    """ORM model mapped to the ``budget`` table.

    Relationships use ``lazy="selectin"`` so they are loaded together with the
    budget rather than lazily on attribute access. Implicit lazy loading does
    not work with an AsyncSession.
    """

    __tablename__ = "budget"
    # Fetch server-generated defaults right after INSERT/UPDATE instead of on a
    # later attribute access.
    __mapper_args__ = {"eager_defaults": True}

    assignments: Mapped[list["Assignment"]] = relationship(
        back_populates="budget", lazy="selectin"
    )
    amount: Mapped[float] = mapped_column(nullable=False)
    name: Mapped[str] = mapped_column(nullable=False)
    category_id: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("category.id"), nullable=True
    )
    # NOTE(review): category_id is nullable, so this may be None at runtime;
    # consider annotating it as Mapped[Optional["Category"]].
    category: Mapped["Category"] = relationship(
        "Category", back_populates="budgets", lazy="selectin"
    )
    description: Mapped[str | None] = mapped_column(nullable=True)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants