Skip to content

Commit

Permalink
QueueIterator raised StopAsyncIteration when channel is closed.
Browse files Browse the repository at this point in the history
  • Loading branch information
Darsstar committed Mar 4, 2024
1 parent a3ef44b commit 28d970d
Show file tree
Hide file tree
Showing 10 changed files with 284 additions and 100 deletions.
5 changes: 5 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[run]
omit = aio_pika/compat.py
branch = True

[report]
exclude_lines =
pragma: no cover
raise NotImplementedError
12 changes: 10 additions & 2 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,14 @@ def iterator(self, **kwargs: Any) -> "AbstractQueueIterator":
raise NotImplementedError


class AbstractQueueIterator(AsyncIterable):
class AbstractQueueIterator(AsyncIterable[AbstractIncomingMessage]):
_amqp_queue: AbstractQueue
_queue: asyncio.Queue
_consumer_tag: ConsumerTag
_consume_kwargs: Dict[str, Any]

@abstractmethod
def close(self, *_: Any) -> Awaitable[Any]:
def close(self) -> Awaitable[Any]:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -531,6 +531,10 @@ def is_closed(self) -> bool:
def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[None]:
raise NotImplementedError

@abstractmethod
async def wait(self) -> None:
raise NotImplementedError

@abstractmethod
async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:
raise NotImplementedError
Expand Down Expand Up @@ -741,6 +745,10 @@ def is_closed(self) -> bool:
async def close(self, exc: ExceptionType = asyncio.CancelledError) -> None:
raise NotImplementedError

@abstractmethod
async def wait(self) -> None:
raise NotImplementedError

@abstractmethod
async def connect(self, timeout: TimeoutType = None) -> None:
raise NotImplementedError
Expand Down
34 changes: 25 additions & 9 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __init__(
confirmations (in pursuit of performance)
"""

# init before any exception can be raised since we use it in __del__
self._closed: asyncio.Event = asyncio.Event()

if not publisher_confirms and on_return_raises:
raise RuntimeError(
'"on_return_raises" not applicable '
Expand All @@ -77,9 +80,6 @@ def __init__(

self._connection: AbstractConnection = connection

# That's means user closed channel instance explicitly
self._closed: bool = False

self._channel: Optional[UnderlayChannel] = None
self._channel_number = channel_number

Expand All @@ -89,6 +89,8 @@ def __init__(
self.publisher_confirms = publisher_confirms
self.on_return_raises = on_return_raises

self.close_callbacks.add(self._set_closed_callback)

@property
def is_initialized(self) -> bool:
"""Returns True when the channel has been opened
Expand All @@ -99,7 +101,7 @@ def is_initialized(self) -> bool:
def is_closed(self) -> bool:
"""Returns True when the channel has been closed from the broker
side or after the close() method has been called."""
if not self.is_initialized or self._closed:
if not self.is_initialized or self._closed.is_set():
return True
channel = self._channel
if channel is None:
Expand All @@ -119,8 +121,11 @@ async def close(
return

log.debug("Closing channel %r", self)
self._closed = True
await self._channel.close()
self._closed.set()

async def wait(self) -> None:
await self._closed.wait()

async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:

Expand Down Expand Up @@ -174,12 +179,12 @@ async def _open(self) -> None:
await channel.close(e)
self._channel = None
raise
self._closed = False
self._closed.clear()

async def initialize(self, timeout: TimeoutType = None) -> None:
if self.is_initialized:
raise RuntimeError("Already initialized")
elif self._closed:
elif self._closed.is_set():
raise RuntimeError("Can't initialize closed channel")

await self._open()
Expand All @@ -197,7 +202,10 @@ async def _on_open(self) -> None:
type=ExchangeType.DIRECT,
)

async def _on_close(self, closing: asyncio.Future) -> None:
async def _on_close(
self,
closing: asyncio.Future
) -> Optional[BaseException]:
try:
exc = closing.exception()
except asyncio.CancelledError as e:
Expand All @@ -207,6 +215,14 @@ async def _on_close(self, closing: asyncio.Future) -> None:
if self._channel and self._channel.channel:
self._channel.channel.on_return_callbacks.discard(self._on_return)

return exc

async def _set_closed_callback(
self,
_: AbstractChannel, exc: BaseException
) -> None:
self._closed.set()

async def _on_initialized(self) -> None:
channel = await self.get_underlay_channel()
channel.on_return_callbacks.add(self._on_return)
Expand All @@ -219,7 +235,7 @@ async def reopen(self) -> None:
await self._open()

def __del__(self) -> None:
self._closed = True
self._closed.set()
self._channel = None

async def declare_exchange(
Expand Down
14 changes: 8 additions & 6 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ class Connection(AbstractConnection):
),
)

_closed: bool
_closed: asyncio.Event

@property
def is_closed(self) -> bool:
return self._closed
return self._closed.is_set()

async def close(
self, exc: Optional[aiormq.abc.ExceptionType] = ConnectionClosed,
Expand All @@ -53,7 +53,10 @@ async def close(
if not transport:
return
await transport.close(exc)
self._closed = True
self._closed.set()

async def wait(self) -> None:
await self._closed.wait()

@classmethod
def _parse_parameters(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -74,7 +77,7 @@ def __init__(
):
self.loop = loop or asyncio.get_event_loop()
self.transport = None
self._closed = False
self._closed = asyncio.Event()
self._close_called = False

self.url = URL(url)
Expand Down Expand Up @@ -201,8 +204,7 @@ async def ready(self) -> None:
def __del__(self) -> None:
if (
self.is_closed or
self.loop.is_closed() or
not hasattr(self, "connection")
self.loop.is_closed()
):
return

Expand Down
83 changes: 71 additions & 12 deletions aio_pika/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from functools import partial
from types import TracebackType
from typing import Any, Awaitable, Callable, Optional, Type, overload
from typing import Any, Awaitable, Callable, Optional, Set, Type, overload

import aiormq
from aiormq.abc import DeliveredMessage
Expand Down Expand Up @@ -421,7 +421,22 @@ class QueueIterator(AbstractQueueIterator):
def consumer_tag(self) -> Optional[ConsumerTag]:
return getattr(self, "_consumer_tag", None)

async def close(self, *_: Any) -> Any:
async def close(self) -> None:
await self._on_close(self._amqp_queue.channel, None)
self._closed.set()

async def _set_closed_callback(
self,
_channel: AbstractChannel,
exc: Optional[BaseException]
) -> None:
self._closed.set()

async def _on_close(
self,
_channel: AbstractChannel,
_exc: Optional[BaseException]
) -> None:
log.debug("Cancelling queue iterator %r", self)

if not hasattr(self, "_consumer_tag"):
Expand All @@ -436,7 +451,7 @@ async def close(self, *_: Any) -> Any:
consumer_tag = self._consumer_tag
del self._consumer_tag

self._amqp_queue.close_callbacks.remove(self.close)
self._amqp_queue.close_callbacks.discard(self._on_close)
await self._amqp_queue.cancel(consumer_tag)

log.debug("Queue iterator %r closed", self)
Expand Down Expand Up @@ -482,9 +497,14 @@ def __init__(self, queue: Queue, **kwargs: Any):
self._consumer_tag: ConsumerTag
self._amqp_queue: AbstractQueue = queue
self._queue = asyncio.Queue()
self._closed = asyncio.Event()
self._consume_kwargs = kwargs

self._amqp_queue.close_callbacks.add(self.close)
self._amqp_queue.close_callbacks.add(self._on_close, weak=True)
self._amqp_queue.close_callbacks.add(
self._set_closed_callback,
weak=True
)

async def on_message(self, message: AbstractIncomingMessage) -> None:
await self._queue.put(message)
Expand Down Expand Up @@ -513,22 +533,61 @@ async def __aexit__(
async def __anext__(self) -> IncomingMessage:
if not hasattr(self, "_consumer_tag"):
await self.consume()

if self._closed.is_set():
raise StopAsyncIteration

message = asyncio.create_task(
self._queue.get(),
name=f"waiting for message from {self}"
)
closed_channel = asyncio.create_task(
self._amqp_queue.channel.wait(),
name=f"waiting for channel {self._amqp_queue.channel} to close "
f"before a message from {self}"
)
closed = asyncio.create_task(
self._closed.wait(),
name=f"waiting for queue iterator to close "
f"before a message from {self}"
)

timeout = self._consume_kwargs.get("timeout")

done: Set[asyncio.Task] = set()
pending = {message, closed_channel, closed}

try:
return await asyncio.wait_for(
self._queue.get(),
timeout=self._consume_kwargs.get("timeout"),
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED, timeout=timeout
)
cancelled = False
except asyncio.CancelledError:
timeout = self._consume_kwargs.get(
"timeout",
self.DEFAULT_CLOSE_TIMEOUT,
)
cancelled = True

for task in pending:
task.cancel()

await asyncio.wait(pending)

if not done and not cancelled:
if timeout is None:
timeout = self.DEFAULT_CLOSE_TIMEOUT
log.info(
"%r closing with timeout %d seconds",
self, timeout,
)
await asyncio.wait_for(self.close(), timeout=timeout)
raise
raise TimeoutError

if not message.cancelled():
return message.result()

if not closed.cancelled() or not closed_channel.cancelled():
self._closed.set()
raise StopAsyncIteration

raise asyncio.CancelledError


__all__ = ("Queue", "QueueIterator", "ConsumerTag")
19 changes: 14 additions & 5 deletions aio_pika/robust_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def __init__(
self.reopen_callbacks: CallbackCollection = CallbackCollection(self)
self.__restore_lock = asyncio.Lock()
self.__restored = asyncio.Event()
self.close_callbacks.add(self.__close_callback)

self.close_callbacks.remove(self._set_closed_callback)

async def ready(self) -> None:
await self._connection.ready()
Expand All @@ -94,23 +95,31 @@ async def restore(self, channel: Any = None) -> None:
await self.reopen()
self.__restored.set()

async def __close_callback(self, _: Any, exc: BaseException) -> None:
async def _on_close(
self,
closing: asyncio.Future
) -> Optional[BaseException]:
exc = await super()._on_close(closing)

if isinstance(exc, asyncio.CancelledError):
# This happens only if the channel is forced to close from the
# outside, for example, if the connection is closed.
# Of course, here you need to exit from this function
# as soon as possible and to avoid a recovery attempt.
self.__restored.clear()
return
self._closed.set()
return exc

in_restore_state = not self.__restored.is_set()
self.__restored.clear()

if self._closed or in_restore_state:
return
if self._closed.is_set() or in_restore_state:
return exc

await self.restore()

return exc

async def _open(self) -> None:
await super()._open()
await self.reopen_callbacks()
Expand Down
5 changes: 5 additions & 0 deletions aio_pika/robust_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def iterator(self, **kwargs: Any) -> AbstractQueueIterator:


class RobustQueueIterator(QueueIterator):
def __init__(self, queue: Queue, **kwargs: Any):
super().__init__(queue, **kwargs)

self._amqp_queue.close_callbacks.discard(self._set_closed_callback)

async def consume(self) -> None:
while True:
try:
Expand Down
Loading

0 comments on commit 28d970d

Please sign in to comment.