Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: mypy #95

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion a_sync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from a_sync import aliases
from a_sync.base import ASyncGenericBase
from a_sync.decorator import a_sync
from a_sync.future import ASyncFuture, future
from a_sync.future import ASyncFuture, future # type: ignore [attr-defined]
from a_sync.iter import ASyncIterable, ASyncIterator
from a_sync.modifiers.semaphores import apply_semaphore
from a_sync.primitives import *
Expand Down
20 changes: 14 additions & 6 deletions a_sync/modified.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def default(self) -> DefaultMode:
return self.modifiers.default


class ASyncFunction(Modified[T], Callable[P, T]):
class ASyncFunction(Modified[T], Callable[P, T], Generic[P, T]):
@overload
def __init__(self, fn: CoroFn[P, T], **modifiers: Unpack[ModifierKwargs]) -> None:...
@overload
Expand All @@ -34,10 +34,18 @@ def __init__(self, fn: AnyFn[P, T], **modifiers: Unpack[ModifierKwargs]) -> None
self.__name__ = fn.__name__
self.__wrapped__ = fn

@overload
def __call__(self, *args: P.args, sync: Literal[True] = True, **kwargs: P.kwargs) -> T:...
@overload
def __call__(self, *args: P.args, sync: Literal[False] = False, **kwargs: P.kwargs) -> Awaitable[T]:...
@overload
def __call__(self, *args: P.args, asynchronous: Literal[False] = False, **kwargs: P.kwargs) -> T:...
@overload
def __call__(self, *args: P.args, asynchronous: Literal[True] = True, **kwargs: P.kwargs) -> Awaitable[T]:...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeAwaitable[T]:
return self.fn(*args, **kwargs)

def __repr__(self) -> str:
async def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.__name__} at {hex(id(self))}>"

@functools.cached_property
Expand Down Expand Up @@ -66,7 +74,7 @@ def _run_sync(self, kwargs: dict):
def _asyncified(self) -> CoroFn[P, T]:
"""Turns 'self._fn' async and applies both sync and async modifiers."""
assert not self._async_def, f"Can only be applied to sync functions, not {self.__wrapped__}"
return self._asyncify(self._modified_fn)
return self._asyncify(self._modified_fn) # type: ignore [arg-type]

@functools.cached_property
def _modified_fn(self) -> AnyFn[P, T]:
Expand All @@ -75,8 +83,8 @@ def _modified_fn(self) -> AnyFn[P, T]:
Applies async modifiers to 'self._fn' if 'self._fn' is a sync function.
"""
if self._async_def:
return self.modifiers.apply_async_modifiers(self.__wrapped__)
return self.modifiers.apply_sync_modifiers(self.__wrapped__)
return self.modifiers.apply_async_modifiers(self.__wrapped__) # type: ignore [arg-type]
return self.modifiers.apply_sync_modifiers(self.__wrapped__) # type: ignore [return-value]

@functools.cached_property
def _async_wrap(self): # -> SyncFn[[CoroFn[P, T]], MaybeAwaitable[T]]:
Expand Down Expand Up @@ -114,5 +122,5 @@ def validate_inputs(self) -> None:
if self.default not in ['sync', 'async', None]:
raise ValueError(f"'default' must be either 'sync', 'async', or None. You passed {self.default}.")

def __call__(self, func: AnyFn[P, T]) -> ASyncFunction[P, T]:
def __call__(self, func: AnyFn[P, T]) -> ASyncFunction[P, T]: # type: ignore [override]
return ASyncFunction(func, **self.modifiers)
4 changes: 2 additions & 2 deletions a_sync/primitives/locks/counter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from collections import defaultdict
from time import time
from typing import Iterable, Optional
from typing import DefaultDict, Iterable, Optional

from a_sync.primitives._debug import _DebugDaemonMixin
from a_sync.primitives.locks.event import Event
Expand All @@ -20,7 +20,7 @@ class CounterLock(_DebugDaemonMixin):
def __init__(self, start_value: int = 0, name: Optional[str] = None):
self._name = name
self._value = start_value
self._events = defaultdict(Event)
self._events: DefaultDict[int, Event] = defaultdict(Event)
self.is_ready = lambda v: self._value >= v

async def wait_for(self, value: int) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion a_sync/utils/as_completed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
try:
from tqdm.asyncio import tqdm_asyncio
except ImportError as e:
class tqdm_asyncio:
class tqdm_asyncio: # type: ignore [no-redef]
def as_completed(*args, **kwargs):
raise ImportError("You must have tqdm installed to use this feature")

Expand Down
2 changes: 1 addition & 1 deletion a_sync/utils/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
try:
from tqdm.asyncio import tqdm_asyncio
except ImportError as e:
class tqdm_asyncio:
class tqdm_asyncio: # type: ignore [no-redef]
async def gather(*args, **kwargs):
raise ImportError("You must have tqdm installed in order to use this feature")

Expand Down
2 changes: 1 addition & 1 deletion a_sync/utils/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1],
async def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1]) -> AsyncIterator[Union[T0, T1]]:...
@overload
async def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], *iterators: AsyncIterator[T]) -> AsyncIterator[Union[T0, T1, T2, T]]:...
async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]:
async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: # type: ignore [misc]
queue = Queue()
task = asyncio.create_task(exhaust_iterators(iterators, queue=queue))
def done_callback(t: asyncio.Task) -> None:
Expand Down