diff --git a/aio_pika/channel.py b/aio_pika/channel.py index c49934a6..fc683e5d 100644 --- a/aio_pika/channel.py +++ b/aio_pika/channel.py @@ -106,6 +106,11 @@ def channel(self) -> aiormq.Channel: def number(self): return self.channel.number if self._channel else None + def _get_operation_timeout(self, timeout: TimeoutType = None): + if timeout is not None: + return timeout + return self._connection.operation_timeout + def __str__(self): return "{0}".format( self.number or "Not initialized channel" @@ -163,7 +168,8 @@ async def initialize(self, timeout: TimeoutType = None) -> None: raise RuntimeError("Can't initialize channel") self._channel = await asyncio.wait_for( - self._create_channel(), timeout=timeout + self._create_channel(), + timeout=self._get_operation_timeout(timeout) ) self._delivery_tag = 0 @@ -247,7 +253,7 @@ async def declare_queue( """ queue = self.QUEUE_CLASS( - connection=self, + connection=self._connection, channel=self.channel, name=name, durable=durable, @@ -271,7 +277,7 @@ async def set_qos( prefetch_count=prefetch_count, prefetch_size=prefetch_size ), - timeout=timeout + timeout=self._get_operation_timeout(timeout) ) async def queue_delete( @@ -286,7 +292,7 @@ async def queue_delete( if_empty=if_empty, nowait=nowait, ), - timeout=timeout + timeout=self._get_operation_timeout(timeout) ) async def exchange_delete( @@ -300,7 +306,7 @@ async def exchange_delete( if_unused=if_unused, nowait=nowait, ), - timeout=timeout + timeout=self._get_operation_timeout(timeout) ) def transaction(self) -> Transaction: @@ -308,7 +314,7 @@ def transaction(self) -> Transaction: raise RuntimeError("Cannot create transaction when publisher " "confirms are enabled") - return Transaction(self._channel) + return Transaction(connection=self._connection, channel=self._channel) async def flow(self, active: bool = True) -> aiormq.spec.Channel.FlowOk: return await self.channel.flow(active=active) diff --git a/aio_pika/connection.py b/aio_pika/connection.py index bfbca1fe..1a46fe7c 100644 --- a/aio_pika/connection.py +++ b/aio_pika/connection.py @@ -46,9 +46,11 @@ def _parse_kwargs(cls, kwargs): result[key] = parser(kwargs.get(key, default)) return result - def __init__(self, url, loop=None, **kwargs): + def __init__(self, url, operation_timeout: TimeoutType = None, loop=None, + **kwargs): self.loop = loop or asyncio.get_event_loop() self.url = URL(url) + self.operation_timeout = operation_timeout self.kwargs = self._parse_kwargs(kwargs or self.url.query) @@ -217,6 +219,7 @@ async def connect( login: str = 'guest', password: str = 'guest', virtualhost: str = '/', ssl: bool = False, loop: asyncio.AbstractEventLoop = None, ssl_options: dict = None, timeout: TimeoutType = None, + operation_timeout: TimeoutType = None, connection_class: Type[ConnectionType] = Connection, **kwargs ) -> ConnectionType: @@ -267,6 +270,7 @@ async def main(): :param ssl: use SSL for connection. Should be used with addition kwargs. :param ssl_options: A dict of values for the SSL connection. :param timeout: connection timeout in seconds + :param operation_timeout: execution timeout in seconds :param loop: Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param connection_class: Factory of a new connection @@ -294,7 +298,11 @@ async def main(): query=kw ) - connection = connection_class(url, loop=loop) + connection = connection_class( + url, + operation_timeout=operation_timeout, + loop=loop + ) await connection.connect(timeout=timeout) return connection diff --git a/aio_pika/exchange.py b/aio_pika/exchange.py index 8cc3f829..4c69a96a 100644 --- a/aio_pika/exchange.py +++ b/aio_pika/exchange.py @@ -36,6 +36,7 @@ def __init__(self, connection, channel: aiormq.Channel, name: str, if not arguments: arguments = {} + self._connection = connection self._channel = channel self.__type = type.value self.name = name @@ -52,6 +53,11 @@ def channel(self) -> aiormq.Channel: return self._channel + def _get_operation_timeout(self, timeout: TimeoutType = None): + if timeout is not None: + return timeout + return self._connection.operation_timeout + def __str__(self): return self.name @@ -71,7 +77,7 @@ async def declare( internal=self.internal, passive=self.passive, arguments=self.arguments, - ), timeout=timeout) + ), timeout=self._get_operation_timeout(timeout)) @staticmethod def _get_exchange_name(exchange: ExchangeType_): @@ -133,7 +139,7 @@ async def bind( destination=self.name, routing_key=routing_key, source=self._get_exchange_name(exchange), - ), timeout=timeout + ), timeout=self._get_operation_timeout(timeout) ) async def unbind( @@ -163,7 +169,7 @@ async def unbind( destination=self.name, routing_key=routing_key, source=self._get_exchange_name(exchange), - ), timeout=timeout + ), timeout=self._get_operation_timeout(timeout) ) async def publish( @@ -197,7 +203,7 @@ async def publish( properties=message.properties, mandatory=mandatory, immediate=immediate - ), timeout=timeout + ), timeout=self._get_operation_timeout(timeout) ) async def delete( @@ -213,7 +219,7 @@ async def delete( log.info("Deleting %r", self) return await asyncio.wait_for( self.channel.exchange_delete(self.name, if_unused=if_unused), - timeout=timeout + timeout=self._get_operation_timeout(timeout) ) diff --git a/aio_pika/queue.py b/aio_pika/queue.py index 8c34fcb2..19d2b894 100644 --- a/aio_pika/queue.py +++ b/aio_pika/queue.py @@ -9,7 +9,7 @@ from .exceptions import QueueEmpty from .exchange import Exchange -from aio_pika.types import ExchangeType as ExchangeType_ +from .types import ExchangeType as ExchangeType_, TimeoutType from .message import IncomingMessage from .tools import create_task, shield @@ -36,6 +36,7 @@ def __init__(self, connection, channel: aiormq.Channel, name, self.loop = connection.loop + self._connection = connection self._channel = channel self.name = name or '' self.durable = durable @@ -52,6 +53,11 @@ def channel(self) -> aiormq.Channel: raise RuntimeError("Channel not opened") return self._channel + def _get_operation_timeout(self, timeout: TimeoutType = None): + if timeout is not None: + return timeout + return self._connection.operation_timeout + def __str__(self): return "%s" % self.name @@ -70,7 +76,9 @@ def __repr__(self): self.arguments, ) - async def declare(self, timeout: int=None) -> aiormq.spec.Queue.DeclareOk: + async def declare( + self, timeout: TimeoutType = None + ) -> aiormq.spec.Queue.DeclareOk: """ Declare queue. :param timeout: execution timeout @@ -84,7 +92,7 @@ async def declare(self, timeout: int=None) -> aiormq.spec.Queue.DeclareOk: queue=self.name, durable=self.durable, exclusive=self.exclusive, auto_delete=self.auto_delete, arguments=self.arguments, passive=self.passive, - ), timeout=timeout + ), timeout=self._get_operation_timeout(timeout) ) # type: aiormq.spec.Queue.DeclareOk self.name = self.declaration_result.queue @@ -92,7 +100,7 @@ async def declare(self, timeout: int=None) -> aiormq.spec.Queue.DeclareOk: async def bind( self, exchange: ExchangeType_, routing_key: str=None, *, - arguments=None, timeout: int=None + arguments=None, timeout: TimeoutType = None ) -> aiormq.spec.Queue.BindOk: """ A binding is a relationship between an exchange and a queue. @@ -126,12 +134,12 @@ async def bind( exchange=Exchange._get_exchange_name(exchange), routing_key=routing_key, arguments=arguments - ), timeout=timeout + ), timeout=self._get_operation_timeout(timeout) ) async def unbind( self, exchange: ExchangeType_, routing_key: str=None, - arguments: dict=None, timeout: int=None + arguments: dict=None, timeout: TimeoutType = None ) -> aiormq.spec.Queue.UnbindOk: """ Remove binding from exchange for this :class:`Queue` instance @@ -159,13 +167,13 @@ async def unbind( exchange=Exchange._get_exchange_name(exchange), routing_key=routing_key, arguments=arguments - ), timeout=timeout + ), timeout=self._get_operation_timeout(timeout) ) async def consume( self, callback: Callable[[IncomingMessage], Any], no_ack: bool = False, exclusive: bool = False, arguments: dict = None, - consumer_tag=None, timeout=None + consumer_tag=None, timeout: TimeoutType = None ) -> ConsumerTag: """ Start to consuming the :class:`Queue`. @@ -203,11 +211,13 @@ async def consume( arguments=arguments, consumer_tag=consumer_tag, ), - timeout=timeout + timeout=self._get_operation_timeout(timeout) )).consumer_tag - async def cancel(self, consumer_tag: ConsumerTag, timeout=None, - nowait: bool=False) -> aiormq.spec.Basic.CancelOk: + async def cancel( + self, consumer_tag: ConsumerTag, timeout: TimeoutType = None, + nowait: bool=False + ) -> aiormq.spec.Basic.CancelOk: """ This method cancels a consumer. This does not affect already delivered messages, but it does mean the server will not send any more messages for that consumer. The client may receive an arbitrary number @@ -230,7 +240,7 @@ async def cancel(self, consumer_tag: ConsumerTag, timeout=None, consumer_tag=consumer_tag, nowait=nowait ), - timeout=timeout + timeout=self._get_operation_timeout(timeout) ) async def get( @@ -249,7 +259,7 @@ async def get( msg = await asyncio.wait_for(self.channel.basic_get( self.name, no_ack=no_ack - ), timeout=timeout + ), timeout=self._get_operation_timeout(timeout) ) # type: Optional[DeliveredMessage] if msg is None: @@ -260,7 +270,7 @@ async def get( return IncomingMessage(msg, no_ack=no_ack) async def purge( - self, no_wait=False, timeout=None + self, no_wait=False, timeout: TimeoutType = None ) -> aiormq.spec.Queue.PurgeOk: """ Purge all messages from the queue. @@ -275,11 +285,12 @@ async def purge( self.channel.queue_purge( self.name, nowait=no_wait, - ), timeout=timeout + ), timeout=self._get_operation_timeout(timeout) ) - async def delete(self, *, if_unused=True, if_empty=True, - timeout=None) -> aiormq.spec.Queue.DeclareOk: + async def delete( + self, *, if_unused=True, if_empty=True, timeout: TimeoutType = None + ) -> aiormq.spec.Queue.DeclareOk: """ Delete the queue. @@ -294,7 +305,7 @@ async def delete(self, *, if_unused=True, if_empty=True, return await asyncio.wait_for( self.channel.queue_delete( self.name, if_unused=if_unused, if_empty=if_empty - ), timeout=timeout + ), timeout=self._get_operation_timeout(timeout) ) def __aiter__(self) -> 'QueueIterator': diff --git a/aio_pika/robust_connection.py b/aio_pika/robust_connection.py index 29fc31c0..0dc9e34f 100644 --- a/aio_pika/robust_connection.py +++ b/aio_pika/robust_connection.py @@ -155,6 +155,7 @@ async def connect_robust( login: str = 'guest', password: str = 'guest', virtualhost: str = '/', ssl: bool = False, loop: asyncio.AbstractEventLoop = None, ssl_options: dict = None, timeout: TimeoutType = None, + operation_timeout: TimeoutType = None, connection_class: Type[ConnectionType] = RobustConnection, **kwargs ) -> ConnectionType: @@ -209,6 +210,7 @@ async def main(): :param ssl: use SSL for connection. Should be used with addition kwargs. :param ssl_options: A dict of values for the SSL connection. :param timeout: connection timeout in seconds + :param operation_timeout: execution timeout in seconds :param loop: Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param connection_class: Factory of a new connection @@ -224,7 +226,8 @@ async def main(): url=url, host=host, port=port, login=login, password=password, virtualhost=virtualhost, ssl=ssl, loop=loop, connection_class=connection_class, - ssl_options=ssl_options, timeout=timeout, **kwargs + ssl_options=ssl_options, timeout=timeout, + operation_timeout=operation_timeout, **kwargs ) ) diff --git a/aio_pika/robust_exchange.py b/aio_pika/robust_exchange.py index fb1e5ee1..5b29ca3a 100644 --- a/aio_pika/robust_exchange.py +++ b/aio_pika/robust_exchange.py @@ -5,6 +5,7 @@ from .exchange import Exchange, ExchangeType from .channel import Channel +from .types import TimeoutType log = getLogger(__name__) @@ -42,7 +43,7 @@ async def on_reconnect(self, channel: Channel): await self.bind(exchange, **kwargs) async def bind(self, exchange, routing_key: str='', *, - arguments=None, timeout: int=None): + arguments=None, timeout: TimeoutType = None): result = await super().bind( exchange, routing_key=routing_key, arguments=arguments, timeout=timeout @@ -55,7 +56,7 @@ async def bind(self, exchange, routing_key: str='', *, return result async def unbind(self, exchange, routing_key: str = '', - arguments: dict=None, timeout: int=None): + arguments: dict=None, timeout: TimeoutType = None): result = await super().unbind(exchange, routing_key, arguments=arguments, timeout=timeout) diff --git a/aio_pika/robust_queue.py b/aio_pika/robust_queue.py index ed1c9702..ab595aae 100644 --- a/aio_pika/robust_queue.py +++ b/aio_pika/robust_queue.py @@ -7,7 +7,7 @@ import aiormq from .channel import Channel -from aio_pika.types import ExchangeType as ExchangeType_ +from .types import ExchangeType as ExchangeType_, TimeoutType from .queue import Queue, ConsumerTag log = getLogger(__name__) @@ -56,7 +56,7 @@ async def on_reconnect(self, channel: Channel): await self.consume(consumer_tag=consumer_tag, **kwargs) async def bind(self, exchange: ExchangeType_, routing_key: str=None, *, - arguments=None, timeout: int=None): + arguments=None, timeout: TimeoutType = None): if routing_key is None: routing_key = self.name @@ -74,7 +74,7 @@ async def bind(self, exchange: ExchangeType_, routing_key: str=None, *, return result async def unbind(self, exchange: ExchangeType_, routing_key: str=None, - arguments: dict=None, timeout: int=None): + arguments: dict=None, timeout: TimeoutType = None): if routing_key is None: routing_key = self.name @@ -88,7 +88,8 @@ async def unbind(self, exchange: ExchangeType_, routing_key: str=None, async def consume(self, callback: FunctionType, no_ack: bool=False, exclusive: bool=False, arguments: dict=None, - consumer_tag=None, timeout=None) -> ConsumerTag: + consumer_tag=None, + timeout: TimeoutType = None) -> ConsumerTag: kwargs = dict( callback=callback, @@ -105,8 +106,8 @@ async def consume(self, callback: FunctionType, no_ack: bool=False, return consumer_tag - async def cancel(self, consumer_tag: ConsumerTag, timeout=None, - nowait: bool = False): + async def cancel(self, consumer_tag: ConsumerTag, + timeout: TimeoutType = None, nowait: bool = False): result = await super().cancel(consumer_tag, timeout, nowait) self._consumers.pop(consumer_tag, None) diff --git a/aio_pika/transaction.py b/aio_pika/transaction.py index 50c6fba0..3a1b9f4f 100644 --- a/aio_pika/transaction.py +++ b/aio_pika/transaction.py @@ -3,6 +3,8 @@ import aiormq +from .types import TimeoutType + class TransactionStates(Enum): created = 'created' @@ -15,8 +17,9 @@ class Transaction: def __str__(self): return self.state.value - def __init__(self, channel): - self.loop = channel.loop + def __init__(self, connection, channel): + self.loop = connection.loop + self._connection = connection self._channel = channel self.state = TransactionStates.created # type: TransactionStates @@ -30,24 +33,33 @@ def channel(self) -> aiormq.Channel: return self._channel - async def select(self, timeout=None) -> aiormq.spec.Tx.SelectOk: + def _get_operation_timeout(self, timeout: TimeoutType = None): + if timeout is not None: + return timeout + return self._connection.operation_timeout + + async def select(self, + timeout: TimeoutType = None) -> aiormq.spec.Tx.SelectOk: result = await asyncio.wait_for( - self.channel.tx_select(), timeout=timeout + self.channel.tx_select(), + timeout=self._get_operation_timeout(timeout) ) self.state = TransactionStates.started return result - async def rollback(self, timeout=None): + async def rollback(self, timeout: TimeoutType = None): result = await asyncio.wait_for( - self.channel.tx_rollback(), timeout=timeout + self.channel.tx_rollback(), + timeout=self._get_operation_timeout(timeout) ) self.state = TransactionStates.rolled_back return result - async def commit(self, timeout=None): + async def commit(self, timeout: TimeoutType = None): result = await asyncio.wait_for( - self.channel.tx_commit(), timeout=timeout + self.channel.tx_commit(), + timeout=self._get_operation_timeout(timeout) ) self.state = TransactionStates.commited