Skip to content

Commit

Permalink
Merge branch 'master' into better-task-mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Feb 13, 2024
2 parents 38d2026 + cc3af4a commit d81bc6c
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from a_sync._typing import *
from a_sync import exceptions
from a_sync.utils.as_completed import as_completed
from a_sync.utils.iterators import as_yielded
from a_sync.utils.gather import gather
from a_sync.utils.iterators import as_yielded, exhaust_iterator


logger = logging.getLogger(__name__)
Expand All @@ -21,10 +21,14 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until
return task

class TaskMapping(DefaultDict[K, "asyncio.Task[V]"]):
def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *, name: str = '', **coro_fn_kwargs: P.kwargs) -> None:
def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *iterables: AnyIterable[K], name: str = '', **coro_fn_kwargs: P.kwargs) -> None:
self._coro_fn = coro_fn
self._coro_fn_kwargs = coro_fn_kwargs
self._name = name
if iterables:
self._loader = create_task(exhaust_iterator(self._tasks_for_iterables(*iterables)))
else:
self._loader = None
def __setitem__(self, item: Any, value: Any) -> None:
raise NotImplementedError("You cannot manually set items in a TaskMapping")
def __getitem__(self, item: K) -> "asyncio.Task[V]":
Expand All @@ -49,8 +53,7 @@ async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]:
async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]:
if self:
raise exceptions.MappingNotEmptyError
async for key in as_yielded(*[_yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
self[key] # ensure task is running
async for _ in self._tasks_for_iterables(*iterables):
async for key, value in self.yield_completed(pop=pop):
yield _yield(key, value, yields)
async for key, value in as_completed(self, aiter=True):
Expand All @@ -65,7 +68,10 @@ async def _await(self) -> Dict[K, V]:
if self._loader:
await self._loader
return await gather(self)

async def _tasks_for_iterables(self, *iterables) -> AsyncIterator["asyncio.Task[V]"]:
async for key in as_yielded(*[_yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
yield self[key] # ensure task is running


__persisted_tasks: Set["asyncio.Task[Any]"] = set()

Expand Down

0 comments on commit d81bc6c

Please sign in to comment.