Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add global timeout for all operations within connection #259

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .exchange import Exchange, ExchangeType
from .message import IncomingMessage
from .queue import Queue
from .tools import OPERATION_TIMEOUT
from .transaction import Transaction
from .types import ReturnCallbackType, CloseCallbackType, TimeoutType

Expand Down Expand Up @@ -107,6 +108,12 @@ def channel(self) -> aiormq.Channel:
def number(self):
return self.channel.number if self._channel else None

def _get_operation_timeout(self, timeout: TimeoutType):
return (
self._connection.operation_timeout if timeout is OPERATION_TIMEOUT
else timeout
)

def __str__(self):
return "{0}".format(
self.number or "Not initialized channel"
Expand Down Expand Up @@ -159,12 +166,14 @@ async def _create_channel(self) -> aiormq.Channel:
channel_number=self._channel_number,
)

async def initialize(self, timeout: TimeoutType = None) -> None:
async def initialize(self,
timeout: TimeoutType = OPERATION_TIMEOUT) -> None:
if self._channel is not None:
raise RuntimeError("Can't initialize channel")

self._channel = await asyncio.wait_for(
self._create_channel(), timeout=timeout
self._create_channel(),
timeout=self._get_operation_timeout(timeout)
)

self._delivery_tag = 0
Expand All @@ -190,7 +199,7 @@ async def declare_exchange(
self, name: str, type: Union[ExchangeType, str] = ExchangeType.DIRECT,
durable: bool = None, auto_delete: bool = False,
internal: bool = False, passive: bool = False, arguments: dict = None,
timeout: TimeoutType = None
timeout: TimeoutType = OPERATION_TIMEOUT
) -> Exchange:
"""
Declare an exchange.
Expand Down Expand Up @@ -228,7 +237,7 @@ async def declare_queue(
self, name: str = None, *, durable: bool = None,
exclusive: bool = False, passive: bool = False,
auto_delete: bool = False, arguments: dict = None,
timeout: TimeoutType = None
timeout: TimeoutType = OPERATION_TIMEOUT
) -> Queue:

"""
Expand All @@ -248,7 +257,7 @@ async def declare_queue(
"""

queue = self.QUEUE_CLASS(
connection=self,
connection=self._connection,
channel=self.channel,
name=name,
durable=durable,
Expand All @@ -264,7 +273,7 @@ async def declare_queue(

async def set_qos(
self, prefetch_count: int = 0, prefetch_size: int = 0,
global_: bool = False, timeout: TimeoutType = None,
global_: bool = False, timeout: TimeoutType = OPERATION_TIMEOUT,
all_channels: bool = None
) -> aiormq.spec.Basic.QosOk:
if all_channels is not None:
Expand All @@ -277,11 +286,11 @@ async def set_qos(
prefetch_size=prefetch_size,
global_=global_
),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)

async def queue_delete(
self, queue_name: str, timeout: TimeoutType = None,
self, queue_name: str, timeout: TimeoutType = OPERATION_TIMEOUT,
if_unused: bool = False, if_empty: bool = False, nowait: bool = False
) -> aiormq.spec.Queue.DeleteOk:

Expand All @@ -292,11 +301,11 @@ async def queue_delete(
if_empty=if_empty,
nowait=nowait,
),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)

async def exchange_delete(
self, exchange_name: str, timeout: TimeoutType = None,
self, exchange_name: str, timeout: TimeoutType = OPERATION_TIMEOUT,
if_unused: bool = False, nowait: bool = False
) -> aiormq.spec.Exchange.DeleteOk:

Expand All @@ -306,15 +315,15 @@ async def exchange_delete(
if_unused=if_unused,
nowait=nowait,
),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)

def transaction(self) -> Transaction:
if self._publisher_confirms:
raise RuntimeError("Cannot create transaction when publisher "
"confirms are enabled")

return Transaction(self._channel)
return Transaction(connection=self._connection, channel=self._channel)

async def flow(self, active: bool = True) -> aiormq.spec.Channel.FlowOk:
return await self.channel.flow(active=active)
Expand Down
12 changes: 10 additions & 2 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ def _parse_kwargs(cls, kwargs):
result[key] = parser(kwargs.get(key, default))
return result

def __init__(self, url, loop=None, **kwargs):
def __init__(self, url, operation_timeout: TimeoutType = None, loop=None,
**kwargs):
self.loop = loop or asyncio.get_event_loop()
self.url = URL(url)
self.operation_timeout = operation_timeout

self.kwargs = self._parse_kwargs(kwargs or self.url.query)

Expand Down Expand Up @@ -219,6 +221,7 @@ async def connect(
login: str = 'guest', password: str = 'guest', virtualhost: str = '/',
ssl: bool = False, loop: asyncio.AbstractEventLoop = None,
ssl_options: dict = None, timeout: TimeoutType = None,
operation_timeout: TimeoutType = None,
connection_class: Type[ConnectionType] = Connection,
client_properties: dict = None, **kwargs
) -> ConnectionType:
Expand Down Expand Up @@ -291,6 +294,7 @@ async def main():
:param ssl: use SSL for connection. Should be used with addition kwargs.
:param ssl_options: A dict of values for the SSL connection.
:param timeout: connection timeout in seconds
:param operation_timeout: execution timeout in seconds
:param loop:
Event loop (:func:`asyncio.get_event_loop()` when :class:`None`)
:param connection_class: Factory of a new connection
Expand Down Expand Up @@ -318,7 +322,11 @@ async def main():
query=kw
)

connection = connection_class(url, loop=loop)
connection = connection_class(
url,
operation_timeout=operation_timeout,
loop=loop
)

await connection.connect(
timeout=timeout, client_properties=client_properties, loop=loop
Expand Down
28 changes: 18 additions & 10 deletions aio_pika/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import aiormq
from .message import Message
from .tools import OPERATION_TIMEOUT
from .types import ExchangeType as ExchangeType_, TimeoutType


Expand Down Expand Up @@ -36,6 +37,7 @@ def __init__(self, connection, channel: aiormq.Channel, name: str,
if not arguments:
arguments = {}

self._connection = connection
self._channel = channel
self.__type = type.value if isinstance(type, ExchangeType) else type
self.name = name
Expand All @@ -52,6 +54,12 @@ def channel(self) -> aiormq.Channel:

return self._channel

def _get_operation_timeout(self, timeout: TimeoutType):
return (
self._connection.operation_timeout if timeout is OPERATION_TIMEOUT
else timeout
)

def __str__(self):
return self.name

Expand All @@ -61,7 +69,7 @@ def __repr__(self):
)

async def declare(
self, timeout: TimeoutType = None
self, timeout: TimeoutType = OPERATION_TIMEOUT
) -> aiormq.spec.Exchange.DeclareOk:
return await asyncio.wait_for(self.channel.exchange_declare(
self.name,
Expand All @@ -71,7 +79,7 @@ async def declare(
internal=self.internal,
passive=self.passive,
arguments=self.arguments,
), timeout=timeout)
), timeout=self._get_operation_timeout(timeout))

@staticmethod
def _get_exchange_name(exchange: ExchangeType_):
Expand All @@ -85,7 +93,7 @@ def _get_exchange_name(exchange: ExchangeType_):

async def bind(
self, exchange: ExchangeType_, routing_key: str = '', *,
arguments: dict = None, timeout: TimeoutType = None
arguments: dict = None, timeout: TimeoutType = OPERATION_TIMEOUT
) -> aiormq.spec.Exchange.BindOk:

""" A binding can also be a relationship between two exchanges.
Expand Down Expand Up @@ -133,12 +141,12 @@ async def bind(
destination=self.name,
routing_key=routing_key,
source=self._get_exchange_name(exchange),
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

async def unbind(
self, exchange: ExchangeType_, routing_key: str = '',
arguments: dict = None, timeout: TimeoutType = None
arguments: dict = None, timeout: TimeoutType = OPERATION_TIMEOUT
) -> aiormq.spec.Exchange.UnbindOk:

""" Remove exchange-to-exchange binding for this
Expand All @@ -163,12 +171,12 @@ async def unbind(
destination=self.name,
routing_key=routing_key,
source=self._get_exchange_name(exchange),
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

async def publish(
self, message: Message, routing_key, *, mandatory: bool = True,
immediate: bool = False, timeout: TimeoutType = None
immediate: bool = False, timeout: TimeoutType = OPERATION_TIMEOUT
) -> Optional[aiormq.types.ConfirmationFrameType]:

""" Publish the message to the queue. `aio-pika` uses
Expand Down Expand Up @@ -197,11 +205,11 @@ async def publish(
properties=message.properties,
mandatory=mandatory,
immediate=immediate
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

async def delete(
self, if_unused: bool = False, timeout: TimeoutType = None
self, if_unused: bool = False, timeout: TimeoutType = OPERATION_TIMEOUT
) -> aiormq.spec.Exchange.DeleteOk:

""" Delete the queue
Expand All @@ -213,7 +221,7 @@ async def delete(
log.info("Deleting %r", self)
return await asyncio.wait_for(
self.channel.exchange_delete(self.name, if_unused=if_unused),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)


Expand Down
Loading