From 28d970d3b4a34082f1e2af7a56604ad15662c544 Mon Sep 17 00:00:00 2001 From: Dos Moonen Date: Fri, 5 Jan 2024 09:42:55 +0100 Subject: [PATCH] QueueIterator raised StopAsyncIteration when channel is closed. --- .coveragerc | 5 + aio_pika/abc.py | 12 ++- aio_pika/channel.py | 34 +++++-- aio_pika/connection.py | 14 +-- aio_pika/queue.py | 83 +++++++++++++--- aio_pika/robust_channel.py | 19 +++- aio_pika/robust_queue.py | 5 + poetry.lock | 10 +- pyproject.toml | 3 + tests/test_amqp.py | 199 +++++++++++++++++++++++++------------ 10 files changed, 284 insertions(+), 100 deletions(-) diff --git a/.coveragerc b/.coveragerc index 230110d8..f31b2d6a 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,8 @@ [run] omit = aio_pika/compat.py branch = True + +[report] +exclude_lines = + pragma: no cover + raise NotImplementedError diff --git a/aio_pika/abc.py b/aio_pika/abc.py index ff81e38f..4975b352 100644 --- a/aio_pika/abc.py +++ b/aio_pika/abc.py @@ -366,14 +366,14 @@ def iterator(self, **kwargs: Any) -> "AbstractQueueIterator": raise NotImplementedError -class AbstractQueueIterator(AsyncIterable): +class AbstractQueueIterator(AsyncIterable[AbstractIncomingMessage]): _amqp_queue: AbstractQueue _queue: asyncio.Queue _consumer_tag: ConsumerTag _consume_kwargs: Dict[str, Any] @abstractmethod - def close(self, *_: Any) -> Awaitable[Any]: + def close(self) -> Awaitable[Any]: raise NotImplementedError @abstractmethod @@ -531,6 +531,10 @@ def is_closed(self) -> bool: def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[None]: raise NotImplementedError + @abstractmethod + async def wait(self) -> None: + raise NotImplementedError + @abstractmethod async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel: raise NotImplementedError @@ -741,6 +745,10 @@ def is_closed(self) -> bool: async def close(self, exc: ExceptionType = asyncio.CancelledError) -> None: raise NotImplementedError + @abstractmethod + async def wait(self) -> None: + raise NotImplementedError + @abstractmethod async def connect(self, timeout: TimeoutType = None) -> None: raise NotImplementedError diff --git a/aio_pika/channel.py b/aio_pika/channel.py index 7741b7af..f1fd3882 100644 --- a/aio_pika/channel.py +++ b/aio_pika/channel.py @@ -69,6 +69,9 @@ def __init__( confirmations (in pursuit of performance) """ + # init before any exception can be raised since we use it in __del__ + self._closed: asyncio.Event = asyncio.Event() + if not publisher_confirms and on_return_raises: raise RuntimeError( '"on_return_raises" not applicable ' @@ -77,9 +80,6 @@ def __init__( self._connection: AbstractConnection = connection - # That's means user closed channel instance explicitly - self._closed: bool = False - self._channel: Optional[UnderlayChannel] = None self._channel_number = channel_number @@ -89,6 +89,8 @@ def __init__( self.publisher_confirms = publisher_confirms self.on_return_raises = on_return_raises + self.close_callbacks.add(self._set_closed_callback) + @property def is_initialized(self) -> bool: """Returns True when the channel has been opened @@ -99,7 +101,7 @@ def is_initialized(self) -> bool: def is_closed(self) -> bool: """Returns True when the channel has been closed from the broker side or after the close() method has been called.""" - if not self.is_initialized or self._closed: + if not self.is_initialized or self._closed.is_set(): return True channel = self._channel if channel is None: @@ -119,8 +121,11 @@ async def close( return log.debug("Closing channel %r", self) - self._closed = True await self._channel.close() + self._closed.set() + + async def wait(self) -> None: + await self._closed.wait() async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel: @@ -174,12 +179,12 @@ async def _open(self) -> None: await channel.close(e) self._channel = None raise - self._closed = False + self._closed.clear() async def initialize(self, timeout: TimeoutType = None) -> None: if self.is_initialized: raise RuntimeError("Already initialized") - elif self._closed: + elif self._closed.is_set(): raise RuntimeError("Can't initialize closed channel") await self._open() @@ -197,7 +202,10 @@ async def _on_open(self) -> None: type=ExchangeType.DIRECT, ) - async def _on_close(self, closing: asyncio.Future) -> None: + async def _on_close( + self, + closing: asyncio.Future + ) -> Optional[BaseException]: try: exc = closing.exception() except asyncio.CancelledError as e: @@ -207,6 +215,14 @@ async def _on_close(self, closing: asyncio.Future) -> None: if self._channel and self._channel.channel: self._channel.channel.on_return_callbacks.discard(self._on_return) + return exc + + async def _set_closed_callback( + self, + _: AbstractChannel, exc: BaseException + ) -> None: + self._closed.set() + async def _on_initialized(self) -> None: channel = await self.get_underlay_channel() channel.on_return_callbacks.add(self._on_return) @@ -219,7 +235,7 @@ async def reopen(self) -> None: await self._open() def __del__(self) -> None: - self._closed = True + self._closed.set() self._channel = None async def declare_exchange( diff --git a/aio_pika/connection.py b/aio_pika/connection.py index fb49db4b..6976af7d 100644 --- a/aio_pika/connection.py +++ b/aio_pika/connection.py @@ -39,11 +39,11 @@ class Connection(AbstractConnection): ), ) - _closed: bool + _closed: asyncio.Event @property def is_closed(self) -> bool: - return self._closed + return self._closed.is_set() async def close( self, exc: Optional[aiormq.abc.ExceptionType] = ConnectionClosed, @@ -53,7 +53,10 @@ async def close( if not transport: return await transport.close(exc) - self._closed = True + self._closed.set() + + async def wait(self) -> None: + await self._closed.wait() @classmethod def _parse_parameters(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: @@ -74,7 +77,7 @@ def __init__( ): self.loop = loop or asyncio.get_event_loop() self.transport = None - self._closed = False + self._closed = asyncio.Event() self._close_called = False self.url = URL(url) @@ -201,8 +204,7 @@ async def ready(self) -> None: def __del__(self) -> None: if ( self.is_closed or - self.loop.is_closed() or - not hasattr(self, "connection") + self.loop.is_closed() ): return diff --git a/aio_pika/queue.py b/aio_pika/queue.py index 2f0f9d45..8ecd2d6c 100644 --- a/aio_pika/queue.py +++ b/aio_pika/queue.py @@ -2,7 +2,7 @@ import sys from functools import partial from types import TracebackType -from typing import Any, Awaitable, Callable, Optional, Type, overload +from typing import Any, Awaitable, Callable, Optional, Set, Type, overload import aiormq from aiormq.abc import DeliveredMessage @@ -421,7 +421,22 @@ class QueueIterator(AbstractQueueIterator): def consumer_tag(self) -> Optional[ConsumerTag]: return getattr(self, "_consumer_tag", None) - async def close(self, *_: Any) -> Any: + async def close(self) -> None: + await self._on_close(self._amqp_queue.channel, None) + self._closed.set() + + async def _set_closed_callback( + self, + _channel: AbstractChannel, + exc: Optional[BaseException] + ) -> None: + self._closed.set() + + async def _on_close( + self, + _channel: AbstractChannel, + _exc: Optional[BaseException] + ) -> None: log.debug("Cancelling queue iterator %r", self) if not hasattr(self, "_consumer_tag"): @@ -436,7 +451,7 @@ async def close(self, *_: Any) -> Any: consumer_tag = self._consumer_tag del self._consumer_tag - self._amqp_queue.close_callbacks.remove(self.close) + self._amqp_queue.close_callbacks.discard(self._on_close) await self._amqp_queue.cancel(consumer_tag) log.debug("Queue iterator %r closed", self) @@ -482,9 +497,14 @@ def __init__(self, queue: Queue, **kwargs: Any): self._consumer_tag: ConsumerTag self._amqp_queue: AbstractQueue = queue self._queue = asyncio.Queue() + self._closed = asyncio.Event() self._consume_kwargs = kwargs - self._amqp_queue.close_callbacks.add(self.close) + self._amqp_queue.close_callbacks.add(self._on_close, weak=True) + self._amqp_queue.close_callbacks.add( + self._set_closed_callback, + weak=True + ) async def on_message(self, message: AbstractIncomingMessage) -> None: await self._queue.put(message) @@ -513,22 +533,61 @@ async def __aexit__( async def __anext__(self) -> IncomingMessage: if not hasattr(self, "_consumer_tag"): await self.consume() + + if self._closed.is_set(): + raise StopAsyncIteration + + message = asyncio.create_task( + self._queue.get(), + name=f"waiting for message from {self}" + ) + closed_channel = asyncio.create_task( + self._amqp_queue.channel.wait(), + name=f"waiting for channel {self._amqp_queue.channel} to close " + f"before a message from {self}" + ) + closed = asyncio.create_task( + self._closed.wait(), + name=f"waiting for queue iterator to close " + f"before a message from {self}" + ) + + timeout = self._consume_kwargs.get("timeout") + + done: Set[asyncio.Task] = set() + pending = {message, closed_channel, closed} + try: - return await asyncio.wait_for( - self._queue.get(), - timeout=self._consume_kwargs.get("timeout"), + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED, timeout=timeout ) + cancelled = False except asyncio.CancelledError: - timeout = self._consume_kwargs.get( - "timeout", - self.DEFAULT_CLOSE_TIMEOUT, - ) + cancelled = True + + for task in pending: + task.cancel() + + await asyncio.wait(pending) + + if not done and not cancelled: + if timeout is None: + timeout = self.DEFAULT_CLOSE_TIMEOUT log.info( "%r closing with timeout %d seconds", self, timeout, ) await asyncio.wait_for(self.close(), timeout=timeout) - raise + raise TimeoutError + + if not message.cancelled(): + return message.result() + + if not closed.cancelled() or not closed_channel.cancelled(): + self._closed.set() + raise StopAsyncIteration + + raise asyncio.CancelledError __all__ = ("Queue", "QueueIterator", "ConsumerTag") diff --git a/aio_pika/robust_channel.py b/aio_pika/robust_channel.py index 5ddf4ab9..8b62f51b 100644 --- a/aio_pika/robust_channel.py +++ b/aio_pika/robust_channel.py @@ -69,7 +69,8 @@ def __init__( self.reopen_callbacks: CallbackCollection = CallbackCollection(self) self.__restore_lock = asyncio.Lock() self.__restored = asyncio.Event() - self.close_callbacks.add(self.__close_callback) + + self.close_callbacks.remove(self._set_closed_callback) async def ready(self) -> None: await self._connection.ready() @@ -94,23 +95,31 @@ async def restore(self, channel: Any = None) -> None: await self.reopen() self.__restored.set() - async def __close_callback(self, _: Any, exc: BaseException) -> None: + async def _on_close( + self, + closing: asyncio.Future + ) -> Optional[BaseException]: + exc = await super()._on_close(closing) + if isinstance(exc, asyncio.CancelledError): # This happens only if the channel is forced to close from the # outside, for example, if the connection is closed. # Of course, here you need to exit from this function # as soon as possible and to avoid a recovery attempt. self.__restored.clear() - return + self._closed.set() + return exc in_restore_state = not self.__restored.is_set() self.__restored.clear() - if self._closed or in_restore_state: - return + if self._closed.is_set() or in_restore_state: + return exc await self.restore() + return exc + async def _open(self) -> None: await super()._open() await self.reopen_callbacks() diff --git a/aio_pika/robust_queue.py b/aio_pika/robust_queue.py index e7fdc1ba..91078d7c 100644 --- a/aio_pika/robust_queue.py +++ b/aio_pika/robust_queue.py @@ -151,6 +151,11 @@ def iterator(self, **kwargs: Any) -> AbstractQueueIterator: class RobustQueueIterator(QueueIterator): + def __init__(self, queue: Queue, **kwargs: Any): + super().__init__(queue, **kwargs) + + self._amqp_queue.close_callbacks.discard(self._set_closed_callback) + async def consume(self) -> None: while True: try: diff --git a/poetry.lock b/poetry.lock index 5243c142..709fdfdd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiomisc" @@ -1314,13 +1314,13 @@ files = [ [[package]] name = "types-docutils" -version = "0.20.0.20240227" +version = "0.20.0.20240304" description = "Typing stubs for docutils" optional = false python-versions = ">=3.8" files = [ - {file = "types-docutils-0.20.0.20240227.tar.gz", hash = "sha256:7f2dbb02356024b5db3efd9df26b236da050ad2eada89872e5284b4a394b7761"}, - {file = "types_docutils-0.20.0.20240227-py3-none-any.whl", hash = "sha256:51c139502ba0add871392cbc37200a3a64096e61eeb6396727443ba6d38ae579"}, + {file = "types-docutils-0.20.0.20240304.tar.gz", hash = "sha256:c35ae35ca835a5aeead758df411cd46cfb7e7f19f2b223c413dae7e069d5b0be"}, + {file = "types_docutils-0.20.0.20240304-py3-none-any.whl", hash = "sha256:ef02f9d05f2b61500638b1358cdf3fbf975cc5dedaa825a2eb5ea71b7318a760"}, ] [[package]] @@ -1629,4 +1629,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "a8d27cbd8d26a79f324915969fae155576f7509726d0e5c2ce16e8a5eae2eb9d" +content-hash = "8f275c943e32fabe432322ce72fe24f7e9b0a29a20693d9714eb2e6fca4f9b31" diff --git a/pyproject.toml b/pyproject.toml index b9063e5f..378c1950 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,9 @@ types-setuptools = "^65.6.0.2" setuptools = "^69.0.3" testcontainers = "^3.7.1" +[tool.poetry.group.uvloop] +optional = true + [tool.poetry.group.uvloop.dependencies] uvloop = "^0.19" diff --git a/tests/test_amqp.py b/tests/test_amqp.py index 33ddb1ca..632228b8 100644 --- a/tests/test_amqp.py +++ b/tests/test_amqp.py @@ -4,7 +4,7 @@ import time import uuid from datetime import datetime, timezone -from typing import Callable, Optional +from typing import Callable, Optional, List from unittest import mock import aiormq.exceptions @@ -23,6 +23,7 @@ ) from aio_pika.exchange import ExchangeType from aio_pika.message import ReturnedMessage +from aio_pika.queue import QueueIterator from tests import get_random_name @@ -1254,7 +1255,7 @@ async def publisher(): async def test_async_for_queue_context( self, event_loop, connection, declare_queue, - ): + ) -> None: channel2 = await self.create_channel(connection) queue = await declare_queue( @@ -1263,31 +1264,46 @@ async def test_async_for_queue_context( channel=channel2, ) - messages = 100 + messages: asyncio.Queue[bytes] = asyncio.Queue(100) + condition = asyncio.Condition() - async def publisher(): + async def publisher() -> None: channel1 = await self.create_channel(connection) - for i in range(messages): + for i in range(messages.maxsize): + body = str(i).encode() + await messages.put(body) await channel1.default_exchange.publish( - Message(body=str(i).encode()), routing_key=queue.name, + Message(body=body), routing_key=queue.name, ) - event_loop.create_task(publisher()) + async def consumer() -> None: + async with queue.iterator() as queue_iterator: + async for message in queue_iterator: + async with message.process(): + async with condition: + data.append(message.body) + messages.task_done() + condition.notify() + + async def application_stop_request() -> None: + async with condition: + await condition.wait_for(messages.full) + await messages.join() + await asyncio.sleep(1) + await connection.close() - count = 0 - data = list() + p = event_loop.create_task(publisher()) + c = event_loop.create_task(consumer()) + asr = event_loop.create_task(application_stop_request()) - async with queue.iterator() as queue_iterator: - async for message in queue_iterator: - async with message.process(): - count += 1 - data.append(message.body) + data: List[bytes] = list() - if count >= messages: - break + await asyncio.gather(p, c, asr) - assert data == list(map(lambda x: str(x).encode(), range(messages))) + assert data == list( + map(lambda x: str(x).encode(), range(messages.maxsize)) + ) async def test_async_with_connection( self, create_connection: Callable, @@ -1303,32 +1319,47 @@ async def test_async_with_connection( channel=channel, ) - messages = 100 + condition = asyncio.Condition() + messages: asyncio.Queue[bytes] = asyncio.Queue( + 100 + ) async def publisher(): channel1 = await self.create_channel(connection) - for i in range(messages): + for i in range(messages.maxsize): + body = str(i).encode() + await messages.put(body) await channel1.default_exchange.publish( - Message(body=str(i).encode()), routing_key=queue.name, + Message(body=body), routing_key=queue.name, ) - event_loop.create_task(publisher()) - - count = 0 data = list() - async with queue.iterator() as queue_iterator: - async for message in queue_iterator: - async with message.process(): - count += 1 - data.append(message.body) + async def consume_loop(): + async with queue.iterator() as queue_iterator: + async for message in queue_iterator: + async with message.process(): + async with condition: + data.append(message.body) + condition.notify() + messages.task_done() + + async def application_close_request(): + async with condition: + await condition.wait_for(messages.full) + await messages.join() + await asyncio.sleep(1) + await connection.close() + + p = event_loop.create_task(publisher()) + cl = event_loop.create_task(consume_loop()) + acr = event_loop.create_task(application_close_request()) - if count >= messages: - break + await asyncio.gather(p, cl, acr) assert data == list( - map(lambda x: str(x).encode(), range(messages)), + map(lambda x: str(x).encode(), range(messages.maxsize)), ) assert channel.is_closed @@ -1384,47 +1415,37 @@ async def test_channel_locked_resource( async def test_queue_iterator_close_was_called_twice( self, create_connection: Callable, event_loop, declare_queue, ): - future = event_loop.create_future() event = asyncio.Event() queue_name = get_random_name() + iterator: QueueIterator async def task_inner(): - nonlocal future nonlocal event + nonlocal iterator nonlocal create_connection - try: - connection = await create_connection() - - async with connection: - channel = await self.create_channel(connection) + connection = await create_connection() - queue = await declare_queue( - queue_name, channel=channel, cleanup=False, - ) + async with connection: + channel = await self.create_channel(connection) - async with queue.iterator() as q: - event.set() + queue = await declare_queue( + queue_name, channel=channel, cleanup=False, + ) - async for message in q: - with message.process(): - break + async with queue.iterator() as iterator: + event.set() - except asyncio.CancelledError as e: - future.set_exception(e) - raise + async for message in iterator: + async with message.process(): + pytest.fail("who sent this message?") task = event_loop.create_task(task_inner()) await event.wait() - event_loop.call_soon(task.cancel) - - with pytest.raises(asyncio.CancelledError): - await task - - with pytest.raises(asyncio.CancelledError): - await future + await iterator.close() + await task async def test_queue_iterator_close_with_noack( self, @@ -1433,7 +1454,7 @@ async def test_queue_iterator_close_with_noack( add_cleanup: Callable, declare_queue, ): - messages = [] + messages: asyncio.Queue = asyncio.Queue() queue_name = get_random_name("test_queue") body = get_random_name("test_body").encode() @@ -1454,7 +1475,7 @@ async def task_inner(): async with queue.iterator(no_ack=True) as q: async for message in q: - messages.append(message) + await messages.put(message) return async with await create_connection() as connection: @@ -1471,14 +1492,70 @@ async def task_inner(): task = event_loop.create_task(task_inner()) - await task + message = await messages.get() - assert messages - assert messages[0].body == body + assert message + assert message.body == body finally: + await task await queue.delete() + async def test_queue_iterator_throws_cancelled_error( + self, + create_connection: Callable, + event_loop, + add_cleanup: Callable, + declare_queue, + ): + event_loop.set_debug(True) + queue_name = get_random_name("test_queue") + + connection = await create_connection() + + async with connection: + channel = await self.create_channel(connection) + + queue = await channel.declare_queue( + queue_name, + ) + + iterator = queue.iterator() + task = event_loop.create_task(iterator.__anext__()) + done, pending = await asyncio.wait({task}, timeout=1) + assert not done + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + async def test_queue_iterator_throws_timeout_error( + self, + create_connection: Callable, + event_loop, + add_cleanup: Callable, + declare_queue, + ): + event_loop.set_debug(True) + queue_name = get_random_name("test_queue") + + connection = await create_connection() + + async with connection: + channel = await self.create_channel(connection) + + queue = await channel.declare_queue( + queue_name, + ) + + iterator = queue.iterator(timeout=1) + task = event_loop.create_task(iterator.__anext__()) + done, pending = await asyncio.wait({task}, timeout=5) + assert done + + with pytest.raises(TimeoutError): + await task + async def test_passive_for_exchange( self, declare_exchange: Callable, connection, add_cleanup: Callable, ):