Skip to content

Commit

Permalink
chore: cleanup Generic types (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Mar 1, 2024
1 parent 21978a4 commit 330e762
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 22 deletions.
37 changes: 21 additions & 16 deletions a_sync/_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

logger = logging.getLogger(__name__)

class ASyncMethodDescriptor(ASyncDescriptor[ASyncFunction[P, T]], Generic[O, P, T]):
__wrapped__: AnyFn[Concatenate[O, P], T]
def __get__(self, instance: O, owner: Any) -> "ASyncBoundMethod[P, T]":
class ASyncMethodDescriptor(ASyncDescriptor[ASyncFunction[P, T]], Generic[I, P, T]):
__wrapped__: AnyFn[Concatenate[I, P], T]
def __get__(self, instance: I, owner: Any) -> "ASyncBoundMethod[I, P, T]":
if instance is None:
return self
try:
Expand All @@ -41,8 +41,8 @@ def __set__(self, instance, value):
def __delete__(self, instance):
raise RuntimeError(f"cannot delete {self.field_name}, you're stuck with {self} forever. sorry.")

class ASyncMethodDescriptorSyncDefault(ASyncMethodDescriptor[O, P, T]):
def __get__(self, instance: O, owner: Any) -> "ASyncBoundMethodSyncDefault[P, T]":
class ASyncMethodDescriptorSyncDefault(ASyncMethodDescriptor[I, P, T]):
def __get__(self, instance: I, owner: Any) -> "ASyncBoundMethodSyncDefault[I, P, T]":
if instance is None:
return self
try:
Expand All @@ -53,8 +53,8 @@ def __get__(self, instance: O, owner: Any) -> "ASyncBoundMethodSyncDefault[P, T]
logger.debug("new bound method: %s", bound)
return bound

class ASyncMethodDescriptorAsyncDefault(ASyncMethodDescriptor[O, P, T]):
def __get__(self, instance: O, owner: Any) -> "ASyncBoundMethodAsyncDefault[P, T]":
class ASyncMethodDescriptorAsyncDefault(ASyncMethodDescriptor[I, P, T]):
def __get__(self, instance: I, owner: Any) -> "ASyncBoundMethodAsyncDefault[I, T]":
if instance is None:
return self
try:
Expand All @@ -65,12 +65,12 @@ def __get__(self, instance: O, owner: Any) -> "ASyncBoundMethodAsyncDefault[P, T
logger.debug("new bound method: %s", bound)
return bound

class ASyncBoundMethod(ASyncFunction[P, T]):
class ASyncBoundMethod(ASyncFunction[P, T], Generic[I, P, T]):
__slots__ = "__self__",
def __init__(
self,
instance: O,
unbound: AnyFn[Concatenate[O, P], T],
instance: I,
unbound: AnyFn[Concatenate[I, P], T],
**modifiers: Unpack[ModifierKwargs],
) -> None:
self.__self__ = instance
Expand Down Expand Up @@ -98,6 +98,9 @@ def __call__(self, *args, **kwargs):
def __bound_to_a_sync_instance__(self) -> bool:
from a_sync.abstract import ASyncABC
return isinstance(self.__self__, ASyncABC)
@functools.cached_property
def __is_async_def__(self) -> bool:
return asyncio.iscoroutinefunction(self.__wrapped__)
def _should_await(self, kwargs: dict) -> bool:
if flag := _kwargs.get_flag_name(kwargs):
return _kwargs.is_sync(flag, kwargs, pop_flag=True) # type: ignore [arg-type]
Expand All @@ -106,17 +109,19 @@ def _should_await(self, kwargs: dict) -> bool:
elif self.__bound_to_a_sync_instance__:
self.__self__: "ASyncABC"
return self.__self__.__a_sync_should_await__(kwargs)
return asyncio.iscoroutinefunction(self.__wrapped__)
return self.__is_async_def__


class ASyncBoundMethodSyncDefault(ASyncBoundMethod[P, T]):
def __get__(self, instance: O, owner: Any) -> ASyncFunctionSyncDefault[P, T]:
class ASyncBoundMethodSyncDefault(ASyncBoundMethod[I, P, T]):
"""just a helper for your IDE's typing tools"""
def __get__(self, instance: Any, owner: Any) -> ASyncFunctionSyncDefault[P, T]:
return super().__get__(instance, owner)
def __call__(self, *args, **kwargs) -> T:
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return super().__call__(*args, **kwargs)

class ASyncBoundMethodAsyncDefault(ASyncBoundMethod[P, T]):
def __get__(self, instance: O, owner: Any) -> ASyncFunctionAsyncDefault[P, T]:
class ASyncBoundMethodAsyncDefault(ASyncBoundMethod[I, P, T]):
"""just a helper for your IDE's typing tools"""
def __get__(self, instance: I, owner: Any) -> ASyncFunctionAsyncDefault[P, T]:
return super().__get__(instance, owner)
def __call__(self, *args, **kwargs) -> Awaitable[T]:
return super().__call__(*args, **kwargs)
1 change: 0 additions & 1 deletion a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
K = TypeVar("K")
V = TypeVar("V")
I = TypeVar("I")
O = TypeVar("O", bound=object)
E = TypeVar('E', bound=Exception)
TYPE = TypeVar("TYPE", bound=Type)
P = ParamSpec("P")
Expand Down
10 changes: 5 additions & 5 deletions a_sync/property.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
self.hidden_method_name = f"__{self.field_name}__"
hidden_modifiers = dict(self.modifiers)
hidden_modifiers["default"] = "async"
self.hidden_method_descriptor = HiddenMethodDescriptor(self.get, self.hidden_method_name, **hidden_modifiers)
self.hidden_method_descriptor: HiddenMethodDescriptor[T] = HiddenMethodDescriptor(self.get, self.hidden_method_name, **hidden_modifiers)
if asyncio.iscoroutinefunction(_fget):
self._fget = self.__wrapped__
else:
Expand Down Expand Up @@ -238,8 +238,8 @@ def a_sync_cached_property( # type: ignore [misc]
decorator = functools.partial(descriptor_class, **modifiers)
return decorator if func is None else decorator(func)

class HiddenMethod(ASyncBoundMethodAsyncDefault[O, T]):
def __init__(self, instance: O, unbound: AnyFn[Concatenate[O, P], T], field_name: str, **modifiers: _helpers.ModifierKwargs) -> None:
class HiddenMethod(ASyncBoundMethodAsyncDefault[I, Tuple[()], T]):
def __init__(self, instance: I, unbound: AnyFn[Concatenate[I, P], T], field_name: str, **modifiers: _helpers.ModifierKwargs) -> None:
super().__init__(instance, unbound, **modifiers)
self.__name__ = field_name
def __repr__(self) -> str:
Expand All @@ -253,8 +253,8 @@ def _should_await(self, kwargs: dict) -> bool:
def __await__(self) -> Generator[Any, None, T]:
return self().__await__()

class HiddenMethodDescriptor(ASyncMethodDescriptorAsyncDefault[O, P, T]):
def __get__(self, instance: O, owner: Any) -> HiddenMethod[O, T]:
class HiddenMethodDescriptor(ASyncMethodDescriptorAsyncDefault[I, Tuple[()], T]):
def __get__(self, instance: I, owner: Any) -> HiddenMethod[I, T]:
if instance is None:
return self
try:
Expand Down

0 comments on commit 330e762

Please sign in to comment.