From cc3af4aca8834d5decd92e1540ee800da815f03f Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Tue, 13 Feb 2024 09:21:38 -0500 Subject: [PATCH] feat: init TaskMapping with iterables (#120) --- a_sync/task.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/a_sync/task.py b/a_sync/task.py index 073645bd..f5d53c93 100644 --- a/a_sync/task.py +++ b/a_sync/task.py @@ -4,7 +4,7 @@ from a_sync._typing import * from a_sync import exceptions -from a_sync.utils.iterators import as_yielded +from a_sync.utils.iterators import as_yielded, exhaust_iterator from a_sync.utils.as_completed import as_completed @@ -20,10 +20,14 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until return task class TaskMapping(DefaultDict[K, "asyncio.Task[V]"]): - def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *, name: str = '', **coro_fn_kwargs: P.kwargs) -> None: + def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *iterables: AnyIterable[K], name: str = '', **coro_fn_kwargs: P.kwargs) -> None: self._coro_fn = coro_fn self._coro_fn_kwargs = coro_fn_kwargs self._name = name + if iterables: + self._loader = create_task(exhaust_iterator(self._tasks_for_iterables(*iterables))) + else: + self._loader = None def __setitem__(self, item: Any, value: Any) -> None: raise NotImplementedError("You cannot manually set items in a TaskMapping") def __getitem__(self, item: K) -> "asyncio.Task[V]": @@ -39,8 +43,7 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]": async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]: if self: raise exceptions.MappingNotEmptyError - async for key in as_yielded(*[_yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined] - self[key] # ensure task is running + async for _ in self._tasks_for_iterables(*iterables): async for key, value in self.yield_completed(pop=pop): yield _yield(key, value, yields) async for key, value in as_completed(self, aiter=True): @@ -51,7 +54,10 @@ async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]: if pop: task = self.pop(k) yield k, await task - + async def _tasks_for_iterables(self, *iterables) -> AsyncIterator["asyncio.Task[V]"]: + async for key in as_yielded(*[_yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined] + yield self[key] # ensure task is running + __persisted_tasks: Set["asyncio.Task[Any]"] = set()