Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: implement library installed check #217

Merged
merged 5 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ async def update(
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
id_attribute: str | InstrumentedAttribute[Any] | None = None,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
) -> ModelT: ...

async def update_many(
Expand Down Expand Up @@ -1109,7 +1111,6 @@ async def count(
"""
with wrap_sqlalchemy_exception():
statement = self.statement if statement is None else statement
fragment = self.get_id_attribute_value(self.model_type)
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
Expand All @@ -1119,7 +1120,7 @@ async def count(
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
statement = statement.add_criteria(
lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None),
lambda s: s.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None),
)
results = await self._execute(statement, uniquify=loader_options_have_wildcard)
return cast(int, results.scalar_one())
Expand All @@ -1133,6 +1134,8 @@ async def update(
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
id_attribute: str | InstrumentedAttribute[Any] | None = None,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
) -> ModelT:
"""Update instance with the attribute values present on `data`.

Expand All @@ -1152,6 +1155,8 @@ async def update(
:class:`SQLAlchemyAsyncRepository.auto_commit <SQLAlchemyAsyncRepository>`
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
load: Set relationships to be loaded
execution_options: Set default execution options

Returns:
The updated instance.
Expand All @@ -1165,7 +1170,7 @@ async def update(
id_attribute=id_attribute,
)
# this will raise for not found, and will put the item in the session
await self.get(item_id, id_attribute=id_attribute)
await self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options)
# this will merge the inbound data to the instance we just put in the session
instance = await self._attach_to_session(data, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
Expand Down
11 changes: 8 additions & 3 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ def update(
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
id_attribute: str | InstrumentedAttribute[Any] | None = None,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
) -> ModelT: ...

def update_many(
Expand Down Expand Up @@ -1110,7 +1112,6 @@ def count(
"""
with wrap_sqlalchemy_exception():
statement = self.statement if statement is None else statement
fragment = self.get_id_attribute_value(self.model_type)
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
Expand All @@ -1120,7 +1121,7 @@ def count(
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
statement = statement.add_criteria(
lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None),
lambda s: s.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None),
)
results = self._execute(statement, uniquify=loader_options_have_wildcard)
return cast(int, results.scalar_one())
Expand All @@ -1134,6 +1135,8 @@ def update(
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
id_attribute: str | InstrumentedAttribute[Any] | None = None,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
) -> ModelT:
"""Update instance with the attribute values present on `data`.

Expand All @@ -1153,6 +1156,8 @@ def update(
:class:`SQLAlchemyAsyncRepository.auto_commit <SQLAlchemyAsyncRepository>`
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
load: Set relationships to be loaded
execution_options: Set default execution options

Returns:
The updated instance.
Expand All @@ -1166,7 +1171,7 @@ def update(
id_attribute=id_attribute,
)
# this will raise for not found, and will put the item in the session
self.get(item_id, id_attribute=id_attribute)
self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options)
# this will merge the inbound data to the instance we just put in the session
instance = self._attach_to_session(data, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
Expand Down
2 changes: 2 additions & 0 deletions advanced_alchemy/repository/memory/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,8 @@ async def update(
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
id_attribute: str | InstrumentedAttribute[Any] | None = None,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
) -> ModelT:
self._find_or_raise_not_found(self.__collection__().key(data))
return self.__collection__().update(data)
Expand Down
2 changes: 2 additions & 0 deletions advanced_alchemy/repository/memory/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,8 @@ def update(
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
id_attribute: str | InstrumentedAttribute[Any] | None = None,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
) -> ModelT:
self._find_or_raise_not_found(self.__collection__().key(data))
return self.__collection__().update(data)
Expand Down
2 changes: 2 additions & 0 deletions advanced_alchemy/service/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ async def update(
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
)

async def update_many(
Expand Down
2 changes: 2 additions & 0 deletions advanced_alchemy/service/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ def update(
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
)

def update_many(
Expand Down
8 changes: 5 additions & 3 deletions advanced_alchemy/service/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from advanced_alchemy.filters import LimitOffset
from advanced_alchemy.repository.typing import ModelOrRowMappingT
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service.typing import ( # type: ignore[attr-defined]
from advanced_alchemy.service.typing import (
MSGSPEC_INSTALLED,
PYDANTIC_INSTALLED,
BaseModel,
ModelDTOT,
Struct,
Expand Down Expand Up @@ -185,7 +187,7 @@ def to_schema(
offset=limit_offset.offset,
total=total,
)
if issubclass(schema_type, Struct):
if MSGSPEC_INSTALLED and issubclass(schema_type, Struct):
if not isinstance(data, Sequence):
return cast(
"ModelDTOT",
Expand Down Expand Up @@ -221,7 +223,7 @@ def to_schema(
total=total,
)

if issubclass(schema_type, BaseModel):
if PYDANTIC_INSTALLED and issubclass(schema_type, BaseModel):
if not isinstance(data, Sequence):
return cast("ModelDTOT", TypeAdapter(schema_type).validate_python(data, from_attributes=True)) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType,reportAttributeAccessIssue,reportCallIssue]
limit_offset = find_filter(LimitOffset, filters=filters)
Expand Down
24 changes: 24 additions & 0 deletions advanced_alchemy/service/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from typing import (
Any,
Final,
Generic,
Protocol,
TypeVar,
Expand All @@ -21,6 +22,9 @@

try:
from msgspec import Struct, convert # pyright: ignore[reportAssignmentType,reportUnusedImport]

MSGSPEC_INSTALLED: Final[bool] = True

except ImportError: # pragma: nocover

class Struct(Protocol): # type: ignore[no-redef] # pragma: nocover
Expand All @@ -30,10 +34,14 @@ def convert(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef] # noqa:
"""Placeholder implementation"""
return {}

MSGSPEC_INSTALLED: Final[bool] = False # type: ignore # pyright: ignore[reportConstantRedefinition,reportGeneralTypeIssues] # noqa: PGH003


try:
from pydantic import BaseModel # pyright: ignore[reportAssignmentType]
from pydantic.type_adapter import TypeAdapter # pyright: ignore[reportUnusedImport, reportAssignmentType]

PYDANTIC_INSTALLED: Final[bool] = True
except ImportError: # pragma: nocover

class BaseModel(Protocol): # type: ignore[no-redef] # pragma: nocover
Expand All @@ -51,10 +59,26 @@ def validate_python(self, data: Any, *args: Any, **kwargs: Any) -> T: # pragma:
"""Stub"""
return cast("T", data)

PYDANTIC_INSTALLED: Final[bool] = False # type: ignore # pyright: ignore[reportConstantRedefinition,reportGeneralTypeIssues] # noqa: PGH003

ModelDictT: TypeAlias = "dict[str, Any] | ModelT"
ModelDictListT: TypeAlias = "list[ModelT | dict[str, Any]] | list[dict[str, Any]]"
FilterTypeT = TypeVar("FilterTypeT", bound="StatementFilter")
ModelDTOT = TypeVar("ModelDTOT", bound="Struct | BaseModel")
PydanticModelDTOT = TypeVar("PydanticModelDTOT", bound="BaseModel")
StructModelDTOT = TypeVar("StructModelDTOT", bound="Struct")

__all__ = (
"ModelDictT",
"ModelDictListT",
"FilterTypeT",
"ModelDTOT",
"PydanticModelDTOT",
"StructModelDTOT",
"PYDANTIC_INSTALLED",
"MSGSPEC_INSTALLED",
"BaseModel",
"TypeAdapter",
"Struct",
"convert",
)
4 changes: 2 additions & 2 deletions tests/fixtures/bigint/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class SlugBookSyncService(SQLAlchemySyncRepositoryService[BigIntSlugBook]):
match_fields = ["title"]

def __init__(self, **repo_kwargs: Any) -> None:
self.repository = SlugBookSyncRepository(**repo_kwargs)
self.repository: SlugBookSyncRepository = self.repository_type(**repo_kwargs) # pyright: ignore

def to_model(self, data: BigIntSlugBook | dict[str, Any], operation: str | None = None) -> BigIntSlugBook:
if isinstance(data, dict) and "slug" not in data and operation == "create":
Expand All @@ -254,7 +254,7 @@ class SlugBookAsyncMockService(SQLAlchemyAsyncRepositoryService[BigIntSlugBook])
match_fields = ["title"]

def __init__(self, **repo_kwargs: Any) -> None:
self.repository = SlugBookAsyncMockRepository(**repo_kwargs)
self.repository: SlugBookAsyncMockRepository = self.repository_type(**repo_kwargs) # pyright: ignore

async def to_model(self, data: BigIntSlugBook | dict[str, Any], operation: str | None = None) -> BigIntSlugBook:
if isinstance(data, dict) and "slug" not in data and operation == "create":
Expand Down
Loading