Skip to content

Commit

Permalink
QueueIterator raises StopAsyncIteration when channel is closed.
Browse files Browse the repository at this point in the history
  • Loading branch information
Darsstar committed Sep 13, 2024
1 parent eb5990e commit 296bab9
Show file tree
Hide file tree
Showing 15 changed files with 481 additions and 243 deletions.
5 changes: 5 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[run]
omit = aio_pika/compat.py
branch = True

[report]
exclude_lines =
pragma: no cover
raise NotImplementedError
12 changes: 10 additions & 2 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,14 @@ def iterator(self, **kwargs: Any) -> "AbstractQueueIterator":
raise NotImplementedError


class AbstractQueueIterator(AsyncIterable):
class AbstractQueueIterator(AsyncIterable[AbstractIncomingMessage]):
_amqp_queue: AbstractQueue
_queue: asyncio.Queue
_consumer_tag: ConsumerTag
_consume_kwargs: Dict[str, Any]

@abstractmethod
def close(self, *_: Any) -> Awaitable[Any]:
def close(self) -> Awaitable[Any]:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -531,6 +531,10 @@ def is_closed(self) -> bool:
def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[None]:
raise NotImplementedError

@abstractmethod
def closed(self) -> Awaitable[Literal[True]]:
raise NotImplementedError

@abstractmethod
async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:
raise NotImplementedError
Expand Down Expand Up @@ -741,6 +745,10 @@ def is_closed(self) -> bool:
async def close(self, exc: ExceptionType = asyncio.CancelledError) -> None:
raise NotImplementedError

@abstractmethod
def closed(self) -> Awaitable[Literal[True]]:
raise NotImplementedError

@abstractmethod
async def connect(self, timeout: TimeoutType = None) -> None:
raise NotImplementedError
Expand Down
46 changes: 37 additions & 9 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import asyncio
import contextlib
import warnings
from abc import ABC
from types import TracebackType
from typing import Any, AsyncContextManager, Generator, Optional, Type, Union
from typing import (
Any, AsyncContextManager, Awaitable, Generator, Literal, Optional, Type,
Union,
)
from warnings import warn

import aiormq
Expand Down Expand Up @@ -77,8 +81,9 @@ def __init__(

self._connection: AbstractConnection = connection

# That's means user closed channel instance explicitly
self._closed: bool = False
self._closed: asyncio.Future = (
asyncio.get_running_loop().create_future()
)

self._channel: Optional[UnderlayChannel] = None
self._channel_number = channel_number
Expand All @@ -89,6 +94,8 @@ def __init__(
self.publisher_confirms = publisher_confirms
self.on_return_raises = on_return_raises

self.close_callbacks.add(self._set_closed_callback)

@property
def is_initialized(self) -> bool:
"""Returns True when the channel has been opened
Expand All @@ -99,7 +106,7 @@ def is_initialized(self) -> bool:
def is_closed(self) -> bool:
"""Returns True when the channel has been closed from the broker
side or after the close() method has been called."""
if not self.is_initialized or self._closed:
if not self.is_initialized or self._closed.done():
return True
channel = self._channel
if channel is None:
Expand All @@ -119,8 +126,12 @@ async def close(
return

log.debug("Closing channel %r", self)
self._closed = True
await self._channel.close()
if not self._closed.done():
self._closed.set_result(True)

def closed(self) -> Awaitable[Literal[True]]:
return self._closed

async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:

Expand Down Expand Up @@ -174,12 +185,13 @@ async def _open(self) -> None:
await channel.close(e)
self._channel = None
raise
self._closed = False
if self._closed.done():
self._closed = asyncio.get_running_loop().create_future()

async def initialize(self, timeout: TimeoutType = None) -> None:
if self.is_initialized:
raise RuntimeError("Already initialized")
elif self._closed:
elif self._closed.done():
raise RuntimeError("Can't initialize closed channel")

await self._open()
Expand All @@ -197,7 +209,10 @@ async def _on_open(self) -> None:
type=ExchangeType.DIRECT,
)

async def _on_close(self, closing: asyncio.Future) -> None:
async def _on_close(
self,
closing: asyncio.Future
) -> Optional[BaseException]:
try:
exc = closing.exception()
except asyncio.CancelledError as e:
Expand All @@ -207,6 +222,15 @@ async def _on_close(self, closing: asyncio.Future) -> None:
if self._channel and self._channel.channel:
self._channel.channel.on_return_callbacks.discard(self._on_return)

return exc

async def _set_closed_callback(
self,
_: AbstractChannel, exc: BaseException
) -> None:
if not self._closed.done():
self._closed.set_result(True)

async def _on_initialized(self) -> None:
channel = await self.get_underlay_channel()
channel.on_return_callbacks.add(self._on_return)
Expand All @@ -219,7 +243,11 @@ async def reopen(self) -> None:
await self._open()

def __del__(self) -> None:
self._closed = True
with contextlib.suppress(AttributeError):
# might raise because an Exception was raised in __init__
if not self._closed.done():
self._closed.set_result(True)

self._channel = None

async def declare_exchange(
Expand Down
20 changes: 13 additions & 7 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
from ssl import SSLContext
from types import TracebackType
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any, Awaitable, Dict, Literal, Optional, Tuple, Type, TypeVar, Union
)


import aiormq.abc
from aiormq.connection import parse_int
Expand Down Expand Up @@ -39,11 +42,11 @@ class Connection(AbstractConnection):
),
)

_closed: bool
_closed: asyncio.Future

@property
def is_closed(self) -> bool:
return self._closed
return self._closed.done()

async def close(
self, exc: Optional[aiormq.abc.ExceptionType] = ConnectionClosed,
Expand All @@ -53,7 +56,11 @@ async def close(
if not transport:
return
await transport.close(exc)
self._closed = True
if not self._closed.done():
self._closed.set_result(True)

def closed(self) -> Awaitable[Literal[True]]:
return self._closed

@classmethod
def _parse_parameters(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -74,7 +81,7 @@ def __init__(
):
self.loop = loop or asyncio.get_event_loop()
self.transport = None
self._closed = False
self._closed = self.loop.create_future()
self._close_called = False

self.url = URL(url)
Expand Down Expand Up @@ -201,8 +208,7 @@ async def ready(self) -> None:
def __del__(self) -> None:
if (
self.is_closed or
self.loop.is_closed() or
not hasattr(self, "connection")
self.loop.is_closed()
):
return

Expand Down
Loading

0 comments on commit 296bab9

Please sign in to comment.