From 3ba566f30b798692b853a903e114e7943b9d1928 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Thu, 14 Nov 2024 20:20:34 +0800 Subject: [PATCH 1/2] handle connection close on CancelledError to prevent stale results --- asyncmy/connection.pyx | 20 +++++++++++++++++--- tests/test_pool.py | 17 +++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/asyncmy/connection.pyx b/asyncmy/connection.pyx index 4768a43..4f02098 100644 --- a/asyncmy/connection.pyx +++ b/asyncmy/connection.pyx @@ -296,6 +296,7 @@ class Connection: self._connected = False self._reader: Optional[StreamReader] = None self._writer: Optional[StreamWriter] = None + self._close_reason = None def _create_ssl_ctx(self, sslp): if isinstance(sslp, ssl.SSLContext): @@ -616,7 +617,11 @@ class Connection: """ buff = bytearray() while True: - packet_header = await self._read_bytes(4) + try: + packet_header = await self._read_bytes(4) + except asyncio.CancelledError: + self._close_on_cancel() + raise btrl, btrh, packet_number = HBB.unpack(packet_header) bytes_to_read = btrl + (btrh << 16) if packet_number != self._next_seq_id: @@ -631,7 +636,11 @@ class Connection: % (packet_number, self._next_seq_id) ) self._next_seq_id = (self._next_seq_id + 1) % 256 - recv_data = await self._read_bytes(bytes_to_read) + try: + recv_data = await self._read_bytes(bytes_to_read) + except asyncio.CancelledError: + self._close_on_cancel() + raise buff.extend(recv_data) # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html if bytes_to_read == 0xFFFFFF: @@ -708,7 +717,7 @@ class Connection: :raise ValueError: If no username was specified. """ if not self._connected: - raise errors.InterfaceError(0, "Not connected") + raise errors.InterfaceError(0, self._close_reason or "Not connected") # If the last query was unbuffered, make sure it finishes before # sending new commands @@ -1028,6 +1037,11 @@ class Connection: def get_server_info(self): return self.server_version + def _close_on_cancel(self): + self.close() + self._close_reason = "Cancelled during execution" + self._connected = False + Warning = errors.Warning Error = errors.Error InterfaceError = errors.InterfaceError diff --git a/tests/test_pool.py b/tests/test_pool.py index 992dba5..4bcb7a4 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,3 +1,4 @@ +import asyncio import pytest from asyncmy.connection import Connection @@ -30,3 +31,19 @@ async def test_acquire(pool): await pool.release(conn) assert pool.freesize == 1 assert pool.size == 1 + + +@pytest.mark.asyncio +async def test_cancel_execute(pool, event_loop): + async def run(index): + async with pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"SELECT {index}") + ret = await cursor.fetchone() + assert ret == (index,) + task = event_loop.create_task(run(1)) + await asyncio.sleep(0) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + await run(2) From 292d7373808e951e51e893c7ade07861d914c620 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Thu, 14 Nov 2024 21:28:16 +0800 Subject: [PATCH 2/2] code format: tests/test_pool.py --- tests/test_pool.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_pool.py b/tests/test_pool.py index 4bcb7a4..a3a765a 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,4 +1,5 @@ import asyncio + import pytest from asyncmy.connection import Connection @@ -41,6 +42,7 @@ async def run(index): await cursor.execute(f"SELECT {index}") ret = await cursor.fetchone() assert ret == (index,) + task = event_loop.create_task(run(1)) await asyncio.sleep(0) task.cancel()