Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle connection close on CancelledError to prevent stale results #108

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions asyncmy/connection.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ class Connection:
self._connected = False
self._reader: Optional[StreamReader] = None
self._writer: Optional[StreamWriter] = None
self._close_reason = None

def _create_ssl_ctx(self, sslp):
if isinstance(sslp, ssl.SSLContext):
Expand Down Expand Up @@ -616,7 +617,11 @@ class Connection:
"""
buff = bytearray()
while True:
packet_header = await self._read_bytes(4)
try:
packet_header = await self._read_bytes(4)
except asyncio.CancelledError:
self._close_on_cancel()
raise
btrl, btrh, packet_number = HBB.unpack(packet_header)
bytes_to_read = btrl + (btrh << 16)
if packet_number != self._next_seq_id:
Expand All @@ -631,7 +636,11 @@ class Connection:
% (packet_number, self._next_seq_id)
)
self._next_seq_id = (self._next_seq_id + 1) % 256
recv_data = await self._read_bytes(bytes_to_read)
try:
recv_data = await self._read_bytes(bytes_to_read)
except asyncio.CancelledError:
self._close_on_cancel()
raise
buff.extend(recv_data)
# https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
if bytes_to_read == 0xFFFFFF:
Expand Down Expand Up @@ -708,7 +717,7 @@ class Connection:
:raise ValueError: If no username was specified.
"""
if not self._connected:
raise errors.InterfaceError(0, "Not connected")
raise errors.InterfaceError(0, self._close_reason or "Not connected")

# If the last query was unbuffered, make sure it finishes before
# sending new commands
Expand Down Expand Up @@ -1028,6 +1037,11 @@ class Connection:
def get_server_info(self):
return self.server_version

def _close_on_cancel(self):
self.close()
self._close_reason = "Cancelled during execution"
self._connected = False

Warning = errors.Warning
Error = errors.Error
InterfaceError = errors.InterfaceError
Expand Down
19 changes: 19 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import pytest

from asyncmy.connection import Connection
Expand Down Expand Up @@ -30,3 +32,20 @@ async def test_acquire(pool):
await pool.release(conn)
assert pool.freesize == 1
assert pool.size == 1


@pytest.mark.asyncio
async def test_cancel_execute(pool, event_loop):
async def run(index):
async with pool.acquire() as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"SELECT {index}")
ret = await cursor.fetchone()
assert ret == (index,)

task = event_loop.create_task(run(1))
await asyncio.sleep(0)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
await run(2)