Skip to content

Commit

Permalink
Add missing connection timeout parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
decaz authored and mosquito committed Sep 4, 2019
1 parent 83bf56d commit 243e0c6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 18 deletions.
17 changes: 11 additions & 6 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import asyncio
import logging
from functools import partial
from typing import Callable, Type
from typing import Callable, Type, TypeVar

from yarl import URL

import aiormq
from aiormq.tools import censor_url
from .channel import Channel
from .tools import CallbackCollection
from .types import TimeoutType

try:
from yarl import DEFAULT_PORTS
Expand Down Expand Up @@ -103,7 +104,7 @@ def _on_connection_close(self, connection, closing, *args, **kwargs):
self.close_callbacks(exc)
log.debug("Closing AMQP connection %r", connection)

async def connect(self, timeout: TimeoutError = None):
async def connect(self, timeout: TimeoutType = None):
""" Connect to AMQP server. This method should be called after
:func:`aio_pika.connection.Connection.__init__`
Expand Down Expand Up @@ -208,13 +209,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()


ConnectionType = TypeVar('ConnectionType', bound=Connection)


async def connect(
url: str = None, *, host: str = 'localhost', port: int = 5672,
login: str = 'guest', password: str = 'guest', virtualhost: str = '/',
ssl: bool = False, loop: asyncio.AbstractEventLoop = None,
ssl_options: dict = None, connection_class: Type[Connection] = Connection,
**kwargs
) -> Connection:
ssl_options: dict = None, timeout: TimeoutType = None,
connection_class: Type[ConnectionType] = Connection, **kwargs
) -> ConnectionType:

""" Make connection to the broker.
Expand Down Expand Up @@ -262,6 +266,7 @@ async def main():
:param virtualhost: virtualhost parameter. `'/'` by default
:param ssl: use SSL for connection. Should be used with addition kwargs.
:param ssl_options: A dict of values for the SSL connection.
:param timeout: connection timeout in seconds
:param loop:
Event loop (:func:`asyncio.get_event_loop()` when :class:`None`)
:param connection_class: Factory of a new connection
Expand Down Expand Up @@ -290,7 +295,7 @@ async def main():
)

connection = connection_class(url, loop=loop)
await connection.connect()
await connection.connect(timeout=timeout)
return connection


Expand Down
24 changes: 13 additions & 11 deletions aio_pika/robust_connection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
from functools import wraps
from logging import getLogger
from typing import Callable
from typing import Callable, Type

from aiormq.connection import parse_bool, parse_int
from .exceptions import CONNECTION_EXCEPTIONS
from .connection import Connection, connect
from .connection import Connection, connect, ConnectionType
from .tools import CallbackCollection
from .types import TimeoutType
from .robust_channel import RobustChannel


Expand Down Expand Up @@ -75,7 +76,7 @@ def add_reconnect_callback(self, callback: Callable[[], None]):

self._on_reconnect_callbacks.add(callback)

async def connect(self, timeout=None):
async def connect(self, timeout: TimeoutType = None):
while True:
try:
return await super().connect(timeout=timeout)
Expand Down Expand Up @@ -149,13 +150,13 @@ async def close(self, exc=asyncio.CancelledError):
return await super().close(exc)


async def connect_robust(url: str = None, *, host: str = 'localhost',
port: int = 5672, login: str = 'guest',
password: str = 'guest', virtualhost: str = '/',
ssl: bool = False, loop=None,
ssl_options: dict = None,
connection_class=RobustConnection,
**kwargs) -> Connection:
async def connect_robust(
url: str = None, *, host: str = 'localhost', port: int = 5672,
login: str = 'guest', password: str = 'guest', virtualhost: str = '/',
ssl: bool = False, loop: asyncio.AbstractEventLoop = None,
ssl_options: dict = None, timeout: TimeoutType = None,
connection_class: Type[ConnectionType] = RobustConnection, **kwargs
) -> ConnectionType:

""" Make robust connection to the broker.
Expand Down Expand Up @@ -207,6 +208,7 @@ async def main():
:param virtualhost: virtualhost parameter. `'/'` by default
:param ssl: use SSL for connection. Should be used with addition kwargs.
:param ssl_options: A dict of values for the SSL connection.
:param timeout: connection timeout in seconds
:param loop:
Event loop (:func:`asyncio.get_event_loop()` when :class:`None`)
:param connection_class: Factory of a new connection
Expand All @@ -222,7 +224,7 @@ async def main():
url=url, host=host, port=port, login=login,
password=password, virtualhost=virtualhost, ssl=ssl,
loop=loop, connection_class=connection_class,
ssl_options=ssl_options, **kwargs
ssl_options=ssl_options, timeout=timeout, **kwargs
)
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, url, **kwargs):
self.url = URL(url)
self.kwargs = kwargs

async def connect(self):
async def connect(self, timeout=None):
return


Expand Down

0 comments on commit 243e0c6

Please sign in to comment.