Skip to content

Commit 08cfbe5

Browse files
committed
fix: improve error handling and request cancellation for issue #88
1 parent 827e494 commit 08cfbe5

File tree

2 files changed

+41
-20
lines changed

2 files changed

+41
-20
lines changed

src/mcp/shared/session.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from contextlib import AbstractAsyncContextManager
23
from datetime import timedelta
34
from typing import Any, Callable, Generic, TypeVar
@@ -273,19 +274,28 @@ async def _receive_loop(self) -> None:
273274
await self._incoming_message_stream_writer.send(responder)
274275

275276
elif isinstance(message.root, JSONRPCNotification):
276-
notification = self._receive_notification_type.model_validate(
277-
message.root.model_dump(
278-
by_alias=True, mode="json", exclude_none=True
277+
try:
278+
notification = self._receive_notification_type.model_validate(
279+
message.root.model_dump(
280+
by_alias=True, mode="json", exclude_none=True
281+
)
282+
)
283+
# Handle cancellation notifications
284+
if isinstance(notification.root, CancelledNotification):
285+
cancelled_id = notification.root.params.requestId
286+
if cancelled_id in self._in_flight:
287+
await self._in_flight[cancelled_id].cancel()
288+
else:
289+
await self._received_notification(notification)
290+
await self._incoming_message_stream_writer.send(
291+
notification
292+
)
293+
except Exception as e:
294+
# For other validation errors, log and continue
295+
logging.warning(
296+
f"Failed to validate notification: {e}. "
297+
f"Message was: {message.root}"
279298
)
280-
)
281-
# Handle cancellation notifications
282-
if isinstance(notification.root, CancelledNotification):
283-
cancelled_id = notification.root.params.requestId
284-
if cancelled_id in self._in_flight:
285-
await self._in_flight[cancelled_id].cancel()
286-
else:
287-
await self._received_notification(notification)
288-
await self._incoming_message_stream_writer.send(notification)
289299
else: # Response or error
290300
stream = self._response_streams.pop(message.root.id, None)
291301
if stream:

tests/issues/test_88_random_error.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,23 @@ async def test_notification_validation_error(tmp_path: Path):
3030

3131
server = Server(name="test")
3232
request_count = 0
33-
slow_request_complete = False
33+
slow_request_started = anyio.Event()
34+
slow_request_complete = anyio.Event()
3435

3536
@server.call_tool()
3637
async def slow_tool(
3738
name: str, arg
3839
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
39-
nonlocal request_count, slow_request_complete
40+
nonlocal request_count
4041
request_count += 1
4142

4243
if name == "slow":
44+
# Signal that slow request has started
45+
slow_request_started.set()
4346
# Long enough to ensure timeout
4447
await anyio.sleep(0.2)
45-
slow_request_complete = True
48+
# Signal completion
49+
slow_request_complete.set()
4650
return [TextContent(type="text", text=f"slow {request_count}")]
4751
elif name == "fast":
4852
# Fast enough to complete before timeout
@@ -71,16 +75,16 @@ async def client(read_stream, write_stream):
7175
# First call should work (fast operation)
7276
result = await session.call_tool("fast")
7377
assert result.content == [TextContent(type="text", text="fast 1")]
74-
assert not slow_request_complete
78+
assert not slow_request_complete.is_set()
7579

7680
# Second call should timeout (slow operation)
7781
with pytest.raises(McpError) as exc_info:
7882
await session.call_tool("slow")
7983
assert "Timed out while waiting" in str(exc_info.value)
8084

8185
# Wait for slow request to complete in the background
82-
await anyio.sleep(0.3)
83-
assert slow_request_complete
86+
with anyio.fail_after(1): # Timeout after 1 second
87+
await slow_request_complete.wait()
8488

8589
# Third call should work (fast operation),
8690
# proving server is still responsive
@@ -91,10 +95,17 @@ async def client(read_stream, write_stream):
9195
server_writer, server_reader = anyio.create_memory_object_stream(1)
9296
client_writer, client_reader = anyio.create_memory_object_stream(1)
9397

98+
server_ready = anyio.Event()
99+
100+
async def wrapped_server_handler(read_stream, write_stream):
101+
server_ready.set()
102+
await server_handler(read_stream, write_stream)
103+
94104
async with anyio.create_task_group() as tg:
95-
tg.start_soon(server_handler, server_reader, client_writer)
105+
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
96106
# Wait for server to start and initialize
97-
await anyio.sleep(0.1)
107+
with anyio.fail_after(1): # Timeout after 1 second
108+
await server_ready.wait()
98109
# Run client in a separate task to avoid cancellation
99110
async with anyio.create_task_group() as client_tg:
100111
client_tg.start_soon(client, client_reader, server_writer)

0 commit comments

Comments
 (0)