-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: better type hints everywhere (#141)
* feat: ASyncMethodDescriptor and ASyncBoundMethod classes * feat: more type checking helper classes * fix(mypy): fix type errs
- Loading branch information
1 parent
a7f4c92
commit 6ad2ddc
Showing
14 changed files
with
435 additions
and
182 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,111 +1,123 @@ | ||
# mypy: disable-error-code=valid-type | ||
# mypy: disable-error-code=misc | ||
import functools | ||
import logging | ||
from inspect import isawaitable | ||
|
||
from a_sync import _helpers | ||
from a_sync import _helpers, _kwargs | ||
from a_sync._descriptor import ASyncDescriptor | ||
from a_sync._typing import * | ||
from a_sync.decorator import a_sync as unbound_a_sync | ||
from a_sync.modified import ASyncFunction | ||
from a_sync.property import (AsyncCachedPropertyDescriptor, | ||
AsyncPropertyDescriptor) | ||
from a_sync.modified import ASyncFunction, ASyncFunctionAsyncDefault, ASyncFunctionSyncDefault | ||
|
||
if TYPE_CHECKING: | ||
from a_sync.abstract import ASyncABC | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def _clean_default_from_modifiers( | ||
coro_fn: AsyncBoundMethod[P, T], # type: ignore [misc] | ||
modifiers: ModifierKwargs, | ||
): | ||
# NOTE: We set the default here manually because the default set by the user will be used later in the code to determine whether to await. | ||
force_await = None | ||
if not asyncio.iscoroutinefunction(coro_fn) and not isinstance(coro_fn, ASyncFunction): | ||
if 'default' not in modifiers or modifiers['default'] != 'async': | ||
if 'default' in modifiers and modifiers['default'] == 'sync': | ||
force_await = True | ||
modifiers['default'] = 'async' | ||
return modifiers, force_await | ||
class ASyncMethodDescriptor(ASyncDescriptor[ASyncFunction[P, T]], Generic[O, P, T]): | ||
_fget: ASyncFunction[Concatenate[O, P], T] | ||
def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethod[P, T]": | ||
if instance is None: | ||
return self | ||
try: | ||
return instance.__dict__[self.field_name] | ||
except KeyError: | ||
from a_sync.abstract import ASyncABC | ||
if self.default == "sync": | ||
bound = ASyncBoundMethodSyncDefault(instance, self._fget, **self.modifiers) | ||
elif self.default == "async": | ||
bound = ASyncBoundMethodAsyncDefault(instance, self._fget, **self.modifiers) | ||
elif isinstance(instance, ASyncABC) and instance.__a_sync_instance_should_await__: | ||
bound = ASyncBoundMethodSyncDefault(instance, self._fget, **self.modifiers) | ||
elif isinstance(instance, ASyncABC) and instance.__a_sync_instance_should_await__: | ||
bound = ASyncBoundMethodAsyncDefault(instance, self._fget, **self.modifiers) | ||
else: | ||
bound = ASyncBoundMethod(instance, self._fget, **self.modifiers) | ||
instance.__dict__[self.field_name] = bound | ||
logger.debug("new bound method: %s", bound) | ||
return bound | ||
def __set__(self, instance, value): | ||
raise RuntimeError(f"cannot set {self.field_name}, {self} is what you get. sorry.") | ||
def __delete__(self, instance): | ||
raise RuntimeError(f"cannot delete {self.field_name}, you're stuck with {self} forever. sorry.") | ||
|
||
|
||
def _wrap_bound_method( | ||
coro_fn: AsyncBoundMethod[P, T], | ||
**modifiers: Unpack[ModifierKwargs] | ||
) -> AsyncBoundMethod[P, T]: | ||
from a_sync.abstract import ASyncABC | ||
|
||
# First we unwrap the coro_fn and rewrap it so overriding flag kwargs are handled automagically. | ||
if isinstance(coro_fn, ASyncFunction): | ||
coro_fn = coro_fn.__wrapped__ | ||
|
||
modifiers, _force_await = _clean_default_from_modifiers(coro_fn, modifiers) | ||
|
||
wrapped_coro_fn: AsyncBoundMethod[P, T] = ASyncFunction(coro_fn, **modifiers) # type: ignore [arg-type, valid-type, misc] | ||
class ASyncMethodDescriptorSyncDefault(ASyncMethodDescriptor[ASyncInstance, P, T]): | ||
def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethodSyncDefault[P, T]": | ||
if instance is None: | ||
return self | ||
try: | ||
return instance.__dict__[self.field_name] | ||
except KeyError: | ||
bound = ASyncBoundMethodSyncDefault(instance, self._fget, **self.modifiers) | ||
instance.__dict__[self.field_name] = bound | ||
logger.debug("new bound method: %s", bound) | ||
return bound | ||
|
||
@functools.wraps(coro_fn) | ||
def bound_a_sync_wrap(self: ASyncABC, *args: P.args, **kwargs: P.kwargs) -> T: # type: ignore [name-defined] | ||
if not isinstance(self, ASyncABC): | ||
raise RuntimeError(f"{self} must be an instance of a class that inherits from ASyncABC.") | ||
class ASyncMethodDescriptorAsyncDefault(ASyncMethodDescriptor[ASyncInstance, P, T]): | ||
def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethodAsyncDefault[P, T]": | ||
if instance is None: | ||
return self | ||
try: | ||
return instance.__dict__[self.field_name] | ||
except KeyError: | ||
bound = ASyncBoundMethodAsyncDefault(instance, self._fget, **self.modifiers) | ||
instance.__dict__[self.field_name] = bound | ||
logger.debug("new bound method: %s", bound) | ||
return bound | ||
|
||
class ASyncBoundMethod(ASyncFunction[P, T]): | ||
def __init__( | ||
self, | ||
instance: ASyncInstance, | ||
unbound: AnyFn[Concatenate[ASyncInstance, P], T], | ||
**modifiers: Unpack[ModifierKwargs], | ||
) -> None: | ||
self.instance = instance | ||
# First we unwrap the coro_fn and rewrap it so overriding flag kwargs are handled automagically. | ||
if isinstance(unbound, ASyncFunction): | ||
modifiers.update(unbound.modifiers) | ||
unbound = unbound.__wrapped__ | ||
if asyncio.iscoroutinefunction(unbound): | ||
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: | ||
return await unbound(self.instance, *args, **kwargs) | ||
else: | ||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: | ||
return unbound(self.instance, *args, **kwargs) | ||
functools.update_wrapper(wrapped, unbound) | ||
super().__init__(wrapped, **modifiers) | ||
def __repr__(self) -> str: | ||
return f"<{self.__class__.__name__} for function {self.__module__}.{self.instance.__class__.__name__}.{self.__name__} bound to {self.instance}>" | ||
def __call__(self, *args, **kwargs): | ||
logger.debug("calling %s", self) | ||
# This could either be a coroutine or a return value from an awaited coroutine, | ||
# depending on if an overriding flag kwarg was passed into the function call. | ||
retval = coro = wrapped_coro_fn(self, *args, **kwargs) | ||
retval = coro = super().__call__(*args, **kwargs) | ||
if not isawaitable(retval): | ||
# The coroutine was already awaited due to the use of an overriding flag kwarg. | ||
# We can return the value. | ||
return retval # type: ignore [return-value] | ||
# The awaitable was not awaited, so now we need to check the flag as defined on 'self' and await if appropriate. | ||
return _helpers._await(coro) if self.__a_sync_should_await__(kwargs, force=_force_await) else coro # type: ignore [call-overload, return-value] | ||
return bound_a_sync_wrap | ||
|
||
class _PropertyGetter(Awaitable[T]): | ||
def __init__(self, coro: Awaitable[T], property: Union[AsyncPropertyDescriptor[T], AsyncCachedPropertyDescriptor[T]]): | ||
self._coro = coro | ||
self._property = property | ||
def __repr__(self) -> str: | ||
return f"<_PropertyGetter for {self._property}._get at {hex(id(self))}>" | ||
def __await__(self) -> Generator[Any, None, T]: | ||
return self._coro.__await__() | ||
|
||
@overload | ||
def _wrap_property( | ||
async_property: AsyncPropertyDescriptor[T], | ||
**modifiers: Unpack[ModifierKwargs] | ||
) -> AsyncPropertyDescriptor[T]:... | ||
|
||
@overload | ||
def _wrap_property( | ||
async_property: AsyncCachedPropertyDescriptor[T], | ||
**modifiers: Unpack[ModifierKwargs] | ||
) -> AsyncCachedPropertyDescriptor:... | ||
return _helpers._await(coro) if self.should_await(kwargs) else coro # type: ignore [call-overload, return-value] | ||
@functools.cached_property | ||
def __bound_to_a_sync_instance__(self) -> bool: | ||
from a_sync.abstract import ASyncABC | ||
return isinstance(self.instance, ASyncABC) | ||
def should_await(self, kwargs: dict) -> bool: | ||
if flag := _kwargs.get_flag_name(kwargs): | ||
return _kwargs.is_sync(flag, kwargs, pop_flag=True) # type: ignore [arg-type] | ||
elif self.default: | ||
return self.default == "sync" | ||
elif self.__bound_to_a_sync_instance__: | ||
return self.instance.__a_sync_should_await__(kwargs) | ||
return asyncio.iscoroutinefunction(self.__wrapped__) | ||
|
||
def _wrap_property( | ||
async_property: Union[AsyncPropertyDescriptor[T], AsyncCachedPropertyDescriptor[T]], | ||
**modifiers: Unpack[ModifierKwargs] | ||
) -> Tuple[Property[T], HiddenMethod[T]]: | ||
if not isinstance(async_property, (AsyncPropertyDescriptor, AsyncCachedPropertyDescriptor)): | ||
raise TypeError(f"{async_property} must be one of: AsyncPropertyDescriptor, AsyncCachedPropertyDescriptor") | ||
|
||
from a_sync.abstract import ASyncABC | ||
class ASyncBoundMethodSyncDefault(ASyncBoundMethod[P, T]): | ||
def __get__(self, instance: ASyncInstance, owner) -> ASyncFunctionSyncDefault[P, T]: | ||
return super().__get__(instance, owner) | ||
def __call__(self, *args, **kwargs) -> T: | ||
return super().__call__(*args, **kwargs) | ||
|
||
async_property.hidden_method_name = f"__{async_property.field_name}__" | ||
|
||
modifiers, _force_await = _clean_default_from_modifiers(async_property, modifiers) | ||
|
||
@unbound_a_sync(**modifiers) | ||
async def _get(instance: ASyncABC) -> T: | ||
return await async_property.__get__(instance, async_property) | ||
|
||
@functools.wraps(async_property) | ||
def a_sync_method(self: ASyncABC, **kwargs) -> T: | ||
if not isinstance(self, ASyncABC): | ||
raise RuntimeError(f"{self} must be an instance of a class that inherits from ASyncABC.") | ||
awaitable = _PropertyGetter(_get(self), async_property) | ||
return _helpers._await(awaitable) if self.__a_sync_should_await__(kwargs, force=_force_await) else awaitable | ||
|
||
@property # type: ignore [misc] | ||
@functools.wraps(async_property) | ||
def a_sync_property(self: ASyncABC) -> T: | ||
coro = getattr(self, async_property.hidden_method_name)(sync=False) | ||
return _helpers._await(coro) if self.__a_sync_should_await__({}, force=_force_await) else coro | ||
|
||
return a_sync_property, a_sync_method | ||
class ASyncBoundMethodAsyncDefault(ASyncBoundMethod[P, T]): | ||
def __get__(self, instance: ASyncInstance, owner) -> ASyncFunctionAsyncDefault[P, T]: | ||
return super().__get__(instance, owner) | ||
def __call__(self, *args, **kwargs) -> Awaitable[T]: | ||
return super().__call__(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
|
||
import functools | ||
|
||
from a_sync._typing import * | ||
from a_sync.modified import ASyncFunction, ModifiedMixin, ModifierManager | ||
|
||
class ASyncDescriptor(ModifiedMixin, Generic[T]): | ||
def __init__(self, _fget: UnboundMethod[ASyncInstance, P, T], field_name=None, **modifiers: ModifierKwargs): | ||
if not callable(_fget): | ||
raise ValueError(f'Unable to decorate {_fget}') | ||
self.modifiers = ModifierManager(modifiers) | ||
self._fn: UnboundMethod[ASyncInstance, P, T] = _fget | ||
if isinstance(_fget, ASyncFunction): | ||
self._fget = _fget | ||
elif asyncio.iscoroutinefunction(_fget): | ||
self._fget: AsyncUnboundMethod[ASyncInstance, P, T] = self.modifiers.apply_async_modifiers(_fget) | ||
else: | ||
self._fget = self._asyncify(_fget) | ||
self.field_name = field_name or _fget.__name__ | ||
functools.update_wrapper(self, _fget) | ||
def __repr__(self) -> str: | ||
return f"<{self.__class__.__name__} for {self._fn}>" | ||
def __set_name__(self, owner, name): | ||
self.field_name = name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.