diff --git a/aio_pika/connection.py b/aio_pika/connection.py index 03f7ce3d..bfbca1fe 100644 --- a/aio_pika/connection.py +++ b/aio_pika/connection.py @@ -1,7 +1,7 @@ import asyncio import logging from functools import partial -from typing import Callable, Type +from typing import Callable, Type, TypeVar from yarl import URL @@ -9,6 +9,7 @@ from aiormq.tools import censor_url from .channel import Channel from .tools import CallbackCollection +from .types import TimeoutType try: from yarl import DEFAULT_PORTS @@ -103,7 +104,7 @@ def _on_connection_close(self, connection, closing, *args, **kwargs): self.close_callbacks(exc) log.debug("Closing AMQP connection %r", connection) - async def connect(self, timeout: TimeoutError = None): + async def connect(self, timeout: TimeoutType = None): """ Connect to AMQP server. This method should be called after :func:`aio_pika.connection.Connection.__init__` @@ -208,13 +209,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() +ConnectionType = TypeVar('ConnectionType', bound=Connection) + + async def connect( url: str = None, *, host: str = 'localhost', port: int = 5672, login: str = 'guest', password: str = 'guest', virtualhost: str = '/', ssl: bool = False, loop: asyncio.AbstractEventLoop = None, - ssl_options: dict = None, connection_class: Type[Connection] = Connection, - **kwargs - ) -> Connection: + ssl_options: dict = None, timeout: TimeoutType = None, + connection_class: Type[ConnectionType] = Connection, **kwargs + ) -> ConnectionType: """ Make connection to the broker. @@ -262,6 +266,7 @@ async def main(): :param virtualhost: virtualhost parameter. `'/'` by default :param ssl: use SSL for connection. Should be used with addition kwargs. :param ssl_options: A dict of values for the SSL connection. + :param timeout: connection timeout in seconds :param loop: Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param connection_class: Factory of a new connection @@ -290,7 +295,7 @@ async def main(): ) connection = connection_class(url, loop=loop) - await connection.connect() + await connection.connect(timeout=timeout) return connection diff --git a/aio_pika/robust_connection.py b/aio_pika/robust_connection.py index 56e57f82..9feb15f6 100644 --- a/aio_pika/robust_connection.py +++ b/aio_pika/robust_connection.py @@ -1,12 +1,13 @@ import asyncio from functools import wraps from logging import getLogger -from typing import Callable +from typing import Callable, Type from aiormq.connection import parse_bool, parse_int from .exceptions import CONNECTION_EXCEPTIONS -from .connection import Connection, connect +from .connection import Connection, connect, ConnectionType from .tools import CallbackCollection +from .types import TimeoutType from .robust_channel import RobustChannel @@ -75,7 +76,7 @@ def add_reconnect_callback(self, callback: Callable[[], None]): self._on_reconnect_callbacks.add(callback) - async def connect(self, timeout=None): + async def connect(self, timeout: TimeoutType = None): while True: try: return await super().connect(timeout=timeout) @@ -149,13 +150,13 @@ async def close(self, exc=asyncio.CancelledError): return await super().close(exc) -async def connect_robust(url: str = None, *, host: str = 'localhost', - port: int = 5672, login: str = 'guest', - password: str = 'guest', virtualhost: str = '/', - ssl: bool = False, loop=None, - ssl_options: dict = None, - connection_class=RobustConnection, - **kwargs) -> Connection: +async def connect_robust( + url: str = None, *, host: str = 'localhost', port: int = 5672, + login: str = 'guest', password: str = 'guest', virtualhost: str = '/', + ssl: bool = False, loop: asyncio.AbstractEventLoop = None, + ssl_options: dict = None, timeout: TimeoutType = None, + connection_class: Type[ConnectionType] = RobustConnection, **kwargs +) -> ConnectionType: """ Make robust connection to the broker. @@ -207,6 +208,7 @@ async def main(): :param virtualhost: virtualhost parameter. `'/'` by default :param ssl: use SSL for connection. Should be used with addition kwargs. :param ssl_options: A dict of values for the SSL connection. + :param timeout: connection timeout in seconds :param loop: Event loop (:func:`asyncio.get_event_loop()` when :class:`None`) :param connection_class: Factory of a new connection @@ -222,7 +224,7 @@ async def main(): url=url, host=host, port=port, login=login, password=password, virtualhost=virtualhost, ssl=ssl, loop=loop, connection_class=connection_class, - ssl_options=ssl_options, **kwargs + ssl_options=ssl_options, timeout=timeout, **kwargs ) ) diff --git a/tests/test_connect.py b/tests/test_connect.py index cb3e4d78..2e2b2ab6 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -32,7 +32,7 @@ def __init__(self, url, **kwargs): self.url = URL(url) self.kwargs = kwargs - async def connect(self): + async def connect(self, timeout=None): return