From f28bfd5d9f23fd400d753d433066a6ee2c822127 Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Tue, 13 Feb 2024 09:31:38 -0500 Subject: [PATCH] feat: yields init kwarg --- a_sync/task.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/a_sync/task.py b/a_sync/task.py index c96fcb6f..e5189e4b 100644 --- a/a_sync/task.py +++ b/a_sync/task.py @@ -4,6 +4,7 @@ from a_sync._typing import * from a_sync import exceptions +from a_sync.iter import ASyncIterable from a_sync.utils.as_completed import as_completed from a_sync.utils.gather import gather from a_sync.utils.iterators import as_yielded, exhaust_iterator @@ -20,11 +21,12 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until __persist(task) return task -class TaskMapping(DefaultDict[K, "asyncio.Task[V]"]): - def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *iterables: AnyIterable[K], name: str = '', **coro_fn_kwargs: P.kwargs) -> None: +class TaskMapping(ASyncIterable[K, V], DefaultDict[K, "asyncio.Task[V]"]): + def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *iterables: AnyIterable[K], name: str = '', yields: Literal['keys', 'both'] = 'both', **coro_fn_kwargs: P.kwargs) -> None: self._coro_fn = coro_fn self._coro_fn_kwargs = coro_fn_kwargs self._name = name + self._yields = yields if iterables: self._loader = create_task(exhaust_iterator(self._tasks_for_iterables(*iterables))) else: @@ -48,8 +50,8 @@ async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]: """aiterate thru all key-task pairs, yielding the key-result pair as each task completes""" if self._loader: await self._loader - async for k, v in as_completed(self, aiter=True): - yield k, v + async for key, value in as_completed(self, aiter=True): + yield _yield(key, value, self._yields) async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]: if self: raise exceptions.MappingNotEmptyError @@ -57,6 +59,8 @@ async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Litera async for key, value in self.yield_completed(pop=pop): yield _yield(key, value, yields) async for key, value in as_completed(self, aiter=True): + if pop: + self.pop(key) yield _yield(key, value, yields) async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]: for k, task in dict(self).items():