From 55b41a537bb029bd82ab903f2f56c567c8979e6d Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 13 Jun 2024 15:43:41 +0000 Subject: [PATCH 1/5] fix: implement library installed check --- advanced_alchemy/service/_util.py | 8 +++++--- advanced_alchemy/service/typing.py | 26 ++++++++++++++++++++++++++ tests/fixtures/bigint/services.py | 4 ++-- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/advanced_alchemy/service/_util.py b/advanced_alchemy/service/_util.py index d609186f..7b7fd4ca 100644 --- a/advanced_alchemy/service/_util.py +++ b/advanced_alchemy/service/_util.py @@ -22,7 +22,9 @@ from advanced_alchemy.filters import LimitOffset from advanced_alchemy.repository.typing import ModelOrRowMappingT from advanced_alchemy.service.pagination import OffsetPagination -from advanced_alchemy.service.typing import ( # type: ignore[attr-defined] +from advanced_alchemy.service.typing import ( + MSGSPEC_INSTALLED, + PYDANTIC_INSTALLED, BaseModel, ModelDTOT, Struct, @@ -185,7 +187,7 @@ def to_schema( offset=limit_offset.offset, total=total, ) - if issubclass(schema_type, Struct): + if MSGSPEC_INSTALLED and issubclass(schema_type, Struct): if not isinstance(data, Sequence): return cast( "ModelDTOT", @@ -221,7 +223,7 @@ def to_schema( total=total, ) - if issubclass(schema_type, BaseModel): + if PYDANTIC_INSTALLED and issubclass(schema_type, BaseModel): if not isinstance(data, Sequence): return cast("ModelDTOT", TypeAdapter(schema_type).validate_python(data, from_attributes=True)) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType,reportAttributeAccessIssue,reportCallIssue] limit_offset = find_filter(LimitOffset, filters=filters) diff --git a/advanced_alchemy/service/typing.py b/advanced_alchemy/service/typing.py index 78f4f9e9..cd68325e 100644 --- a/advanced_alchemy/service/typing.py +++ b/advanced_alchemy/service/typing.py @@ -8,10 +8,12 @@ from typing import ( Any, + Final, Generic, Protocol, TypeVar, cast, + runtime_checkable, ) from typing_extensions import TypeAlias @@ -21,8 +23,12 @@ try: from msgspec import Struct, convert # pyright: ignore[reportAssignmentType,reportUnusedImport] + + MSGSPEC_INSTALLED: Final[bool] = True + except ImportError: # pragma: nocover + @runtime_checkable class Struct(Protocol): # type: ignore[no-redef] # pragma: nocover """Placeholder Implementation""" @@ -30,12 +36,17 @@ def convert(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef] # noqa: """Placeholder implementation""" return {} + MSGSPEC_INSTALLED: Final[bool] = False # type: ignore # pyright: ignore[reportConstantRedefinition,reportGeneralTypeIssues] # noqa: PGH003 + try: from pydantic import BaseModel # pyright: ignore[reportAssignmentType] from pydantic.type_adapter import TypeAdapter # pyright: ignore[reportUnusedImport, reportAssignmentType] + + PYDANTIC_INSTALLED: Final[bool] = True except ImportError: # pragma: nocover + @runtime_checkable class BaseModel(Protocol): # type: ignore[no-redef] # pragma: nocover """Placeholder Implementation""" @@ -58,3 +69,18 @@ def validate_python(self, data: Any, *args: Any, **kwargs: Any) -> T: # pragma: ModelDTOT = TypeVar("ModelDTOT", bound="Struct | BaseModel") PydanticModelDTOT = TypeVar("PydanticModelDTOT", bound="BaseModel") StructModelDTOT = TypeVar("StructModelDTOT", bound="Struct") + +__all__ = ( + "ModelDictT", + "ModelDictListT", + "FilterTypeT", + "ModelDTOT", + "PydanticModelDTOT", + "StructModelDTOT", + "PYDANTIC_INSTALLED", + "MSGSPEC_INSTALLED", + "BaseModel", + "TypeAdapter", + "Struct", + "convert", +) diff --git a/tests/fixtures/bigint/services.py b/tests/fixtures/bigint/services.py index 7ae71556..37260787 100644 --- a/tests/fixtures/bigint/services.py +++ b/tests/fixtures/bigint/services.py @@ -237,7 +237,7 @@ class SlugBookSyncService(SQLAlchemySyncRepositoryService[BigIntSlugBook]): match_fields = ["title"] def __init__(self, **repo_kwargs: Any) -> None: - self.repository = SlugBookSyncRepository(**repo_kwargs) + self.repository: SlugBookSyncRepository = self.repository_type(**repo_kwargs) # pyright: ignore def to_model(self, data: BigIntSlugBook | dict[str, Any], operation: str | None = None) -> BigIntSlugBook: if isinstance(data, dict) and "slug" not in data and operation == "create": @@ -254,7 +254,7 @@ class SlugBookAsyncMockService(SQLAlchemyAsyncRepositoryService[BigIntSlugBook]) match_fields = ["title"] def __init__(self, **repo_kwargs: Any) -> None: - self.repository = SlugBookAsyncMockRepository(**repo_kwargs) + self.repository: SlugBookAsyncMockRepository = self.repository_type(**repo_kwargs) # pyright: ignore async def to_model(self, data: BigIntSlugBook | dict[str, Any], operation: str | None = None) -> BigIntSlugBook: if isinstance(data, dict) and "slug" not in data and operation == "create": From a0e84e25e8a49718336531c2f709bc498ad9a648 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 13 Jun 2024 15:50:03 +0000 Subject: [PATCH 2/5] fix: populate load and execution_options for update --- advanced_alchemy/repository/_async.py | 8 +++++++- advanced_alchemy/repository/_sync.py | 8 +++++++- advanced_alchemy/repository/memory/_async.py | 2 ++ advanced_alchemy/repository/memory/_sync.py | 2 ++ advanced_alchemy/service/_async.py | 2 ++ advanced_alchemy/service/_sync.py | 2 ++ 6 files changed, 22 insertions(+), 2 deletions(-) diff --git a/advanced_alchemy/repository/_async.py b/advanced_alchemy/repository/_async.py index c03323d6..a9deb3ee 100644 --- a/advanced_alchemy/repository/_async.py +++ b/advanced_alchemy/repository/_async.py @@ -218,6 +218,8 @@ async def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: ... async def update_many( @@ -1133,6 +1135,8 @@ async def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: """Update instance with the attribute values present on `data`. @@ -1152,6 +1156,8 @@ async def update( :class:`SQLAlchemyAsyncRepository.auto_commit ` id_attribute: Allows customization of the unique identifier to use for model fetching. Defaults to `id`, but can reference any surrogate or candidate key for the table. + load: Set relationships to be loaded + execution_options: Set default execution options Returns: The updated instance. @@ -1165,7 +1171,7 @@ async def update( id_attribute=id_attribute, ) # this will raise for not found, and will put the item in the session - await self.get(item_id, id_attribute=id_attribute) + await self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options) # this will merge the inbound data to the instance we just put in the session instance = await self._attach_to_session(data, strategy="merge") await self._flush_or_commit(auto_commit=auto_commit) diff --git a/advanced_alchemy/repository/_sync.py b/advanced_alchemy/repository/_sync.py index 0ac59978..4cfef38f 100644 --- a/advanced_alchemy/repository/_sync.py +++ b/advanced_alchemy/repository/_sync.py @@ -219,6 +219,8 @@ def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: ... def update_many( @@ -1134,6 +1136,8 @@ def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: """Update instance with the attribute values present on `data`. @@ -1153,6 +1157,8 @@ def update( :class:`SQLAlchemyAsyncRepository.auto_commit ` id_attribute: Allows customization of the unique identifier to use for model fetching. Defaults to `id`, but can reference any surrogate or candidate key for the table. + load: Set relationships to be loaded + execution_options: Set default execution options Returns: The updated instance. @@ -1166,7 +1172,7 @@ def update( id_attribute=id_attribute, ) # this will raise for not found, and will put the item in the session - self.get(item_id, id_attribute=id_attribute) + self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options) # this will merge the inbound data to the instance we just put in the session instance = self._attach_to_session(data, strategy="merge") self._flush_or_commit(auto_commit=auto_commit) diff --git a/advanced_alchemy/repository/memory/_async.py b/advanced_alchemy/repository/memory/_async.py index 51d83386..3fc7fb67 100644 --- a/advanced_alchemy/repository/memory/_async.py +++ b/advanced_alchemy/repository/memory/_async.py @@ -535,6 +535,8 @@ async def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: self._find_or_raise_not_found(self.__collection__().key(data)) return self.__collection__().update(data) diff --git a/advanced_alchemy/repository/memory/_sync.py b/advanced_alchemy/repository/memory/_sync.py index 044b0031..d85eec5b 100644 --- a/advanced_alchemy/repository/memory/_sync.py +++ b/advanced_alchemy/repository/memory/_sync.py @@ -536,6 +536,8 @@ def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: self._find_or_raise_not_found(self.__collection__().key(data)) return self.__collection__().update(data) diff --git a/advanced_alchemy/service/_async.py b/advanced_alchemy/service/_async.py index 017820db..9f393d28 100644 --- a/advanced_alchemy/service/_async.py +++ b/advanced_alchemy/service/_async.py @@ -503,6 +503,8 @@ async def update( auto_expunge=auto_expunge, auto_refresh=auto_refresh, id_attribute=id_attribute, + load=load, + execution_options=execution_options, ) async def update_many( diff --git a/advanced_alchemy/service/_sync.py b/advanced_alchemy/service/_sync.py index e78cfd51..a4610f5b 100644 --- a/advanced_alchemy/service/_sync.py +++ b/advanced_alchemy/service/_sync.py @@ -504,6 +504,8 @@ def update( auto_expunge=auto_expunge, auto_refresh=auto_refresh, id_attribute=id_attribute, + load=load, + execution_options=execution_options, ) def update_many( From 9a5ab7375d43b2bc740bfa7ac5d3ed214a4380e1 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 13 Jun 2024 15:57:47 +0000 Subject: [PATCH 3/5] fix: remove runtime checkable --- advanced_alchemy/service/typing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/advanced_alchemy/service/typing.py b/advanced_alchemy/service/typing.py index cd68325e..a811219a 100644 --- a/advanced_alchemy/service/typing.py +++ b/advanced_alchemy/service/typing.py @@ -13,7 +13,6 @@ Protocol, TypeVar, cast, - runtime_checkable, ) from typing_extensions import TypeAlias @@ -28,7 +27,6 @@ except ImportError: # pragma: nocover - @runtime_checkable class Struct(Protocol): # type: ignore[no-redef] # pragma: nocover """Placeholder Implementation""" @@ -46,7 +44,6 @@ def convert(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef] # noqa: PYDANTIC_INSTALLED: Final[bool] = True except ImportError: # pragma: nocover - @runtime_checkable class BaseModel(Protocol): # type: ignore[no-redef] # pragma: nocover """Placeholder Implementation""" @@ -62,6 +59,7 @@ def validate_python(self, data: Any, *args: Any, **kwargs: Any) -> T: # pragma: """Stub""" return cast("T", data) + PYDANTIC_INSTALLED: Final[bool] = False # type: ignore # pyright: ignore[reportConstantRedefinition,reportGeneralTypeIssues] # noqa: PGH003 ModelDictT: TypeAlias = "dict[str, Any] | ModelT" ModelDictListT: TypeAlias = "list[ModelT | dict[str, Any]] | list[dict[str, Any]]" From 94049cf11db0bd6cf0d60c6f2b5d3504e4b22979 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 13 Jun 2024 16:05:56 +0000 Subject: [PATCH 4/5] feat: `count` doesn't not expect to have an argument here. --- advanced_alchemy/repository/_async.py | 3 +-- advanced_alchemy/repository/_sync.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/advanced_alchemy/repository/_async.py b/advanced_alchemy/repository/_async.py index a9deb3ee..8d8de5af 100644 --- a/advanced_alchemy/repository/_async.py +++ b/advanced_alchemy/repository/_async.py @@ -1111,7 +1111,6 @@ async def count( """ with wrap_sqlalchemy_exception(): statement = self.statement if statement is None else statement - fragment = self.get_id_attribute_value(self.model_type) loader_options, loader_options_have_wildcard = self._get_loader_options(load) statement = self._get_base_stmt( statement=statement, @@ -1121,7 +1120,7 @@ async def count( statement = self._apply_filters(*filters, apply_pagination=False, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) statement = statement.add_criteria( - lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None), + lambda s: s.with_only_columns(sql_func.count(), maintain_column_froms=True).order_by(None), ) results = await self._execute(statement, uniquify=loader_options_have_wildcard) return cast(int, results.scalar_one()) diff --git a/advanced_alchemy/repository/_sync.py b/advanced_alchemy/repository/_sync.py index 4cfef38f..670fe34f 100644 --- a/advanced_alchemy/repository/_sync.py +++ b/advanced_alchemy/repository/_sync.py @@ -1112,7 +1112,6 @@ def count( """ with wrap_sqlalchemy_exception(): statement = self.statement if statement is None else statement - fragment = self.get_id_attribute_value(self.model_type) loader_options, loader_options_have_wildcard = self._get_loader_options(load) statement = self._get_base_stmt( statement=statement, @@ -1122,7 +1121,7 @@ def count( statement = self._apply_filters(*filters, apply_pagination=False, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) statement = statement.add_criteria( - lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None), + lambda s: s.with_only_columns(sql_func.count(), maintain_column_froms=True).order_by(None), ) results = self._execute(statement, uniquify=loader_options_have_wildcard) return cast(int, results.scalar_one()) From 23e98bd589120b12636d27cd31d54c1179e0b2d2 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 13 Jun 2024 16:20:59 +0000 Subject: [PATCH 5/5] fix: align count syntax --- advanced_alchemy/repository/_async.py | 2 +- advanced_alchemy/repository/_sync.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/advanced_alchemy/repository/_async.py b/advanced_alchemy/repository/_async.py index 8d8de5af..d1e7726a 100644 --- a/advanced_alchemy/repository/_async.py +++ b/advanced_alchemy/repository/_async.py @@ -1120,7 +1120,7 @@ async def count( statement = self._apply_filters(*filters, apply_pagination=False, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) statement = statement.add_criteria( - lambda s: s.with_only_columns(sql_func.count(), maintain_column_froms=True).order_by(None), + lambda s: s.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None), ) results = await self._execute(statement, uniquify=loader_options_have_wildcard) return cast(int, results.scalar_one()) diff --git a/advanced_alchemy/repository/_sync.py b/advanced_alchemy/repository/_sync.py index 670fe34f..ca3becca 100644 --- a/advanced_alchemy/repository/_sync.py +++ b/advanced_alchemy/repository/_sync.py @@ -1121,7 +1121,7 @@ def count( statement = self._apply_filters(*filters, apply_pagination=False, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) statement = statement.add_criteria( - lambda s: s.with_only_columns(sql_func.count(), maintain_column_froms=True).order_by(None), + lambda s: s.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None), ) results = self._execute(statement, uniquify=loader_options_have_wildcard) return cast(int, results.scalar_one())