1
1
from datetime import timedelta
2
+ from typing import Any , Protocol
2
3
3
4
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
4
- from pydantic import AnyUrl
5
+ from pydantic import AnyUrl , TypeAdapter
5
6
6
7
import mcp .types as types
7
- from mcp .shared .session import BaseSession
8
+ from mcp .shared .context import RequestContext
9
+ from mcp .shared .session import BaseSession , RequestResponder
8
10
from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
9
11
10
12
13
+ class SamplingFnT (Protocol ):
14
+ async def __call__ (
15
+ self ,
16
+ context : RequestContext ["ClientSession" , Any ],
17
+ params : types .CreateMessageRequestParams ,
18
+ ) -> types .CreateMessageResult | types .ErrorData : ...
19
+
20
+
21
+ class ListRootsFnT (Protocol ):
22
+ async def __call__ (
23
+ self , context : RequestContext ["ClientSession" , Any ]
24
+ ) -> types .ListRootsResult | types .ErrorData : ...
25
+
26
+
27
+ async def _default_sampling_callback (
28
+ context : RequestContext ["ClientSession" , Any ],
29
+ params : types .CreateMessageRequestParams ,
30
+ ) -> types .CreateMessageResult | types .ErrorData :
31
+ return types .ErrorData (
32
+ code = types .INVALID_REQUEST ,
33
+ message = "Sampling not supported" ,
34
+ )
35
+
36
+
37
+ async def _default_list_roots_callback (
38
+ context : RequestContext ["ClientSession" , Any ],
39
+ ) -> types .ListRootsResult | types .ErrorData :
40
+ return types .ErrorData (
41
+ code = types .INVALID_REQUEST ,
42
+ message = "List roots not supported" ,
43
+ )
44
+
45
+
46
+ ClientResponse = TypeAdapter (types .ClientResult | types .ErrorData )
47
+
48
+
11
49
class ClientSession (
12
50
BaseSession [
13
51
types .ClientRequest ,
@@ -22,6 +60,8 @@ def __init__(
22
60
read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
23
61
write_stream : MemoryObjectSendStream [types .JSONRPCMessage ],
24
62
read_timeout_seconds : timedelta | None = None ,
63
+ sampling_callback : SamplingFnT | None = None ,
64
+ list_roots_callback : ListRootsFnT | None = None ,
25
65
) -> None :
26
66
super ().__init__ (
27
67
read_stream ,
@@ -30,23 +70,34 @@ def __init__(
30
70
types .ServerNotification ,
31
71
read_timeout_seconds = read_timeout_seconds ,
32
72
)
73
+ self ._sampling_callback = sampling_callback or _default_sampling_callback
74
+ self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
33
75
34
76
async def initialize (self ) -> types .InitializeResult :
77
+ sampling = (
78
+ types .SamplingCapability () if self ._sampling_callback is not None else None
79
+ )
80
+ roots = (
81
+ types .RootsCapability (
82
+ # TODO: Should this be based on whether we
83
+ # _will_ send notifications, or only whether
84
+ # they're supported?
85
+ listChanged = True ,
86
+ )
87
+ if self ._list_roots_callback is not None
88
+ else None
89
+ )
90
+
35
91
result = await self .send_request (
36
92
types .ClientRequest (
37
93
types .InitializeRequest (
38
94
method = "initialize" ,
39
95
params = types .InitializeRequestParams (
40
96
protocolVersion = types .LATEST_PROTOCOL_VERSION ,
41
97
capabilities = types .ClientCapabilities (
42
- sampling = None ,
98
+ sampling = sampling ,
43
99
experimental = None ,
44
- roots = types .RootsCapability (
45
- # TODO: Should this be based on whether we
46
- # _will_ send notifications, or only whether
47
- # they're supported?
48
- listChanged = True
49
- ),
100
+ roots = roots ,
50
101
),
51
102
clientInfo = types .Implementation (name = "mcp" , version = "0.1.0" ),
52
103
),
@@ -243,3 +294,32 @@ async def send_roots_list_changed(self) -> None:
243
294
)
244
295
)
245
296
)
297
+
298
+ async def _received_request (
299
+ self , responder : RequestResponder [types .ServerRequest , types .ClientResult ]
300
+ ) -> None :
301
+ ctx = RequestContext [ClientSession , Any ](
302
+ request_id = responder .request_id ,
303
+ meta = responder .request_meta ,
304
+ session = self ,
305
+ lifespan_context = None ,
306
+ )
307
+
308
+ match responder .request .root :
309
+ case types .CreateMessageRequest (params = params ):
310
+ with responder :
311
+ response = await self ._sampling_callback (ctx , params )
312
+ client_response = ClientResponse .validate_python (response )
313
+ await responder .respond (client_response )
314
+
315
+ case types .ListRootsRequest ():
316
+ with responder :
317
+ response = await self ._list_roots_callback (ctx )
318
+ client_response = ClientResponse .validate_python (response )
319
+ await responder .respond (client_response )
320
+
321
+ case types .PingRequest ():
322
+ with responder :
323
+ return await responder .respond (
324
+ types .ClientResult (root = types .EmptyResult ())
325
+ )
0 commit comments