diff --git a/aio_pika/abc.py b/aio_pika/abc.py index cadc91f1..1c5a23b5 100644 --- a/aio_pika/abc.py +++ b/aio_pika/abc.py @@ -1,13 +1,14 @@ import asyncio from abc import ABC, abstractmethod +from contextlib import asynccontextmanager from datetime import datetime, timedelta from enum import Enum, IntEnum, unique from functools import singledispatch from types import TracebackType from typing import ( - Any, AsyncContextManager, AsyncIterable, Awaitable, Callable, Dict, - FrozenSet, Generator, Iterator, MutableMapping, NamedTuple, Optional, Set, - Tuple, Type, TypeVar, Union, + Any, AsyncIterable, AsyncIterator, Awaitable, Callable, Dict, FrozenSet, + Generator, Iterator, MutableMapping, NamedTuple, Optional, Set, Tuple, Type, + TypeVar, Union, ) @@ -191,13 +192,14 @@ def channel(self) -> aiormq.abc.AbstractChannel: raise NotImplementedError @abstractmethod - def process( + @asynccontextmanager + async def process( self, requeue: bool = False, reject_on_redelivered: bool = False, ignore_processed: bool = False, - ) -> "AbstractProcessContext": - raise NotImplementedError + ) -> AsyncIterator["AbstractMessage"]: + yield self @abstractmethod async def ack(self, multiple: bool = False) -> None: @@ -220,21 +222,6 @@ def processed(self) -> bool: raise NotImplementedError -class AbstractProcessContext(AsyncContextManager): - @abstractmethod - async def __aenter__(self) -> AbstractIncomingMessage: - raise NotImplementedError - - @abstractmethod - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - raise NotImplementedError - - class AbstractQueue: channel: aiormq.abc.AbstractChannel name: str @@ -467,7 +454,7 @@ async def close(self, exc: Optional[ExceptionType] = None) -> Any: self.channel.closing.remove_done_callback(self.close_callback) self.channel.connection.closing.remove_done_callback( - self.close_callback + self.close_callback, ) @@ -871,7 +858,6 @@ def _get_exchange_name_from_str(value: str) -> str: "AbstractExchange", "AbstractIncomingMessage", "AbstractMessage", - "AbstractProcessContext", "AbstractQueue", "AbstractQueueIterator", "AbstractRobustChannel", diff --git a/aio_pika/message.py b/aio_pika/message.py index ceef0374..a2beab02 100644 --- a/aio_pika/message.py +++ b/aio_pika/message.py @@ -1,12 +1,12 @@ import json import time +from contextlib import asynccontextmanager from datetime import datetime, timedelta from functools import singledispatch from pprint import pformat -from types import TracebackType from typing import ( - Any, Callable, Dict, Iterable, Iterator, List, MutableMapping, Optional, - Type, TypeVar, Union, + Any, AsyncIterator, Callable, Dict, Iterable, Iterator, List, + MutableMapping, Optional, Type, TypeVar, Union, ) import aiormq @@ -15,8 +15,8 @@ from .abc import ( MILLISECONDS, ZERO_TIME, AbstractChannel, AbstractIncomingMessage, - AbstractMessage, AbstractProcessContext, DateType, DeliveryMode, - HeadersPythonValues, HeadersType, NoneType, + AbstractMessage, DateType, DeliveryMode, HeadersPythonValues, HeadersType, + NoneType, ) from .exceptions import MessageProcessError from .log import get_logger @@ -511,12 +511,13 @@ def __init__(self, message: DeliveredMessage, no_ack: bool = False): def channel(self) -> aiormq.abc.AbstractChannel: return self.__channel - def process( + @asynccontextmanager + async def process( self, requeue: bool = False, reject_on_redelivered: bool = False, ignore_processed: bool = False, - ) -> AbstractProcessContext: + ) -> AsyncIterator["Message"]: """ Context manager for processing the message >>> async def on_message_received(message: IncomingMessage): @@ -542,15 +543,39 @@ def process( :param ignore_processed: Do nothing if message already processed """ - return ProcessContext( - self, - requeue=requeue, - reject_on_redelivered=reject_on_redelivered, - ignore_processed=ignore_processed, - ) + # noinspection PyBroadException + try: + yield self + except BaseException: + if not ignore_processed or not self.processed: + if reject_on_redelivered and self.redelivered: + if not self.channel.is_closed: + log.info( + "Message %r was redelivered and will be rejected", + self, + ) + await self.reject(requeue=False) + raise + + log.warning( + "Message %r was redelivered and reject is not sent " + "since channel is closed", + self, + ) + raise + + if not self.channel.is_closed: + await self.reject(requeue=requeue) + raise + + log.warning("Reject is not sent since channel is closed") + raise + else: + if not ignore_processed or not self.processed: + await self.ack() async def ack(self, multiple: bool = False) -> None: - """ Send basic.ack is used for positive acknowledgements + """ Send Basic.Ack frame is used for positive acknowledgements .. note:: This method looks like a blocking-method, but actually it just @@ -581,7 +606,7 @@ async def ack(self, multiple: bool = False) -> None: async def reject(self, requeue: bool = False) -> None: """ When `requeue=True` the message will be returned to queue. - Otherwise message will be dropped. + Otherwise, message will be dropped. .. note:: This method looks like a blocking-method, but actually it just @@ -655,54 +680,4 @@ class ReturnedMessage(IncomingMessage): ReturnCallback = Callable[[AbstractChannel, ReturnedMessage], Any] -class ProcessContext(AbstractProcessContext): - def __init__( - self, - message: IncomingMessage, - *, - requeue: bool, - reject_on_redelivered: bool, - ignore_processed: bool - ): - self.message = message - self.requeue = requeue - self.reject_on_redelivered = reject_on_redelivered - self.ignore_processed = ignore_processed - - async def __aenter__(self) -> IncomingMessage: - return self.message - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - if not exc_type: - if not self.ignore_processed or not self.message.processed: - await self.message.ack() - - return - - if not self.ignore_processed or not self.message.processed: - if self.reject_on_redelivered and self.message.redelivered: - if not self.message.channel.is_closed: - log.info( - "Message %r was redelivered and will be rejected", - self.message, - ) - await self.message.reject(requeue=False) - return - log.warning( - "Message %r was redelivered and reject is not sent " - "since channel is closed", - self.message, - ) - else: - if not self.message.channel.is_closed: - await self.message.reject(requeue=self.requeue) - return - log.warning("Reject is not sent since channel is closed") - - __all__ = "Message", "IncomingMessage", "ReturnedMessage", diff --git a/aio_pika/pool.py b/aio_pika/pool.py index 93fc4080..e3b2e044 100644 --- a/aio_pika/pool.py +++ b/aio_pika/pool.py @@ -1,9 +1,10 @@ import abc import asyncio +from contextlib import asynccontextmanager from types import TracebackType from typing import ( - Any, AsyncContextManager, Awaitable, Callable, Coroutine, Generic, Optional, - Set, Tuple, Type, TypeVar, Union, + Any, AsyncIterator, Awaitable, Callable, Coroutine, Generic, Optional, Set, + Tuple, Type, TypeVar, Union, ) from aiormq.tools import awaitable @@ -72,11 +73,18 @@ def __init__( def is_closed(self) -> bool: return self.__closed - def acquire(self) -> "PoolItemContextManager[T]": + @asynccontextmanager + async def acquire(self) -> AsyncIterator[T]: if self.__closed: raise PoolInvalidStateError("acquire operation on closed pool") - return PoolItemContextManager[T](self) + item: T = await self._get() + + try: + yield item + finally: + if item is not None: + self.put(item) @property def _has_released(self) -> bool: @@ -141,25 +149,3 @@ async def __aexit__( return await asyncio.ensure_future(self.close()) - - -class PoolItemContextManager(Generic[T], AsyncContextManager): - __slots__ = "pool", "item" - - def __init__(self, pool: Pool): - self.pool = pool - self.item: T - - async def __aenter__(self) -> T: - # noinspection PyProtectedMember - self.item = await self.pool._get() - return self.item - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - if self.item is not None: - self.pool.put(self.item)