Skip to content

Commit

Permalink
fix: TaskMapper aiter when loader is running
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Feb 14, 2024
1 parent 2d33281 commit aea14bd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
4 changes: 4 additions & 0 deletions a_sync/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def __init__(self):
err += f"Check your traceback to determine which, then try calling asynchronously instead with one of the following kwargs:\n"
err += f"{_flags.VIABLE_FLAGS}"

class MappingIsEmptyError(Exception):
def __init__(self):
super().__init__("TaskMapping does not contain anything to yield")

class MappingNotEmptyError(Exception):
def __init__(self):
super().__init__("TaskMapping already contains some data. In order to use `map`, you need a fresh one.")
28 changes: 19 additions & 9 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ def __init__(self, coro_fn: MappingFn[K, P, V] = None, *iterables: AnyIterable[K
self._coro_fn_kwargs = coro_fn_kwargs
self._name = name
if iterables:
self._aiterable = self._tasks_for_iterables(*iterables)
self._loader = create_task(exhaust_iterator(self._aiterable))
self._loader = create_task(exhaust_iterator(self._tasks_for_iterables(*iterables)))
else:
self._aiterable = None
self._loader = None
def __setitem__(self, item: Any, value: Any) -> None:
raise NotImplementedError("You cannot manually set items in a TaskMapping")
Expand All @@ -53,14 +51,24 @@ def __await__(self) -> Generator[Any, None, Dict[K, V]]:
async def __aiter__(self) -> Union[AsyncIterator[Tuple[K, V]], AsyncIterator[K]]:
"""aiterate thru all key-task pairs, yielding the key-result pair as each task completes"""
yielded = set()
async for _ in self._aiterable:
async for key, value in self.yield_completed(pop=False):
yielded.add(key)
# if you inited the TaskMapping with some iterators, we will load those
if self._loader:
while not self._loader.done():
async for key, value in self.yield_completed(pop=False):
yielded.add(key)
if key not in yielded:
yield _yield(key, value, "both")
await asyncio.sleep(0)
# loader is already done by this point, but we need to check for exceptions
await self._loader
elif not self:
# if you didn't init the TaskMapping with iterators and you didn't start any tasks manually, we should fail
raise exceptions.MappingIsEmptyError
# if there are any tasks that still need to complete, we need to await them and yield them
if self:
async for key, value in as_completed(self, aiter=True):
if key not in yielded:
yield _yield(key, value, "both")
async for key, value in as_completed(self, aiter=True):
if key not in yielded:
yield _yield(key, value, "both")
async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]:
if self:
raise exceptions.MappingNotEmptyError
Expand All @@ -82,6 +90,8 @@ async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]:
async def _await(self) -> Dict[K, V]:
if self._loader:
await self._loader
if not self:
raise exceptions.MappingIsEmptyError
return await gather(self)
async def _tasks_for_iterables(self, *iterables) -> AsyncIterator["asyncio.Task[V]"]:
async for key in as_yielded(*[_yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
Expand Down

0 comments on commit aea14bd

Please sign in to comment.