Skip to content

Commit

Permalink
QueueIterator raised StopAsyncIteration when channel is closed.
Browse files Browse the repository at this point in the history
  • Loading branch information
Darsstar committed Mar 19, 2024
1 parent d43ade7 commit 9820f08
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 102 deletions.
5 changes: 5 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[run]
omit = aio_pika/compat.py
branch = True

[report]
exclude_lines =
pragma: no cover
raise NotImplementedError
12 changes: 10 additions & 2 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,14 @@ def iterator(self, **kwargs: Any) -> "AbstractQueueIterator":
raise NotImplementedError


class AbstractQueueIterator(AsyncIterable):
class AbstractQueueIterator(AsyncIterable[AbstractIncomingMessage]):
_amqp_queue: AbstractQueue
_queue: asyncio.Queue
_consumer_tag: ConsumerTag
_consume_kwargs: Dict[str, Any]

@abstractmethod
def close(self, *_: Any) -> Awaitable[Any]:
def close(self) -> Awaitable[Any]:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -531,6 +531,10 @@ def is_closed(self) -> bool:
def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[None]:
raise NotImplementedError

@abstractmethod
def closed(self) -> Awaitable[Literal[True]]:
raise NotImplementedError

@abstractmethod
async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:
raise NotImplementedError
Expand Down Expand Up @@ -741,6 +745,10 @@ def is_closed(self) -> bool:
async def close(self, exc: ExceptionType = asyncio.CancelledError) -> None:
raise NotImplementedError

@abstractmethod
def closed(self) -> Awaitable[Literal[True]]:
raise NotImplementedError

@abstractmethod
async def connect(self, timeout: TimeoutType = None) -> None:
raise NotImplementedError
Expand Down
44 changes: 35 additions & 9 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import asyncio
import contextlib
import warnings
from abc import ABC
from types import TracebackType
from typing import Any, AsyncContextManager, Generator, Optional, Type, Union
from typing import (
Any, AsyncContextManager, Awaitable, Generator, Literal, Optional, Type,
Union,
)
from warnings import warn

import aiormq
Expand Down Expand Up @@ -77,8 +81,7 @@ def __init__(

self._connection: AbstractConnection = connection

# That's means user closed channel instance explicitly
self._closed: bool = False
self._closed: asyncio.Future[Literal[True]] = asyncio.get_running_loop().create_future()

self._channel: Optional[UnderlayChannel] = None
self._channel_number = channel_number
Expand All @@ -89,6 +92,8 @@ def __init__(
self.publisher_confirms = publisher_confirms
self.on_return_raises = on_return_raises

self.close_callbacks.add(self._set_closed_callback)

@property
def is_initialized(self) -> bool:
"""Returns True when the channel has been opened
Expand All @@ -99,7 +104,7 @@ def is_initialized(self) -> bool:
def is_closed(self) -> bool:
"""Returns True when the channel has been closed from the broker
side or after the close() method has been called."""
if not self.is_initialized or self._closed:
if not self.is_initialized or self._closed.done():
return True
channel = self._channel
if channel is None:
Expand All @@ -119,8 +124,12 @@ async def close(
return

log.debug("Closing channel %r", self)
self._closed = True
await self._channel.close()
if not self._closed.done():
self._closed.set_result(True)

def closed(self) -> Awaitable[Literal[True]]:
return self._closed

async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:

Expand Down Expand Up @@ -174,12 +183,13 @@ async def _open(self) -> None:
await channel.close(e)
self._channel = None
raise
self._closed = False
if self._closed.done():
self._closed = asyncio.get_running_loop().create_future()

async def initialize(self, timeout: TimeoutType = None) -> None:
if self.is_initialized:
raise RuntimeError("Already initialized")
elif self._closed:
elif self._closed.done():
raise RuntimeError("Can't initialize closed channel")

await self._open()
Expand All @@ -197,7 +207,10 @@ async def _on_open(self) -> None:
type=ExchangeType.DIRECT,
)

async def _on_close(self, closing: asyncio.Future) -> None:
async def _on_close(
self,
closing: asyncio.Future
) -> Optional[BaseException]:
try:
exc = closing.exception()
except asyncio.CancelledError as e:
Expand All @@ -207,6 +220,15 @@ async def _on_close(self, closing: asyncio.Future) -> None:
if self._channel and self._channel.channel:
self._channel.channel.on_return_callbacks.discard(self._on_return)

return exc

async def _set_closed_callback(
self,
_: AbstractChannel, exc: BaseException
) -> None:
if not self._closed.done():
self._closed.set_result(True)

async def _on_initialized(self) -> None:
channel = await self.get_underlay_channel()
channel.on_return_callbacks.add(self._on_return)
Expand All @@ -219,7 +241,11 @@ async def reopen(self) -> None:
await self._open()

def __del__(self) -> None:
self._closed = True
with contextlib.suppress(AttributeError):
# might raise because an Exception was raised in __init__
if not self._closed.done():
self._closed.set_result(True)

self._channel = None

async def declare_exchange(
Expand Down
20 changes: 13 additions & 7 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
from ssl import SSLContext
from types import TracebackType
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any, Awaitable, Dict, Literal, Optional, Tuple, Type, TypeVar, Union
)


import aiormq.abc
from aiormq.connection import parse_int
Expand Down Expand Up @@ -39,11 +42,11 @@ class Connection(AbstractConnection):
),
)

_closed: bool
_closed: asyncio.Future[Literal[True]]

@property
def is_closed(self) -> bool:
return self._closed
return self._closed.done()

async def close(
self, exc: Optional[aiormq.abc.ExceptionType] = ConnectionClosed,
Expand All @@ -53,7 +56,11 @@ async def close(
if not transport:
return
await transport.close(exc)
self._closed = True
if not self._closed.done():
self._closed.set_result(True)

def closed(self) -> Awaitable[Literal[True]]:
return self._closed

@classmethod
def _parse_parameters(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -74,7 +81,7 @@ def __init__(
):
self.loop = loop or asyncio.get_event_loop()
self.transport = None
self._closed = False
self._closed = self.loop.create_future()
self._close_called = False

self.url = URL(url)
Expand Down Expand Up @@ -201,8 +208,7 @@ async def ready(self) -> None:
def __del__(self) -> None:
if (
self.is_closed or
self.loop.is_closed() or
not hasattr(self, "connection")
self.loop.is_closed()
):
return

Expand Down
77 changes: 64 additions & 13 deletions aio_pika/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import sys
from functools import partial
from types import TracebackType
from typing import Any, Awaitable, Callable, Optional, Type, overload
from typing import (
Any, Awaitable, Callable, Optional, Type, cast, overload
)

import aiormq
from aiormq.abc import DeliveredMessage
Expand Down Expand Up @@ -92,7 +94,6 @@ async def declare(
""" Declare queue.
:param timeout: execution timeout
:param passive: Only check to see if the queue exists.
:return: :class:`None`
"""
log.debug("Declaring queue: %r", self)
Expand Down Expand Up @@ -421,7 +422,24 @@ class QueueIterator(AbstractQueueIterator):
def consumer_tag(self) -> Optional[ConsumerTag]:
return getattr(self, "_consumer_tag", None)

async def close(self, *_: Any) -> Any:
async def close(self) -> None:
await self._on_close(self._amqp_queue.channel, None)
if not self._closed.done():
self._closed.set_result(True)

async def _set_closed_callback(
self,
_channel: AbstractChannel,
exc: Optional[BaseException]
) -> None:
if not self._closed.done():
self._closed.set_result(True)

async def _on_close(
self,
_channel: AbstractChannel,
_exc: Optional[BaseException]
) -> None:
log.debug("Cancelling queue iterator %r", self)

if not hasattr(self, "_consumer_tag"):
Expand All @@ -436,7 +454,7 @@ async def close(self, *_: Any) -> Any:
consumer_tag = self._consumer_tag
del self._consumer_tag

self._amqp_queue.close_callbacks.remove(self.close)
self._amqp_queue.close_callbacks.discard(self._on_close)
await self._amqp_queue.cancel(consumer_tag)

log.debug("Queue iterator %r closed", self)
Expand Down Expand Up @@ -482,9 +500,14 @@ def __init__(self, queue: Queue, **kwargs: Any):
self._consumer_tag: ConsumerTag
self._amqp_queue: AbstractQueue = queue
self._queue = asyncio.Queue()
self._closed = asyncio.get_running_loop().create_future()
self._consume_kwargs = kwargs

self._amqp_queue.close_callbacks.add(self.close)
self._amqp_queue.close_callbacks.add(self._on_close, weak=True)
self._amqp_queue.close_callbacks.add(
self._set_closed_callback,
weak=True
)

async def on_message(self, message: AbstractIncomingMessage) -> None:
await self._queue.put(message)
Expand All @@ -511,24 +534,52 @@ async def __aexit__(
await self.close()

async def __anext__(self) -> IncomingMessage:
if self._closed.done():
raise StopAsyncIteration

if not hasattr(self, "_consumer_tag"):
await self.consume()

closed = self._closed
closed_channel = cast(
asyncio.Future[Literal[True]], self._amqp_queue.channel.closed()
)
message = asyncio.create_task(
self._queue.get(),
name=f"waiting for message from {self}"
)

timeout = self._consume_kwargs.get("timeout")
pending = {message, closed_channel, closed}

try:
return await asyncio.wait_for(
self._queue.get(),
timeout=self._consume_kwargs.get("timeout"),
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED, timeout=timeout
)
timed_out = not done
except asyncio.CancelledError:
timeout = self._consume_kwargs.get(
"timeout",
self.DEFAULT_CLOSE_TIMEOUT,
)
timed_out = False

if message.done():
return message.result()
else:
message.cancel()

if timed_out:
if timeout is None:
timeout = self.DEFAULT_CLOSE_TIMEOUT
log.info(
"%r closing with timeout %d seconds",
self, timeout,
)
await asyncio.wait_for(self.close(), timeout=timeout)
raise
raise TimeoutError
elif closed.done() or closed_channel.done():
if not self._closed.done():
self._closed.set_result(True)
raise StopAsyncIteration

raise asyncio.CancelledError


__all__ = ("Queue", "QueueIterator", "ConsumerTag")
20 changes: 15 additions & 5 deletions aio_pika/robust_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def __init__(
self.reopen_callbacks: CallbackCollection = CallbackCollection(self)
self.__restore_lock = asyncio.Lock()
self.__restored = asyncio.Event()
self.close_callbacks.add(self.__close_callback)

self.close_callbacks.remove(self._set_closed_callback)

async def ready(self) -> None:
await self._connection.ready()
Expand All @@ -94,23 +95,32 @@ async def restore(self, channel: Any = None) -> None:
await self.reopen()
self.__restored.set()

async def __close_callback(self, _: Any, exc: BaseException) -> None:
async def _on_close(
self,
closing: asyncio.Future
) -> Optional[BaseException]:
exc = await super()._on_close(closing)

if isinstance(exc, asyncio.CancelledError):
# This happens only if the channel is forced to close from the
# outside, for example, if the connection is closed.
# Of course, here you need to exit from this function
# as soon as possible and to avoid a recovery attempt.
self.__restored.clear()
return
if not self._closed.done():
self._closed.set_result(True)
return exc

in_restore_state = not self.__restored.is_set()
self.__restored.clear()

if self._closed or in_restore_state:
return
if self._closed.done() or in_restore_state:
return exc

await self.restore()

return exc

async def reopen(self) -> None:
await super().reopen()
await self.reopen_callbacks()
Expand Down
Loading

0 comments on commit 9820f08

Please sign in to comment.