Skip to content

Commit

Permalink
feat: type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Oct 10, 2023
1 parent 37b7aea commit 43152d3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
20 changes: 17 additions & 3 deletions a_sync/utils/as_completed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,19 @@
KT = TypeVar('KT')
VT = TypeVar('VT')

def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any):
@overload
def as_completed(fs: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float], return_exceptions: bool, aiter = False, tqdm: bool, **tqdm_kwargs: Any) -> Iterator[Awaitable[Tuple[KT, VT]]]:
...
@overload
def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float], return_exceptions: bool, aiter = False, tqdm: bool, **tqdm_kwargs: Any) -> Iterator[Awaitable[T]]:
...
@overload
def as_completed(fs: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float], return_exceptions: bool, aiter = True, tqdm: bool, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[KT, VT]]:
...
@overload
def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float], return_exceptions: bool, aiter = True, tqdm: bool, **tqdm_kwargs: Any) -> ASyncIterator[T]:
...
def as_completed(fs, *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any):
if return_exceptions:
raise NotImplementedError
return (
Expand All @@ -22,9 +34,11 @@ def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None,
)

@overload
def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter = True, tqdm: bool, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[KT, VT]]:...
def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter = True, tqdm: bool, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[KT, VT]]:
...
@overload
def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter = False, tqdm: bool, **tqdm_kwargs: Any) -> Iterator[Awaitable[Tuple[KT, VT]]]:...
def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter = False, tqdm: bool, **tqdm_kwargs: Any) -> Iterator[Awaitable[Tuple[KT, VT]]]:
...
def as_completed_mapping(mapping: Mapping[KT, Awaitable[VT]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Union[Iterator[Awaitable[Tuple[KT, VT]]], ASyncIterator[Tuple[KT, VT]]]:
return as_completed([__mapping_wrap(k, v) for k, v in mapping.items()], timeout=timeout, return_exceptions=return_exceptions, aiter=aiter, tqdm=tqdm, **tqdm_kwargs)

Expand Down
9 changes: 8 additions & 1 deletion a_sync/utils/gather.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import asyncio
from typing import Any, Awaitable, Dict, List, Mapping, TypeVar
from typing import (Any, Awaitable, Dict, List, Mapping, TypeVar, Union,
overload)

from tqdm.asyncio import tqdm_asyncio

Expand All @@ -10,7 +11,13 @@
KT = TypeVar('KT')
VT = TypeVar('VT')

@overload
async def gather(*awaitables: Mapping[KT, Awaitable[VT]], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Dict[KT, VT]:
...
@overload
async def gather(*awaitables: Awaitable[T], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> List[T]:
...
async def gather(*awaitables: Union[Awaitable[T], Mapping[KT, Awaitable[VT]]], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Union[List[T], Dict[KT, VT]]:
return await (
gather_mapping(awaitables[0], return_exceptions=return_exceptions, tqdm=tqdm, **tqdm_kwargs) if _is_mapping(awaitables)
else tqdm_asyncio.gather(*awaitables, return_exceptions=return_exceptions, **tqdm_kwargs) if tqdm
Expand Down

0 comments on commit 43152d3

Please sign in to comment.