diff --git a/a_sync/utils/iterators.py b/a_sync/utils/iterators.py index 51d4e3d1..5e1daed0 100644 --- a/a_sync/utils/iterators.py +++ b/a_sync/utils/iterators.py @@ -11,13 +11,10 @@ T = TypeVar('T') async def exhaust_iterator(iterator: AsyncIterator[T], *, queue: Optional[asyncio.Queue] = None) -> None: - if queue: - async for thing in iterator: + async for thing in iterator: + if queue: logger.debug('putting %s from %s to queue %s', thing, iterator, queue) queue.put_nowait(thing) - else: - async for thing in iterator: - pass async def exhaust_iterators(iterators, *, queue: Optional[asyncio.Queue] = None) -> None: await asyncio.gather(*[exhaust_iterator(iterator, queue=queue) for iterator in iterators]) @@ -33,6 +30,8 @@ async def exhaust_iterators(iterators, *, queue: Optional[asyncio.Queue] = None) T8 = TypeVar('T8') T9 = TypeVar('T9') +@overload +async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]:... @overload async def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], iterator3: AsyncIterator[T3], iterator4: AsyncIterator[T4], iterator5: AsyncIterator[T5], iterator6: AsyncIterator[T6], iterator7: AsyncIterator[T7], iterator8: AsyncIterator[T8], iterator9: AsyncIterator[T9]) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]]:... @overload @@ -62,10 +61,13 @@ def done_callback(t: asyncio.Task) -> None: task.add_done_callback(done_callback) while not task.done(): next_fut = asyncio.get_event_loop().create_future() - _chain_future(asyncio.create_task(queue.get()), next_fut) + get_task = asyncio.create_task(coro=queue.get(), name=str(queue)) + _chain_future(get_task, next_fut) yield await next_fut for next in queue.get_nowait(-1): yield next + if e := task.exception(): + get_task.cancel() raise e