Skip to content

Commit

Permalink
feat: type check Protocols (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Mar 1, 2024
1 parent 771a369 commit 50d73a3
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 124 deletions.
58 changes: 26 additions & 32 deletions a_sync/_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,30 @@
from a_sync._typing import *
from a_sync.modified import ASyncFunction, ASyncFunctionAsyncDefault, ASyncFunctionSyncDefault

if TYPE_CHECKING:
from a_sync.abstract import ASyncABC

logger = logging.getLogger(__name__)

class ASyncMethodDescriptor(ASyncDescriptor[ASyncFunction[P, T]], Generic[O, P, T]):
wrapped: ASyncFunction[Concatenate[O, P], T]
def __get__(self, instance: O, owner) -> "ASyncBoundMethod[P, T]":
__wrapped__: AnyFn[Concatenate[O, P], T]
def __get__(self, instance: O, owner: Any) -> "ASyncBoundMethod[P, T]":
if instance is None:
return self
try:
return instance.__dict__[self.field_name]
except KeyError:
from a_sync.abstract import ASyncABC
if self.default == "sync":
bound = ASyncBoundMethodSyncDefault(instance, self.wrapped, **self.modifiers)
bound = ASyncBoundMethodSyncDefault(instance, self.__wrapped__, **self.modifiers)
elif self.default == "async":
bound = ASyncBoundMethodAsyncDefault(instance, self.wrapped, **self.modifiers)
bound = ASyncBoundMethodAsyncDefault(instance, self.__wrapped__, **self.modifiers)
elif isinstance(instance, ASyncABC) and instance.__a_sync_instance_should_await__:
bound = ASyncBoundMethodSyncDefault(instance, self.wrapped, **self.modifiers)
bound = ASyncBoundMethodSyncDefault(instance, self.__wrapped__, **self.modifiers)
elif isinstance(instance, ASyncABC) and instance.__a_sync_instance_should_await__:
bound = ASyncBoundMethodAsyncDefault(instance, self.wrapped, **self.modifiers)
bound = ASyncBoundMethodAsyncDefault(instance, self.__wrapped__, **self.modifiers)
else:
bound = ASyncBoundMethod(instance, self.wrapped, **self.modifiers)
bound = ASyncBoundMethod(instance, self.__wrapped__, **self.modifiers)
instance.__dict__[self.field_name] = bound
logger.debug("new bound method: %s", bound)
return bound
Expand All @@ -39,32 +41,32 @@ def __set__(self, instance, value):
def __delete__(self, instance):
raise RuntimeError(f"cannot delete {self.field_name}, you're stuck with {self} forever. sorry.")

class ASyncMethodDescriptorSyncDefault(ASyncMethodDescriptor[ASyncInstance, P, T]):
def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethodSyncDefault[P, T]":
class ASyncMethodDescriptorSyncDefault(ASyncMethodDescriptor[O, P, T]):
def __get__(self, instance: O, owner: Any) -> "ASyncBoundMethodSyncDefault[P, T]":
if instance is None:
return self
try:
return instance.__dict__[self.field_name]
except KeyError:
bound = ASyncBoundMethodSyncDefault(instance, self.wrapped, **self.modifiers)
bound = ASyncBoundMethodSyncDefault(instance, self.__wrapped__, **self.modifiers)
instance.__dict__[self.field_name] = bound
logger.debug("new bound method: %s", bound)
return bound

class ASyncMethodDescriptorAsyncDefault(ASyncMethodDescriptor[ASyncInstance, P, T]):
def __get__(self, instance: ASyncInstance, owner) -> "ASyncBoundMethodAsyncDefault[P, T]":
class ASyncMethodDescriptorAsyncDefault(ASyncMethodDescriptor[O, P, T]):
def __get__(self, instance: O, owner: Any) -> "ASyncBoundMethodAsyncDefault[P, T]":
if instance is None:
return self
try:
return instance.__dict__[self.field_name]
except KeyError:
bound = ASyncBoundMethodAsyncDefault(instance, self.wrapped, **self.modifiers)
bound = ASyncBoundMethodAsyncDefault(instance, self.__wrapped__, **self.modifiers)
instance.__dict__[self.field_name] = bound
logger.debug("new bound method: %s", bound)
return bound

class ASyncBoundMethod(ASyncFunction[P, T]):
__slots__ = "__unbound__", "__self__"
__slots__ = "__self__",
def __init__(
self,
instance: O,
Expand All @@ -75,54 +77,46 @@ def __init__(
# First we unwrap the coro_fn and rewrap it so overriding flag kwargs are handled automagically.
if isinstance(unbound, ASyncFunction):
modifiers.update(unbound.modifiers)
print(modifiers)
self.__unbound__ = unbound.wrapped
else:
self.__unbound__ = unbound
bound = self._bound_async if asyncio.iscoroutinefunction(self.__unbound__) else self._bound_sync
super().__init__(bound, **modifiers)
functools.update_wrapper(self, self.__unbound__)
unbound = unbound.__wrapped__
super().__init__(unbound, **modifiers)
functools.update_wrapper(self, unbound)
def __repr__(self) -> str:
instance_type = type(self.instance)
instance_type = type(self.__self__)
return f"<{self.__class__.__name__} for function {instance_type.__module__}.{instance_type.__name__}.{self.__name__} bound to {self.__self__}>"
def __call__(self, *args, **kwargs):
logger.debug("calling %s", self)
# This could either be a coroutine or a return value from an awaited coroutine,
# depending on if an overriding flag kwarg was passed into the function call.
retval = coro = super().__call__(*args, **kwargs)
retval = coro = super().__call__(self.__self__, *args, **kwargs)
if not isawaitable(retval):
# The coroutine was already awaited due to the use of an overriding flag kwarg.
# We can return the value.
return retval # type: ignore [return-value]
# The awaitable was not awaited, so now we need to check the flag as defined on 'self' and await if appropriate.
return _helpers._await(coro) if self.should_await(kwargs) else coro # type: ignore [call-overload, return-value]
return _helpers._await(coro) if self._should_await(kwargs) else coro # type: ignore [call-overload, return-value]
@functools.cached_property
def __bound_to_a_sync_instance__(self) -> bool:
from a_sync.abstract import ASyncABC
return isinstance(self.__self__, ASyncABC)
def should_await(self, kwargs: dict) -> bool:
def _should_await(self, kwargs: dict) -> bool:
if flag := _kwargs.get_flag_name(kwargs):
return _kwargs.is_sync(flag, kwargs, pop_flag=True) # type: ignore [arg-type]
elif self.default:
return self.default == "sync"
elif self.__bound_to_a_sync_instance__:
self.__self__: "ASyncABC"
return self.__self__.__a_sync_should_await__(kwargs)
return asyncio.iscoroutinefunction(self.wrapped)
def _bound_sync(self, *args: P.args, **kwargs: P.kwargs) -> T:
return self.__unbound__(self.__self__, *args, **kwargs)
async def _bound_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
return await self.__unbound__(self.__self__, *args, **kwargs)
return asyncio.iscoroutinefunction(self.__wrapped__)


class ASyncBoundMethodSyncDefault(ASyncBoundMethod[P, T]):
def __get__(self, instance: ASyncInstance, owner) -> ASyncFunctionSyncDefault[P, T]:
def __get__(self, instance: O, owner: Any) -> ASyncFunctionSyncDefault[P, T]:
return super().__get__(instance, owner)
def __call__(self, *args, **kwargs) -> T:
return super().__call__(*args, **kwargs)

class ASyncBoundMethodAsyncDefault(ASyncBoundMethod[P, T]):
def __get__(self, instance: ASyncInstance, owner) -> ASyncFunctionAsyncDefault[P, T]:
def __get__(self, instance: O, owner: Any) -> ASyncFunctionAsyncDefault[P, T]:
return super().__get__(instance, owner)
def __call__(self, *args, **kwargs) -> Awaitable[T]:
return super().__call__(*args, **kwargs)
13 changes: 7 additions & 6 deletions a_sync/_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,23 @@ class ASyncDescriptor(ModifiedMixin, Generic[T]):
__slots__ = "field_name", "_fget"
def __init__(
self,
_fget: AnyUnboundMethod[ASyncInstance, P, T],
_fget: AnyUnboundMethod[P, T],
field_name: Optional[str] = None,
**modifiers: ModifierKwargs,
) -> None:
if not callable(_fget):
raise ValueError(f'Unable to decorate {_fget}')
self.modifiers = ModifierManager(modifiers)
if isinstance(_fget, ASyncFunction):
self.wrapped = _fget
self.modifiers.update(_fget.modifiers)
self.__wrapped__ = _fget
elif asyncio.iscoroutinefunction(_fget):
self.wrapped: AsyncUnboundMethod[ASyncInstance, P, T] = self.modifiers.apply_async_modifiers(_fget)
self.__wrapped__: AsyncUnboundMethod[P, T] = self.modifiers.apply_async_modifiers(_fget)
else:
self.wrapped = self._asyncify(_fget)
self.__wrapped__ = _fget
self.field_name = field_name or _fget.__name__
functools.update_wrapper(self, _fget)
functools.update_wrapper(self, self.__wrapped__)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} for {self.wrapped}>"
return f"<{self.__class__.__name__} for {self.__wrapped__}>"
def __set_name__(self, owner, name):
self.field_name = name
38 changes: 14 additions & 24 deletions a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,11 @@

from typing_extensions import Concatenate, ParamSpec, Self, Unpack

if TYPE_CHECKING:
from a_sync.abstract import ASyncABC
ASyncInstance = TypeVar("ASyncInstance", bound=ASyncABC)
else:
ASyncInstance = TypeVar("ASyncInstance", bound=object)

T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
I = TypeVar("I")
O = TypeVar("O", bound=object)
E = TypeVar('E', bound=Exception)
TYPE = TypeVar("TYPE", bound=Type)
Expand All @@ -32,30 +28,24 @@
SyncFn = Callable[P, T]
AnyFn = Union[CoroFn[P, T], SyncFn[P, T]]

class CoroBoundMethod(Protocol[O, P, T]):
__self__: O
class CoroBoundMethod(Protocol[I, P, T]):
__self__: I
__call__: Callable[P, Awaitable[T]]
class SyncBoundMethod(Protocol[O, P, T]):
__self__: O
class SyncBoundMethod(Protocol[I, P, T]):
__self__: I
__call__: Callable[P, T]
AnyBoundMethod = Union[CoroBoundMethod[Any, P, T], SyncBoundMethod[Any, P, T]]

class CoroClassMethod(Protocol[TYPE, P, T]):
__self__: TYPE
__call__: Callable[P, Awaitable[T]]
class SyncClassMethod(Protocol[TYPE, P, T]):
__self__: TYPE
__call__: Callable[P, Awaitable[T]]
AnyClassMethod = Union[CoroClassMethod[type, P, T], SyncClassMethod[type, P, T]]

class AsyncUnboundMethod(Protocol[O, P, T]):
__get__: Callable[[O, None], CoroBoundMethod[O, P, T]]
class SyncUnboundMethod(Protocol[O, P, T]):
__get__: Callable[[O, None], SyncBoundMethod[O, P, T]]
AnyUnboundMethod = Union[AsyncUnboundMethod[O, P, T], SyncUnboundMethod[O, P, T]]
@runtime_checkable
class AsyncUnboundMethod(Protocol[P, T]):
__get__: Callable[[I, None], CoroBoundMethod[I, P, T]]
@runtime_checkable
class SyncUnboundMethod(Protocol[P, T]):
__get__: Callable[[I, None], SyncBoundMethod[I, P, T]]
AnyUnboundMethod = Union[AsyncUnboundMethod[P, T], SyncUnboundMethod[P, T]]

class AsyncPropertyGetter(CoroBoundMethod[object, tuple, T]):...
class PropertyGetter(SyncBoundMethod[object, tuple, T]):...
class AsyncPropertyGetter(CoroBoundMethod[Any, Tuple[()], T]):...
class PropertyGetter(SyncBoundMethod[Any, Tuple[()], T]):...
AnyPropertyGetter = Union[AsyncPropertyGetter[T], PropertyGetter[T]]

AsyncDecorator = Callable[[CoroFn[P, T]], CoroFn[P, T]]
Expand Down
34 changes: 7 additions & 27 deletions a_sync/modified.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def __init__(self, fn: SyncFn[P, T], **modifiers: Unpack[ModifierKwargs]) -> Non
def __init__(self, fn: AnyFn[P, T], **modifiers: Unpack[ModifierKwargs]) -> None:
_helpers._validate_wrapped_fn(fn)
self.modifiers = ModifierManager(modifiers)
self.wrapped = fn
functools.update_wrapper(self, self.wrapped)
self.__wrapped__ = fn
functools.update_wrapper(self, self.__wrapped__)

@overload
def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T:...
Expand Down Expand Up @@ -61,7 +61,7 @@ def _sync_default(self) -> bool:

@property
def _async_def(self) -> bool:
return asyncio.iscoroutinefunction(self.wrapped)
return asyncio.iscoroutinefunction(self.__wrapped__)

def _run_sync(self, kwargs: dict) -> bool:
if flag := _kwargs.get_flag_name(kwargs):
Expand All @@ -75,7 +75,7 @@ def _run_sync(self, kwargs: dict) -> bool:
def _asyncified(self) -> CoroFn[P, T]:
"""Turns 'self._fn' async and applies both sync and async modifiers."""
if self._async_def:
raise TypeError(f"Can only be applied to sync functions, not {self.wrapped}")
raise TypeError(f"Can only be applied to sync functions, not {self.__wrapped__}")
return self._asyncify(self._modified_fn) # type: ignore [arg-type]

@functools.cached_property
Expand All @@ -85,8 +85,8 @@ def _modified_fn(self) -> AnyFn[P, T]:
Applies async modifiers to 'self._fn' if 'self._fn' is a sync function.
"""
if self._async_def:
return self.modifiers.apply_async_modifiers(self.wrapped) # type: ignore [arg-type]
return self.modifiers.apply_sync_modifiers(self.wrapped) # type: ignore [return-value]
return self.modifiers.apply_async_modifiers(self.__wrapped__) # type: ignore [arg-type]
return self.modifiers.apply_sync_modifiers(self.__wrapped__) # type: ignore [return-value]

@functools.cached_property
def _async_wrap(self): # -> SyncFn[[CoroFn[P, T]], MaybeAwaitable[T]]:
Expand Down Expand Up @@ -167,34 +167,14 @@ class ASyncDecoratorSyncDefault(ASyncDecorator):
@overload
def __call__(self, func: AnyBoundMethod[P, T]) -> ASyncFunctionSyncDefault[P, T]: # type: ignore [override]
...
# classmethod matchers
# NOTE: these could potentially match improperly if you pass types around thru your functions but that is kinda rare and should be a non-issue for the purposes of this lib
@overload
def __call__(self, func: AnyClassMethod[P, T]) -> ASyncFunctionSyncDefault[P, T]: # type: ignore [override]
...
@overload
def __call__(self, func: CoroFn[P, T]) -> ASyncFunctionSyncDefault[P, T]: # type: ignore [override]
...
@overload
def __call__(self, func: SyncFn[P, T]) -> ASyncFunctionSyncDefault[P, T]: # type: ignore [override]
...
def __call__(self, func: AnyFn[P, T]) -> ASyncFunctionSyncDefault[P, T]: # type: ignore [override]
...

class ASyncDecoratorAsyncDefault(ASyncDecorator):
@overload
def __call__(self, func: AnyBoundMethod[P, T]) -> ASyncFunctionAsyncDefault[P, T]: # type: ignore [override]
...
# classmethod matchers
# NOTE: these could potentially match improperly if you pass types around thru your functions but that is kinda rare and should be a non-issue for the purposes of this lib
@overload
def __call__(self, func: AnyClassMethod[P, T]) -> ASyncFunctionAsyncDefault[P, T]: # type: ignore [override]
...
# regular functions
@overload
def __call__(self, func: CoroFn[P, T]) -> ASyncFunctionAsyncDefault[P, T]: # type: ignore [override]
...
@overload
def __call__(self, func: SyncFn[P, T]) -> ASyncFunctionAsyncDefault[P, T]: # type: ignore [override]
def __call__(self, func: AnyFn[P, T]) -> ASyncFunctionAsyncDefault[P, T]: # type: ignore [override]
...

15 changes: 9 additions & 6 deletions a_sync/property.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
logger = logging.getLogger(__name__)

class _ASyncPropertyDescriptorBase(ASyncDescriptor[T]):
wrapped: AsyncPropertyGetter[T]
__wrapped__: AsyncPropertyGetter[T]
__slots__ = "hidden_method_name", "hidden_method_descriptor", "_fget"
def __init__(
self,
Expand All @@ -26,10 +26,13 @@ def __init__(
hidden_modifiers = dict(self.modifiers)
hidden_modifiers["default"] = "async"
self.hidden_method_descriptor = HiddenMethodDescriptor(self.get, self.hidden_method_name, **hidden_modifiers)
self._fget = self.wrapped
if asyncio.iscoroutinefunction(_fget):
self._fget = self.__wrapped__
else:
self._fget = _helpers._asyncify(self.__wrapped__, self.modifiers.executor)
async def get(self, instance: object) -> T:
return await super().__get__(instance, None)
def __get__(self, instance: object, owner) -> T:
def __get__(self, instance: object, owner: Any) -> T:
awaitable = super().__get__(instance, owner)
# if the user didn't specify a default behavior, we will defer to the instance
if _is_a_sync_instance(instance):
Expand All @@ -48,7 +51,7 @@ class ASyncPropertyDescriptorSyncDefault(property[T]):

class ASyncPropertyDescriptorAsyncDefault(property[T]):
"""This is a helper class used for type checking. You will not run into any instance of this in prod."""
def __get__(self, instance, owner) -> Awaitable[T]:
def __get__(self, instance, owner: Any) -> Awaitable[T]:
return super().__get__(instance, owner)


Expand Down Expand Up @@ -248,8 +251,8 @@ def _should_await(self, kwargs: dict) -> bool:
except (AttributeError, exceptions.NoFlagsFound):
return False

class HiddenMethodDescriptor(ASyncMethodDescriptorAsyncDefault[ASyncInstance, P, T]):
def __get__(self, instance: ASyncInstance, owner) -> HiddenMethod[ASyncInstance, T]:
class HiddenMethodDescriptor(ASyncMethodDescriptorAsyncDefault[O, P, T]):
def __get__(self, instance: O, owner: Any) -> HiddenMethod[O, T]:
if instance is None:
return self
try:
Expand Down
Loading

0 comments on commit 50d73a3

Please sign in to comment.