From 7c2171652ba9328ff93f5f53d9a4f0f5a0b9f615 Mon Sep 17 00:00:00 2001 From: BobTheBuidler Date: Tue, 10 Oct 2023 03:05:25 +0000 Subject: [PATCH] fix: as_yielded --- a_sync/utils/iterators.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/a_sync/utils/iterators.py b/a_sync/utils/iterators.py index 5e1daed0..756322d7 100644 --- a/a_sync/utils/iterators.py +++ b/a_sync/utils/iterators.py @@ -18,6 +18,8 @@ async def exhaust_iterator(iterator: AsyncIterator[T], *, queue: Optional[asynci async def exhaust_iterators(iterators, *, queue: Optional[asyncio.Queue] = None) -> None: await asyncio.gather(*[exhaust_iterator(iterator, queue=queue) for iterator in iterators]) + if queue: + queue.put_nowait(_Done()) T0 = TypeVar('T0') T1 = TypeVar('T1') @@ -58,16 +60,20 @@ async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: def done_callback(t: asyncio.Task) -> None: if t.exception() and not next_fut.done(): next_fut.set_exception(t.exception()) + task.add_done_callback(done_callback) while not task.done(): next_fut = asyncio.get_event_loop().create_future() get_task = asyncio.create_task(coro=queue.get(), name=str(queue)) _chain_future(get_task, next_fut) - yield await next_fut - for next in queue.get_nowait(-1): - yield next - + for item in (await next_fut, *queue.get_nowait(-1)): + if isinstance(item, _Done): + return + yield item + if e := task.exception(): get_task.cancel() raise e +class _Done: + pass \ No newline at end of file