From fad0db1c2714cac53f86f0bd590dee470ad111d8 Mon Sep 17 00:00:00 2001 From: BobTheBuidler Date: Tue, 23 Apr 2024 01:12:33 +0000 Subject: [PATCH] fix: pass `exclude_if` from gather to gather_mapping --- a_sync/utils/gather.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/a_sync/utils/gather.py b/a_sync/utils/gather.py index 3401a797..e2132e61 100644 --- a/a_sync/utils/gather.py +++ b/a_sync/utils/gather.py @@ -20,7 +20,7 @@ async def gather(*args, **kwargs): async def gather( *awaitables: Mapping[K, Awaitable[V]], return_exceptions: bool = False, - exclude_if: Optional[Excluder[T]] = None, + exclude_if: Optional[Excluder[V]] = None, tqdm: bool = False, **tqdm_kwargs: Any, ) -> Dict[K, V]: @@ -74,16 +74,23 @@ async def gather( results = await gather(mapping) ``` """ + is_mapping = _is_mapping(awaitables) results = await ( - gather_mapping(awaitables[0], return_exceptions=return_exceptions, tqdm=tqdm, **tqdm_kwargs) if _is_mapping(awaitables) + gather_mapping(awaitables[0], return_exceptions=return_exceptions, exclude_if=exclude_if, tqdm=tqdm, **tqdm_kwargs) if is_mapping else tqdm_asyncio.gather(*(_exc_wrap(a) for a in awaitables) if return_exceptions else awaitables, **tqdm_kwargs) if tqdm else asyncio.gather(*awaitables, return_exceptions=return_exceptions) # type: ignore [arg-type] ) - if exclude_if: + if exclude_if and not is_mapping: results = [r for r in results if not exclude_if(r)] return results -async def gather_mapping(mapping: Mapping[K, Awaitable[V]], return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Dict[K, V]: +async def gather_mapping( + mapping: Mapping[K, Awaitable[V]], + return_exceptions: bool = False, + exclude_if: Optional[Excluder[V]] = None, + tqdm: bool = False, + **tqdm_kwargs: Any, +) -> Dict[K, V]: """ Concurrently awaits a mapping of awaitable objects and returns a dictionary of results. @@ -106,7 +113,12 @@ async def gather_mapping(mapping: Mapping[K, Awaitable[V]], return_exceptions: b results = await gather_mapping(mapping) ``` """ - results = {k: v async for k, v in as_completed_mapping(mapping, return_exceptions=return_exceptions, aiter=True, tqdm=tqdm, **tqdm_kwargs)} - return {k: results[k] for k in mapping.keys()} # return data in same order as input mapping + results = { + k: v + async for k, v in as_completed_mapping(mapping, return_exceptions=return_exceptions, aiter=True, tqdm=tqdm, **tqdm_kwargs) + if exclude_if is None or not exclude_if(v) + } + # return data in same order as input mapping + return {k: results[k] for k in mapping} _is_mapping = lambda awaitables: len(awaitables) == 1 and isinstance(awaitables[0], Mapping)