From 6ad2ddca78ad34ca3d3fe41578dd18b44eaa43bc Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Thu, 29 Feb 2024 18:37:45 -0500 Subject: [PATCH] feat: better type hints everywhere (#141) * feat: ASyncMethodDescriptor and ASyncBoundMethod classes * feat: more type checking helper classes * fix(mypy): fix type errs --- a_sync/__init__.py | 7 +- a_sync/_bound.py | 192 +++++++++++--------- a_sync/_descriptor.py | 24 +++ a_sync/_helpers.py | 9 +- a_sync/_meta.py | 21 +-- a_sync/_typing.py | 15 +- a_sync/abstract.py | 17 +- a_sync/exceptions.py | 6 +- a_sync/modified.py | 16 +- a_sync/primitives/locks/semaphore.py | 2 +- a_sync/property.py | 262 ++++++++++++++++++++++----- tests/fixtures.py | 9 +- tests/test_base.py | 26 ++- tests/test_meta.py | 11 +- 14 files changed, 435 insertions(+), 182 deletions(-) create mode 100644 a_sync/_descriptor.py diff --git a/a_sync/__init__.py b/a_sync/__init__.py index b5731e89..f2b2e377 100644 --- a/a_sync/__init__.py +++ b/a_sync/__init__.py @@ -6,6 +6,7 @@ from a_sync.iter import ASyncIterable, ASyncIterator from a_sync.modifiers.semaphores import apply_semaphore from a_sync.primitives import * +from a_sync.property import ASyncCachedPropertyDescriptor, ASyncPropertyDescriptor, cached_property, property from a_sync.singleton import ASyncGenericSingleton from a_sync.task import TaskMapping as map from a_sync.task import TaskMapping, create_task @@ -19,7 +20,6 @@ # alias for backward-compatability, will be removed eventually, probably in 0.1.0 ASyncBase = ASyncGenericBase - __all__ = [ "all", "any", @@ -34,4 +34,9 @@ "ASyncIterator", "ASyncGenericSingleton", "TaskMapping", + # property + "cached_property", + "property", + "ASyncPropertyDescriptor", + "ASyncCachedPropertyDescriptor", ] diff --git a/a_sync/_bound.py b/a_sync/_bound.py index 68ba5952..0377ace6 100644 --- a/a_sync/_bound.py +++ b/a_sync/_bound.py @@ -1,111 +1,123 @@ # mypy: disable-error-code=valid-type # mypy: disable-error-code=misc import functools +import logging from inspect import isawaitable -from a_sync import _helpers +from a_sync import _helpers, _kwargs +from a_sync._descriptor import ASyncDescriptor from a_sync._typing import * -from a_sync.decorator import a_sync as unbound_a_sync -from a_sync.modified import ASyncFunction -from a_sync.property import (AsyncCachedPropertyDescriptor, - AsyncPropertyDescriptor) +from a_sync.modified import ASyncFunction, ASyncFunctionAsyncDefault, ASyncFunctionSyncDefault -if TYPE_CHECKING: - from a_sync.abstract import ASyncABC +logger = logging.getLogger(__name__) -def _clean_default_from_modifiers( - coro_fn: AsyncBoundMethod[P, T], # type: ignore [misc] - modifiers: ModifierKwargs, -): - # NOTE: We set the default here manually because the default set by the user will be used later in the code to determine whether to await. - force_await = None - if not asyncio.iscoroutinefunction(coro_fn) and not isinstance(coro_fn, ASyncFunction): - if 'default' not in modifiers or modifiers['default'] != 'async': - if 'default' in modifiers and modifiers['default'] == 'sync': - force_await = True - modifiers['default'] = 'async' - return modifiers, force_await +class ASyncMethodDescriptor(ASyncDescriptor[ASyncFunction[P, T]], Generic[O, P, T]): + _fget: ASyncFunction[Concatenate[O, P], T] + def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethod[P, T]": + if instance is None: + return self + try: + return instance.__dict__[self.field_name] + except KeyError: + from a_sync.abstract import ASyncABC + if self.default == "sync": + bound = ASyncBoundMethodSyncDefault(instance, self._fget, **self.modifiers) + elif self.default == "async": + bound = ASyncBoundMethodAsyncDefault(instance, self._fget, **self.modifiers) + elif isinstance(instance, ASyncABC) and instance.__a_sync_instance_should_await__: + bound = ASyncBoundMethodSyncDefault(instance, self._fget, **self.modifiers) + elif isinstance(instance, ASyncABC) and instance.__a_sync_instance_should_await__: + bound = ASyncBoundMethodAsyncDefault(instance, self._fget, **self.modifiers) + else: + bound = ASyncBoundMethod(instance, self._fget, **self.modifiers) + instance.__dict__[self.field_name] = bound + logger.debug("new bound method: %s", bound) + return bound + def __set__(self, instance, value): + raise RuntimeError(f"cannot set {self.field_name}, {self} is what you get. sorry.") + def __delete__(self, instance): + raise RuntimeError(f"cannot delete {self.field_name}, you're stuck with {self} forever. sorry.") - -def _wrap_bound_method( - coro_fn: AsyncBoundMethod[P, T], - **modifiers: Unpack[ModifierKwargs] -) -> AsyncBoundMethod[P, T]: - from a_sync.abstract import ASyncABC - - # First we unwrap the coro_fn and rewrap it so overriding flag kwargs are handled automagically. - if isinstance(coro_fn, ASyncFunction): - coro_fn = coro_fn.__wrapped__ - - modifiers, _force_await = _clean_default_from_modifiers(coro_fn, modifiers) - - wrapped_coro_fn: AsyncBoundMethod[P, T] = ASyncFunction(coro_fn, **modifiers) # type: ignore [arg-type, valid-type, misc] +class ASyncMethodDescriptorSyncDefault(ASyncMethodDescriptor[ASyncInstance, P, T]): + def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethodSyncDefault[P, T]": + if instance is None: + return self + try: + return instance.__dict__[self.field_name] + except KeyError: + bound = ASyncBoundMethodSyncDefault(instance, self._fget, **self.modifiers) + instance.__dict__[self.field_name] = bound + logger.debug("new bound method: %s", bound) + return bound - @functools.wraps(coro_fn) - def bound_a_sync_wrap(self: ASyncABC, *args: P.args, **kwargs: P.kwargs) -> T: # type: ignore [name-defined] - if not isinstance(self, ASyncABC): - raise RuntimeError(f"{self} must be an instance of a class that inherits from ASyncABC.") +class ASyncMethodDescriptorAsyncDefault(ASyncMethodDescriptor[ASyncInstance, P, T]): + def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethodAsyncDefault[P, T]": + if instance is None: + return self + try: + return instance.__dict__[self.field_name] + except KeyError: + bound = ASyncBoundMethodAsyncDefault(instance, self._fget, **self.modifiers) + instance.__dict__[self.field_name] = bound + logger.debug("new bound method: %s", bound) + return bound + +class ASyncBoundMethod(ASyncFunction[P, T]): + def __init__( + self, + instance: ASyncInstance, + unbound: AnyFn[Concatenate[ASyncInstance, P], T], + **modifiers: Unpack[ModifierKwargs], + ) -> None: + self.instance = instance + # First we unwrap the coro_fn and rewrap it so overriding flag kwargs are handled automagically. + if isinstance(unbound, ASyncFunction): + modifiers.update(unbound.modifiers) + unbound = unbound.__wrapped__ + if asyncio.iscoroutinefunction(unbound): + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + return await unbound(self.instance, *args, **kwargs) + else: + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + return unbound(self.instance, *args, **kwargs) + functools.update_wrapper(wrapped, unbound) + super().__init__(wrapped, **modifiers) + def __repr__(self) -> str: + return f"<{self.__class__.__name__} for function {self.__module__}.{self.instance.__class__.__name__}.{self.__name__} bound to {self.instance}>" + def __call__(self, *args, **kwargs): + logger.debug("calling %s", self) # This could either be a coroutine or a return value from an awaited coroutine, # depending on if an overriding flag kwarg was passed into the function call. - retval = coro = wrapped_coro_fn(self, *args, **kwargs) + retval = coro = super().__call__(*args, **kwargs) if not isawaitable(retval): # The coroutine was already awaited due to the use of an overriding flag kwarg. # We can return the value. return retval # type: ignore [return-value] # The awaitable was not awaited, so now we need to check the flag as defined on 'self' and await if appropriate. - return _helpers._await(coro) if self.__a_sync_should_await__(kwargs, force=_force_await) else coro # type: ignore [call-overload, return-value] - return bound_a_sync_wrap - -class _PropertyGetter(Awaitable[T]): - def __init__(self, coro: Awaitable[T], property: Union[AsyncPropertyDescriptor[T], AsyncCachedPropertyDescriptor[T]]): - self._coro = coro - self._property = property - def __repr__(self) -> str: - return f"<_PropertyGetter for {self._property}._get at {hex(id(self))}>" - def __await__(self) -> Generator[Any, None, T]: - return self._coro.__await__() - -@overload -def _wrap_property( - async_property: AsyncPropertyDescriptor[T], - **modifiers: Unpack[ModifierKwargs] -) -> AsyncPropertyDescriptor[T]:... - -@overload -def _wrap_property( - async_property: AsyncCachedPropertyDescriptor[T], - **modifiers: Unpack[ModifierKwargs] -) -> AsyncCachedPropertyDescriptor:... + return _helpers._await(coro) if self.should_await(kwargs) else coro # type: ignore [call-overload, return-value] + @functools.cached_property + def __bound_to_a_sync_instance__(self) -> bool: + from a_sync.abstract import ASyncABC + return isinstance(self.instance, ASyncABC) + def should_await(self, kwargs: dict) -> bool: + if flag := _kwargs.get_flag_name(kwargs): + return _kwargs.is_sync(flag, kwargs, pop_flag=True) # type: ignore [arg-type] + elif self.default: + return self.default == "sync" + elif self.__bound_to_a_sync_instance__: + return self.instance.__a_sync_should_await__(kwargs) + return asyncio.iscoroutinefunction(self.__wrapped__) -def _wrap_property( - async_property: Union[AsyncPropertyDescriptor[T], AsyncCachedPropertyDescriptor[T]], - **modifiers: Unpack[ModifierKwargs] -) -> Tuple[Property[T], HiddenMethod[T]]: - if not isinstance(async_property, (AsyncPropertyDescriptor, AsyncCachedPropertyDescriptor)): - raise TypeError(f"{async_property} must be one of: AsyncPropertyDescriptor, AsyncCachedPropertyDescriptor") - from a_sync.abstract import ASyncABC +class ASyncBoundMethodSyncDefault(ASyncBoundMethod[P, T]): + def __get__(self, instance: ASyncInstance, owner) -> ASyncFunctionSyncDefault[P, T]: + return super().__get__(instance, owner) + def __call__(self, *args, **kwargs) -> T: + return super().__call__(*args, **kwargs) - async_property.hidden_method_name = f"__{async_property.field_name}__" - - modifiers, _force_await = _clean_default_from_modifiers(async_property, modifiers) - - @unbound_a_sync(**modifiers) - async def _get(instance: ASyncABC) -> T: - return await async_property.__get__(instance, async_property) - - @functools.wraps(async_property) - def a_sync_method(self: ASyncABC, **kwargs) -> T: - if not isinstance(self, ASyncABC): - raise RuntimeError(f"{self} must be an instance of a class that inherits from ASyncABC.") - awaitable = _PropertyGetter(_get(self), async_property) - return _helpers._await(awaitable) if self.__a_sync_should_await__(kwargs, force=_force_await) else awaitable - - @property # type: ignore [misc] - @functools.wraps(async_property) - def a_sync_property(self: ASyncABC) -> T: - coro = getattr(self, async_property.hidden_method_name)(sync=False) - return _helpers._await(coro) if self.__a_sync_should_await__({}, force=_force_await) else coro - - return a_sync_property, a_sync_method +class ASyncBoundMethodAsyncDefault(ASyncBoundMethod[P, T]): + def __get__(self, instance: ASyncInstance, owner) -> ASyncFunctionAsyncDefault[P, T]: + return super().__get__(instance, owner) + def __call__(self, *args, **kwargs) -> Awaitable[T]: + return super().__call__(*args, **kwargs) diff --git a/a_sync/_descriptor.py b/a_sync/_descriptor.py new file mode 100644 index 00000000..e111e108 --- /dev/null +++ b/a_sync/_descriptor.py @@ -0,0 +1,24 @@ + +import functools + +from a_sync._typing import * +from a_sync.modified import ASyncFunction, ModifiedMixin, ModifierManager + +class ASyncDescriptor(ModifiedMixin, Generic[T]): + def __init__(self, _fget: UnboundMethod[ASyncInstance, P, T], field_name=None, **modifiers: ModifierKwargs): + if not callable(_fget): + raise ValueError(f'Unable to decorate {_fget}') + self.modifiers = ModifierManager(modifiers) + self._fn: UnboundMethod[ASyncInstance, P, T] = _fget + if isinstance(_fget, ASyncFunction): + self._fget = _fget + elif asyncio.iscoroutinefunction(_fget): + self._fget: AsyncUnboundMethod[ASyncInstance, P, T] = self.modifiers.apply_async_modifiers(_fget) + else: + self._fget = self._asyncify(_fget) + self.field_name = field_name or _fget.__name__ + functools.update_wrapper(self, _fget) + def __repr__(self) -> str: + return f"<{self.__class__.__name__} for {self._fn}>" + def __set_name__(self, owner, name): + self.field_name = name diff --git a/a_sync/_helpers.py b/a_sync/_helpers.py index 43f94dc4..64dc5580 100644 --- a/a_sync/_helpers.py +++ b/a_sync/_helpers.py @@ -10,10 +10,9 @@ from async_property.cached import \ AsyncCachedPropertyDescriptor # type: ignore [import] -from a_sync import _flags +from a_sync import _flags, exceptions from a_sync._typing import * -from a_sync.exceptions import (ASyncRuntimeError, KwargsUnsupportedError, - SyncModeInAsyncContextError) +from a_sync.modified import ASyncFunction def get_event_loop() -> asyncio.AbstractEventLoop: @@ -44,10 +43,12 @@ def _await(awaitable: Awaitable[T]) -> T: return get_event_loop().run_until_complete(awaitable) except RuntimeError as e: if str(e) == "This event loop is already running": - raise SyncModeInAsyncContextError from e + raise exceptions.SyncModeInAsyncContextError from e raise e def _asyncify(func: SyncFn[P, T], executor: Executor) -> CoroFn[P, T]: # type: ignore [misc] + if asyncio.iscoroutinefunction(func) or isinstance(func, ASyncFunction): + raise exceptions.FunctionNotSync(func) @functools.wraps(func) async def _asyncify_wrap(*args: P.args, **kwargs: P.kwargs) -> T: return await asyncio.futures.wrap_future( diff --git a/a_sync/_meta.py b/a_sync/_meta.py index 5d22ef4a..85db68ed 100644 --- a/a_sync/_meta.py +++ b/a_sync/_meta.py @@ -4,10 +4,12 @@ from abc import ABCMeta from typing import Any, Dict, Tuple -from a_sync import ENVIRONMENT_VARIABLES, _bound, modifiers +from a_sync import ENVIRONMENT_VARIABLES, modifiers +from a_sync._bound import ASyncMethodDescriptor from a_sync.future import _ASyncFutureWrappedFn # type: ignore [attr-defined] from a_sync.modified import ASyncFunction, ModifiedMixin -from a_sync.property import PropertyDescriptor +from a_sync.property import ASyncPropertyDescriptor, ASyncCachedPropertyDescriptor +from a_sync.primitives.locks.semaphore import Semaphore logger = logging.getLogger(__name__) @@ -30,7 +32,7 @@ def __new__(cls, new_class_name, bases, attrs): elif "__" in attr_name: logger.debug(f"`%s.%s` incluldes a double-underscore, skipping", new_class_name, attr_name) continue - elif isinstance(attr_value, _ASyncFutureWrappedFn): + elif isinstance(attr_value, (_ASyncFutureWrappedFn, Semaphore)): logger.debug(f"`%s.%s` is a %s, skipping", new_class_name, attr_name, attr_value.__class__.__name__) continue logger.debug(f"inspecting `{new_class_name}.{attr_name}` of type {attr_value.__class__.__name__}") @@ -46,23 +48,20 @@ def __new__(cls, new_class_name, bases, attrs): else: logger.debug(f"I did not find any modifiers") logger.debug(f"full modifier set for `{new_class_name}.{attr_name}`: {fn_modifiers}") - if isinstance(attr_value, PropertyDescriptor): + if isinstance(attr_value, (ASyncPropertyDescriptor, ASyncCachedPropertyDescriptor)): # Wrap property logger.debug(f"`{attr_name} is a property, now let's wrap it") - wrapped, hidden = _bound._wrap_property(attr_value, **fn_modifiers) - attrs[attr_name] = wrapped - logger.debug(f"`{attr_name}` is now `{wrapped}`") logger.debug(f"since `{attr_name}` is a property, we will add a hidden dundermethod so you can still access it both sync and async") - attrs[attr_value.hidden_method_name] = hidden - logger.debug(f"`{new_class_name}.{attr_value.hidden_method_name}` is now {hidden}") + attrs[attr_value.hidden_method_name] = attr_value.hidden_method_descriptor + logger.debug(f"`{new_class_name}.{attr_value.hidden_method_name}` is now {attr_value.hidden_method_descriptor}") elif isinstance(attr_value, ASyncFunction): - attrs[attr_name] = _bound._wrap_bound_method(attr_value, **fn_modifiers) + attrs[attr_name] = ASyncMethodDescriptor(attr_value, **fn_modifiers) else: raise NotImplementedError(attr_name, attr_value) elif callable(attr_value): # NOTE We will need to improve this logic if somebody needs to use it with classmethods or staticmethods. - attrs[attr_name] = _bound._wrap_bound_method(attr_value, **fn_modifiers) + attrs[attr_name] = ASyncMethodDescriptor(attr_value, **fn_modifiers) else: logger.debug(f"`{new_class_name}.{attr_name}` is not callable, we will take no action with it") return super(ASyncMeta, cls).__new__(cls, new_class_name, bases, attrs) diff --git a/a_sync/_typing.py b/a_sync/_typing.py index bfec8215..fc694539 100644 --- a/a_sync/_typing.py +++ b/a_sync/_typing.py @@ -12,10 +12,14 @@ if TYPE_CHECKING: from a_sync.abstract import ASyncABC + ASyncInstance = TypeVar("ASyncInstance", bound=ASyncABC) +else: + ASyncInstance = TypeVar("ASyncInstance", bound=object) T = TypeVar("T") K = TypeVar("K") V = TypeVar("V") +O = TypeVar("O", bound=object) E = TypeVar('E', bound=Exception) P = ParamSpec("P") @@ -23,15 +27,16 @@ MaybeAwaitable = Union[Awaitable[T], T] -Property = Callable[["ASyncABC"], T] -HiddenMethod = Callable[["ASyncABC", Dict[str, bool]], T] -AsyncBoundMethod = Callable[Concatenate["ASyncABC", P], Awaitable[T]] -BoundMethod = Callable[Concatenate["ASyncABC", P], T] - CoroFn = Callable[P, Awaitable[T]] SyncFn = Callable[P, T] AnyFn = Union[CoroFn[P, T], SyncFn[P, T]] +AsyncUnboundMethod = Callable[Concatenate[ASyncInstance, P], Awaitable[T]] +SyncUnboundMethod = Callable[Concatenate[ASyncInstance, P], T] +UnboundMethod = Union[AsyncUnboundMethod[ASyncInstance, P, T], SyncUnboundMethod[ASyncInstance, P, T]] + +Property = Callable[[object], T] + AsyncDecorator = Callable[[CoroFn[P, T]], CoroFn[P, T]] AllToAsyncDecorator = Callable[[AnyFn[P, T]], CoroFn[P, T]] diff --git a/a_sync/abstract.py b/a_sync/abstract.py index 04e89b00..c5426e05 100644 --- a/a_sync/abstract.py +++ b/a_sync/abstract.py @@ -16,36 +16,35 @@ class ASyncABC(metaclass=ASyncMeta): # Concrete Methods (overridable) # ################################## - def __a_sync_should_await__(self, kwargs: dict, force: Optional[Literal[True]] = None) -> bool: + def __a_sync_should_await__(self, kwargs: dict) -> bool: """Returns a boolean that indicates whether methods of 'instance' should be called as sync or async methods.""" try: # Defer to kwargs always - return self.__should_await_from_kwargs(kwargs) + return self.__a_sync_should_await_from_kwargs__(kwargs) except exceptions.NoFlagsFound: # No flag found in kwargs, check for a flag attribute. - return force if force else self.__should_await_from_instance + return self.__a_sync_instance_should_await__ @functools.cached_property - def __should_await_from_instance(self) -> bool: + def __a_sync_instance_should_await__(self) -> bool: """ You can override this if you want. If you want to be able to hotswap instance modes, you can redefine this as a non-cached property. """ return _flags.negate_if_necessary(self.__a_sync_flag_name__, self.__a_sync_flag_value__) - def __should_await_from_kwargs(self, kwargs: dict) -> bool: + def __a_sync_should_await_from_kwargs__(self, kwargs: dict) -> bool: """You can override this if you want.""" if flag := _kwargs.get_flag_name(kwargs): - return _kwargs.is_sync(flag, kwargs, pop_flag=True) - else: - raise NoFlagsFound("kwargs", kwargs.keys()) + return _kwargs.is_sync(flag, kwargs, pop_flag=True) # type: ignore [arg-type] + raise NoFlagsFound("kwargs", kwargs.keys()) @classmethod def __a_sync_instance_will_be_sync__(cls, args: tuple, kwargs: dict) -> bool: """You can override this if you want.""" logger.debug("checking `%s.%s.__init__` signature against provided kwargs to determine a_sync mode for the new instance", cls.__module__, cls.__name__) if flag := _kwargs.get_flag_name(kwargs): - sync = _kwargs.is_sync(flag, kwargs) + sync = _kwargs.is_sync(flag, kwargs) # type: ignore [arg-type] logger.debug("kwargs indicate the new instance created with args %s %s is %ssynchronous", args, kwargs, 'a' if sync is False else '') return sync logger.debug("No valid flags found in kwargs, checking class definition for defined default") diff --git a/a_sync/exceptions.py b/a_sync/exceptions.py index 165098d3..df0a5dce 100644 --- a/a_sync/exceptions.py +++ b/a_sync/exceptions.py @@ -52,7 +52,11 @@ class ImproperFunctionType(ValueError): class FunctionNotAsync(ImproperFunctionType): def __init__(self, fn): - super().__init__(f"'coro_fn' must be a coroutine function defined with 'async def'. You passed {fn}.") + super().__init__(f"`coro_fn` must be a coroutine function defined with `async def`. You passed {fn}.") + +class FunctionNotSync(ImproperFunctionType): + def __init__(self, fn): + super().__init__(f"`func` must be a coroutine function defined with `def`. You passed {fn}.") class KwargsUnsupportedError(ValueError): def __init__(self): diff --git a/a_sync/modified.py b/a_sync/modified.py index 19c09d9a..3aca45af 100644 --- a/a_sync/modified.py +++ b/a_sync/modified.py @@ -34,13 +34,13 @@ def __init__(self, fn: AnyFn[P, T], **modifiers: Unpack[ModifierKwargs]) -> None functools.update_wrapper(self, self.__wrapped__) @overload - def __call__(self, *args: P.args, sync: Literal[True] = True, **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T:... @overload - def __call__(self, *args: P.args, sync: Literal[False] = False, **kwargs: P.kwargs) -> Awaitable[T]:... + def __call__(self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs) -> Awaitable[T]:... @overload - def __call__(self, *args: P.args, asynchronous: Literal[False] = False, **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs) -> T:... @overload - def __call__(self, *args: P.args, asynchronous: Literal[True] = True, **kwargs: P.kwargs) -> Awaitable[T]:... + def __call__(self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs) -> Awaitable[T]:... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeAwaitable[T]: return self.fn(*args, **kwargs) @@ -136,13 +136,13 @@ def __call__(self, func: AnyFn[P, T]) -> ASyncFunction[P, T]: # type: ignore [o class ASyncFunctionSyncDefault(ASyncFunction[P, T]): @overload - def __call__(self, *args: P.args, sync: Literal[True] = True, **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T:... @overload - def __call__(self, *args: P.args, sync: Literal[False] = False, **kwargs: P.kwargs) -> Awaitable[T]:... + def __call__(self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs) -> Awaitable[T]:... @overload - def __call__(self, *args: P.args, asynchronous: Literal[False] = False, **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs) -> T:... @overload - def __call__(self, *args: P.args, asynchronous: Literal[True] = True, **kwargs: P.kwargs) -> Awaitable[T]:... + def __call__(self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs) -> Awaitable[T]:... @overload def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeAwaitable[T]: diff --git a/a_sync/primitives/locks/semaphore.py b/a_sync/primitives/locks/semaphore.py index d775de76..3970c946 100644 --- a/a_sync/primitives/locks/semaphore.py +++ b/a_sync/primitives/locks/semaphore.py @@ -19,7 +19,7 @@ def __init__(self, value: int, name=None, **kwargs) -> None: self._decorated: Set[str] = set() # Dank new functionality - def __call__(self, fn: Callable[P, T]) -> Callable[P, T]: + def __call__(self, fn: CoroFn[P, T]) -> CoroFn[P, T]: return self.decorate(fn) # type: ignore [arg-type, return-value] def __repr__(self) -> str: diff --git a/a_sync/property.py b/a_sync/property.py index 393a80b0..3bedd37b 100644 --- a/a_sync/property.py +++ b/a_sync/property.py @@ -1,84 +1,264 @@ -import asyncio +import functools +import logging import async_property as ap # type: ignore [import] -from a_sync import config +from a_sync import _helpers, config, exceptions +from a_sync._bound import ASyncBoundMethodAsyncDefault, ASyncMethodDescriptorAsyncDefault +from a_sync._descriptor import ASyncDescriptor from a_sync._typing import * -from a_sync.modified import ModifiedMixin -from a_sync.modifiers.manager import ModifierManager - - -class PropertyDescriptor(ModifiedMixin, Generic[T]): - def __init__(self, _fget: Callable[..., T], field_name=None, **modifiers: ModifierKwargs): - if not callable(_fget): - raise ValueError(f'Unable to decorate {_fget}') - self.modifiers = ModifierManager(modifiers) - self._fn = _fget - _fget = self.modifiers.apply_async_modifiers(_fget) if asyncio.iscoroutinefunction(_fget) else self._asyncify(_fget) - super().__init__(_fget, field_name=field_name) # type: ignore [call-arg] - def __repr__(self) -> str: - return f"<{self.__class__.__module__}.{self.__class__.__name__} for {self._fn} at {hex(id(self))}>" - -class AsyncPropertyDescriptor(PropertyDescriptor[T], ap.base.AsyncPropertyDescriptor): - pass - -class AsyncCachedPropertyDescriptor(PropertyDescriptor[T], ap.cached.AsyncCachedPropertyDescriptor): + + +logger = logging.getLogger(__name__) + +class _ASyncPropertyDescriptorBase(ASyncDescriptor[T]): + _fget: Property[T] + def __init__(self, _fget: Property[Awaitable[T]], field_name=None, **modifiers: config.ModifierKwargs): + super().__init__(_fget, field_name, **modifiers) + self.hidden_method_name = f"__{self.field_name}__" + hidden_modifiers = dict(self.modifiers) + hidden_modifiers["default"] = "async" + self.hidden_method_descriptor = HiddenMethodDescriptor(self.get, self.hidden_method_name, **hidden_modifiers) + async def get(self, instance: object) -> T: + return await super().__get__(instance, None) + def __get__(self, instance: object, owner) -> T: + awaitable = super().__get__(instance, owner) + # if the user didn't specify a default behavior, we will defer to the instance + if _is_a_sync_instance(instance): + should_await = self.default == "sync" if self.default else instance.__a_sync_instance_should_await__ + else: + should_await = self.default == "sync" if self.default else not asyncio.get_event_loop().is_running() + return _helpers._await(awaitable) if should_await else awaitable + +class ASyncPropertyDescriptor(_ASyncPropertyDescriptorBase[T], ap.base.AsyncPropertyDescriptor): pass +class property(ASyncPropertyDescriptor[T]):... + +class ASyncPropertyDescriptorSyncDefault(property[T]): + """This is a helper class used for type checking. You will not run into any instance of this in prod.""" + +class ASyncPropertyDescriptorAsyncDefault(property[T]): + """This is a helper class used for type checking. You will not run into any instance of this in prod.""" + def __get__(self, instance, owner) -> Awaitable[T]: + return super().__get__(instance, owner) + + +ASyncPropertyDecorator = Callable[[Property[T]], property[T]] +ASyncPropertyDecoratorSyncDefault = Callable[[Property[T]], ASyncPropertyDescriptorSyncDefault[T]] +ASyncPropertyDecoratorAsyncDefault = Callable[[Property[T]], ASyncPropertyDescriptorAsyncDefault[T]] + +@overload +def a_sync_property( # type: ignore [misc] + func: Literal[None], + default: Literal["sync"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncPropertyDecoratorSyncDefault[T]:... + +@overload +def a_sync_property( # type: ignore [misc] + func: Literal[None], + default: Literal["sync"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncPropertyDecoratorSyncDefault[T]:... + +@overload +def a_sync_property( # type: ignore [misc] + func: Literal[None], + default: Literal["async"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncPropertyDecoratorAsyncDefault[T]:... @overload def a_sync_property( # type: ignore [misc] func: Literal[None], default: DefaultMode = config.DEFAULT_MODE, **modifiers: Unpack[ModifierKwargs], -) -> Callable[[Property[T]], AsyncPropertyDescriptor[T]]:... +) -> ASyncPropertyDecorator[T]:... + +@overload +def a_sync_property( # type: ignore [misc] + default: Literal["sync"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncPropertyDecoratorSyncDefault[T]:... + +@overload +def a_sync_property( # type: ignore [misc] + default: Literal["async"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncPropertyDecoratorAsyncDefault[T]:... + +@overload +def a_sync_property( # type: ignore [misc] + func: Property[T], + default: Literal["sync"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncPropertyDescriptorSyncDefault[T]:... + +@overload +def a_sync_property( # type: ignore [misc] + func: Property[T], + default: Literal["async"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncPropertyDescriptorAsyncDefault[T]:... @overload def a_sync_property( # type: ignore [misc] func: Property[T], default: DefaultMode = config.DEFAULT_MODE, **modifiers: Unpack[ModifierKwargs], -) -> AsyncPropertyDescriptor[T]:... +) -> ASyncPropertyDescriptor[T]:... def a_sync_property( # type: ignore [misc] func: Union[Property[T], DefaultMode] = None, - default: DefaultMode = config.DEFAULT_MODE, **modifiers: Unpack[ModifierKwargs], ) -> Union[ - AsyncPropertyDescriptor[T], - Callable[[Property[T]], AsyncPropertyDescriptor[T]], + ASyncPropertyDescriptor[T], + ASyncPropertyDescriptorSyncDefault[T], + ASyncPropertyDescriptorAsyncDefault[T], + ASyncPropertyDecorator[T], + ASyncPropertyDecoratorSyncDefault[T], + ASyncPropertyDecoratorAsyncDefault[T], ]: - if func in ['sync', 'async']: - modifiers['default'] = func - func = None - def modifier_wrap(func: Property[T]) -> AsyncPropertyDescriptor[T]: - return AsyncPropertyDescriptor(func, **modifiers) - return modifier_wrap if func is None else modifier_wrap(func) # type: ignore [arg-type] - + func, modifiers = _parse_args(func, modifiers) + if modifiers.get("default") == "sync": + descriptor_class = ASyncPropertyDescriptorSyncDefault + elif modifiers.get("default") == "async": + descriptor_class = ASyncPropertyDescriptorAsyncDefault + else: + descriptor_class = property + decorator = functools.partial(descriptor_class, **modifiers) + return decorator if func is None else decorator(func) + + +class ASyncCachedPropertyDescriptor(_ASyncPropertyDescriptorBase[T], ap.cached.AsyncCachedPropertyDescriptor): + __slots__ = "_fset", "_fdel", "__async_property__" + def __init__(self, _fget, _fset=None, _fdel=None, field_name=None, **modifiers: Unpack[ModifierKwargs]): + super().__init__(_fget, field_name, **modifiers) + self._check_method_sync(_fset, 'setter') + self._check_method_sync(_fdel, 'deleter') + self._fset = _fset + self._fdel = _fdel + +class cached_property(ASyncCachedPropertyDescriptor[T]):... + +class ASyncCachedPropertyDescriptorSyncDefault(cached_property[T]): + """This is a helper class used for type checking. You will not run into any instance of this in prod.""" + +class ASyncCachedPropertyDescriptorAsyncDefault(cached_property[T]): + """This is a helper class used for type checking. You will not run into any instance of this in prod.""" + +ASyncCachedPropertyDecorator = Callable[[Property[T]], cached_property[T]] +ASyncCachedPropertyDecoratorSyncDefault = Callable[[Property[T]], ASyncCachedPropertyDescriptorSyncDefault[T]] +ASyncCachedPropertyDecoratorAsyncDefault = Callable[[Property[T]], ASyncCachedPropertyDescriptorAsyncDefault[T]] + +@overload +def a_sync_cached_property( # type: ignore [misc] + func: Literal[None], + default: Literal["sync"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncCachedPropertyDecoratorSyncDefault[T]:... + +@overload +def a_sync_cached_property( # type: ignore [misc] + func: Literal[None], + default: Literal["async"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncCachedPropertyDecoratorAsyncDefault[T]:... @overload def a_sync_cached_property( # type: ignore [misc] func: Literal[None], default: DefaultMode, **modifiers: Unpack[ModifierKwargs], -) -> Callable[[Property[T]], AsyncCachedPropertyDescriptor[T]]:... +) -> ASyncCachedPropertyDecorator[T]:... + +@overload +def a_sync_cached_property( # type: ignore [misc] + default: Literal["sync"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncCachedPropertyDecoratorSyncDefault[T]:... + +@overload +def a_sync_cached_property( # type: ignore [misc] + default: Literal["async"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncCachedPropertyDecoratorAsyncDefault[T]:... +@overload +def a_sync_cached_property( # type: ignore [misc] + func: Property[T], + default: Literal["sync"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncCachedPropertyDescriptorSyncDefault[T]:... + +@overload +def a_sync_cached_property( # type: ignore [misc] + func: Property[T], + default: Literal["async"], + **modifiers: Unpack[ModifierKwargs], +) -> ASyncCachedPropertyDescriptorAsyncDefault[T]:... + @overload def a_sync_cached_property( # type: ignore [misc] func: Property[T], default: DefaultMode = config.DEFAULT_MODE, **modifiers: Unpack[ModifierKwargs], -) -> AsyncCachedPropertyDescriptor[T]:... +) -> ASyncCachedPropertyDescriptor[T]:... def a_sync_cached_property( # type: ignore [misc] func: Optional[Property[T]] = None, - default: DefaultMode = config.DEFAULT_MODE, **modifiers: Unpack[ModifierKwargs], ) -> Union[ - AsyncCachedPropertyDescriptor[T], - Callable[[Property[T]], AsyncCachedPropertyDescriptor[T]], + ASyncCachedPropertyDescriptor[T], + ASyncCachedPropertyDescriptorSyncDefault[T], + ASyncCachedPropertyDescriptorAsyncDefault[T], + ASyncCachedPropertyDecorator[T], + ASyncCachedPropertyDecoratorSyncDefault[T], + ASyncCachedPropertyDecoratorAsyncDefault[T], ]: - def modifier_wrap(func: Property[T]) -> AsyncCachedPropertyDescriptor[T]: - return AsyncCachedPropertyDescriptor(func, **modifiers) - return modifier_wrap if func is None else modifier_wrap(func) + func, modifiers = _parse_args(func, modifiers) + if modifiers.get("default") == "sync": + descriptor_class = ASyncCachedPropertyDescriptorSyncDefault + elif modifiers.get("default") == "sync": + descriptor_class = ASyncCachedPropertyDescriptorAsyncDefault + else: + descriptor_class = ASyncCachedPropertyDescriptor + decorator = functools.partial(descriptor_class, **modifiers) + return decorator if func is None else decorator(func) + + +class HiddenMethod(ASyncBoundMethodAsyncDefault[ASyncInstance, T]): + def should_await(self, kwargs: dict) -> bool: + try: + return self.instance.__a_sync_should_await_from_kwargs__(kwargs) + except exceptions.NoFlagsFound: + return False + +class HiddenMethodDescriptor(ASyncMethodDescriptorAsyncDefault[ASyncInstance, P, T]): + def __get__(self, instance: ASyncInstance, owner) -> HiddenMethod[ASyncInstance, T]: + if instance is None: + return self + try: + return instance.__dict__[self.field_name] + except KeyError: + bound = HiddenMethod(instance, self._fget, **self.modifiers) + instance.__dict__[self.field_name] = bound + logger.debug("new hidden method: %s", bound) + return bound + +def _is_a_sync_instance(instance: object) -> bool: + try: + return instance.__dict__["__is_a_sync_instance__"] + except KeyError: + from a_sync.abstract import ASyncABC + is_a_sync = isinstance(instance, ASyncABC) + instance.__dict__["__is_a_sync_instance__"] = is_a_sync + return is_a_sync + +def _parse_args(func: Union[None, DefaultMode, Property[T]], modifiers: ModifierKwargs) -> Tuple[Optional[Property[T]], ModifierKwargs]: + if func in ['sync', 'async']: + modifiers['default'] = func + func = None + return func, modifiers diff --git a/tests/fixtures.py b/tests/fixtures.py index 05ceda02..6fa3bc43 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -14,7 +14,7 @@ increment = pytest.mark.parametrize('i', range(10)) class TestClass(ASyncBase): - def __init__(self, v: int, sync: bool): + def __init__(self, v: int, sync: bool = False): self.v = v self.sync = sync @@ -94,21 +94,20 @@ class TestSingletonMeta(TestClass, metaclass=ASyncSingletonMeta): pass class TestSemaphore(ASyncBase): - semaphore=1 # NOTE: this is detected propely by undecorated test_fn but not the properties + #semaphore=1 # NOTE: this is detected propely by undecorated test_fn but not the properties def __init__(self, v: int, sync: bool): self.v = v self.sync = sync # spec on class and function both working - #@a_sync.a_sync(semaphore=1) + @a_sync.a_sync(semaphore=1) async def test_fn(self) -> int: await asyncio.sleep(1) return self.v - # spec on class, function, property all working - @a_sync.aka.property#('async', semaphore=1) + @a_sync.aka.property('async', semaphore=1) async def test_property(self) -> int: await asyncio.sleep(1) return self.v * 2 diff --git a/tests/test_base.py b/tests/test_base.py index 43231122..6e59b7da 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -4,6 +4,7 @@ import pytest +from a_sync._bound import ASyncBoundMethodAsyncDefault from a_sync._meta import ASyncMeta from tests.fixtures import TestClass, TestInheritor, TestMeta, increment, TestSync @@ -28,9 +29,17 @@ def test_base_sync(cls: type, i: int): assert isinstance(val, int) # Can we access hidden methods for properties? - assert sync_instance.__test_property__() == i * 2 + getter = sync_instance.__test_property__ + assert isinstance(getter, ASyncBoundMethodAsyncDefault), getter + getter_coro = getter() + assert asyncio.iscoroutine(getter_coro), getter_coro + assert asyncio.get_event_loop().run_until_complete(getter_coro) == i * 2 start = time.time() - assert sync_instance.__test_cached_property__() == i * 3 + getter = sync_instance.__test_cached_property__ + assert isinstance(getter, ASyncBoundMethodAsyncDefault), getter + getter_coro = getter() + assert asyncio.iscoroutine(getter_coro), getter_coro + assert asyncio.get_event_loop().run_until_complete(getter_coro) == i * 3 # Can we override them too? assert asyncio.get_event_loop().run_until_complete(sync_instance.__test_cached_property__(sync=False)) == i * 3 duration = time.time() - start @@ -65,8 +74,17 @@ async def test_base_async(cls: type, i: int): async_instance.test_fn(sync=True) # Can we access hidden methods for properties? - assert await async_instance.__test_property__() == i * 2 - assert await async_instance.__test_cached_property__() == i * 3 + getter = async_instance.__test_property__ + assert isinstance(getter, ASyncBoundMethodAsyncDefault), getter + getter_coro = getter() + assert asyncio.iscoroutine(getter_coro), getter_coro + assert await getter_coro == i * 2 + + getter = async_instance.__test_cached_property__ + assert isinstance(getter, ASyncBoundMethodAsyncDefault), getter + getter_coro = getter() + assert asyncio.iscoroutine(getter_coro), getter_coro + assert await getter_coro == i * 3 # Can we override them too? with pytest.raises(RuntimeError): async_instance.__test_cached_property__(sync=True) diff --git a/tests/test_meta.py b/tests/test_meta.py index 3151df4d..ee2d6f5f 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -5,6 +5,7 @@ from a_sync._meta import ASyncMeta from a_sync.base import ASyncGenericBase +from a_sync.property import HiddenMethod from a_sync.singleton import ASyncGenericSingleton from tests.fixtures import TestSingleton, TestSingletonMeta, increment @@ -30,9 +31,15 @@ def test_singleton_meta_sync(cls: type, i: int): assert isinstance(val, int) # Can we access hidden methods for properties? - assert sync_instance.__test_property__() == 0 + getter = sync_instance.__test_property__ + assert isinstance(getter, HiddenMethod), getter + getter_coro = getter() + assert asyncio.get_event_loop().run_until_complete(getter_coro) == 0 start = time.time() - assert sync_instance.__test_cached_property__() == 0 + getter = sync_instance.__test_property__ + assert isinstance(getter, HiddenMethod), getter + getter_coro = getter() + assert asyncio.get_event_loop().run_until_complete(getter_coro) == 0 # Can we override them too? assert asyncio.get_event_loop().run_until_complete(sync_instance.__test_cached_property__(sync=False)) == 0 duration = time.time() - start