Skip to content

Commit

Permalink
chore: fix mypy errs (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Feb 16, 2024
1 parent 378ebc5 commit 6f9efc3
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 52 deletions.
2 changes: 1 addition & 1 deletion a_sync/_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def _clean_default_from_modifiers(
coro_fn: AsyncBoundMethod[P, T], # type: ignore [misc]
modifiers: dict
modifiers: ModifierKwargs,
):
# NOTE: We set the default here manually because the default set by the user will be used later in the code to determine whether to await.
force_await = None
Expand Down
8 changes: 4 additions & 4 deletions a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from concurrent.futures._base import Executor
from decimal import Decimal
from typing import (TYPE_CHECKING, Any, AsyncIterable, AsyncIterator, Awaitable,
Callable, DefaultDict, Deque, Dict, Generator, Generic,
ItemsView, Iterable, Iterator, KeysView, List, Literal,
Optional, Protocol, Set, Tuple, Type, TypedDict, TypeVar,
Union, ValuesView, final, overload)
Callable, Coroutine, DefaultDict, Deque, Dict, Generator,
Generic, ItemsView, Iterable, Iterator, KeysView, List, Literal,
Mapping, Optional, Protocol, Set, Tuple, Type, TypedDict,
TypeVar, Union, ValuesView, final, overload)

from typing_extensions import Concatenate, ParamSpec, Self, Unpack

Expand Down
2 changes: 1 addition & 1 deletion a_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __a_sync_default_mode__(cls) -> bool: # type: ignore [override]
return sync

@classmethod
def __get_a_sync_flag_name_from_signature(cls) -> str:
def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]:
logger.debug("Searching for flags defined on %s.__init__", cls)
if cls.__name__ == "ASyncGenericBase":
logger.debug("There are no flags defined on the base class, this is expected. Skipping.")
Expand Down
2 changes: 1 addition & 1 deletion a_sync/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def wrap(self, aiterator: AsyncIterator[T]) -> "ASyncWrappedIterator[T]":
class ASyncWrappedIterable(ASyncIterable[T]):
def __init__(self, async_iterable: AsyncIterable[T]):
self.__aiterable = async_iterable
def __aiter__(self) -> AsyncIterable[T]:
def __aiter__(self) -> AsyncIterator[T]:
return self.__aiterable.__aiter__()

class ASyncWrappedIterator(ASyncIterator[T]):
Expand Down
2 changes: 1 addition & 1 deletion a_sync/primitives/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyn
if self.sync_mode:
fut = asyncio.ensure_future(self._exec_sync(fn, *args, **kwargs))
else:
fut = asyncio.futures.wrap_future(super().submit(fn, *args, **kwargs))
fut = asyncio.futures.wrap_future(super().submit(fn, *args, **kwargs)) # type: ignore [assignment]
self._start_debug_daemon(fut, fn, *args, **kwargs)
return fut
def __repr__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion a_sync/primitives/locks/semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, value: int, name=None, **kwargs) -> None:

# Dank new functionality
def __call__(self, fn: Callable[P, T]) -> Callable[P, T]:
return self.decorate(fn)
return self.decorate(fn) # type: ignore [arg-type, return-value]

def __repr__(self) -> str:
representation = f"<{self.__class__.__name__} name={self.name} value={self._value} waiters={len(self)}>"
Expand Down
2 changes: 1 addition & 1 deletion a_sync/property.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def a_sync_property( # type: ignore [misc]
func = None
def modifier_wrap(func: Property[T]) -> AsyncPropertyDescriptor[T]:
return AsyncPropertyDescriptor(func, **modifiers)
return modifier_wrap if func is None else modifier_wrap(func)
return modifier_wrap if func is None else modifier_wrap(func) # type: ignore [arg-type]


@overload
Expand Down
3 changes: 2 additions & 1 deletion a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, coro_fn: MappingFn[K, P, V] = None, *iterables: AnyIterable[K
self._coro_fn = coro_fn
self._coro_fn_kwargs = coro_fn_kwargs
self._name = name
self._loader: Optional["asyncio.Task[None]"]
if iterables:
self._loader = create_task(exhaust_iterator(self._tasks_for_iterables(*iterables)))
else:
Expand All @@ -48,7 +49,7 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]":
def __await__(self) -> Generator[Any, None, Dict[K, V]]:
"""await all tasks and returns a mapping with the results for each key"""
return self._await().__await__()
async def __aiter__(self) -> Union[AsyncIterator[Tuple[K, V]], AsyncIterator[K]]:
async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]:
"""aiterate thru all key-task pairs, yielding the key-result pair as each task completes"""
yielded = set()
# if you inited the TaskMapping with some iterators, we will load those
Expand Down
34 changes: 16 additions & 18 deletions a_sync/utils/as_completed.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@

import asyncio
from typing import (Any, AsyncIterator, Awaitable, Coroutine, Iterable,
Iterator, Literal, Mapping, Optional, Tuple, TypeVar,
Union, overload)

try:
from tqdm.asyncio import tqdm_asyncio
Expand All @@ -11,23 +8,20 @@ class tqdm_asyncio: # type: ignore [no-redef]
def as_completed(*args, **kwargs):
raise ImportError("You must have tqdm installed to use this feature")

from a_sync._typing import *
from a_sync.iter import ASyncIterator

T = TypeVar('T')
KT = TypeVar('KT')
VT = TypeVar('VT')

@overload
def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[False] = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Iterator[Coroutine[None, None, T]]:
...
@overload
def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[True] = True, tqdm: bool = False, **tqdm_kwargs: Any) -> ASyncIterator[T]:
...
@overload
def as_completed(fs: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[False] = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Iterator[Coroutine[None, None, Tuple[KT, VT]]]:
def as_completed(fs: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[False] = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Iterator[Coroutine[None, None, Tuple[K, V]]]:
...
@overload
def as_completed(fs: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[True] = True, tqdm: bool = False, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[KT, VT]]:
def as_completed(fs: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[True] = True, tqdm: bool = False, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[K, V]]:
...
def as_completed(fs, *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any):
"""
Expand All @@ -42,15 +36,15 @@ def as_completed(fs, *, timeout: Optional[float] = None, return_exceptions: bool
- Provides progress reporting using tqdm if 'tqdm' is set to True.
Args:
fs (Iterable[Awaitable[T] or Mapping[KT, Awaitable[VT]]]): The awaitables to await concurrently. It can be a list of individual awaitables or a mapping of awaitables.
fs (Iterable[Awaitable[T] or Mapping[K, Awaitable[V]]]): The awaitables to await concurrently. It can be a list of individual awaitables or a mapping of awaitables.
timeout (float, optional): The maximum time, in seconds, to wait for the completion of awaitables. Defaults to None (no timeout).
return_exceptions (bool, optional): If True, exceptions are returned as results instead of raising them. Defaults to False.
aiter (bool, optional): If True, returns an async iterator of results. Defaults to False.
tqdm (bool, optional): If True, enables progress reporting using tqdm. Defaults to False.
**tqdm_kwargs: Additional keyword arguments for tqdm if progress reporting is enabled.
Returns:
Iterator[Coroutine[None, None, T] or ASyncIterator[Tuple[KT, VT]]]: An iterator of results when awaiting individual awaitables or an async iterator when awaiting mappings.
Iterator[Coroutine[None, None, T] or ASyncIterator[Tuple[K, V]]]: An iterator of results when awaiting individual awaitables or an async iterator when awaiting mappings.
Examples:
Awaiting individual awaitables:
Expand Down Expand Up @@ -87,27 +81,27 @@ def as_completed(fs, *, timeout: Optional[float] = None, return_exceptions: bool
)

@overload
def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[True] = True, tqdm: bool = False, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[KT, VT]]:
def as_completed_mapping(mapping: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[True] = True, tqdm: bool = False, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[K, V]]:
...
@overload
def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[False] = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Iterator[Coroutine[None, None, Tuple[KT, VT]]]:
def as_completed_mapping(mapping: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[False] = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Iterator[Coroutine[None, None, Tuple[K, V]]]:
...
def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Union[Iterator[Coroutine[None, None, Tuple[KT, VT]]], ASyncIterator[Tuple[KT, VT]]]:
def as_completed_mapping(mapping: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Union[Iterator[Coroutine[None, None, Tuple[K, V]]], ASyncIterator[Tuple[K, V]]]:
"""
Concurrently awaits a mapping of awaitable objects and returns an iterator or async iterator of results.
This function is designed to await a mapping of awaitable objects, where each key-value pair represents a unique awaitable. It enables concurrent execution and gathers results into an iterator or an async iterator.
Args:
mapping (Mapping[KT, Awaitable[VT]]): A dictionary-like object where keys are of type KT and values are awaitable objects of type VT.
mapping (Mapping[K, Awaitable[V]]): A dictionary-like object where keys are of type K and values are awaitable objects of type V.
timeout (float, optional): The maximum time, in seconds, to wait for the completion of awaitables. Defaults to None (no timeout).
return_exceptions (bool, optional): If True, exceptions are returned as results instead of raising them. Defaults to False.
aiter (bool, optional): If True, returns an async iterator of results. Defaults to False.
tqdm (bool, optional): If True, enables progress reporting using tqdm. Defaults to False.
**tqdm_kwargs: Additional keyword arguments for tqdm if progress reporting is enabled.
Returns:
Union[Iterator[Coroutine[None, None, Tuple[KT, VT]]] or ASyncIterator[Tuple[KT, VT]]]: An iterator of results or an async iterator when awaiting mappings.
Union[Iterator[Coroutine[None, None, Tuple[K, V]]] or ASyncIterator[Tuple[K, V]]]: An iterator of results or an async iterator when awaiting mappings.
Example:
```
Expand All @@ -126,8 +120,12 @@ def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Option
async def __yield_as_completed(futs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None, return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> AsyncIterator[T]:
for fut in as_completed(futs, timeout=timeout, return_exceptions=return_exceptions, tqdm=tqdm, **tqdm_kwargs):
yield await fut

async def __mapping_wrap(k: KT, v: Awaitable[VT], return_exceptions: bool = False) -> VT:

@overload
async def __mapping_wrap(k: K, v: Awaitable[V], return_exceptions: Literal[True] = True) -> Union[V, Exception]:...
@overload
async def __mapping_wrap(k: K, v: Awaitable[V], return_exceptions: Literal[False] = False) -> V:...
async def __mapping_wrap(k: K, v: Awaitable[V], return_exceptions: bool = False) -> Union[V, Exception]:
try:
return k, await v
except Exception as e:
Expand Down
21 changes: 9 additions & 12 deletions a_sync/utils/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,17 @@ async def gather(*args, **kwargs):
from a_sync._typing import *
from a_sync.utils.as_completed import as_completed_mapping

T = TypeVar('T')
KT = TypeVar('KT')
VT = TypeVar('VT')

Excluder = Callable[[T], bool]

@overload
async def gather(
*awaitables: Mapping[KT, Awaitable[VT]],
*awaitables: Mapping[K, Awaitable[V]],
return_exceptions: bool = False,
exclude_if: Optional[Excluder[T]] = None,
tqdm: bool = False,
**tqdm_kwargs: Any,
) -> Dict[KT, VT]:
) -> Dict[K, V]:
...
@overload
async def gather(
Expand All @@ -38,12 +35,12 @@ async def gather(
) -> List[T]:
...
async def gather(
*awaitables: Union[Awaitable[T], Mapping[KT, Awaitable[VT]]],
*awaitables: Union[Awaitable[T], Mapping[K, Awaitable[V]]],
return_exceptions: bool = False,
exclude_if: Optional[Excluder[T]] = None,
tqdm: bool = False,
**tqdm_kwargs: Any,
) -> Union[List[T], Dict[KT, VT]]:
) -> Union[List[T], Dict[K, V]]:
"""
Concurrently awaits a list of awaitable objects or mappings of awaitables and returns the results.
Expand All @@ -55,13 +52,13 @@ async def gather(
- Provides progress reporting using tqdm if 'tqdm' is set to True.
Args:
*awaitables (Union[Awaitable[T], Mapping[KT, Awaitable[VT]]]): The awaitables to await concurrently. It can be a single awaitable or a mapping of awaitables.
*awaitables (Union[Awaitable[T], Mapping[K, Awaitable[V]]]): The awaitables to await concurrently. It can be a single awaitable or a mapping of awaitables.
return_exceptions (bool, optional): If True, exceptions are returned as results instead of raising them. Defaults to False.
tqdm (bool, optional): If True, enables progress reporting using tqdm. Defaults to False.
**tqdm_kwargs: Additional keyword arguments for tqdm if progress reporting is enabled.
Returns:
Union[List[T], Dict[KT, VT]]: A list of results when awaiting individual awaitables or a dictionary of results when awaiting mappings.
Union[List[T], Dict[K, V]]: A list of results when awaiting individual awaitables or a dictionary of results when awaiting mappings.
Examples:
Awaiting individual awaitables:
Expand All @@ -86,20 +83,20 @@ async def gather(
results = [r for r in results if not exclude_if(r)]
return results

async def gather_mapping(mapping: Mapping[KT, Awaitable[VT]], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Dict[KT, VT]:
async def gather_mapping(mapping: Mapping[K, Awaitable[V]], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Dict[K, V]:
"""
Concurrently awaits a mapping of awaitable objects and returns a dictionary of results.
This function is designed to await a mapping of awaitable objects, where each key-value pair represents a unique awaitable. It enables concurrent execution and gathers results into a dictionary.
Args:
mapping (Mapping[KT, Awaitable[VT]]): A dictionary-like object where keys are of type KT and values are awaitable objects of type VT.
mapping (Mapping[K, Awaitable[V]]): A dictionary-like object where keys are of type K and values are awaitable objects of type V.
return_exceptions (bool, optional): If True, exceptions are returned as results instead of raising them. Defaults to False.
tqdm (bool, optional): If True, enables progress reporting using tqdm. Defaults to False.
**tqdm_kwargs: Additional keyword arguments for tqdm if progress reporting is enabled.
Returns:
Dict[KT, VT]: A dictionary with keys corresponding to the keys of the input mapping and values containing the results of the corresponding awaitables.
Dict[K, V]: A dictionary with keys corresponding to the keys of the input mapping and values containing the results of the corresponding awaitables.
Example:
The 'results' dictionary will contain the awaited results, where keys match the keys in the 'mapping' and values contain the results of the corresponding awaitables.
Expand Down
Loading

0 comments on commit 6f9efc3

Please sign in to comment.