Skip to content

Commit

Permalink
Merge pull request #615 from Darsstar/QueueIterator-raises-StopAsyncI…
Browse files Browse the repository at this point in the history
…terator

QueueIterator raises StopAsyncIteration when iterator/channel is closed.
  • Loading branch information
mosquito authored Nov 21, 2024
2 parents 9c0ab3b + 1acd87e commit 001dcce
Show file tree
Hide file tree
Showing 24 changed files with 680 additions and 382 deletions.
5 changes: 5 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[run]
omit = aio_pika/compat.py
branch = True

[report]
exclude_lines =
pragma: no cover
raise NotImplementedError
8 changes: 2 additions & 6 deletions aio_pika/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@
from .robust_queue import RobustQueue


try:
from importlib.metadata import Distribution
__version__ = Distribution.from_name("aio-pika").version
except ImportError:
import pkg_resources
__version__ = pkg_resources.get_distribution("aio-pika").version
from importlib.metadata import Distribution
__version__ = Distribution.from_name("aio-pika").version


__all__ = (
Expand Down
53 changes: 34 additions & 19 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import dataclasses
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
Expand All @@ -9,16 +10,10 @@
from types import TracebackType
from typing import (
Any, AsyncContextManager, AsyncIterable, Awaitable, Callable, Dict,
Generator, Iterator, Mapping, Optional, Tuple, Type, TypeVar, Union,
overload,
Generator, Iterator, Literal, Mapping, Optional, Tuple, Type, TypedDict,
TypeVar, Union, overload,
)


if sys.version_info >= (3, 8):
from typing import Literal, TypedDict
else:
from typing_extensions import Literal, TypedDict

import aiormq.abc
from aiormq.abc import ExceptionType
from pamqp.common import Arguments, FieldValue
Expand Down Expand Up @@ -255,7 +250,10 @@ class AbstractQueue:
arguments: Arguments
passive: bool
declaration_result: aiormq.spec.Queue.DeclareOk
close_callbacks: CallbackCollection
close_callbacks: CallbackCollection[
AbstractQueue,
[Optional[BaseException]],
]

@abstractmethod
def __init__(
Expand Down Expand Up @@ -366,14 +364,14 @@ def iterator(self, **kwargs: Any) -> "AbstractQueueIterator":
raise NotImplementedError


class AbstractQueueIterator(AsyncIterable):
class AbstractQueueIterator(AsyncIterable[AbstractIncomingMessage]):
_amqp_queue: AbstractQueue
_queue: asyncio.Queue
_consumer_tag: ConsumerTag
_consume_kwargs: Dict[str, Any]

@abstractmethod
def close(self, *_: Any) -> Awaitable[Any]:
def close(self) -> Awaitable[Any]:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -511,8 +509,14 @@ class AbstractChannel(PoolInstance, ABC):
QUEUE_CLASS: Type[AbstractQueue]
EXCHANGE_CLASS: Type[AbstractExchange]

close_callbacks: CallbackCollection
return_callbacks: CallbackCollection
close_callbacks: CallbackCollection[
AbstractChannel,
[Optional[BaseException]],
]
return_callbacks: CallbackCollection[
AbstractChannel,
[AbstractIncomingMessage],
]
default_exchange: AbstractExchange

publisher_confirms: bool
Expand All @@ -531,6 +535,10 @@ def is_closed(self) -> bool:
def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[None]:
raise NotImplementedError

@abstractmethod
def closed(self) -> Awaitable[Literal[True]]:
raise NotImplementedError

@abstractmethod
async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:
raise NotImplementedError
Expand Down Expand Up @@ -718,7 +726,10 @@ def parse(self, value: Optional[str]) -> Any:
class AbstractConnection(PoolInstance, ABC):
PARAMETERS: Tuple[ConnectionParameter, ...]

close_callbacks: CallbackCollection
close_callbacks: CallbackCollection[
AbstractConnection,
[Optional[BaseException]],
]
connected: asyncio.Event
transport: Optional[UnderlayConnection]
kwargs: Mapping[str, Any]
Expand All @@ -741,6 +752,10 @@ def is_closed(self) -> bool:
async def close(self, exc: ExceptionType = asyncio.CancelledError) -> None:
raise NotImplementedError

@abstractmethod
def closed(self) -> Awaitable[Literal[True]]:
raise NotImplementedError

@abstractmethod
async def connect(self, timeout: TimeoutType = None) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -831,7 +846,7 @@ async def bind(


class AbstractRobustChannel(AbstractChannel):
reopen_callbacks: CallbackCollection
reopen_callbacks: CallbackCollection[AbstractRobustChannel, []]

@abstractmethod
def reopen(self) -> Awaitable[None]:
Expand Down Expand Up @@ -874,7 +889,7 @@ async def declare_queue(


class AbstractRobustConnection(AbstractConnection):
reconnect_callbacks: CallbackCollection
reconnect_callbacks: CallbackCollection[AbstractRobustConnection, []]

@property
@abstractmethod
Expand All @@ -896,10 +911,10 @@ def channel(


ChannelCloseCallback = Callable[
[AbstractChannel, Optional[BaseException]], Any,
[Optional[AbstractChannel], Optional[BaseException]], Any,
]
ConnectionCloseCallback = Callable[
[AbstractConnection, Optional[BaseException]], Any,
[Optional[AbstractConnection], Optional[BaseException]], Any,
]
ConnectionType = TypeVar("ConnectionType", bound=AbstractConnection)

Expand Down
47 changes: 38 additions & 9 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import asyncio
import contextlib
import warnings
from abc import ABC
from types import TracebackType
from typing import Any, AsyncContextManager, Generator, Optional, Type, Union
from typing import (
Any, AsyncContextManager, Awaitable, Generator, Literal, Optional, Type,
Union,
)
from warnings import warn

import aiormq
Expand Down Expand Up @@ -77,8 +81,9 @@ def __init__(

self._connection: AbstractConnection = connection

# That's means user closed channel instance explicitly
self._closed: bool = False
self._closed: asyncio.Future = (
asyncio.get_running_loop().create_future()
)

self._channel: Optional[UnderlayChannel] = None
self._channel_number = channel_number
Expand All @@ -89,6 +94,8 @@ def __init__(
self.publisher_confirms = publisher_confirms
self.on_return_raises = on_return_raises

self.close_callbacks.add(self._set_closed_callback)

@property
def is_initialized(self) -> bool:
"""Returns True when the channel has been opened
Expand All @@ -99,7 +106,7 @@ def is_initialized(self) -> bool:
def is_closed(self) -> bool:
"""Returns True when the channel has been closed from the broker
side or after the close() method has been called."""
if not self.is_initialized or self._closed:
if not self.is_initialized or self._closed.done():
return True
channel = self._channel
if channel is None:
Expand All @@ -119,8 +126,12 @@ async def close(
return

log.debug("Closing channel %r", self)
self._closed = True
await self._channel.close()
if not self._closed.done():
self._closed.set_result(True)

def closed(self) -> Awaitable[Literal[True]]:
return self._closed

async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:

Expand Down Expand Up @@ -174,12 +185,13 @@ async def _open(self) -> None:
await channel.close(e)
self._channel = None
raise
self._closed = False
if self._closed.done():
self._closed = asyncio.get_running_loop().create_future()

async def initialize(self, timeout: TimeoutType = None) -> None:
if self.is_initialized:
raise RuntimeError("Already initialized")
elif self._closed:
elif self._closed.done():
raise RuntimeError("Can't initialize closed channel")

await self._open()
Expand All @@ -197,7 +209,10 @@ async def _on_open(self) -> None:
type=ExchangeType.DIRECT,
)

async def _on_close(self, closing: asyncio.Future) -> None:
async def _on_close(
self,
closing: asyncio.Future
) -> Optional[BaseException]:
try:
exc = closing.exception()
except asyncio.CancelledError as e:
Expand All @@ -207,6 +222,16 @@ async def _on_close(self, closing: asyncio.Future) -> None:
if self._channel and self._channel.channel:
self._channel.channel.on_return_callbacks.discard(self._on_return)

return exc

async def _set_closed_callback(
self,
_: Optional[AbstractChannel],
exc: Optional[BaseException],
) -> None:
if not self._closed.done():
self._closed.set_result(True)

async def _on_initialized(self) -> None:
channel = await self.get_underlay_channel()
channel.on_return_callbacks.add(self._on_return)
Expand All @@ -219,7 +244,11 @@ async def reopen(self) -> None:
await self._open()

def __del__(self) -> None:
self._closed = True
with contextlib.suppress(AttributeError):
# might raise because an Exception was raised in __init__
if not self._closed.done():
self._closed.set_result(True)

self._channel = None

async def declare_exchange(
Expand Down
20 changes: 13 additions & 7 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
from ssl import SSLContext
from types import TracebackType
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any, Awaitable, Dict, Literal, Optional, Tuple, Type, TypeVar, Union
)


import aiormq.abc
from aiormq.connection import parse_int
Expand Down Expand Up @@ -39,11 +42,11 @@ class Connection(AbstractConnection):
),
)

_closed: bool
_closed: asyncio.Future

@property
def is_closed(self) -> bool:
return self._closed
return self._closed.done()

async def close(
self, exc: Optional[aiormq.abc.ExceptionType] = ConnectionClosed,
Expand All @@ -53,7 +56,11 @@ async def close(
if not transport:
return
await transport.close(exc)
self._closed = True
if not self._closed.done():
self._closed.set_result(True)

def closed(self) -> Awaitable[Literal[True]]:
return self._closed

@classmethod
def _parse_parameters(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -74,7 +81,7 @@ def __init__(
):
self.loop = loop or asyncio.get_event_loop()
self.transport = None
self._closed = False
self._closed = self.loop.create_future()
self._close_called = False

self.url = URL(url)
Expand Down Expand Up @@ -201,8 +208,7 @@ async def ready(self) -> None:
def __del__(self) -> None:
if (
self.is_closed or
self.loop.is_closed() or
not hasattr(self, "connection")
self.loop.is_closed()
):
return

Expand Down
8 changes: 4 additions & 4 deletions aio_pika/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def encode_expiration_timedelta(value: timedelta) -> str:
return str(int(value.total_seconds() * MILLISECONDS))


@encode_expiration.register(NoneType) # type: ignore
@encode_expiration.register(NoneType)
def encode_expiration_none(_: Any) -> None:
return None

Expand All @@ -62,7 +62,7 @@ def decode_expiration_str(t: str) -> float:
return float(t) / MILLISECONDS


@decode_expiration.register(NoneType) # type: ignore
@decode_expiration.register(NoneType)
def decode_expiration_none(_: Any) -> None:
return None

Expand All @@ -88,7 +88,7 @@ def encode_timestamp_timedelta(value: timedelta) -> datetime:
return datetime.now(tz=timezone.utc) + value


@encode_timestamp.register(NoneType) # type: ignore
@encode_timestamp.register(NoneType)
def encode_timestamp_none(_: Any) -> None:
return None

Expand All @@ -103,7 +103,7 @@ def decode_timestamp_datetime(value: datetime) -> datetime:
return value


@decode_timestamp.register(NoneType) # type: ignore
@decode_timestamp.register(NoneType)
def decode_timestamp_none(_: Any) -> None:
return None

Expand Down
6 changes: 3 additions & 3 deletions aio_pika/patterns/master.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
AbstractChannel, AbstractExchange, AbstractIncomingMessage, AbstractQueue,
ConsumerTag, DeliveryMode,
)
from aio_pika.message import Message, ReturnedMessage
from aio_pika.message import Message

from ..tools import create_task, ensure_awaitable
from .base import Base, CallbackType, Proxy, T
Expand Down Expand Up @@ -113,8 +113,8 @@ def exchange(self) -> AbstractExchange:

@staticmethod
def on_message_returned(
channel: AbstractChannel,
message: ReturnedMessage,
channel: Optional[AbstractChannel],
message: AbstractIncomingMessage,
) -> None:
log.warning(
"Message returned. Probably destination queue does not exists: %r",
Expand Down
Loading

0 comments on commit 001dcce

Please sign in to comment.