diff --git a/mcp_python/server/__init__.py b/mcp_python/server/__init__.py index 581f259..7801f64 100644 --- a/mcp_python/server/__init__.py +++ b/mcp_python/server/__init__.py @@ -32,6 +32,7 @@ ReadResourceResult, Resource, ResourceReference, + ServerCapabilities, ServerResult, SetLevelRequest, SubscribeRequest, @@ -40,7 +41,6 @@ logger = logging.getLogger(__name__) - request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar( "request_ctx" ) @@ -53,6 +53,33 @@ def __init__(self, name: str): self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} logger.info(f"Initializing server '{name}'") + def create_initialization_options(self) -> types.InitializationOptions: + """Create initialization options from this server instance.""" + def pkg_version(package: str) -> str: + try: + from importlib.metadata import version + return version(package) + except Exception: + return "unknown" + + return types.InitializationOptions( + server_name=self.name, + server_version=pkg_version("mcp_python"), + capabilities=self.get_capabilities(), + ) + + def get_capabilities(self) -> ServerCapabilities: + """Convert existing handlers to a ServerCapabilities object.""" + def get_capability(req_type: type) -> dict[str, Any] | None: + return {} if req_type in self.request_handlers else None + + return ServerCapabilities( + prompts=get_capability(ListPromptsRequest), + resources=get_capability(ListResourcesRequest), + tools=get_capability(ListPromptsRequest), + logging=get_capability(SetLevelRequest) + ) + @property def request_context(self) -> RequestContext: """If called outside of a request context, this will raise a LookupError.""" @@ -280,9 +307,10 @@ async def run( self, read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], + initialization_options: types.InitializationOptions ): with warnings.catch_warnings(record=True) as w: - async with ServerSession(read_stream, write_stream) as session: + async with ServerSession(read_stream, write_stream, initialization_options) as session: async for message in session.incoming_messages: logger.debug(f"Received message: {message}") diff --git a/mcp_python/server/__main__.py b/mcp_python/server/__main__.py index 907b453..efb7dd8 100644 --- a/mcp_python/server/__main__.py +++ b/mcp_python/server/__main__.py @@ -1,10 +1,12 @@ import logging import sys - +import importlib.metadata import anyio from mcp_python.server.session import ServerSession +from mcp_python.server.types import InitializationOptions from mcp_python.server.stdio import stdio_server +from mcp_python.types import ServerCapabilities if not sys.warnoptions: import warnings @@ -26,8 +28,9 @@ async def receive_loop(session: ServerSession): async def main(): + version = importlib.metadata.version("mcp_python") async with stdio_server() as (read_stream, write_stream): - async with ServerSession(read_stream, write_stream) as session, write_stream: + async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream: await receive_loop(session) diff --git a/mcp_python/server/session.py b/mcp_python/server/session.py index 687fda6..c64f799 100644 --- a/mcp_python/server/session.py +++ b/mcp_python/server/session.py @@ -10,6 +10,7 @@ BaseSession, RequestResponder, ) +from mcp_python.server.types import InitializationOptions from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION from mcp_python.types import ( ClientNotification, @@ -52,9 +53,11 @@ def __init__( self, read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], + init_options: InitializationOptions ) -> None: super().__init__(read_stream, write_stream, ClientRequest, ClientNotification) self._initialization_state = InitializationState.NotInitialized + self._init_options = init_options async def _received_request( self, responder: RequestResponder[ClientRequest, ServerResult] @@ -66,15 +69,10 @@ async def _received_request( ServerResult( InitializeResult( protocolVersion=SUPPORTED_PROTOCOL_VERSION, - capabilities=ServerCapabilities( - logging=None, - resources=None, - tools=None, - experimental=None, - prompts={}, - ), + capabilities=self._init_options.capabilities, serverInfo=Implementation( - name="mcp_python", version="0.1.0" + name=self._init_options.server_name, + version=self._init_options.server_version ), ) ) diff --git a/mcp_python/server/types.py b/mcp_python/server/types.py index 2993d84..1b56f24 100644 --- a/mcp_python/server/types.py +++ b/mcp_python/server/types.py @@ -5,7 +5,8 @@ from dataclasses import dataclass from typing import Literal -from mcp_python.types import Role +from pydantic import BaseModel +from mcp_python.types import Role, ServerCapabilities @dataclass @@ -25,3 +26,9 @@ class Message: class PromptResponse: messages: list[Message] desc: str | None = None + + +class InitializationOptions(BaseModel): + server_name: str + server_version: str + capabilities: ServerCapabilities diff --git a/pyproject.toml b/pyproject.toml index 4d2258b..208aece 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,3 +35,8 @@ target-version = "py38" [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] + +[tool.uv] +dev-dependencies = [ + "trio>=0.26.2", +] diff --git a/tests/server/test_session.py b/tests/server/test_session.py index eae7a77..01813a7 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -2,11 +2,14 @@ import pytest from mcp_python.client.session import ClientSession +from mcp_python.server import Server from mcp_python.server.session import ServerSession +from mcp_python.server.types import InitializationOptions from mcp_python.types import ( ClientNotification, InitializedNotification, JSONRPCMessage, + ServerCapabilities, ) @@ -30,7 +33,7 @@ async def run_server(): nonlocal received_initialized async with ServerSession( - client_to_server_receive, server_to_client_send + client_to_server_receive, server_to_client_send, InitializationOptions(server_name='mcp_python', server_version='0.1.0', capabilities=ServerCapabilities()) ) as server_session: async for message in server_session.incoming_messages: if isinstance(message, Exception): @@ -57,3 +60,31 @@ async def run_server(): pass assert received_initialized + + +@pytest.mark.anyio +async def test_server_capabilities(): + server = Server("test") + + # Initially no capabilities + caps = server.get_capabilities() + assert caps.prompts is None + assert caps.resources is None + + # Add a prompts handler + @server.list_prompts() + async def list_prompts(): + return [] + + caps = server.get_capabilities() + assert caps.prompts == {} + assert caps.resources is None + + # Add a resources handler + @server.list_resources() + async def list_resources(): + return [] + + caps = server.get_capabilities() + assert caps.prompts == {} + assert caps.resources == {}