From 5b2f6c8d61919859bdfbe418981b2d983da355bf Mon Sep 17 00:00:00 2001 From: Dos Moonen Date: Thu, 14 Mar 2024 16:55:37 +0100 Subject: [PATCH] Prevent deadlock in RobustChannel.reopen() --- aio_pika/channel.py | 4 ++-- aio_pika/robust_channel.py | 4 ++-- tests/test_amqp_robust_proxy.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/aio_pika/channel.py b/aio_pika/channel.py index 7741b7af..453cb599 100644 --- a/aio_pika/channel.py +++ b/aio_pika/channel.py @@ -153,12 +153,12 @@ def __str__(self) -> str: return "{}".format(self.number or "Not initialized channel") async def _open(self) -> None: - await self._connection.ready() - transport = self._connection.transport if transport is None: raise ChannelInvalidStateError("No active transport in channel") + await transport.ready() + channel = await UnderlayChannel.create( transport.connection, self._on_close, diff --git a/aio_pika/robust_channel.py b/aio_pika/robust_channel.py index 5ddf4ab9..b73d9b65 100644 --- a/aio_pika/robust_channel.py +++ b/aio_pika/robust_channel.py @@ -111,8 +111,8 @@ async def __close_callback(self, _: Any, exc: BaseException) -> None: await self.restore() - async def _open(self) -> None: - await super()._open() + async def reopen(self) -> None: + await super().reopen() await self.reopen_callbacks() async def _on_open(self) -> None: diff --git a/tests/test_amqp_robust_proxy.py b/tests/test_amqp_robust_proxy.py index 0938d26d..c9eb2d01 100644 --- a/tests/test_amqp_robust_proxy.py +++ b/tests/test_amqp_robust_proxy.py @@ -14,7 +14,7 @@ import aio_pika from aio_pika.abc import AbstractRobustChannel -from aio_pika.exceptions import QueueEmpty +from aio_pika.exceptions import QueueEmpty, CONNECTION_EXCEPTIONS from aio_pika.message import Message from aio_pika.robust_channel import RobustChannel from aio_pika.robust_connection import RobustConnection @@ -565,6 +565,7 @@ async def test_channel_reconnect_after_5kb( assert messages + assert on_reconnect.is_set() await connection.close() await direct_connection.close() @@ -666,7 +667,7 @@ async def test_channel_reconnect_stairway( try: await channel.set_qos(prefetch_count=1) break - except aiormq.ChannelInvalidStateError: + except CONNECTION_EXCEPTIONS: await asyncio.sleep(0.1) continue @@ -689,5 +690,6 @@ async def test_channel_reconnect_stairway( assert messages + assert on_reconnect.is_set() await connection.close() await direct_connection.close()