Skip to content

Commit

Permalink
fix: pass exclude_if from gather to gather_mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Apr 23, 2024
1 parent e5eb08a commit fad0db1
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions a_sync/utils/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def gather(*args, **kwargs):
async def gather(
*awaitables: Mapping[K, Awaitable[V]],
return_exceptions: bool = False,
exclude_if: Optional[Excluder[T]] = None,
exclude_if: Optional[Excluder[V]] = None,
tqdm: bool = False,
**tqdm_kwargs: Any,
) -> Dict[K, V]:
Expand Down Expand Up @@ -74,16 +74,23 @@ async def gather(
results = await gather(mapping)
```
"""
is_mapping = _is_mapping(awaitables)
results = await (
gather_mapping(awaitables[0], return_exceptions=return_exceptions, tqdm=tqdm, **tqdm_kwargs) if _is_mapping(awaitables)
gather_mapping(awaitables[0], return_exceptions=return_exceptions, exclude_if=exclude_if, tqdm=tqdm, **tqdm_kwargs) if is_mapping
else tqdm_asyncio.gather(*(_exc_wrap(a) for a in awaitables) if return_exceptions else awaitables, **tqdm_kwargs) if tqdm
else asyncio.gather(*awaitables, return_exceptions=return_exceptions) # type: ignore [arg-type]
)
if exclude_if:
if exclude_if and not is_mapping:
results = [r for r in results if not exclude_if(r)]
return results

async def gather_mapping(mapping: Mapping[K, Awaitable[V]], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Dict[K, V]:
async def gather_mapping(
mapping: Mapping[K, Awaitable[V]],
return_exceptions: bool = False,
exclude_if: Optional[Excluder[V]] = None,
tqdm: bool = False,
**tqdm_kwargs: Any,
) -> Dict[K, V]:
"""
Concurrently awaits a mapping of awaitable objects and returns a dictionary of results.
Expand All @@ -106,7 +113,12 @@ async def gather_mapping(mapping: Mapping[K, Awaitable[V]], return_exceptions: b
results = await gather_mapping(mapping)
```
"""
results = {k: v async for k, v in as_completed_mapping(mapping, return_exceptions=return_exceptions, aiter=True, tqdm=tqdm, **tqdm_kwargs)}
return {k: results[k] for k in mapping.keys()} # return data in same order as input mapping
results = {
k: v
async for k, v in as_completed_mapping(mapping, return_exceptions=return_exceptions, aiter=True, tqdm=tqdm, **tqdm_kwargs)
if exclude_if is None or not exclude_if(v)
}
# return data in same order as input mapping
return {k: results[k] for k in mapping}

_is_mapping = lambda awaitables: len(awaitables) == 1 and isinstance(awaitables[0], Mapping)

0 comments on commit fad0db1

Please sign in to comment.