Skip to content

Commit

Permalink
max reconnect attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikita Kharlov authored and harlov committed Nov 15, 2019
1 parent 6498575 commit d95a3cc
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 6 deletions.
2 changes: 1 addition & 1 deletion aio_pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ async def main():
query=kw
)

connection = connection_class(url, loop=loop)
connection = connection_class(url, loop=loop, **kwargs)

await connection.connect(
timeout=timeout, client_properties=client_properties
Expand Down
5 changes: 5 additions & 0 deletions aio_pika/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
pass


class MaxReconnectAttemptsReached(Exception):
pass


__all__ = (
'AMQPChannelError',
'AMQPConnectionError',
Expand All @@ -53,6 +57,7 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
'DuplicateConsumerTag',
'IncompatibleProtocolError',
'InvalidFrameError',
'MaxReconnectAttemptsReached',
'MessageProcessError',
'MethodNotImplemented',
'ProbableAuthenticationError',
Expand Down
20 changes: 19 additions & 1 deletion aio_pika/robust_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Type

from aiormq.connection import parse_bool, parse_int
from .exceptions import CONNECTION_EXCEPTIONS
from .exceptions import CONNECTION_EXCEPTIONS, MaxReconnectAttemptsReached
from .connection import Connection, connect, ConnectionType
from .tools import CallbackCollection
from .types import TimeoutType
Expand All @@ -29,6 +29,7 @@ class RobustConnection(Connection):

CHANNEL_CLASS = RobustChannel
KWARGS_TYPES = (
('max_reconnect_attempts', parse_int, '0'),
('reconnect_interval', parse_int, '5'),
('fail_fast', parse_bool, '1'),
)
Expand All @@ -43,7 +44,9 @@ def __init__(self, url, loop=None, **kwargs):
self.fail_fast = self.kwargs['fail_fast']

self.__channels = set()
self._reconnect_attempt = None
self._reconnect_callbacks = CallbackCollection()
self._stop_callbacks = CallbackCollection()
self._closed = False

@property
Expand Down Expand Up @@ -77,6 +80,9 @@ def add_reconnect_callback(self, callback: Callable[[], None]):

self._reconnect_callbacks.add(callback)

def add_stop_callback(self, callback: Callable[[Exception], None]):
self._stop_callbacks.add(callback)

async def connect(self, timeout: TimeoutType = None, **kwargs):
if kwargs:
# Store connect kwargs for reconnects
Expand Down Expand Up @@ -104,6 +110,16 @@ async def reconnect(self):
if self.is_closed:
return

if self.kwargs['max_reconnect_attempts'] > 0:
if self._reconnect_attempt is None:
self._reconnect_attempt = 1
else:
self._reconnect_attempt += 1

if self._reconnect_attempt > self.kwargs['max_reconnect_attempts']:
await self.close(MaxReconnectAttemptsReached())
return

try:
await super().connect()
except CONNECTION_EXCEPTIONS:
Expand Down Expand Up @@ -131,6 +147,7 @@ def channel(self, channel_number: int = None,
return channel

async def _on_reconnect(self):
self._reconnect_attempt = None
for number, channel in self._channels.items():
try:
await channel.on_reconnect(self, number)
Expand All @@ -151,6 +168,7 @@ async def close(self, exc=asyncio.CancelledError):
return

self._closed = True
self._stop_callbacks(exc)

if self.connection is None:
return
Expand Down
44 changes: 40 additions & 4 deletions tests/test_amqp_robust.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aiormq import ChannelLockedResource

from aio_pika import connect_robust, Message
from aio_pika.exceptions import MaxReconnectAttemptsReached
from aio_pika.robust_channel import RobustChannel
from aio_pika.robust_connection import RobustConnection
from aio_pika.robust_queue import RobustQueue
Expand All @@ -27,6 +28,7 @@ def __init__(self, *, loop, shost='127.0.0.1', sport,
self.src_port = sport
self.dst_host = dhost
self.dst_port = dport
self._run_task = None
self.connections = set()

async def _pipe(self, reader: asyncio.StreamReader,
Expand Down Expand Up @@ -54,12 +56,19 @@ async def handle_client(self, creader: asyncio.StreamReader,
])

async def start(self):
return await asyncio.start_server(
self._run_task = await asyncio.start_server(
self.handle_client,
host=self.src_host,
port=self.src_port,
loop=self.loop,
)
return self._run_task

async def stop(self):
assert self._run_task is not None
self._run_task.close()
await self.disconnect()
self._run_task = None

async def disconnect(self):
tasks = list()
Expand All @@ -72,7 +81,8 @@ async def close(writer):
writer = self.connections.pop() # type: asyncio.StreamWriter
tasks.append(self.loop.create_task(close(writer)))

await asyncio.wait(tasks)
if tasks:
await asyncio.wait(tasks)


class TestCase(AMQPTestCase):
Expand All @@ -84,7 +94,7 @@ def get_unused_port() -> int:
sock.close()
return port

async def create_connection(self, cleanup=True):
async def create_connection(self, cleanup=True, max_reconnect_attempts=0):
self.proxy = Proxy(
dhost=AMQP_URL.host,
dport=AMQP_URL.port,
Expand All @@ -98,7 +108,11 @@ async def create_connection(self, cleanup=True):
self.proxy.src_host
).with_port(
self.proxy.src_port
).update_query(reconnect_interval=1)
).update_query(
reconnect_interval=1
).update_query(
max_reconnect_attempts=max_reconnect_attempts
)

client = await connect_robust(str(url), loop=self.loop)

Expand Down Expand Up @@ -210,6 +224,28 @@ async def reader():

assert len(shared) == 10

async def test_robust_reconnect_max_attempts(self):
client = await self.create_connection(max_reconnect_attempts=2)
self.assertIsInstance(client, RobustConnection)

first_close = asyncio.Future()
stopped = asyncio.Future()

def stop_callback(exc):
assert isinstance(exc, MaxReconnectAttemptsReached)
stopped.set_result(True)

def close_callback(f):
first_close.set_result(True)

client.add_stop_callback(stop_callback)
client.connection.closing.add_done_callback(close_callback)
await self.proxy.stop()
await first_close
# 1 interval before first try and 2 after attempts
await asyncio.wait_for(stopped,
timeout=client.reconnect_interval * 3 + 0.1)

async def test_channel_locked_resource2(self):
ch1 = await self.create_channel()
ch2 = await self.create_channel()
Expand Down

0 comments on commit d95a3cc

Please sign in to comment.