diff --git a/a_sync/utils/as_completed.py b/a_sync/utils/as_completed.py index 53ca83c7..97ef3205 100644 --- a/a_sync/utils/as_completed.py +++ b/a_sync/utils/as_completed.py @@ -11,7 +11,19 @@ KT = TypeVar('KT') VT = TypeVar('VT') -def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any): +@overload +def as_completed(fs: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float], return_exceptions: bool, aiter = False, tqdm: bool, **tqdm_kwargs: Any) -> Iterator[Awaitable[Tuple[KT, VT]]]: + ... +@overload +def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float], return_exceptions: bool, aiter = False, tqdm: bool, **tqdm_kwargs: Any) -> Iterator[Awaitable[T]]: + ... +@overload +def as_completed(fs: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float], return_exceptions: bool, aiter = True, tqdm: bool, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[KT, VT]]: + ... +@overload +def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float], return_exceptions: bool, aiter = True, tqdm: bool, **tqdm_kwargs: Any) -> ASyncIterator[T]: + ... +def as_completed(fs, *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any): if return_exceptions: raise NotImplementedError return ( @@ -22,9 +34,11 @@ def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None, ) @overload -def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter = True, tqdm: bool, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[KT, VT]]:... +def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter = True, tqdm: bool, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[KT, VT]]: + ... @overload -def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter = False, tqdm: bool, **tqdm_kwargs: Any) -> Iterator[Awaitable[Tuple[KT, VT]]]:... +def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter = False, tqdm: bool, **tqdm_kwargs: Any) -> Iterator[Awaitable[Tuple[KT, VT]]]: + ... def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Union[Iterator[Awaitable[Tuple[KT, VT]]], ASyncIterator[Tuple[KT, VT]]]: return as_completed([__mapping_wrap(k, v) for k, v in mapping.items()], timeout=timeout, return_exceptions=return_exceptions, aiter=aiter, tqdm=tqdm, **tqdm_kwargs) diff --git a/a_sync/utils/gather.py b/a_sync/utils/gather.py index b749cca2..6593e55a 100644 --- a/a_sync/utils/gather.py +++ b/a_sync/utils/gather.py @@ -1,6 +1,7 @@ import asyncio -from typing import Any, Awaitable, Dict, List, Mapping, TypeVar +from typing import (Any, Awaitable, Dict, List, Mapping, TypeVar, Union, + overload) from tqdm.asyncio import tqdm_asyncio @@ -10,7 +11,13 @@ KT = TypeVar('KT') VT = TypeVar('VT') +@overload +async def gather(*awaitables: Mapping[KT, Awaitable[VT]], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Dict[KT, VT]: + ... +@overload async def gather(*awaitables: Awaitable[T], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> List[T]: + ... +async def gather(*awaitables: Union[Awaitable[T], Mapping[KT, Awaitable[VT]]], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Union[List[T], Dict[KT, VT]]: return await ( gather_mapping(awaitables[0], return_exceptions=return_exceptions, tqdm=tqdm, **tqdm_kwargs) if _is_mapping(awaitables) else tqdm_asyncio.gather(*awaitables, return_exceptions=return_exceptions, **tqdm_kwargs) if tqdm