Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Initialization options that are passed to ServerSession #16

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions mcp_python/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ReadResourceResult,
Resource,
ResourceReference,
ServerCapabilities,
ServerResult,
SetLevelRequest,
SubscribeRequest,
Expand All @@ -40,7 +41,6 @@

logger = logging.getLogger(__name__)


request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
"request_ctx"
)
Expand All @@ -53,6 +53,33 @@ def __init__(self, name: str):
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
logger.info(f"Initializing server '{name}'")

def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""
def pkg_version(package: str) -> str:
try:
from importlib.metadata import version
return version(package)
except Exception:
return "unknown"

return types.InitializationOptions(
server_name=self.name,
server_version=pkg_version("mcp_python"),
capabilities=self.get_capabilities(),
)

def get_capabilities(self) -> ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object."""
def get_capability(req_type: type) -> dict[str, Any] | None:
return {} if req_type in self.request_handlers else None

return ServerCapabilities(
prompts=get_capability(ListPromptsRequest),
resources=get_capability(ListResourcesRequest),
tools=get_capability(ListPromptsRequest),
logging=get_capability(SetLevelRequest)
)

@property
def request_context(self) -> RequestContext:
"""If called outside of a request context, this will raise a LookupError."""
Expand Down Expand Up @@ -280,9 +307,10 @@ async def run(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(read_stream, write_stream) as session:
async with ServerSession(read_stream, write_stream, initialization_options) as session:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")

Expand Down
7 changes: 5 additions & 2 deletions mcp_python/server/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import sys

import importlib.metadata
import anyio

from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions
from mcp_python.server.stdio import stdio_server
from mcp_python.types import ServerCapabilities

if not sys.warnoptions:
import warnings
Expand All @@ -26,8 +28,9 @@ async def receive_loop(session: ServerSession):


async def main():
version = importlib.metadata.version("mcp_python")
async with stdio_server() as (read_stream, write_stream):
async with ServerSession(read_stream, write_stream) as session, write_stream:
async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream:
await receive_loop(session)


Expand Down
14 changes: 6 additions & 8 deletions mcp_python/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BaseSession,
RequestResponder,
)
from mcp_python.server.types import InitializationOptions
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
from mcp_python.types import (
ClientNotification,
Expand Down Expand Up @@ -52,9 +53,11 @@ def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
init_options: InitializationOptions
) -> None:
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options

async def _received_request(
self, responder: RequestResponder[ClientRequest, ServerResult]
Expand All @@ -66,15 +69,10 @@ async def _received_request(
ServerResult(
InitializeResult(
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
capabilities=ServerCapabilities(
logging=None,
resources=None,
tools=None,
experimental=None,
prompts={},
),
capabilities=self._init_options.capabilities,
serverInfo=Implementation(
name="mcp_python", version="0.1.0"
name=self._init_options.server_name,
version=self._init_options.server_version
),
)
)
Expand Down
9 changes: 8 additions & 1 deletion mcp_python/server/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from dataclasses import dataclass
from typing import Literal

from mcp_python.types import Role
from pydantic import BaseModel
from mcp_python.types import Role, ServerCapabilities


@dataclass
Expand All @@ -25,3 +26,9 @@ class Message:
class PromptResponse:
messages: list[Message]
desc: str | None = None


class InitializationOptions(BaseModel):
server_name: str
server_version: str
capabilities: ServerCapabilities
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@ target-version = "py38"

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]

[tool.uv]
dev-dependencies = [
"trio>=0.26.2",
]
33 changes: 32 additions & 1 deletion tests/server/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import pytest

from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions
from mcp_python.types import (
ClientNotification,
InitializedNotification,
JSONRPCMessage,
ServerCapabilities,
)


Expand All @@ -30,7 +33,7 @@ async def run_server():
nonlocal received_initialized

async with ServerSession(
client_to_server_receive, server_to_client_send
client_to_server_receive, server_to_client_send, InitializationOptions(server_name='mcp_python', server_version='0.1.0', capabilities=ServerCapabilities())
) as server_session:
async for message in server_session.incoming_messages:
if isinstance(message, Exception):
Expand All @@ -57,3 +60,31 @@ async def run_server():
pass

assert received_initialized


@pytest.mark.anyio
async def test_server_capabilities():
server = Server("test")

# Initially no capabilities
caps = server.get_capabilities()
assert caps.prompts is None
assert caps.resources is None

# Add a prompts handler
@server.list_prompts()
async def list_prompts():
return []

caps = server.get_capabilities()
assert caps.prompts == {}
assert caps.resources is None

# Add a resources handler
@server.list_resources()
async def list_resources():
return []

caps = server.get_capabilities()
assert caps.prompts == {}
assert caps.resources == {}