Skip to content

Commit ff22f48

Browse files
jerome3o-anthropicSecretiveShelldsp-ant
authored
Add client handling for sampling, list roots, ping (#218)
Adds sampling and list roots callbacks to the ClientSession, allowing the client to handle requests from the server. Co-authored-by: TerminalMan <[email protected]> Co-authored-by: David Soria Parra <[email protected]>
1 parent 1066199 commit ff22f48

File tree

6 files changed

+256
-12
lines changed

6 files changed

+256
-12
lines changed

README.md

+13-1
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,21 @@ server_params = StdioServerParameters(
476476
env=None # Optional environment variables
477477
)
478478

479+
# Optional: create a sampling callback
480+
async def handle_sampling_message(message: types.CreateMessageRequestParams) -> types.CreateMessageResult:
481+
return types.CreateMessageResult(
482+
role="assistant",
483+
content=types.TextContent(
484+
type="text",
485+
text="Hello, world! from model",
486+
),
487+
model="gpt-3.5-turbo",
488+
stopReason="endTurn",
489+
)
490+
479491
async def run():
480492
async with stdio_client(server_params) as (read, write):
481-
async with ClientSession(read, write) as session:
493+
async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session:
482494
# Initialize the connection
483495
await session.initialize()
484496

src/mcp/client/session.py

+89-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,51 @@
11
from datetime import timedelta
2+
from typing import Any, Protocol
23

34
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4-
from pydantic import AnyUrl
5+
from pydantic import AnyUrl, TypeAdapter
56

67
import mcp.types as types
7-
from mcp.shared.session import BaseSession
8+
from mcp.shared.context import RequestContext
9+
from mcp.shared.session import BaseSession, RequestResponder
810
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
911

1012

13+
class SamplingFnT(Protocol):
14+
async def __call__(
15+
self,
16+
context: RequestContext["ClientSession", Any],
17+
params: types.CreateMessageRequestParams,
18+
) -> types.CreateMessageResult | types.ErrorData: ...
19+
20+
21+
class ListRootsFnT(Protocol):
22+
async def __call__(
23+
self, context: RequestContext["ClientSession", Any]
24+
) -> types.ListRootsResult | types.ErrorData: ...
25+
26+
27+
async def _default_sampling_callback(
28+
context: RequestContext["ClientSession", Any],
29+
params: types.CreateMessageRequestParams,
30+
) -> types.CreateMessageResult | types.ErrorData:
31+
return types.ErrorData(
32+
code=types.INVALID_REQUEST,
33+
message="Sampling not supported",
34+
)
35+
36+
37+
async def _default_list_roots_callback(
38+
context: RequestContext["ClientSession", Any],
39+
) -> types.ListRootsResult | types.ErrorData:
40+
return types.ErrorData(
41+
code=types.INVALID_REQUEST,
42+
message="List roots not supported",
43+
)
44+
45+
46+
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
47+
48+
1149
class ClientSession(
1250
BaseSession[
1351
types.ClientRequest,
@@ -22,6 +60,8 @@ def __init__(
2260
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
2361
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
2462
read_timeout_seconds: timedelta | None = None,
63+
sampling_callback: SamplingFnT | None = None,
64+
list_roots_callback: ListRootsFnT | None = None,
2565
) -> None:
2666
super().__init__(
2767
read_stream,
@@ -30,23 +70,34 @@ def __init__(
3070
types.ServerNotification,
3171
read_timeout_seconds=read_timeout_seconds,
3272
)
73+
self._sampling_callback = sampling_callback or _default_sampling_callback
74+
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
3375

3476
async def initialize(self) -> types.InitializeResult:
77+
sampling = (
78+
types.SamplingCapability() if self._sampling_callback is not None else None
79+
)
80+
roots = (
81+
types.RootsCapability(
82+
# TODO: Should this be based on whether we
83+
# _will_ send notifications, or only whether
84+
# they're supported?
85+
listChanged=True,
86+
)
87+
if self._list_roots_callback is not None
88+
else None
89+
)
90+
3591
result = await self.send_request(
3692
types.ClientRequest(
3793
types.InitializeRequest(
3894
method="initialize",
3995
params=types.InitializeRequestParams(
4096
protocolVersion=types.LATEST_PROTOCOL_VERSION,
4197
capabilities=types.ClientCapabilities(
42-
sampling=None,
98+
sampling=sampling,
4399
experimental=None,
44-
roots=types.RootsCapability(
45-
# TODO: Should this be based on whether we
46-
# _will_ send notifications, or only whether
47-
# they're supported?
48-
listChanged=True
49-
),
100+
roots=roots,
50101
),
51102
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
52103
),
@@ -243,3 +294,32 @@ async def send_roots_list_changed(self) -> None:
243294
)
244295
)
245296
)
297+
298+
async def _received_request(
299+
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
300+
) -> None:
301+
ctx = RequestContext[ClientSession, Any](
302+
request_id=responder.request_id,
303+
meta=responder.request_meta,
304+
session=self,
305+
lifespan_context=None,
306+
)
307+
308+
match responder.request.root:
309+
case types.CreateMessageRequest(params=params):
310+
with responder:
311+
response = await self._sampling_callback(ctx, params)
312+
client_response = ClientResponse.validate_python(response)
313+
await responder.respond(client_response)
314+
315+
case types.ListRootsRequest():
316+
with responder:
317+
response = await self._list_roots_callback(ctx)
318+
client_response = ClientResponse.validate_python(response)
319+
await responder.respond(client_response)
320+
321+
case types.PingRequest():
322+
with responder:
323+
return await responder.respond(
324+
types.ClientResult(root=types.EmptyResult())
325+
)

src/mcp/shared/memory.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import anyio
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111

12-
from mcp.client.session import ClientSession
12+
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
1313
from mcp.server import Server
1414
from mcp.types import JSONRPCMessage
1515

@@ -54,6 +54,8 @@ async def create_client_server_memory_streams() -> (
5454
async def create_connected_server_and_client_session(
5555
server: Server,
5656
read_timeout_seconds: timedelta | None = None,
57+
sampling_callback: SamplingFnT | None = None,
58+
list_roots_callback: ListRootsFnT | None = None,
5759
raise_exceptions: bool = False,
5860
) -> AsyncGenerator[ClientSession, None]:
5961
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -80,6 +82,8 @@ async def create_connected_server_and_client_session(
8082
read_stream=client_read,
8183
write_stream=client_write,
8284
read_timeout_seconds=read_timeout_seconds,
85+
sampling_callback=sampling_callback,
86+
list_roots_callback=list_roots_callback,
8387
) as client_session:
8488
await client_session.initialize()
8589
yield client_session
+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
from pydantic import FileUrl
3+
4+
from mcp.client.session import ClientSession
5+
from mcp.server.fastmcp.server import Context
6+
from mcp.shared.context import RequestContext
7+
from mcp.shared.memory import (
8+
create_connected_server_and_client_session as create_session,
9+
)
10+
from mcp.types import (
11+
ListRootsResult,
12+
Root,
13+
TextContent,
14+
)
15+
16+
17+
@pytest.mark.anyio
18+
async def test_list_roots_callback():
19+
from mcp.server.fastmcp import FastMCP
20+
21+
server = FastMCP("test")
22+
23+
callback_return = ListRootsResult(
24+
roots=[
25+
Root(
26+
uri=FileUrl("file://users/fake/test"),
27+
name="Test Root 1",
28+
),
29+
Root(
30+
uri=FileUrl("file://users/fake/test/2"),
31+
name="Test Root 2",
32+
),
33+
]
34+
)
35+
36+
async def list_roots_callback(
37+
context: RequestContext[ClientSession, None],
38+
) -> ListRootsResult:
39+
return callback_return
40+
41+
@server.tool("test_list_roots")
42+
async def test_list_roots(context: Context, message: str):
43+
roots = await context.session.list_roots()
44+
assert roots == callback_return
45+
return True
46+
47+
# Test with list_roots callback
48+
async with create_session(
49+
server._mcp_server, list_roots_callback=list_roots_callback
50+
) as client_session:
51+
# Make a request to trigger sampling callback
52+
result = await client_session.call_tool(
53+
"test_list_roots", {"message": "test message"}
54+
)
55+
assert result.isError is False
56+
assert isinstance(result.content[0], TextContent)
57+
assert result.content[0].text == "true"
58+
59+
# Test without list_roots callback
60+
async with create_session(server._mcp_server) as client_session:
61+
# Make a request to trigger sampling callback
62+
result = await client_session.call_tool(
63+
"test_list_roots", {"message": "test message"}
64+
)
65+
assert result.isError is True
66+
assert isinstance(result.content[0], TextContent)
67+
assert (
68+
result.content[0].text
69+
== "Error executing tool test_list_roots: List roots not supported"
70+
)
+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import pytest
2+
3+
from mcp.client.session import ClientSession
4+
from mcp.shared.context import RequestContext
5+
from mcp.shared.memory import (
6+
create_connected_server_and_client_session as create_session,
7+
)
8+
from mcp.types import (
9+
CreateMessageRequestParams,
10+
CreateMessageResult,
11+
SamplingMessage,
12+
TextContent,
13+
)
14+
15+
16+
@pytest.mark.anyio
17+
async def test_sampling_callback():
18+
from mcp.server.fastmcp import FastMCP
19+
20+
server = FastMCP("test")
21+
22+
callback_return = CreateMessageResult(
23+
role="assistant",
24+
content=TextContent(
25+
type="text", text="This is a response from the sampling callback"
26+
),
27+
model="test-model",
28+
stopReason="endTurn",
29+
)
30+
31+
async def sampling_callback(
32+
context: RequestContext[ClientSession, None],
33+
params: CreateMessageRequestParams,
34+
) -> CreateMessageResult:
35+
return callback_return
36+
37+
@server.tool("test_sampling")
38+
async def test_sampling_tool(message: str):
39+
value = await server.get_context().session.create_message(
40+
messages=[
41+
SamplingMessage(
42+
role="user", content=TextContent(type="text", text=message)
43+
)
44+
],
45+
max_tokens=100,
46+
)
47+
assert value == callback_return
48+
return True
49+
50+
# Test with sampling callback
51+
async with create_session(
52+
server._mcp_server, sampling_callback=sampling_callback
53+
) as client_session:
54+
# Make a request to trigger sampling callback
55+
result = await client_session.call_tool(
56+
"test_sampling", {"message": "Test message for sampling"}
57+
)
58+
assert result.isError is False
59+
assert isinstance(result.content[0], TextContent)
60+
assert result.content[0].text == "true"
61+
62+
# Test without sampling callback
63+
async with create_session(server._mcp_server) as client_session:
64+
# Make a request to trigger sampling callback
65+
result = await client_session.call_tool(
66+
"test_sampling", {"message": "Test message for sampling"}
67+
)
68+
assert result.isError is True
69+
assert isinstance(result.content[0], TextContent)
70+
assert (
71+
result.content[0].text
72+
== "Error executing tool test_sampling: Sampling not supported"
73+
)

tests/client/test_stdio.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
import shutil
2+
13
import pytest
24

35
from mcp.client.stdio import StdioServerParameters, stdio_client
46
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
57

8+
tee: str = shutil.which("tee") # type: ignore
9+
610

711
@pytest.mark.anyio
12+
@pytest.mark.skipif(tee is None, reason="could not find tee command")
813
async def test_stdio_client():
9-
server_parameters = StdioServerParameters(command="/usr/bin/tee")
14+
server_parameters = StdioServerParameters(command=tee)
1015

1116
async with stdio_client(server_parameters) as (read_stream, write_stream):
1217
# Test sending and receiving messages

0 commit comments

Comments
 (0)