diff --git a/aio_pika/abc.py b/aio_pika/abc.py index 227356ea..0c9244ac 100644 --- a/aio_pika/abc.py +++ b/aio_pika/abc.py @@ -367,14 +367,14 @@ def iterator(self, **kwargs: Any) -> "AbstractQueueIterator": raise NotImplementedError -class AbstractQueueIterator(AsyncIterable): +class AbstractQueueIterator(AsyncIterable[AbstractIncomingMessage]): _amqp_queue: AbstractQueue - _queue: asyncio.Queue + _queue: asyncio.Queue[AbstractIncomingMessage] _consumer_tag: ConsumerTag _consume_kwargs: Dict[str, Any] @abstractmethod - def close(self, *_: Any) -> Awaitable[Any]: + def close(self) -> Awaitable[Any]: raise NotImplementedError @abstractmethod @@ -532,6 +532,10 @@ def is_closed(self) -> bool: def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[None]: raise NotImplementedError + @abstractmethod + async def wait(self) -> None: + raise NotImplementedError + @abstractmethod async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel: raise NotImplementedError @@ -742,6 +746,10 @@ def is_closed(self) -> bool: async def close(self, exc: ExceptionType = asyncio.CancelledError) -> None: raise NotImplementedError + @abstractmethod + async def wait(self) -> None: + raise NotImplementedError + @abstractmethod async def connect(self, timeout: TimeoutType = None) -> None: raise NotImplementedError diff --git a/aio_pika/channel.py b/aio_pika/channel.py index 7741b7af..15ab9da1 100644 --- a/aio_pika/channel.py +++ b/aio_pika/channel.py @@ -77,8 +77,7 @@ def __init__( self._connection: AbstractConnection = connection - # That's means user closed channel instance explicitly - self._closed: bool = False + self._closed: asyncio.Event = asyncio.Event() self._channel: Optional[UnderlayChannel] = None self._channel_number = channel_number @@ -89,6 +88,8 @@ def __init__( self.publisher_confirms = publisher_confirms self.on_return_raises = on_return_raises + self.close_callbacks.add(self._set_closed_callback) + @property def is_initialized(self) -> bool: """Returns True when the channel has been opened @@ -99,7 +100,7 @@ def is_initialized(self) -> bool: def is_closed(self) -> bool: """Returns True when the channel has been closed from the broker side or after the close() method has been called.""" - if not self.is_initialized or self._closed: + if not self.is_initialized or self._closed.is_set(): return True channel = self._channel if channel is None: @@ -119,8 +120,11 @@ async def close( return log.debug("Closing channel %r", self) - self._closed = True await self._channel.close() + self._closed.set() + + async def wait(self) -> None: + await self._closed.wait() async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel: @@ -174,12 +178,12 @@ async def _open(self) -> None: await channel.close(e) self._channel = None raise - self._closed = False + self._closed.clear() async def initialize(self, timeout: TimeoutType = None) -> None: if self.is_initialized: raise RuntimeError("Already initialized") - elif self._closed: + elif self._closed.is_set(): raise RuntimeError("Can't initialize closed channel") await self._open() @@ -197,7 +201,10 @@ async def _on_open(self) -> None: type=ExchangeType.DIRECT, ) - async def _on_close(self, closing: asyncio.Future) -> None: + async def _on_close( + self, + closing: asyncio.Future + ) -> Optional[BaseException]: try: exc = closing.exception() except asyncio.CancelledError as e: @@ -207,6 +214,14 @@ async def _on_close(self, closing: asyncio.Future) -> None: if self._channel and self._channel.channel: self._channel.channel.on_return_callbacks.discard(self._on_return) + return exc + + async def _set_closed_callback( + self, + _: AbstractChannel, exc: BaseException + ) -> None: + self._closed.set() + async def _on_initialized(self) -> None: channel = await self.get_underlay_channel() channel.on_return_callbacks.add(self._on_return) @@ -219,7 +234,7 @@ async def reopen(self) -> None: await self._open() def __del__(self) -> None: - self._closed = True + self._closed.set() self._channel = None async def declare_exchange( diff --git a/aio_pika/connection.py b/aio_pika/connection.py index fb49db4b..6976af7d 100644 --- a/aio_pika/connection.py +++ b/aio_pika/connection.py @@ -39,11 +39,11 @@ class Connection(AbstractConnection): ), ) - _closed: bool + _closed: asyncio.Event @property def is_closed(self) -> bool: - return self._closed + return self._closed.is_set() async def close( self, exc: Optional[aiormq.abc.ExceptionType] = ConnectionClosed, @@ -53,7 +53,10 @@ async def close( if not transport: return await transport.close(exc) - self._closed = True + self._closed.set() + + async def wait(self) -> None: + await self._closed.wait() @classmethod def _parse_parameters(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: @@ -74,7 +77,7 @@ def __init__( ): self.loop = loop or asyncio.get_event_loop() self.transport = None - self._closed = False + self._closed = asyncio.Event() self._close_called = False self.url = URL(url) @@ -201,8 +204,7 @@ async def ready(self) -> None: def __del__(self) -> None: if ( self.is_closed or - self.loop.is_closed() or - not hasattr(self, "connection") + self.loop.is_closed() ): return diff --git a/aio_pika/queue.py b/aio_pika/queue.py index 2f0f9d45..1981fcc7 100644 --- a/aio_pika/queue.py +++ b/aio_pika/queue.py @@ -421,7 +421,22 @@ class QueueIterator(AbstractQueueIterator): def consumer_tag(self) -> Optional[ConsumerTag]: return getattr(self, "_consumer_tag", None) - async def close(self, *_: Any) -> Any: + async def close(self) -> None: + await self._on_close(self._amqp_queue.channel, None) + self._closed.set() + + async def _set_closed_callback( + self, + _channel: AbstractChannel, + exc: Optional[BaseException] + ) -> None: + self._closed.set() + + async def _on_close( + self, + _channel: AbstractChannel, + _exc: Optional[BaseException] + ) -> None: log.debug("Cancelling queue iterator %r", self) if not hasattr(self, "_consumer_tag"): @@ -436,7 +451,7 @@ async def close(self, *_: Any) -> Any: consumer_tag = self._consumer_tag del self._consumer_tag - self._amqp_queue.close_callbacks.remove(self.close) + self._amqp_queue.close_callbacks.discard(self._on_close) await self._amqp_queue.cancel(consumer_tag) log.debug("Queue iterator %r closed", self) @@ -482,9 +497,14 @@ def __init__(self, queue: Queue, **kwargs: Any): self._consumer_tag: ConsumerTag self._amqp_queue: AbstractQueue = queue self._queue = asyncio.Queue() + self._closed = asyncio.Event() self._consume_kwargs = kwargs - self._amqp_queue.close_callbacks.add(self.close) + self._amqp_queue.close_callbacks.add(self._on_close, weak=True) + self._amqp_queue.close_callbacks.add( + self._set_closed_callback, + weak=True + ) async def on_message(self, message: AbstractIncomingMessage) -> None: await self._queue.put(message) @@ -513,22 +533,64 @@ async def __aexit__( async def __anext__(self) -> IncomingMessage: if not hasattr(self, "_consumer_tag"): await self.consume() - try: - return await asyncio.wait_for( - self._queue.get(), - timeout=self._consume_kwargs.get("timeout"), - ) - except asyncio.CancelledError: - timeout = self._consume_kwargs.get( - "timeout", - self.DEFAULT_CLOSE_TIMEOUT, + + if self._closed.is_set(): + raise StopAsyncIteration + + message = asyncio.create_task( + self._queue.get(), + name=f"waiting for message from {self}" + ) + closed_channel = asyncio.create_task( + self._amqp_queue.channel.wait(), + name=f"waiting for channel {self._amqp_queue.channel} to close " + f"before a message from {self}" + ) + closed = asyncio.create_task( + self._closed.wait(), + name=f"waiting for queue iterator to close " + f"before a message from {self}" + ) + + timeout = self._consume_kwargs.get("timeout") + sleep = asyncio.get_running_loop().create_future() + + if timeout is not None: + sleep = asyncio.create_task( + asyncio.sleep(timeout), + name=f"waiting for {self} to timeout after {timeout} seconds" ) + else: + timeout = self.DEFAULT_CLOSE_TIMEOUT + + pending = {message, closed_channel, closed, sleep} + + done, pending = await asyncio.wait( + pending, + return_when=asyncio.FIRST_COMPLETED + ) + + for task in pending: + task.cancel() + + await asyncio.wait(pending) + + if not message.cancelled(): + return message.result() + + if not closed.cancelled() or not closed_channel.cancelled(): + self._closed.set() + raise StopAsyncIteration + + if not sleep.cancelled(): log.info( "%r closing with timeout %d seconds", self, timeout, ) await asyncio.wait_for(self.close(), timeout=timeout) - raise + raise TimeoutError + + raise asyncio.CancelledError __all__ = ("Queue", "QueueIterator", "ConsumerTag") diff --git a/aio_pika/robust_channel.py b/aio_pika/robust_channel.py index 5ddf4ab9..8b62f51b 100644 --- a/aio_pika/robust_channel.py +++ b/aio_pika/robust_channel.py @@ -69,7 +69,8 @@ def __init__( self.reopen_callbacks: CallbackCollection = CallbackCollection(self) self.__restore_lock = asyncio.Lock() self.__restored = asyncio.Event() - self.close_callbacks.add(self.__close_callback) + + self.close_callbacks.remove(self._set_closed_callback) async def ready(self) -> None: await self._connection.ready() @@ -94,23 +95,31 @@ async def restore(self, channel: Any = None) -> None: await self.reopen() self.__restored.set() - async def __close_callback(self, _: Any, exc: BaseException) -> None: + async def _on_close( + self, + closing: asyncio.Future + ) -> Optional[BaseException]: + exc = await super()._on_close(closing) + if isinstance(exc, asyncio.CancelledError): # This happens only if the channel is forced to close from the # outside, for example, if the connection is closed. # Of course, here you need to exit from this function # as soon as possible and to avoid a recovery attempt. self.__restored.clear() - return + self._closed.set() + return exc in_restore_state = not self.__restored.is_set() self.__restored.clear() - if self._closed or in_restore_state: - return + if self._closed.is_set() or in_restore_state: + return exc await self.restore() + return exc + async def _open(self) -> None: await super()._open() await self.reopen_callbacks() diff --git a/aio_pika/robust_queue.py b/aio_pika/robust_queue.py index e7fdc1ba..91078d7c 100644 --- a/aio_pika/robust_queue.py +++ b/aio_pika/robust_queue.py @@ -151,6 +151,11 @@ def iterator(self, **kwargs: Any) -> AbstractQueueIterator: class RobustQueueIterator(QueueIterator): + def __init__(self, queue: Queue, **kwargs: Any): + super().__init__(queue, **kwargs) + + self._amqp_queue.close_callbacks.discard(self._set_closed_callback) + async def consume(self) -> None: while True: try: diff --git a/tests/test_amqp.py b/tests/test_amqp.py index 33ddb1ca..3eda497f 100644 --- a/tests/test_amqp.py +++ b/tests/test_amqp.py @@ -4,7 +4,7 @@ import time import uuid from datetime import datetime, timezone -from typing import Callable, Optional +from typing import Callable, Optional, List from unittest import mock import aiormq.exceptions @@ -23,6 +23,7 @@ ) from aio_pika.exchange import ExchangeType from aio_pika.message import ReturnedMessage +from aio_pika.queue import QueueIterator from tests import get_random_name @@ -1254,7 +1255,7 @@ async def publisher(): async def test_async_for_queue_context( self, event_loop, connection, declare_queue, - ): + ) -> None: channel2 = await self.create_channel(connection) queue = await declare_queue( @@ -1263,31 +1264,46 @@ async def test_async_for_queue_context( channel=channel2, ) - messages = 100 + messages: asyncio.Queue[bytes] = asyncio.Queue(100) + condition = asyncio.Condition() - async def publisher(): + async def publisher() -> None: channel1 = await self.create_channel(connection) - for i in range(messages): + for i in range(messages.maxsize): + body = str(i).encode() + await messages.put(body) await channel1.default_exchange.publish( - Message(body=str(i).encode()), routing_key=queue.name, + Message(body=body), routing_key=queue.name, ) - event_loop.create_task(publisher()) + async def consumer() -> None: + async with queue.iterator() as queue_iterator: + async for message in queue_iterator: + async with message.process(): + async with condition: + data.append(message.body) + messages.task_done() + condition.notify() + + async def application_stop_request() -> None: + async with condition: + await condition.wait_for(messages.full) + await messages.join() + await asyncio.sleep(1) + await connection.close() - count = 0 - data = list() + p = event_loop.create_task(publisher()) + c = event_loop.create_task(consumer()) + asr = event_loop.create_task(application_stop_request()) - async with queue.iterator() as queue_iterator: - async for message in queue_iterator: - async with message.process(): - count += 1 - data.append(message.body) + data: List[bytes] = list() - if count >= messages: - break + await asyncio.gather(p, c, asr) - assert data == list(map(lambda x: str(x).encode(), range(messages))) + assert data == list( + map(lambda x: str(x).encode(), range(messages.maxsize)) + ) async def test_async_with_connection( self, create_connection: Callable, @@ -1303,32 +1319,47 @@ async def test_async_with_connection( channel=channel, ) - messages = 100 + condition = asyncio.Condition() + messages: asyncio.Queue[bytes] = asyncio.Queue( + 100 + ) async def publisher(): channel1 = await self.create_channel(connection) - for i in range(messages): + for i in range(messages.maxsize): + body = str(i).encode() + await messages.put(body) await channel1.default_exchange.publish( - Message(body=str(i).encode()), routing_key=queue.name, + Message(body=body), routing_key=queue.name, ) - event_loop.create_task(publisher()) - - count = 0 data = list() - async with queue.iterator() as queue_iterator: - async for message in queue_iterator: - async with message.process(): - count += 1 - data.append(message.body) + async def consume_loop(): + async with queue.iterator() as queue_iterator: + async for message in queue_iterator: + async with message.process(): + async with condition: + data.append(message.body) + condition.notify() + messages.task_done() + + async def application_close_request(): + async with condition: + await condition.wait_for(messages.full) + await messages.join() + await asyncio.sleep(1) + await connection.close() + + p = event_loop.create_task(publisher()) + cl = event_loop.create_task(consume_loop()) + acr = event_loop.create_task(application_close_request()) - if count >= messages: - break + await asyncio.gather(p, cl, acr) assert data == list( - map(lambda x: str(x).encode(), range(messages)), + map(lambda x: str(x).encode(), range(messages.maxsize)), ) assert channel.is_closed @@ -1384,47 +1415,37 @@ async def test_channel_locked_resource( async def test_queue_iterator_close_was_called_twice( self, create_connection: Callable, event_loop, declare_queue, ): - future = event_loop.create_future() event = asyncio.Event() queue_name = get_random_name() + iterator: QueueIterator async def task_inner(): - nonlocal future nonlocal event + nonlocal iterator nonlocal create_connection - try: - connection = await create_connection() - - async with connection: - channel = await self.create_channel(connection) + connection = await create_connection() - queue = await declare_queue( - queue_name, channel=channel, cleanup=False, - ) + async with connection: + channel = await self.create_channel(connection) - async with queue.iterator() as q: - event.set() + queue = await declare_queue( + queue_name, channel=channel, cleanup=False, + ) - async for message in q: - with message.process(): - break + async with queue.iterator() as iterator: + event.set() - except asyncio.CancelledError as e: - future.set_exception(e) - raise + async for message in iterator: + with message.process(): + pytest.fail("who sent this message?") task = event_loop.create_task(task_inner()) await event.wait() - event_loop.call_soon(task.cancel) - - with pytest.raises(asyncio.CancelledError): - await task - - with pytest.raises(asyncio.CancelledError): - await future + await iterator.close() + await task async def test_queue_iterator_close_with_noack( self, @@ -1433,7 +1454,7 @@ async def test_queue_iterator_close_with_noack( add_cleanup: Callable, declare_queue, ): - messages = [] + messages = asyncio.Queue() queue_name = get_random_name("test_queue") body = get_random_name("test_body").encode() @@ -1454,7 +1475,7 @@ async def task_inner(): async with queue.iterator(no_ack=True) as q: async for message in q: - messages.append(message) + await messages.put(message) return async with await create_connection() as connection: @@ -1471,12 +1492,13 @@ async def task_inner(): task = event_loop.create_task(task_inner()) - await task + message = await messages.get() - assert messages - assert messages[0].body == body + assert message + assert message.body == body finally: + await task await queue.delete() async def test_passive_for_exchange(