Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create Client websocket.py #179

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"sse-starlette>=1.6.1",
"pydantic-settings>=2.5.2",
"uvicorn>=0.23.1",
"websockets>=15.0.1",
]

[project.optional-dependencies]
Expand Down
86 changes: 86 additions & 0 deletions src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import anyio
from pydantic import ValidationError
from websockets.asyncio.client import connect as ws_connect
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from websockets.typing import Subprotocol

import mcp.types as types

logger = logging.getLogger(__name__)


@asynccontextmanager
async def websocket_client(url: str) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
MemoryObjectSendStream[types.JSONRPCMessage],
],
None,
]:
"""
WebSocket client transport for MCP, symmetrical to the server version.

Connects to 'url' using the 'mcp' subprotocol, then yields:
(read_stream, write_stream)

- read_stream: As you read from this stream, you'll receive either valid
JSONRPCMessage objects or Exception objects (when validation fails).
- write_stream: Write JSONRPCMessage objects to this stream to send them
over the WebSocket to the server.
"""

# Create two in-memory streams:
# - One for incoming messages (read_stream, written by ws_reader)
# - One for outgoing messages (write_stream, read by ws_writer)
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

# Connect using websockets, requesting the "mcp" subprotocol
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
async def ws_reader():
"""
Reads text messages from the WebSocket, parses them as JSON-RPC messages,
and sends them into read_stream_writer.
"""
try:
async with read_stream_writer:
async for raw_text in ws:
try:
message = types.JSONRPCMessage.model_validate_json(raw_text)
await read_stream_writer.send(message)
except ValidationError as exc:
# If JSON parse or model validation fails, send the exception
await read_stream_writer.send(exc)
except (anyio.ClosedResourceError, Exception):
await ws.close()

async def ws_writer():
"""
Reads JSON-RPC messages from write_stream_reader and sends them to the server.
"""
try:
async with write_stream_reader:
async for message in write_stream_reader:
# Convert to a dict, then to JSON
msg_dict = message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
await ws.send(json.dumps(msg_dict))
except (anyio.ClosedResourceError, Exception):
await ws.close()

async with anyio.create_task_group() as tg:
# Start reader and writer tasks
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)

# Yield the receive/send streams
yield (read_stream, write_stream)

# Once the caller's 'async with' block exits, we shut down
tg.cancel_scope.cancel()
227 changes: 227 additions & 0 deletions tests/shared/test_ws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import multiprocessing
import socket
import time
from typing import AsyncGenerator, Generator

import anyio
import pytest
import uvicorn
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import WebSocketRoute

from mcp.client.session import ClientSession
from mcp.client.websocket import websocket_client
from mcp.server import Server
from mcp.server.websocket import websocket_server
from mcp.shared.exceptions import McpError
from mcp.types import (
EmptyResult,
ErrorData,
InitializeResult,
ReadResourceResult,
TextContent,
TextResourceContents,
Tool,
)

SERVER_NAME = "test_server_for_WS"


@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


@pytest.fixture
def server_url(server_port: int) -> str:
return f"ws://127.0.0.1:{server_port}"


# Test server implementation
class ServerTest(Server):
def __init__(self):
super().__init__(SERVER_NAME)

@self.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
if uri.scheme == "foobar":
return f"Read {uri.host}"
elif uri.scheme == "slow":
# Simulate a slow resource
await anyio.sleep(2.0)
return f"Slow response from {uri.host}"

raise McpError(
error=ErrorData(
code=404, message="OOPS! no resource with that URI was found"
)
)

@self.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="test_tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
)
]

@self.call_tool()
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
return [TextContent(type="text", text=f"Called {name}")]


# Test fixtures
def make_server_app() -> Starlette:
"""Create test Starlette app with WebSocket transport"""
server = ServerTest()

async def handle_ws(websocket):
async with websocket_server(
websocket.scope, websocket.receive, websocket.send
) as streams:
await server.run(
streams[0], streams[1], server.create_initialization_options()
)

app = Starlette(
routes=[
WebSocketRoute("/ws", endpoint=handle_ws),
]
)

return app


def run_server(server_port: int) -> None:
app = make_server_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting server on {server_port}")
server.run()

# Give server time to start
while not server.started:
print("waiting for server to start")
time.sleep(0.5)


@pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process(
target=run_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting process")
proc.start()

# Wait for server to be running
max_attempts = 20
attempt = 0
print("waiting for server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
"Server failed to start after {} attempts".format(max_attempts)
)

yield

print("killing server")
# Signal the server to stop
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("server process failed to terminate")


@pytest.fixture()
async def initialized_ws_client_session(
server, server_url: str
) -> AsyncGenerator[ClientSession, None]:
"""Create and initialize a WebSocket client session"""
async with websocket_client(server_url + "/ws") as streams:
async with ClientSession(*streams) as session:
# Test initialization
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == SERVER_NAME

# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)

yield session


# Tests
@pytest.mark.anyio
async def test_ws_client_basic_connection(server: None, server_url: str) -> None:
"""Test the WebSocket connection establishment"""
async with websocket_client(server_url + "/ws") as streams:
async with ClientSession(*streams) as session:
# Test initialization
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == SERVER_NAME

# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)


@pytest.mark.anyio
async def test_ws_client_happy_request_and_response(
initialized_ws_client_session: ClientSession,
) -> None:
"""Test a successful request and response via WebSocket"""
result = await initialized_ws_client_session.read_resource("foobar://example")
assert isinstance(result, ReadResourceResult)
assert isinstance(result.contents, list)
assert len(result.contents) > 0
assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Read example"


@pytest.mark.anyio
async def test_ws_client_exception_handling(
initialized_ws_client_session: ClientSession,
) -> None:
"""Test exception handling in WebSocket communication"""
with pytest.raises(McpError) as exc_info:
await initialized_ws_client_session.read_resource("unknown://example")
assert exc_info.value.error.code == 404


@pytest.mark.anyio
async def test_ws_client_timeout(
initialized_ws_client_session: ClientSession,
) -> None:
"""Test timeout handling in WebSocket communication"""
# Set a very short timeout to trigger a timeout exception
with pytest.raises(TimeoutError):
with anyio.fail_after(0.1): # 100ms timeout
await initialized_ws_client_session.read_resource("slow://example")

# Now test that we can still use the session after a timeout
with anyio.fail_after(5): # Longer timeout to allow completion
result = await initialized_ws_client_session.read_resource("foobar://example")
assert isinstance(result, ReadResourceResult)
assert isinstance(result.contents, list)
assert len(result.contents) > 0
assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Read example"
Loading