Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QueueIterator raises StopAsyncIteration when iterator/channel is closed. #615

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[run]
omit = aio_pika/compat.py
branch = True

[report]
exclude_lines =
pragma: no cover
raise NotImplementedError
8 changes: 2 additions & 6 deletions aio_pika/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@
from .robust_queue import RobustQueue


try:
from importlib.metadata import Distribution
__version__ = Distribution.from_name("aio-pika").version
except ImportError:
import pkg_resources
__version__ = pkg_resources.get_distribution("aio-pika").version
from importlib.metadata import Distribution
__version__ = Distribution.from_name("aio-pika").version


__all__ = (
Expand Down
53 changes: 34 additions & 19 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import dataclasses
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
Expand All @@ -9,16 +10,10 @@
from types import TracebackType
from typing import (
Any, AsyncContextManager, AsyncIterable, Awaitable, Callable, Dict,
Generator, Iterator, Mapping, Optional, Tuple, Type, TypeVar, Union,
overload,
Generator, Iterator, Literal, Mapping, Optional, Tuple, Type, TypedDict,
TypeVar, Union, overload,
)


if sys.version_info >= (3, 8):
from typing import Literal, TypedDict
else:
from typing_extensions import Literal, TypedDict

import aiormq.abc
from aiormq.abc import ExceptionType
from pamqp.common import Arguments, FieldValue
Expand Down Expand Up @@ -255,7 +250,10 @@ class AbstractQueue:
arguments: Arguments
passive: bool
declaration_result: aiormq.spec.Queue.DeclareOk
close_callbacks: CallbackCollection
close_callbacks: CallbackCollection[
AbstractQueue,
[Optional[BaseException]],
]

@abstractmethod
def __init__(
Expand Down Expand Up @@ -366,14 +364,14 @@ def iterator(self, **kwargs: Any) -> "AbstractQueueIterator":
raise NotImplementedError


class AbstractQueueIterator(AsyncIterable):
class AbstractQueueIterator(AsyncIterable[AbstractIncomingMessage]):
_amqp_queue: AbstractQueue
_queue: asyncio.Queue
_consumer_tag: ConsumerTag
_consume_kwargs: Dict[str, Any]

@abstractmethod
def close(self, *_: Any) -> Awaitable[Any]:
def close(self) -> Awaitable[Any]:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -511,8 +509,14 @@ class AbstractChannel(PoolInstance, ABC):
QUEUE_CLASS: Type[AbstractQueue]
EXCHANGE_CLASS: Type[AbstractExchange]

close_callbacks: CallbackCollection
return_callbacks: CallbackCollection
close_callbacks: CallbackCollection[
AbstractChannel,
[Optional[BaseException]],
]
return_callbacks: CallbackCollection[
AbstractChannel,
[AbstractIncomingMessage],
]
default_exchange: AbstractExchange

publisher_confirms: bool
Expand All @@ -531,6 +535,10 @@ def is_closed(self) -> bool:
def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[None]:
raise NotImplementedError

@abstractmethod
def closed(self) -> Awaitable[Literal[True]]:
raise NotImplementedError

@abstractmethod
async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:
raise NotImplementedError
Expand Down Expand Up @@ -718,7 +726,10 @@ def parse(self, value: Optional[str]) -> Any:
class AbstractConnection(PoolInstance, ABC):
PARAMETERS: Tuple[ConnectionParameter, ...]

close_callbacks: CallbackCollection
close_callbacks: CallbackCollection[
AbstractConnection,
[Optional[BaseException]],
]
connected: asyncio.Event
transport: Optional[UnderlayConnection]
kwargs: Mapping[str, Any]
Expand All @@ -741,6 +752,10 @@ def is_closed(self) -> bool:
async def close(self, exc: ExceptionType = asyncio.CancelledError) -> None:
raise NotImplementedError

@abstractmethod
def closed(self) -> Awaitable[Literal[True]]:
raise NotImplementedError

@abstractmethod
async def connect(self, timeout: TimeoutType = None) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -831,7 +846,7 @@ async def bind(


class AbstractRobustChannel(AbstractChannel):
reopen_callbacks: CallbackCollection
reopen_callbacks: CallbackCollection[AbstractRobustChannel, []]

@abstractmethod
def reopen(self) -> Awaitable[None]:
Expand Down Expand Up @@ -874,7 +889,7 @@ async def declare_queue(


class AbstractRobustConnection(AbstractConnection):
reconnect_callbacks: CallbackCollection
reconnect_callbacks: CallbackCollection[AbstractRobustConnection, []]

@property
@abstractmethod
Expand All @@ -896,10 +911,10 @@ def channel(


ChannelCloseCallback = Callable[
[AbstractChannel, Optional[BaseException]], Any,
[Optional[AbstractChannel], Optional[BaseException]], Any,
]
ConnectionCloseCallback = Callable[
[AbstractConnection, Optional[BaseException]], Any,
[Optional[AbstractConnection], Optional[BaseException]], Any,
]
ConnectionType = TypeVar("ConnectionType", bound=AbstractConnection)

Expand Down
47 changes: 38 additions & 9 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import asyncio
import contextlib
import warnings
from abc import ABC
from types import TracebackType
from typing import Any, AsyncContextManager, Generator, Optional, Type, Union
from typing import (
Any, AsyncContextManager, Awaitable, Generator, Literal, Optional, Type,
Union,
)
from warnings import warn

import aiormq
Expand Down Expand Up @@ -77,8 +81,9 @@ def __init__(

self._connection: AbstractConnection = connection

# That's means user closed channel instance explicitly
self._closed: bool = False
self._closed: asyncio.Future = (
asyncio.get_running_loop().create_future()
)

self._channel: Optional[UnderlayChannel] = None
self._channel_number = channel_number
Expand All @@ -89,6 +94,8 @@ def __init__(
self.publisher_confirms = publisher_confirms
self.on_return_raises = on_return_raises

self.close_callbacks.add(self._set_closed_callback)

@property
def is_initialized(self) -> bool:
"""Returns True when the channel has been opened
Expand All @@ -99,7 +106,7 @@ def is_initialized(self) -> bool:
def is_closed(self) -> bool:
"""Returns True when the channel has been closed from the broker
side or after the close() method has been called."""
if not self.is_initialized or self._closed:
if not self.is_initialized or self._closed.done():
return True
channel = self._channel
if channel is None:
Expand All @@ -119,8 +126,12 @@ async def close(
return

log.debug("Closing channel %r", self)
self._closed = True
await self._channel.close()
if not self._closed.done():
self._closed.set_result(True)

def closed(self) -> Awaitable[Literal[True]]:
return self._closed

async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:

Expand Down Expand Up @@ -174,12 +185,13 @@ async def _open(self) -> None:
await channel.close(e)
self._channel = None
raise
self._closed = False
if self._closed.done():
self._closed = asyncio.get_running_loop().create_future()

async def initialize(self, timeout: TimeoutType = None) -> None:
if self.is_initialized:
raise RuntimeError("Already initialized")
elif self._closed:
elif self._closed.done():
raise RuntimeError("Can't initialize closed channel")

await self._open()
Expand All @@ -197,7 +209,10 @@ async def _on_open(self) -> None:
type=ExchangeType.DIRECT,
)

async def _on_close(self, closing: asyncio.Future) -> None:
async def _on_close(
self,
closing: asyncio.Future
) -> Optional[BaseException]:
try:
exc = closing.exception()
except asyncio.CancelledError as e:
Expand All @@ -207,6 +222,16 @@ async def _on_close(self, closing: asyncio.Future) -> None:
if self._channel and self._channel.channel:
self._channel.channel.on_return_callbacks.discard(self._on_return)

return exc

async def _set_closed_callback(
self,
_: Optional[AbstractChannel],
exc: Optional[BaseException],
) -> None:
if not self._closed.done():
self._closed.set_result(True)

async def _on_initialized(self) -> None:
channel = await self.get_underlay_channel()
channel.on_return_callbacks.add(self._on_return)
Expand All @@ -219,7 +244,11 @@ async def reopen(self) -> None:
await self._open()

def __del__(self) -> None:
self._closed = True
with contextlib.suppress(AttributeError):
# might raise because an Exception was raised in __init__
if not self._closed.done():
self._closed.set_result(True)

self._channel = None

async def declare_exchange(
Expand Down
20 changes: 13 additions & 7 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
from ssl import SSLContext
from types import TracebackType
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any, Awaitable, Dict, Literal, Optional, Tuple, Type, TypeVar, Union
)


import aiormq.abc
from aiormq.connection import parse_int
Expand Down Expand Up @@ -39,11 +42,11 @@ class Connection(AbstractConnection):
),
)

_closed: bool
_closed: asyncio.Future

@property
def is_closed(self) -> bool:
return self._closed
return self._closed.done()

async def close(
self, exc: Optional[aiormq.abc.ExceptionType] = ConnectionClosed,
Expand All @@ -53,7 +56,11 @@ async def close(
if not transport:
return
await transport.close(exc)
self._closed = True
if not self._closed.done():
self._closed.set_result(True)

def closed(self) -> Awaitable[Literal[True]]:
return self._closed

@classmethod
def _parse_parameters(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -74,7 +81,7 @@ def __init__(
):
self.loop = loop or asyncio.get_event_loop()
self.transport = None
self._closed = False
self._closed = self.loop.create_future()
self._close_called = False

self.url = URL(url)
Expand Down Expand Up @@ -201,8 +208,7 @@ async def ready(self) -> None:
def __del__(self) -> None:
if (
self.is_closed or
self.loop.is_closed() or
not hasattr(self, "connection")
Darsstar marked this conversation as resolved.
Show resolved Hide resolved
self.loop.is_closed()
):
return

Expand Down
8 changes: 4 additions & 4 deletions aio_pika/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def encode_expiration_timedelta(value: timedelta) -> str:
return str(int(value.total_seconds() * MILLISECONDS))


@encode_expiration.register(NoneType) # type: ignore
@encode_expiration.register(NoneType)
def encode_expiration_none(_: Any) -> None:
return None

Expand All @@ -62,7 +62,7 @@ def decode_expiration_str(t: str) -> float:
return float(t) / MILLISECONDS


@decode_expiration.register(NoneType) # type: ignore
@decode_expiration.register(NoneType)
def decode_expiration_none(_: Any) -> None:
return None

Expand All @@ -88,7 +88,7 @@ def encode_timestamp_timedelta(value: timedelta) -> datetime:
return datetime.now(tz=timezone.utc) + value


@encode_timestamp.register(NoneType) # type: ignore
@encode_timestamp.register(NoneType)
def encode_timestamp_none(_: Any) -> None:
return None

Expand All @@ -103,7 +103,7 @@ def decode_timestamp_datetime(value: datetime) -> datetime:
return value


@decode_timestamp.register(NoneType) # type: ignore
@decode_timestamp.register(NoneType)
def decode_timestamp_none(_: Any) -> None:
return None

Expand Down
6 changes: 3 additions & 3 deletions aio_pika/patterns/master.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
AbstractChannel, AbstractExchange, AbstractIncomingMessage, AbstractQueue,
ConsumerTag, DeliveryMode,
)
from aio_pika.message import Message, ReturnedMessage
from aio_pika.message import Message

from ..tools import create_task, ensure_awaitable
from .base import Base, CallbackType, Proxy, T
Expand Down Expand Up @@ -113,8 +113,8 @@ def exchange(self) -> AbstractExchange:

@staticmethod
def on_message_returned(
channel: AbstractChannel,
message: ReturnedMessage,
channel: Optional[AbstractChannel],
message: AbstractIncomingMessage,
) -> None:
log.warning(
"Message returned. Probably destination queue does not exists: %r",
Expand Down
Loading
Loading