diff --git a/README.md b/README.md index 161d23b..a2b1073 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ pip install fastcrud Or, if using poetry: ```sh - poetry add fastcrud +poetry add fastcrud ```

Usage

@@ -141,14 +141,10 @@ async def lifespan(app: FastAPI): # FastAPI app app = FastAPI(lifespan=lifespan) -# CRUD operations setup -crud = FastCRUD(Item) - # CRUD router setup item_router = crud_router( session=get_session, model=Item, - crud=crud, create_schema=ItemCreateSchema, update_schema=ItemUpdateSchema, path="/items", diff --git a/docs/advanced/crud.md b/docs/advanced/crud.md index 70f0853..4404ac5 100644 --- a/docs/advanced/crud.md +++ b/docs/advanced/crud.md @@ -65,6 +65,74 @@ item_count = await item_crud.count( ) ``` +## Using `get_joined` and `get_multi_joined` for multiple models + +To facilitate complex data relationships, `get_joined` and `get_multi_joined` can be configured to handle joins with multiple models. This is achieved using the `joins_config` parameter, where you can specify a list of `JoinConfig` instances, each representing a distinct join configuration. + +#### Example: Joining User, Tier, and Department Models + +Consider a scenario where you want to retrieve users along with their associated tier and department information. Here's how you can achieve this using `get_multi_joined`. + +Start by creating a list of the multiple models to be joined: + +```python hl_lines="1 3-10 12-19" title="Join Configurations" +from fastcrud import JoinConfig + +joins_config = [ + JoinConfig( + model=Tier, + join_on=User.tier_id == Tier.id, + join_prefix="tier_", + schema_to_select=TierSchema, + join_type="left", + ), + + JoinConfig( + model=Department, + join_on=User.department_id == Department.id, + join_prefix="dept_", + schema_to_select=DepartmentSchema, + join_type="inner", + ) +] + +users = await user_crud.get_multi_joined( + db=session, + schema_to_select=UserSchema, + joins_config=joins_config, + offset=0, + limit=10, + sort_columns='username', + sort_orders='asc' +) +``` + +Then just pass this list to joins_config: + +```python hl_lines="10" title="Passing to get_multi_joined" +from fastcrud import JoinConfig + +joins_config = [ + ... +] + +users = await user_crud.get_multi_joined( + db=session, + schema_to_select=UserSchema, + joins_config=joins_config, + offset=0, + limit=10, + sort_columns='username', + sort_orders='asc' +) +``` + +In this example, users are joined with the `Tier` and `Department` models. The `join_on` parameter specifies the condition for the join, `join_prefix` assigns a prefix to columns from the joined models (to avoid naming conflicts), and `join_type` determines whether it's a left or inner join. + +!!! WARNING + + If both single join parameters and `joins_config` are used simultaneously, an error will be raised. + ## Conclusion The advanced features of FastCRUD, such as `allow_multiple` and support for advanced filters, empower developers to efficiently manage database records with complex conditions. By leveraging these capabilities, you can build more dynamic, robust, and scalable FastAPI applications that effectively interact with your data model. diff --git a/docs/usage/crud.md b/docs/usage/crud.md index d462f71..2e4438c 100644 --- a/docs/usage/crud.md +++ b/docs/usage/crud.md @@ -219,18 +219,19 @@ items = await item_crud.get_multi(db, offset=0, limit=10, sort_columns=['name'], ```python get_joined( - db: AsyncSession, - join_model: type[ModelType], - join_prefix: Optional[str] = None, - join_on: Optional[Union[Join, None]] = None, - schema_to_select: Optional[type[BaseModel]] = None, - join_schema_to_select: Optional[type[BaseModel]] = None, - join_type: str = "left", - **kwargs: Any + db: AsyncSession, + join_model: Optional[type[DeclarativeBase]] = None, + join_prefix: Optional[str] = None, + join_on: Optional[Union[Join, BinaryExpression]] = None, + schema_to_select: Optional[type[BaseModel]] = None, + join_schema_to_select: Optional[type[BaseModel]] = None, + join_type: str = "left", + joins_config: Optional[list[JoinConfig]] = None, + **kwargs: Any, ) -> Optional[dict[str, Any]] ``` -**Purpose**: To fetch a single record while performing a join operation with another model. +**Purpose**: To fetch a single record with one or multiple joins on other models. **Usage Example**: Fetches order details for a specific order by joining with the Customer table, selecting specific columns as defined in OrderSchema and CustomerSchema. ```python @@ -248,9 +249,9 @@ order_details = await order_crud.get_joined( ```python get_multi_joined( db: AsyncSession, - join_model: type[ModelType], + join_model: Optional[type[ModelType]] = None, join_prefix: Optional[str] = None, - join_on: Optional[Join] = None, + join_on: Optional[Any] = None, schema_to_select: Optional[type[BaseModel]] = None, join_schema_to_select: Optional[type[BaseModel]] = None, join_type: str = "left", @@ -259,7 +260,8 @@ get_multi_joined( sort_columns: Optional[Union[str, list[str]]] = None, sort_orders: Optional[Union[str, list[str]]] = None, return_as_model: bool = False, - **kwargs: Any + joins_config: Optional[list[JoinConfig]] = None, + **kwargs: Any, ) -> dict[str, Any] ``` diff --git a/fastcrud/__init__.py b/fastcrud/__init__.py index 1d3ada5..6ba8e15 100644 --- a/fastcrud/__init__.py +++ b/fastcrud/__init__.py @@ -1,5 +1,6 @@ from .crud.fast_crud import FastCRUD from .endpoint.endpoint_creator import EndpointCreator from .endpoint.crud_router import crud_router +from .crud.helper import JoinConfig -__all__ = ["FastCRUD", "EndpointCreator", "crud_router"] +__all__ = ["FastCRUD", "EndpointCreator", "crud_router", "JoinConfig"] diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index 379273d..a9aa444 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -16,6 +16,7 @@ _extract_matching_columns_from_schema, _auto_detect_join_condition, _add_column_with_prefix, + JoinConfig, ) ModelType = TypeVar("ModelType", bound=DeclarativeBase) @@ -536,17 +537,19 @@ async def get_multi( async def get_joined( self, db: AsyncSession, - join_model: type[ModelType], + join_model: Optional[type[DeclarativeBase]] = None, join_prefix: Optional[str] = None, - join_on: Optional[Union[Join, None]] = None, + join_on: Optional[Union[Join, BinaryExpression]] = None, schema_to_select: Optional[type[BaseModel]] = None, join_schema_to_select: Optional[type[BaseModel]] = None, join_type: str = "left", + joins_config: Optional[list[JoinConfig]] = None, **kwargs: Any, ) -> Optional[dict[str, Any]]: """ - Fetches a single record with a join on another model. If 'join_on' is not provided, the method attempts - to automatically detect the join condition using foreign key relationships. Advanced filters supported: + Fetches a single record with one or multiple joins on other models. If 'join_on' is not provided, the method attempts + to automatically detect the join condition using foreign key relationships. For multiple joins, use 'joins_config' to + specify each join configuration. Advanced filters supported: '__gt' (greater than), '__lt' (less than), '__gte' (greater than or equal to), @@ -557,16 +560,21 @@ async def get_joined( db: The SQLAlchemy async session. join_model: The model to join with. join_prefix: Optional prefix to be added to all columns of the joined model. If None, no prefix is added. - join_on: SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is - auto-detected based on foreign keys. + join_on: SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is auto-detected based on foreign keys. schema_to_select: Pydantic schema for selecting specific columns from the primary model. Required if `return_as_model` is True. join_schema_to_select: Pydantic schema for selecting specific columns from the joined model. join_type: Specifies the type of join operation to perform. Can be "left" for a left outer join or "inner" for an inner join. + joins_config: A list of JoinConfig instances, each specifying a model to join with, join condition, optional prefix for column names, schema for selecting specific columns, and the type of join. This parameter enables support for multiple joins. **kwargs: Filters to apply to the primary model query, supporting advanced comparison operators for refined searching. Returns: A dictionary representing the joined record, or None if no record matches the criteria. + Raises: + ValueError: If both single join parameters and 'joins_config' are used simultaneously. + ArgumentError: If any provided model in 'joins_config' is not recognized or invalid. + NoResultFound: If no record matches the criteria with the provided filters. + Examples: Simple example: Joining User and Tier models without explicitly providing join_on ```python @@ -607,6 +615,32 @@ async def get_joined( ) ``` + Example of using 'joins_config' for multiple joins: + ```python + from fastcrud import JoinConfig + + result = await crud_user.get_joined( + db=session, + schema_to_select=UserSchema, + joins_config=[ + JoinConfig( + model=Tier, + join_on=User.tier_id == Tier.id, + join_prefix="tier_", + schema_to_select=TierSchema, + join_type="left", + ), + JoinConfig( + model=Department, + join_on=User.department_id == Department.id, + join_prefix="dept_", + schema_to_select=DepartmentSchema, + join_type="inner", + ) + ] + ) + ``` + Return example: prefix added, no schema_to_select or join_schema_to_select ```python { @@ -629,37 +663,52 @@ async def get_joined( } ``` """ - if join_on is None: - join_on = _auto_detect_join_condition(self.model, join_model) + if joins_config and ( + join_model or join_prefix or join_on or join_schema_to_select + ): + raise ValueError( + "Cannot use both single join parameters and joinsConfig simultaneously." + ) + elif not joins_config and not join_model: + raise ValueError("You need one of join_model or joins_config.") primary_select = _extract_matching_columns_from_schema( model=self.model, schema=schema_to_select ) - join_select = [] - - if join_schema_to_select: - columns = _extract_matching_columns_from_schema( - model=join_model, schema=join_schema_to_select + stmt: Select = select(*primary_select) + + join_definitions = joins_config if joins_config else [] + if join_model: + join_definitions.append( + JoinConfig( + model=join_model, + join_on=join_on, + join_prefix=join_prefix, + schema_to_select=join_schema_to_select, + join_type=join_type, + ) ) - else: - columns = inspect(join_model).c - - for column in columns: - labeled_column = _add_column_with_prefix(column, join_prefix) - if f"{join_prefix}{column.name}" not in [ - col.name for col in primary_select - ]: - join_select.append(labeled_column) - - if join_type == "left": - stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on) - elif join_type == "inner": - stmt = select(*primary_select, *join_select).join(join_model, join_on) - else: - raise ValueError( - f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid." + + for join in join_definitions: + join_select = _extract_matching_columns_from_schema( + join.model, join.schema_to_select ) + if join.join_prefix: + join_select = [ + _add_column_with_prefix(column, join.join_prefix) + for column in join_select + ] + + if join.join_type == "left": + stmt = stmt.outerjoin(join.model, join.join_on).add_columns( + *join_select + ) + elif join.join_type == "inner": + stmt = stmt.join(join.model, join.join_on).add_columns(*join_select) + else: + raise ValueError(f"Unsupported join type: {join.join_type}.") + filters = self._parse_filters(**kwargs) if filters: stmt = stmt.filter(*filters) @@ -675,9 +724,9 @@ async def get_joined( async def get_multi_joined( self, db: AsyncSession, - join_model: type[ModelType], + join_model: Optional[type[ModelType]] = None, join_prefix: Optional[str] = None, - join_on: Optional[Join] = None, + join_on: Optional[Any] = None, schema_to_select: Optional[type[BaseModel]] = None, join_schema_to_select: Optional[type[BaseModel]] = None, join_type: str = "left", @@ -686,6 +735,7 @@ async def get_multi_joined( sort_columns: Optional[Union[str, list[str]]] = None, sort_orders: Optional[Union[str, list[str]]] = None, return_as_model: bool = False, + joins_config: Optional[list[JoinConfig]] = None, **kwargs: Any, ) -> dict[str, Any]: """ @@ -710,6 +760,7 @@ async def get_multi_joined( sort_columns: A single column name or a list of column names on which to apply sorting. sort_orders: A single sort order ('asc' or 'desc') or a list of sort orders corresponding to the columns in sort_columns. If not provided, defaults to 'asc' for each column. return_as_model: If True, converts the fetched data to Pydantic models based on schema_to_select. Defaults to False. + joins_config: List of JoinConfig instances for specifying multiple joins. Each instance defines a model to join with, join condition, optional prefix for column names, schema for selecting specific columns, and join type. **kwargs: Filters to apply to the primary query, including advanced comparison operators for refined searching. Returns: @@ -717,6 +768,7 @@ async def get_multi_joined( Raises: ValueError: If limit or offset is negative, or if schema_to_select is required but not provided or invalid. + Also if both 'joins_config' and any of the single join parameters are provided or none of 'joins_config' and 'join_model' is provided. Examples: Fetching multiple User records joined with Tier records, using left join, returning raw data: @@ -790,40 +842,91 @@ async def get_multi_joined( return_as_model=True ) ``` + + Example using 'joins_config' for multiple joins: + ```python + from fastcrud import JoinConfig + + users = await crud_user.get_multi_joined( + db=session, + schema_to_select=UserSchema, + joins_config=[ + JoinConfig( + model=Tier, + join_on=User.tier_id == Tier.id, + join_prefix="tier_", + schema_to_select=TierSchema, + join_type="left", + ), + JoinConfig( + model=Department, + join_on=User.department_id == Department.id, + join_prefix="dept_", + schema_to_select=DepartmentSchema, + join_type="inner", + ) + ], + offset=0, + limit=10, + sort_columns='username', + sort_orders='asc' + ) + ``` """ + if joins_config and ( + join_model or join_prefix or join_on or join_schema_to_select + ): + raise ValueError( + "Cannot use both single join parameters and joinsConfig simultaneously." + ) + elif not joins_config and not join_model: + raise ValueError("You need one of join_model or joins_config.") + if limit < 0 or offset < 0: raise ValueError("Limit and offset must be non-negative.") - if join_on is None: - join_on = _auto_detect_join_condition(self.model, join_model) + joins: list[JoinConfig] = [] + if join_model is not None: + joins.append( + JoinConfig( + model=join_model, + join_on=join_on + or _auto_detect_join_condition(self.model, join_model), + join_prefix=join_prefix, + schema_to_select=join_schema_to_select, + join_type=join_type, + ) + ) + elif joins_config: + joins.extend(joins_config) primary_select = _extract_matching_columns_from_schema( model=self.model, schema=schema_to_select ) - join_select = [] + stmt: Select = select(*primary_select) - if join_schema_to_select: - columns = _extract_matching_columns_from_schema( - model=join_model, schema=join_schema_to_select - ) - else: - columns = inspect(join_model).c - - for column in columns: - labeled_column = _add_column_with_prefix(column, join_prefix) - if f"{join_prefix}{column.name}" not in [ - col.name for col in primary_select - ]: - join_select.append(labeled_column) - - if join_type == "left": - stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on) - elif join_type == "inner": - stmt = select(*primary_select, *join_select).join(join_model, join_on) - else: - raise ValueError( - f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid." - ) + for join in joins: + if join.schema_to_select: + join_select = _extract_matching_columns_from_schema( + join.model, join.schema_to_select + ) + else: + join_select = inspect(join.model).c + + if join.join_prefix: + join_select = [ + _add_column_with_prefix(column, join.join_prefix) + for column in join_select + ] + + if join.join_type == "left": + stmt = stmt.outerjoin(join.model, join.join_on).add_columns( + *join_select + ) + elif join.join_type == "inner": + stmt = stmt.join(join.model, join.join_on).add_columns(*join_select) + else: + raise ValueError(f"Unsupported join type: {join.join_type}.") filters = self._parse_filters(**kwargs) if filters: @@ -835,7 +938,7 @@ async def get_multi_joined( stmt = stmt.offset(offset).limit(limit) result = await db.execute(stmt) - data = result.mappings().all() + data = [dict(row) for row in result.mappings().all()] if return_as_model and schema_to_select: data = [schema_to_select.model_construct(**row) for row in data] diff --git a/fastcrud/crud/helper.py b/fastcrud/crud/helper.py index 562d121..308df74 100644 --- a/fastcrud/crud/helper.py +++ b/fastcrud/crud/helper.py @@ -1,4 +1,4 @@ -from typing import Any, Union, Optional +from typing import Any, Union, Optional, NamedTuple from sqlalchemy import inspect from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import DeclarativeBase @@ -9,6 +9,14 @@ from pydantic import BaseModel +class JoinConfig(NamedTuple): + model: Any + join_on: Any + join_prefix: Optional[str] = None + schema_to_select: Optional[type[BaseModel]] = None + join_type: str = "left" + + def _extract_matching_columns_from_schema( model: type[DeclarativeBase], schema: Optional[Union[type[BaseModel], list]] ) -> list[Any]: diff --git a/pyproject.toml b/pyproject.toml index 1cb1f91..786f2fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "fastcrud" -version = "0.8.1" +version = "0.9.0" description = "FastCRUD is a Python package for FastAPI, offering robust async CRUD operations and flexible endpoint creation utilities." authors = ["Igor Benav "] license = "MIT" diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py index 820b9e1..c2a38c0 100644 --- a/tests/sqlalchemy/conftest.py +++ b/tests/sqlalchemy/conftest.py @@ -1,11 +1,11 @@ +from typing import Optional + import pytest import pytest_asyncio - from sqlalchemy import Column, Integer, String, ForeignKey, Boolean, DateTime from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker, DeclarativeBase, relationship from pydantic import BaseModel, ConfigDict - from fastapi import FastAPI from fastapi.testclient import TestClient @@ -17,12 +17,23 @@ class Base(DeclarativeBase): pass +class CategoryModel(Base): + __tablename__ = "category" + tests = relationship("ModelTest", back_populates="category") + id = Column(Integer, primary_key=True) + name = Column(String, unique=True) + + class ModelTest(Base): __tablename__ = "test" id = Column(Integer, primary_key=True) name = Column(String) tier_id = Column(Integer, ForeignKey("tier.id")) + category_id = Column( + Integer, ForeignKey("category.id"), nullable=True, default=None + ) tier = relationship("TierModel", back_populates="tests") + category = relationship("CategoryModel", back_populates="tests") is_deleted = Column(Boolean, default=False) deleted_at = Column(DateTime, nullable=True, default=None) @@ -38,12 +49,14 @@ class CreateSchemaTest(BaseModel): model_config = ConfigDict(extra="forbid") name: str tier_id: int + category_id: Optional[int] = None class ReadSchemaTest(BaseModel): id: int name: str tier_id: int + category_id: Optional[int] class UpdateSchemaTest(BaseModel): @@ -62,6 +75,11 @@ class TierDeleteSchemaTest(BaseModel): pass +class CategorySchemaTest(BaseModel): + id: Optional[int] = None + name: str + + async_engine = create_async_engine( "sqlite+aiosqlite:///:memory:", echo=True, future=True ) @@ -95,17 +113,17 @@ async def async_session() -> AsyncSession: @pytest.fixture(scope="function") def test_data() -> list[dict]: return [ - {"id": 1, "name": "Charlie", "tier_id": 1}, - {"id": 2, "name": "Alice", "tier_id": 2}, - {"id": 3, "name": "Bob", "tier_id": 1}, - {"id": 4, "name": "David", "tier_id": 2}, - {"id": 5, "name": "Eve", "tier_id": 1}, - {"id": 6, "name": "Frank", "tier_id": 2}, - {"id": 7, "name": "Grace", "tier_id": 1}, - {"id": 8, "name": "Hannah", "tier_id": 2}, - {"id": 9, "name": "Ivan", "tier_id": 1}, - {"id": 10, "name": "Judy", "tier_id": 2}, - {"id": 11, "name": "Alice", "tier_id": 1}, + {"id": 1, "name": "Charlie", "tier_id": 1, "category_id": 1}, + {"id": 2, "name": "Alice", "tier_id": 2, "category_id": 1}, + {"id": 3, "name": "Bob", "tier_id": 1, "category_id": 2}, + {"id": 4, "name": "David", "tier_id": 2, "category_id": 1}, + {"id": 5, "name": "Eve", "tier_id": 1, "category_id": 1}, + {"id": 6, "name": "Frank", "tier_id": 2, "category_id": 2}, + {"id": 7, "name": "Grace", "tier_id": 1, "category_id": 2}, + {"id": 8, "name": "Hannah", "tier_id": 2, "category_id": 1}, + {"id": 9, "name": "Ivan", "tier_id": 1, "category_id": 1}, + {"id": 10, "name": "Judy", "tier_id": 2, "category_id": 2}, + {"id": 11, "name": "Alice", "tier_id": 1, "category_id": 1}, ] @@ -114,6 +132,11 @@ def test_data_tier() -> list[dict]: return [{"id": 1, "name": "Premium"}, {"id": 2, "name": "Basic"}] +@pytest.fixture(scope="function") +def test_data_category() -> list[dict]: + return [{"id": 1, "name": "Tech"}, {"id": 2, "name": "Health"}] + + @pytest.fixture def test_model(): return ModelTest diff --git a/tests/sqlalchemy/crud/test_get_joined.py b/tests/sqlalchemy/crud/test_get_joined.py index b0d7380..273c481 100644 --- a/tests/sqlalchemy/crud/test_get_joined.py +++ b/tests/sqlalchemy/crud/test_get_joined.py @@ -1,11 +1,13 @@ import pytest from sqlalchemy import and_ -from fastcrud.crud.fast_crud import FastCRUD +from fastcrud import FastCRUD, JoinConfig from ...sqlalchemy.conftest import ( ModelTest, TierModel, CreateSchemaTest, TierSchemaTest, + CategoryModel, + CategorySchemaTest, ) @@ -174,3 +176,49 @@ async def test_count_with_advanced_filters(async_session, test_model, test_data) count_lt = await crud.count(async_session, id__lt=10) assert count_lt > 0, "Should count records with ID less than 10" + + +@pytest.mark.asyncio +async def test_get_joined_multiple_models( + async_session, test_data, test_data_tier, test_data_category +): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for category_item in test_data_category: + async_session.add(CategoryModel(**category_item)) + await async_session.commit() + + for user_item in test_data: + user_item_modified = user_item.copy() + user_item_modified["category_id"] = 1 + async_session.add(ModelTest(**user_item_modified)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + result = await crud.get_joined( + db=async_session, + joins_config=[ + JoinConfig( + model=TierModel, + join_prefix="tier_", + schema_to_select=TierSchemaTest, + join_on=ModelTest.tier_id == TierModel.id, + join_type="left", + ), + JoinConfig( + model=CategoryModel, + join_prefix="category_", + schema_to_select=CategorySchemaTest, + join_on=ModelTest.category_id == CategoryModel.id, + join_type="left", + ), + ], + schema_to_select=CreateSchemaTest, + ) + + assert result is not None + assert "name" in result + assert "tier_name" in result + assert "category_name" in result diff --git a/tests/sqlalchemy/crud/test_get_multi_joined.py b/tests/sqlalchemy/crud/test_get_multi_joined.py index 3e15aa5..677f8d9 100644 --- a/tests/sqlalchemy/crud/test_get_multi_joined.py +++ b/tests/sqlalchemy/crud/test_get_multi_joined.py @@ -1,11 +1,13 @@ import pytest -from fastcrud.crud.fast_crud import FastCRUD +from fastcrud import FastCRUD, JoinConfig from ...sqlalchemy.conftest import ( ModelTest, TierModel, CreateSchemaTest, TierSchemaTest, ReadSchemaTest, + CategoryModel, + CategorySchemaTest, ) @@ -67,7 +69,6 @@ async def test_get_multi_joined_sorting(async_session, test_data, test_data_tier @pytest.mark.asyncio async def test_get_multi_joined_filtering(async_session, test_data, test_data_tier): - # Assuming there's a user with a specific name in test_data specific_user_name = "Charlie" for tier_item in test_data_tier: async_session.add(TierModel(**tier_item)) @@ -84,7 +85,7 @@ async def test_get_multi_joined_filtering(async_session, test_data, test_data_ti join_prefix="tier_", schema_to_select=CreateSchemaTest, join_schema_to_select=TierSchemaTest, - name=specific_user_name, # Filter based on ModelTest attribute + name=specific_user_name, offset=0, limit=10, ) @@ -256,3 +257,50 @@ async def test_get_multi_joined_advanced_filtering( assert all( item["id"] > 5 for item in advanced_filter_result["data"] ), "All fetched records should meet the advanced filter condition" + + +@pytest.mark.asyncio +async def test_get_multi_joined_with_additional_join_model( + async_session, test_data, test_data_tier, test_data_category +): + for category_item in test_data_category: + async_session.add(CategoryModel(**category_item)) + await async_session.commit() + + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + result = await crud.get_multi_joined( + db=async_session, + joins_config=[ + JoinConfig( + model=TierModel, + join_prefix="tier_", + schema_to_select=TierSchemaTest, + join_on=ModelTest.tier_id == TierModel.id, + join_type="left", + ), + JoinConfig( + model=CategoryModel, + join_prefix="category_", + schema_to_select=CategorySchemaTest, + join_on=ModelTest.category_id == CategoryModel.id, + join_type="left", + ), + ], + schema_to_select=ReadSchemaTest, + offset=0, + limit=10, + ) + + assert len(result["data"]) == min(10, len(test_data)) + assert result["total_count"] == len(test_data) + assert all( + "tier_name" in item and "category_name" in item for item in result["data"] + ) diff --git a/tests/sqlmodel/conftest.py b/tests/sqlmodel/conftest.py index 89c4538..176d8c7 100644 --- a/tests/sqlmodel/conftest.py +++ b/tests/sqlmodel/conftest.py @@ -15,12 +15,21 @@ from fastcrud.endpoint.crud_router import crud_router +class CategoryModel(SQLModel, table=True): + __tablename__ = "category" + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + tests: list["ModelTest"] = Relationship(back_populates="category") + + class ModelTest(SQLModel, table=True): __tablename__ = "test" - id: int = Field(primary_key=True) + id: Optional[int] = Field(default=None, primary_key=True) name: str - tier_id: int = Field(foreign_key="tier.id") + tier_id: int = Field(default=None, foreign_key="tier.id") + category_id: Optional[int] = Field(default=None, foreign_key="category.id") tier: "TierModel" = Relationship(back_populates="tests") + category: "CategoryModel" = Relationship(back_populates="tests") is_deleted: bool = Field(default=False) deleted_at: Optional[datetime] = Field(default=None) @@ -112,6 +121,11 @@ def test_data_tier() -> list[dict]: return [{"id": 1, "name": "Premium"}, {"id": 2, "name": "Basic"}] +@pytest.fixture(scope="function") +def test_data_category() -> list[dict]: + return [{"id": 1, "name": "Tech"}, {"id": 2, "name": "Health"}] + + @pytest.fixture def test_model(): return ModelTest diff --git a/tests/sqlmodel/crud/test_get_joined.py b/tests/sqlmodel/crud/test_get_joined.py index b0d7380..273c481 100644 --- a/tests/sqlmodel/crud/test_get_joined.py +++ b/tests/sqlmodel/crud/test_get_joined.py @@ -1,11 +1,13 @@ import pytest from sqlalchemy import and_ -from fastcrud.crud.fast_crud import FastCRUD +from fastcrud import FastCRUD, JoinConfig from ...sqlalchemy.conftest import ( ModelTest, TierModel, CreateSchemaTest, TierSchemaTest, + CategoryModel, + CategorySchemaTest, ) @@ -174,3 +176,49 @@ async def test_count_with_advanced_filters(async_session, test_model, test_data) count_lt = await crud.count(async_session, id__lt=10) assert count_lt > 0, "Should count records with ID less than 10" + + +@pytest.mark.asyncio +async def test_get_joined_multiple_models( + async_session, test_data, test_data_tier, test_data_category +): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for category_item in test_data_category: + async_session.add(CategoryModel(**category_item)) + await async_session.commit() + + for user_item in test_data: + user_item_modified = user_item.copy() + user_item_modified["category_id"] = 1 + async_session.add(ModelTest(**user_item_modified)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + result = await crud.get_joined( + db=async_session, + joins_config=[ + JoinConfig( + model=TierModel, + join_prefix="tier_", + schema_to_select=TierSchemaTest, + join_on=ModelTest.tier_id == TierModel.id, + join_type="left", + ), + JoinConfig( + model=CategoryModel, + join_prefix="category_", + schema_to_select=CategorySchemaTest, + join_on=ModelTest.category_id == CategoryModel.id, + join_type="left", + ), + ], + schema_to_select=CreateSchemaTest, + ) + + assert result is not None + assert "name" in result + assert "tier_name" in result + assert "category_name" in result diff --git a/tests/sqlmodel/crud/test_get_multi_joined.py b/tests/sqlmodel/crud/test_get_multi_joined.py index 3e15aa5..0c1d812 100644 --- a/tests/sqlmodel/crud/test_get_multi_joined.py +++ b/tests/sqlmodel/crud/test_get_multi_joined.py @@ -1,11 +1,13 @@ import pytest -from fastcrud.crud.fast_crud import FastCRUD +from fastcrud import FastCRUD, JoinConfig from ...sqlalchemy.conftest import ( ModelTest, TierModel, CreateSchemaTest, TierSchemaTest, ReadSchemaTest, + CategoryModel, + CategorySchemaTest, ) @@ -256,3 +258,50 @@ async def test_get_multi_joined_advanced_filtering( assert all( item["id"] > 5 for item in advanced_filter_result["data"] ), "All fetched records should meet the advanced filter condition" + + +@pytest.mark.asyncio +async def test_get_multi_joined_with_additional_join_model( + async_session, test_data, test_data_tier, test_data_category +): + for category_item in test_data_category: + async_session.add(CategoryModel(**category_item)) + await async_session.commit() + + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + result = await crud.get_multi_joined( + db=async_session, + joins_config=[ + JoinConfig( + model=TierModel, + join_prefix="tier_", + schema_to_select=TierSchemaTest, + join_on=ModelTest.tier_id == TierModel.id, + join_type="left", + ), + JoinConfig( + model=CategoryModel, + join_prefix="category_", + schema_to_select=CategorySchemaTest, + join_on=ModelTest.category_id == CategoryModel.id, + join_type="left", + ), + ], + schema_to_select=ReadSchemaTest, + offset=0, + limit=10, + ) + + assert len(result["data"]) == min(10, len(test_data)) + assert result["total_count"] == len(test_data) + assert all( + "tier_name" in item and "category_name" in item for item in result["data"] + )