diff --git a/advanced_alchemy/repository/_async.py b/advanced_alchemy/repository/_async.py index c03323d6..d1e7726a 100644 --- a/advanced_alchemy/repository/_async.py +++ b/advanced_alchemy/repository/_async.py @@ -218,6 +218,8 @@ async def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: ... async def update_many( @@ -1109,7 +1111,6 @@ async def count( """ with wrap_sqlalchemy_exception(): statement = self.statement if statement is None else statement - fragment = self.get_id_attribute_value(self.model_type) loader_options, loader_options_have_wildcard = self._get_loader_options(load) statement = self._get_base_stmt( statement=statement, @@ -1119,7 +1120,7 @@ async def count( statement = self._apply_filters(*filters, apply_pagination=False, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) statement = statement.add_criteria( - lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None), + lambda s: s.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None), ) results = await self._execute(statement, uniquify=loader_options_have_wildcard) return cast(int, results.scalar_one()) @@ -1133,6 +1134,8 @@ async def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: """Update instance with the attribute values present on `data`. @@ -1152,6 +1155,8 @@ async def update( :class:`SQLAlchemyAsyncRepository.auto_commit ` id_attribute: Allows customization of the unique identifier to use for model fetching. Defaults to `id`, but can reference any surrogate or candidate key for the table. + load: Set relationships to be loaded + execution_options: Set default execution options Returns: The updated instance. @@ -1165,7 +1170,7 @@ async def update( id_attribute=id_attribute, ) # this will raise for not found, and will put the item in the session - await self.get(item_id, id_attribute=id_attribute) + await self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options) # this will merge the inbound data to the instance we just put in the session instance = await self._attach_to_session(data, strategy="merge") await self._flush_or_commit(auto_commit=auto_commit) diff --git a/advanced_alchemy/repository/_sync.py b/advanced_alchemy/repository/_sync.py index 0ac59978..ca3becca 100644 --- a/advanced_alchemy/repository/_sync.py +++ b/advanced_alchemy/repository/_sync.py @@ -219,6 +219,8 @@ def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: ... def update_many( @@ -1110,7 +1112,6 @@ def count( """ with wrap_sqlalchemy_exception(): statement = self.statement if statement is None else statement - fragment = self.get_id_attribute_value(self.model_type) loader_options, loader_options_have_wildcard = self._get_loader_options(load) statement = self._get_base_stmt( statement=statement, @@ -1120,7 +1121,7 @@ def count( statement = self._apply_filters(*filters, apply_pagination=False, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) statement = statement.add_criteria( - lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None), + lambda s: s.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None), ) results = self._execute(statement, uniquify=loader_options_have_wildcard) return cast(int, results.scalar_one()) @@ -1134,6 +1135,8 @@ def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: """Update instance with the attribute values present on `data`. @@ -1153,6 +1156,8 @@ def update( :class:`SQLAlchemyAsyncRepository.auto_commit ` id_attribute: Allows customization of the unique identifier to use for model fetching. Defaults to `id`, but can reference any surrogate or candidate key for the table. + load: Set relationships to be loaded + execution_options: Set default execution options Returns: The updated instance. @@ -1166,7 +1171,7 @@ def update( id_attribute=id_attribute, ) # this will raise for not found, and will put the item in the session - self.get(item_id, id_attribute=id_attribute) + self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options) # this will merge the inbound data to the instance we just put in the session instance = self._attach_to_session(data, strategy="merge") self._flush_or_commit(auto_commit=auto_commit) diff --git a/advanced_alchemy/repository/memory/_async.py b/advanced_alchemy/repository/memory/_async.py index 51d83386..3fc7fb67 100644 --- a/advanced_alchemy/repository/memory/_async.py +++ b/advanced_alchemy/repository/memory/_async.py @@ -535,6 +535,8 @@ async def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: self._find_or_raise_not_found(self.__collection__().key(data)) return self.__collection__().update(data) diff --git a/advanced_alchemy/repository/memory/_sync.py b/advanced_alchemy/repository/memory/_sync.py index 044b0031..d85eec5b 100644 --- a/advanced_alchemy/repository/memory/_sync.py +++ b/advanced_alchemy/repository/memory/_sync.py @@ -536,6 +536,8 @@ def update( auto_expunge: bool | None = None, auto_refresh: bool | None = None, id_attribute: str | InstrumentedAttribute[Any] | None = None, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, ) -> ModelT: self._find_or_raise_not_found(self.__collection__().key(data)) return self.__collection__().update(data) diff --git a/advanced_alchemy/service/_async.py b/advanced_alchemy/service/_async.py index 017820db..9f393d28 100644 --- a/advanced_alchemy/service/_async.py +++ b/advanced_alchemy/service/_async.py @@ -503,6 +503,8 @@ async def update( auto_expunge=auto_expunge, auto_refresh=auto_refresh, id_attribute=id_attribute, + load=load, + execution_options=execution_options, ) async def update_many( diff --git a/advanced_alchemy/service/_sync.py b/advanced_alchemy/service/_sync.py index e78cfd51..a4610f5b 100644 --- a/advanced_alchemy/service/_sync.py +++ b/advanced_alchemy/service/_sync.py @@ -504,6 +504,8 @@ def update( auto_expunge=auto_expunge, auto_refresh=auto_refresh, id_attribute=id_attribute, + load=load, + execution_options=execution_options, ) def update_many( diff --git a/advanced_alchemy/service/_util.py b/advanced_alchemy/service/_util.py index d609186f..7b7fd4ca 100644 --- a/advanced_alchemy/service/_util.py +++ b/advanced_alchemy/service/_util.py @@ -22,7 +22,9 @@ from advanced_alchemy.filters import LimitOffset from advanced_alchemy.repository.typing import ModelOrRowMappingT from advanced_alchemy.service.pagination import OffsetPagination -from advanced_alchemy.service.typing import ( # type: ignore[attr-defined] +from advanced_alchemy.service.typing import ( + MSGSPEC_INSTALLED, + PYDANTIC_INSTALLED, BaseModel, ModelDTOT, Struct, @@ -185,7 +187,7 @@ def to_schema( offset=limit_offset.offset, total=total, ) - if issubclass(schema_type, Struct): + if MSGSPEC_INSTALLED and issubclass(schema_type, Struct): if not isinstance(data, Sequence): return cast( "ModelDTOT", @@ -221,7 +223,7 @@ def to_schema( total=total, ) - if issubclass(schema_type, BaseModel): + if PYDANTIC_INSTALLED and issubclass(schema_type, BaseModel): if not isinstance(data, Sequence): return cast("ModelDTOT", TypeAdapter(schema_type).validate_python(data, from_attributes=True)) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType,reportAttributeAccessIssue,reportCallIssue] limit_offset = find_filter(LimitOffset, filters=filters) diff --git a/advanced_alchemy/service/typing.py b/advanced_alchemy/service/typing.py index 78f4f9e9..a811219a 100644 --- a/advanced_alchemy/service/typing.py +++ b/advanced_alchemy/service/typing.py @@ -8,6 +8,7 @@ from typing import ( Any, + Final, Generic, Protocol, TypeVar, @@ -21,6 +22,9 @@ try: from msgspec import Struct, convert # pyright: ignore[reportAssignmentType,reportUnusedImport] + + MSGSPEC_INSTALLED: Final[bool] = True + except ImportError: # pragma: nocover class Struct(Protocol): # type: ignore[no-redef] # pragma: nocover @@ -30,10 +34,14 @@ def convert(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef] # noqa: """Placeholder implementation""" return {} + MSGSPEC_INSTALLED: Final[bool] = False # type: ignore # pyright: ignore[reportConstantRedefinition,reportGeneralTypeIssues] # noqa: PGH003 + try: from pydantic import BaseModel # pyright: ignore[reportAssignmentType] from pydantic.type_adapter import TypeAdapter # pyright: ignore[reportUnusedImport, reportAssignmentType] + + PYDANTIC_INSTALLED: Final[bool] = True except ImportError: # pragma: nocover class BaseModel(Protocol): # type: ignore[no-redef] # pragma: nocover @@ -51,6 +59,7 @@ def validate_python(self, data: Any, *args: Any, **kwargs: Any) -> T: # pragma: """Stub""" return cast("T", data) + PYDANTIC_INSTALLED: Final[bool] = False # type: ignore # pyright: ignore[reportConstantRedefinition,reportGeneralTypeIssues] # noqa: PGH003 ModelDictT: TypeAlias = "dict[str, Any] | ModelT" ModelDictListT: TypeAlias = "list[ModelT | dict[str, Any]] | list[dict[str, Any]]" @@ -58,3 +67,18 @@ def validate_python(self, data: Any, *args: Any, **kwargs: Any) -> T: # pragma: ModelDTOT = TypeVar("ModelDTOT", bound="Struct | BaseModel") PydanticModelDTOT = TypeVar("PydanticModelDTOT", bound="BaseModel") StructModelDTOT = TypeVar("StructModelDTOT", bound="Struct") + +__all__ = ( + "ModelDictT", + "ModelDictListT", + "FilterTypeT", + "ModelDTOT", + "PydanticModelDTOT", + "StructModelDTOT", + "PYDANTIC_INSTALLED", + "MSGSPEC_INSTALLED", + "BaseModel", + "TypeAdapter", + "Struct", + "convert", +) diff --git a/tests/fixtures/bigint/services.py b/tests/fixtures/bigint/services.py index 7ae71556..37260787 100644 --- a/tests/fixtures/bigint/services.py +++ b/tests/fixtures/bigint/services.py @@ -237,7 +237,7 @@ class SlugBookSyncService(SQLAlchemySyncRepositoryService[BigIntSlugBook]): match_fields = ["title"] def __init__(self, **repo_kwargs: Any) -> None: - self.repository = SlugBookSyncRepository(**repo_kwargs) + self.repository: SlugBookSyncRepository = self.repository_type(**repo_kwargs) # pyright: ignore def to_model(self, data: BigIntSlugBook | dict[str, Any], operation: str | None = None) -> BigIntSlugBook: if isinstance(data, dict) and "slug" not in data and operation == "create": @@ -254,7 +254,7 @@ class SlugBookAsyncMockService(SQLAlchemyAsyncRepositoryService[BigIntSlugBook]) match_fields = ["title"] def __init__(self, **repo_kwargs: Any) -> None: - self.repository = SlugBookAsyncMockRepository(**repo_kwargs) + self.repository: SlugBookAsyncMockRepository = self.repository_type(**repo_kwargs) # pyright: ignore async def to_model(self, data: BigIntSlugBook | dict[str, Any], operation: str | None = None) -> BigIntSlugBook: if isinstance(data, dict) and "slug" not in data and operation == "create":