Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace context manager classes to asynccontextmanager decorators #499

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 9 additions & 23 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import asyncio
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from enum import Enum, IntEnum, unique
from functools import singledispatch
from types import TracebackType
from typing import (
Any, AsyncContextManager, AsyncIterable, Awaitable, Callable, Dict,
FrozenSet, Generator, Iterator, MutableMapping, NamedTuple, Optional, Set,
Tuple, Type, TypeVar, Union,
Any, AsyncIterable, AsyncIterator, Awaitable, Callable, Dict, FrozenSet,
Generator, Iterator, MutableMapping, NamedTuple, Optional, Set, Tuple, Type,
TypeVar, Union,
)


Expand Down Expand Up @@ -191,13 +192,14 @@ def channel(self) -> aiormq.abc.AbstractChannel:
raise NotImplementedError

@abstractmethod
def process(
@asynccontextmanager
async def process(
self,
requeue: bool = False,
reject_on_redelivered: bool = False,
ignore_processed: bool = False,
) -> "AbstractProcessContext":
raise NotImplementedError
) -> AsyncIterator["AbstractMessage"]:
yield self

@abstractmethod
async def ack(self, multiple: bool = False) -> None:
Expand All @@ -220,21 +222,6 @@ def processed(self) -> bool:
raise NotImplementedError


class AbstractProcessContext(AsyncContextManager):
@abstractmethod
async def __aenter__(self) -> AbstractIncomingMessage:
raise NotImplementedError

@abstractmethod
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
raise NotImplementedError


class AbstractQueue:
channel: aiormq.abc.AbstractChannel
name: str
Expand Down Expand Up @@ -467,7 +454,7 @@ async def close(self, exc: Optional[ExceptionType] = None) -> Any:

self.channel.closing.remove_done_callback(self.close_callback)
self.channel.connection.closing.remove_done_callback(
self.close_callback
self.close_callback,
)


Expand Down Expand Up @@ -871,7 +858,6 @@ def _get_exchange_name_from_str(value: str) -> str:
"AbstractExchange",
"AbstractIncomingMessage",
"AbstractMessage",
"AbstractProcessContext",
"AbstractQueue",
"AbstractQueueIterator",
"AbstractRobustChannel",
Expand Down
105 changes: 40 additions & 65 deletions aio_pika/message.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import json
import time
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from functools import singledispatch
from pprint import pformat
from types import TracebackType
from typing import (
Any, Callable, Dict, Iterable, Iterator, List, MutableMapping, Optional,
Type, TypeVar, Union,
Any, AsyncIterator, Callable, Dict, Iterable, Iterator, List,
MutableMapping, Optional, Type, TypeVar, Union,
)

import aiormq
Expand All @@ -15,8 +15,8 @@

from .abc import (
MILLISECONDS, ZERO_TIME, AbstractChannel, AbstractIncomingMessage,
AbstractMessage, AbstractProcessContext, DateType, DeliveryMode,
HeadersPythonValues, HeadersType, NoneType,
AbstractMessage, DateType, DeliveryMode, HeadersPythonValues, HeadersType,
NoneType,
)
from .exceptions import MessageProcessError
from .log import get_logger
Expand Down Expand Up @@ -511,12 +511,13 @@ def __init__(self, message: DeliveredMessage, no_ack: bool = False):
def channel(self) -> aiormq.abc.AbstractChannel:
return self.__channel

def process(
@asynccontextmanager
async def process(
self,
requeue: bool = False,
reject_on_redelivered: bool = False,
ignore_processed: bool = False,
) -> AbstractProcessContext:
) -> AsyncIterator["Message"]:
""" Context manager for processing the message

>>> async def on_message_received(message: IncomingMessage):
Expand All @@ -542,15 +543,39 @@ def process(
:param ignore_processed: Do nothing if message already processed

"""
return ProcessContext(
self,
requeue=requeue,
reject_on_redelivered=reject_on_redelivered,
ignore_processed=ignore_processed,
)
# noinspection PyBroadException
try:
yield self
except BaseException:
if not ignore_processed or not self.processed:
if reject_on_redelivered and self.redelivered:
if not self.channel.is_closed:
log.info(
"Message %r was redelivered and will be rejected",
self,
)
await self.reject(requeue=False)
raise

log.warning(
"Message %r was redelivered and reject is not sent "
"since channel is closed",
self,
)
raise

if not self.channel.is_closed:
await self.reject(requeue=requeue)
raise

log.warning("Reject is not sent since channel is closed")
raise
else:
if not ignore_processed or not self.processed:
await self.ack()

async def ack(self, multiple: bool = False) -> None:
""" Send basic.ack is used for positive acknowledgements
""" Send Basic.Ack frame is used for positive acknowledgements

.. note::
This method looks like a blocking-method, but actually it just
Expand Down Expand Up @@ -581,7 +606,7 @@ async def ack(self, multiple: bool = False) -> None:

async def reject(self, requeue: bool = False) -> None:
""" When `requeue=True` the message will be returned to queue.
Otherwise message will be dropped.
Otherwise, message will be dropped.

.. note::
This method looks like a blocking-method, but actually it just
Expand Down Expand Up @@ -655,54 +680,4 @@ class ReturnedMessage(IncomingMessage):
ReturnCallback = Callable[[AbstractChannel, ReturnedMessage], Any]


class ProcessContext(AbstractProcessContext):
def __init__(
self,
message: IncomingMessage,
*,
requeue: bool,
reject_on_redelivered: bool,
ignore_processed: bool
):
self.message = message
self.requeue = requeue
self.reject_on_redelivered = reject_on_redelivered
self.ignore_processed = ignore_processed

async def __aenter__(self) -> IncomingMessage:
return self.message

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if not exc_type:
if not self.ignore_processed or not self.message.processed:
await self.message.ack()

return

if not self.ignore_processed or not self.message.processed:
if self.reject_on_redelivered and self.message.redelivered:
if not self.message.channel.is_closed:
log.info(
"Message %r was redelivered and will be rejected",
self.message,
)
await self.message.reject(requeue=False)
return
log.warning(
"Message %r was redelivered and reject is not sent "
"since channel is closed",
self.message,
)
else:
if not self.message.channel.is_closed:
await self.message.reject(requeue=self.requeue)
return
log.warning("Reject is not sent since channel is closed")


__all__ = "Message", "IncomingMessage", "ReturnedMessage",
38 changes: 12 additions & 26 deletions aio_pika/pool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import abc
import asyncio
from contextlib import asynccontextmanager
from types import TracebackType
from typing import (
Any, AsyncContextManager, Awaitable, Callable, Coroutine, Generic, Optional,
Set, Tuple, Type, TypeVar, Union,
Any, AsyncIterator, Awaitable, Callable, Coroutine, Generic, Optional, Set,
Tuple, Type, TypeVar, Union,
)

from aiormq.tools import awaitable
Expand Down Expand Up @@ -72,11 +73,18 @@ def __init__(
def is_closed(self) -> bool:
return self.__closed

def acquire(self) -> "PoolItemContextManager[T]":
@asynccontextmanager
async def acquire(self) -> AsyncIterator[T]:
if self.__closed:
raise PoolInvalidStateError("acquire operation on closed pool")

return PoolItemContextManager[T](self)
item: T = await self._get()

try:
yield item
finally:
if item is not None:
self.put(item)

@property
def _has_released(self) -> bool:
Expand Down Expand Up @@ -141,25 +149,3 @@ async def __aexit__(
return

await asyncio.ensure_future(self.close())


class PoolItemContextManager(Generic[T], AsyncContextManager):
__slots__ = "pool", "item"

def __init__(self, pool: Pool):
self.pool = pool
self.item: T

async def __aenter__(self) -> T:
# noinspection PyProtectedMember
self.item = await self.pool._get()
return self.item

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if self.item is not None:
self.pool.put(self.item)