Skip to content

Commit

Permalink
feat: better type hints everywhere (#141)
Browse files Browse the repository at this point in the history
* feat: ASyncMethodDescriptor and ASyncBoundMethod classes

* feat: more type checking helper classes

* fix(mypy): fix type errs
  • Loading branch information
BobTheBuidler authored Feb 29, 2024
1 parent a7f4c92 commit 6ad2ddc
Show file tree
Hide file tree
Showing 14 changed files with 435 additions and 182 deletions.
7 changes: 6 additions & 1 deletion a_sync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from a_sync.iter import ASyncIterable, ASyncIterator
from a_sync.modifiers.semaphores import apply_semaphore
from a_sync.primitives import *
from a_sync.property import ASyncCachedPropertyDescriptor, ASyncPropertyDescriptor, cached_property, property
from a_sync.singleton import ASyncGenericSingleton
from a_sync.task import TaskMapping as map
from a_sync.task import TaskMapping, create_task
Expand All @@ -19,7 +20,6 @@
# alias for backward-compatability, will be removed eventually, probably in 0.1.0
ASyncBase = ASyncGenericBase


__all__ = [
"all",
"any",
Expand All @@ -34,4 +34,9 @@
"ASyncIterator",
"ASyncGenericSingleton",
"TaskMapping",
# property
"cached_property",
"property",
"ASyncPropertyDescriptor",
"ASyncCachedPropertyDescriptor",
]
192 changes: 102 additions & 90 deletions a_sync/_bound.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,123 @@
# mypy: disable-error-code=valid-type
# mypy: disable-error-code=misc
import functools
import logging
from inspect import isawaitable

from a_sync import _helpers
from a_sync import _helpers, _kwargs
from a_sync._descriptor import ASyncDescriptor
from a_sync._typing import *
from a_sync.decorator import a_sync as unbound_a_sync
from a_sync.modified import ASyncFunction
from a_sync.property import (AsyncCachedPropertyDescriptor,
AsyncPropertyDescriptor)
from a_sync.modified import ASyncFunction, ASyncFunctionAsyncDefault, ASyncFunctionSyncDefault

if TYPE_CHECKING:
from a_sync.abstract import ASyncABC

logger = logging.getLogger(__name__)

def _clean_default_from_modifiers(
coro_fn: AsyncBoundMethod[P, T], # type: ignore [misc]
modifiers: ModifierKwargs,
):
# NOTE: We set the default here manually because the default set by the user will be used later in the code to determine whether to await.
force_await = None
if not asyncio.iscoroutinefunction(coro_fn) and not isinstance(coro_fn, ASyncFunction):
if 'default' not in modifiers or modifiers['default'] != 'async':
if 'default' in modifiers and modifiers['default'] == 'sync':
force_await = True
modifiers['default'] = 'async'
return modifiers, force_await
class ASyncMethodDescriptor(ASyncDescriptor[ASyncFunction[P, T]], Generic[O, P, T]):
_fget: ASyncFunction[Concatenate[O, P], T]
def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethod[P, T]":
if instance is None:
return self
try:
return instance.__dict__[self.field_name]
except KeyError:
from a_sync.abstract import ASyncABC
if self.default == "sync":
bound = ASyncBoundMethodSyncDefault(instance, self._fget, **self.modifiers)
elif self.default == "async":
bound = ASyncBoundMethodAsyncDefault(instance, self._fget, **self.modifiers)
elif isinstance(instance, ASyncABC) and instance.__a_sync_instance_should_await__:
bound = ASyncBoundMethodSyncDefault(instance, self._fget, **self.modifiers)
elif isinstance(instance, ASyncABC) and instance.__a_sync_instance_should_await__:
bound = ASyncBoundMethodAsyncDefault(instance, self._fget, **self.modifiers)
else:
bound = ASyncBoundMethod(instance, self._fget, **self.modifiers)
instance.__dict__[self.field_name] = bound
logger.debug("new bound method: %s", bound)
return bound
def __set__(self, instance, value):
raise RuntimeError(f"cannot set {self.field_name}, {self} is what you get. sorry.")
def __delete__(self, instance):
raise RuntimeError(f"cannot delete {self.field_name}, you're stuck with {self} forever. sorry.")


def _wrap_bound_method(
coro_fn: AsyncBoundMethod[P, T],
**modifiers: Unpack[ModifierKwargs]
) -> AsyncBoundMethod[P, T]:
from a_sync.abstract import ASyncABC

# First we unwrap the coro_fn and rewrap it so overriding flag kwargs are handled automagically.
if isinstance(coro_fn, ASyncFunction):
coro_fn = coro_fn.__wrapped__

modifiers, _force_await = _clean_default_from_modifiers(coro_fn, modifiers)

wrapped_coro_fn: AsyncBoundMethod[P, T] = ASyncFunction(coro_fn, **modifiers) # type: ignore [arg-type, valid-type, misc]
class ASyncMethodDescriptorSyncDefault(ASyncMethodDescriptor[ASyncInstance, P, T]):
def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethodSyncDefault[P, T]":
if instance is None:
return self
try:
return instance.__dict__[self.field_name]
except KeyError:
bound = ASyncBoundMethodSyncDefault(instance, self._fget, **self.modifiers)
instance.__dict__[self.field_name] = bound
logger.debug("new bound method: %s", bound)
return bound

@functools.wraps(coro_fn)
def bound_a_sync_wrap(self: ASyncABC, *args: P.args, **kwargs: P.kwargs) -> T: # type: ignore [name-defined]
if not isinstance(self, ASyncABC):
raise RuntimeError(f"{self} must be an instance of a class that inherits from ASyncABC.")
class ASyncMethodDescriptorAsyncDefault(ASyncMethodDescriptor[ASyncInstance, P, T]):
def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethodAsyncDefault[P, T]":
if instance is None:
return self
try:
return instance.__dict__[self.field_name]
except KeyError:
bound = ASyncBoundMethodAsyncDefault(instance, self._fget, **self.modifiers)
instance.__dict__[self.field_name] = bound
logger.debug("new bound method: %s", bound)
return bound

class ASyncBoundMethod(ASyncFunction[P, T]):
def __init__(
self,
instance: ASyncInstance,
unbound: AnyFn[Concatenate[ASyncInstance, P], T],
**modifiers: Unpack[ModifierKwargs],
) -> None:
self.instance = instance
# First we unwrap the coro_fn and rewrap it so overriding flag kwargs are handled automagically.
if isinstance(unbound, ASyncFunction):
modifiers.update(unbound.modifiers)
unbound = unbound.__wrapped__
if asyncio.iscoroutinefunction(unbound):
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
return await unbound(self.instance, *args, **kwargs)
else:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
return unbound(self.instance, *args, **kwargs)
functools.update_wrapper(wrapped, unbound)
super().__init__(wrapped, **modifiers)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} for function {self.__module__}.{self.instance.__class__.__name__}.{self.__name__} bound to {self.instance}>"
def __call__(self, *args, **kwargs):
logger.debug("calling %s", self)
# This could either be a coroutine or a return value from an awaited coroutine,
# depending on if an overriding flag kwarg was passed into the function call.
retval = coro = wrapped_coro_fn(self, *args, **kwargs)
retval = coro = super().__call__(*args, **kwargs)
if not isawaitable(retval):
# The coroutine was already awaited due to the use of an overriding flag kwarg.
# We can return the value.
return retval # type: ignore [return-value]
# The awaitable was not awaited, so now we need to check the flag as defined on 'self' and await if appropriate.
return _helpers._await(coro) if self.__a_sync_should_await__(kwargs, force=_force_await) else coro # type: ignore [call-overload, return-value]
return bound_a_sync_wrap

class _PropertyGetter(Awaitable[T]):
def __init__(self, coro: Awaitable[T], property: Union[AsyncPropertyDescriptor[T], AsyncCachedPropertyDescriptor[T]]):
self._coro = coro
self._property = property
def __repr__(self) -> str:
return f"<_PropertyGetter for {self._property}._get at {hex(id(self))}>"
def __await__(self) -> Generator[Any, None, T]:
return self._coro.__await__()

@overload
def _wrap_property(
async_property: AsyncPropertyDescriptor[T],
**modifiers: Unpack[ModifierKwargs]
) -> AsyncPropertyDescriptor[T]:...

@overload
def _wrap_property(
async_property: AsyncCachedPropertyDescriptor[T],
**modifiers: Unpack[ModifierKwargs]
) -> AsyncCachedPropertyDescriptor:...
return _helpers._await(coro) if self.should_await(kwargs) else coro # type: ignore [call-overload, return-value]
@functools.cached_property
def __bound_to_a_sync_instance__(self) -> bool:
from a_sync.abstract import ASyncABC
return isinstance(self.instance, ASyncABC)
def should_await(self, kwargs: dict) -> bool:
if flag := _kwargs.get_flag_name(kwargs):
return _kwargs.is_sync(flag, kwargs, pop_flag=True) # type: ignore [arg-type]
elif self.default:
return self.default == "sync"
elif self.__bound_to_a_sync_instance__:
return self.instance.__a_sync_should_await__(kwargs)
return asyncio.iscoroutinefunction(self.__wrapped__)

def _wrap_property(
async_property: Union[AsyncPropertyDescriptor[T], AsyncCachedPropertyDescriptor[T]],
**modifiers: Unpack[ModifierKwargs]
) -> Tuple[Property[T], HiddenMethod[T]]:
if not isinstance(async_property, (AsyncPropertyDescriptor, AsyncCachedPropertyDescriptor)):
raise TypeError(f"{async_property} must be one of: AsyncPropertyDescriptor, AsyncCachedPropertyDescriptor")

from a_sync.abstract import ASyncABC
class ASyncBoundMethodSyncDefault(ASyncBoundMethod[P, T]):
def __get__(self, instance: ASyncInstance, owner) -> ASyncFunctionSyncDefault[P, T]:
return super().__get__(instance, owner)
def __call__(self, *args, **kwargs) -> T:
return super().__call__(*args, **kwargs)

async_property.hidden_method_name = f"__{async_property.field_name}__"

modifiers, _force_await = _clean_default_from_modifiers(async_property, modifiers)

@unbound_a_sync(**modifiers)
async def _get(instance: ASyncABC) -> T:
return await async_property.__get__(instance, async_property)

@functools.wraps(async_property)
def a_sync_method(self: ASyncABC, **kwargs) -> T:
if not isinstance(self, ASyncABC):
raise RuntimeError(f"{self} must be an instance of a class that inherits from ASyncABC.")
awaitable = _PropertyGetter(_get(self), async_property)
return _helpers._await(awaitable) if self.__a_sync_should_await__(kwargs, force=_force_await) else awaitable

@property # type: ignore [misc]
@functools.wraps(async_property)
def a_sync_property(self: ASyncABC) -> T:
coro = getattr(self, async_property.hidden_method_name)(sync=False)
return _helpers._await(coro) if self.__a_sync_should_await__({}, force=_force_await) else coro

return a_sync_property, a_sync_method
class ASyncBoundMethodAsyncDefault(ASyncBoundMethod[P, T]):
def __get__(self, instance: ASyncInstance, owner) -> ASyncFunctionAsyncDefault[P, T]:
return super().__get__(instance, owner)
def __call__(self, *args, **kwargs) -> Awaitable[T]:
return super().__call__(*args, **kwargs)
24 changes: 24 additions & 0 deletions a_sync/_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

import functools

from a_sync._typing import *
from a_sync.modified import ASyncFunction, ModifiedMixin, ModifierManager

class ASyncDescriptor(ModifiedMixin, Generic[T]):
def __init__(self, _fget: UnboundMethod[ASyncInstance, P, T], field_name=None, **modifiers: ModifierKwargs):
if not callable(_fget):
raise ValueError(f'Unable to decorate {_fget}')
self.modifiers = ModifierManager(modifiers)
self._fn: UnboundMethod[ASyncInstance, P, T] = _fget
if isinstance(_fget, ASyncFunction):
self._fget = _fget
elif asyncio.iscoroutinefunction(_fget):
self._fget: AsyncUnboundMethod[ASyncInstance, P, T] = self.modifiers.apply_async_modifiers(_fget)
else:
self._fget = self._asyncify(_fget)
self.field_name = field_name or _fget.__name__
functools.update_wrapper(self, _fget)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} for {self._fn}>"
def __set_name__(self, owner, name):
self.field_name = name
9 changes: 5 additions & 4 deletions a_sync/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
from async_property.cached import \
AsyncCachedPropertyDescriptor # type: ignore [import]

from a_sync import _flags
from a_sync import _flags, exceptions
from a_sync._typing import *
from a_sync.exceptions import (ASyncRuntimeError, KwargsUnsupportedError,
SyncModeInAsyncContextError)
from a_sync.modified import ASyncFunction


def get_event_loop() -> asyncio.AbstractEventLoop:
Expand Down Expand Up @@ -44,10 +43,12 @@ def _await(awaitable: Awaitable[T]) -> T:
return get_event_loop().run_until_complete(awaitable)
except RuntimeError as e:
if str(e) == "This event loop is already running":
raise SyncModeInAsyncContextError from e
raise exceptions.SyncModeInAsyncContextError from e
raise e

def _asyncify(func: SyncFn[P, T], executor: Executor) -> CoroFn[P, T]: # type: ignore [misc]
if asyncio.iscoroutinefunction(func) or isinstance(func, ASyncFunction):
raise exceptions.FunctionNotSync(func)
@functools.wraps(func)
async def _asyncify_wrap(*args: P.args, **kwargs: P.kwargs) -> T:
return await asyncio.futures.wrap_future(
Expand Down
21 changes: 10 additions & 11 deletions a_sync/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from abc import ABCMeta
from typing import Any, Dict, Tuple

from a_sync import ENVIRONMENT_VARIABLES, _bound, modifiers
from a_sync import ENVIRONMENT_VARIABLES, modifiers
from a_sync._bound import ASyncMethodDescriptor
from a_sync.future import _ASyncFutureWrappedFn # type: ignore [attr-defined]
from a_sync.modified import ASyncFunction, ModifiedMixin
from a_sync.property import PropertyDescriptor
from a_sync.property import ASyncPropertyDescriptor, ASyncCachedPropertyDescriptor
from a_sync.primitives.locks.semaphore import Semaphore

logger = logging.getLogger(__name__)

Expand All @@ -30,7 +32,7 @@ def __new__(cls, new_class_name, bases, attrs):
elif "__" in attr_name:
logger.debug(f"`%s.%s` incluldes a double-underscore, skipping", new_class_name, attr_name)
continue
elif isinstance(attr_value, _ASyncFutureWrappedFn):
elif isinstance(attr_value, (_ASyncFutureWrappedFn, Semaphore)):
logger.debug(f"`%s.%s` is a %s, skipping", new_class_name, attr_name, attr_value.__class__.__name__)
continue
logger.debug(f"inspecting `{new_class_name}.{attr_name}` of type {attr_value.__class__.__name__}")
Expand All @@ -46,23 +48,20 @@ def __new__(cls, new_class_name, bases, attrs):
else:
logger.debug(f"I did not find any modifiers")
logger.debug(f"full modifier set for `{new_class_name}.{attr_name}`: {fn_modifiers}")
if isinstance(attr_value, PropertyDescriptor):
if isinstance(attr_value, (ASyncPropertyDescriptor, ASyncCachedPropertyDescriptor)):
# Wrap property
logger.debug(f"`{attr_name} is a property, now let's wrap it")
wrapped, hidden = _bound._wrap_property(attr_value, **fn_modifiers)
attrs[attr_name] = wrapped
logger.debug(f"`{attr_name}` is now `{wrapped}`")
logger.debug(f"since `{attr_name}` is a property, we will add a hidden dundermethod so you can still access it both sync and async")
attrs[attr_value.hidden_method_name] = hidden
logger.debug(f"`{new_class_name}.{attr_value.hidden_method_name}` is now {hidden}")
attrs[attr_value.hidden_method_name] = attr_value.hidden_method_descriptor
logger.debug(f"`{new_class_name}.{attr_value.hidden_method_name}` is now {attr_value.hidden_method_descriptor}")
elif isinstance(attr_value, ASyncFunction):
attrs[attr_name] = _bound._wrap_bound_method(attr_value, **fn_modifiers)
attrs[attr_name] = ASyncMethodDescriptor(attr_value, **fn_modifiers)
else:
raise NotImplementedError(attr_name, attr_value)

elif callable(attr_value):
# NOTE We will need to improve this logic if somebody needs to use it with classmethods or staticmethods.
attrs[attr_name] = _bound._wrap_bound_method(attr_value, **fn_modifiers)
attrs[attr_name] = ASyncMethodDescriptor(attr_value, **fn_modifiers)
else:
logger.debug(f"`{new_class_name}.{attr_name}` is not callable, we will take no action with it")
return super(ASyncMeta, cls).__new__(cls, new_class_name, bases, attrs)
Expand Down
15 changes: 10 additions & 5 deletions a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,31 @@

if TYPE_CHECKING:
from a_sync.abstract import ASyncABC
ASyncInstance = TypeVar("ASyncInstance", bound=ASyncABC)
else:
ASyncInstance = TypeVar("ASyncInstance", bound=object)

T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
O = TypeVar("O", bound=object)
E = TypeVar('E', bound=Exception)
P = ParamSpec("P")

Numeric = Union[int, float, Decimal]

MaybeAwaitable = Union[Awaitable[T], T]

Property = Callable[["ASyncABC"], T]
HiddenMethod = Callable[["ASyncABC", Dict[str, bool]], T]
AsyncBoundMethod = Callable[Concatenate["ASyncABC", P], Awaitable[T]]
BoundMethod = Callable[Concatenate["ASyncABC", P], T]

CoroFn = Callable[P, Awaitable[T]]
SyncFn = Callable[P, T]
AnyFn = Union[CoroFn[P, T], SyncFn[P, T]]

AsyncUnboundMethod = Callable[Concatenate[ASyncInstance, P], Awaitable[T]]
SyncUnboundMethod = Callable[Concatenate[ASyncInstance, P], T]
UnboundMethod = Union[AsyncUnboundMethod[ASyncInstance, P, T], SyncUnboundMethod[ASyncInstance, P, T]]

Property = Callable[[object], T]

AsyncDecorator = Callable[[CoroFn[P, T]], CoroFn[P, T]]

AllToAsyncDecorator = Callable[[AnyFn[P, T]], CoroFn[P, T]]
Expand Down
Loading

0 comments on commit 6ad2ddc

Please sign in to comment.