Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to protocol version 2024-10-07 #19

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mcp_python/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from pydantic import AnyUrl

from mcp_python.shared.session import BaseSession
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
CallToolResult,
ClientCapabilities,
ClientNotification,
Expand Down Expand Up @@ -49,7 +50,7 @@ async def initialize(self) -> InitializeResult:
InitializeRequest(
method="initialize",
params=InitializeRequestParams(
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(
sampling=None, experimental=None
),
Expand All @@ -60,7 +61,7 @@ async def initialize(self) -> InitializeResult:
InitializeResult,
)

if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION:
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(
"Unsupported protocol version from the server: "
f"{result.protocolVersion}"
Expand Down
2 changes: 1 addition & 1 deletion mcp_python/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def decorator(func: Callable[[], Awaitable[list[Resource]]]):
async def handler(_: Any):
resources = await func()
return ServerResult(
ListResourcesResult(resources=resources, resourceTemplates=None)
ListResourcesResult(resources=resources)
)

self.request_handlers[ListResourcesRequest] = handler
Expand Down
4 changes: 2 additions & 2 deletions mcp_python/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
BaseSession,
RequestResponder,
)
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification,
ClientRequest,
CreateMessageResult,
Expand Down Expand Up @@ -67,7 +67,7 @@ async def _received_request(
await responder.respond(
ServerResult(
InitializeResult(
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=Implementation(
name=self._init_options.server_name,
Expand Down
4 changes: 3 additions & 1 deletion mcp_python/shared/version.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
SUPPORTED_PROTOCOL_VERSION = 1
from mcp_python.types import LATEST_PROTOCOL_VERSION

SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION]
70 changes: 59 additions & 11 deletions mcp_python/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
not separate types in the schema.
"""

LATEST_PROTOCOL_VERSION = "2024-10-07"

ProgressToken = str | int
Cursor = str


class RequestParams(BaseModel):
Expand Down Expand Up @@ -64,6 +66,14 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")


class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""


class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
"""Base class for JSON-RPC notifications."""

Expand All @@ -83,6 +93,14 @@ class Result(BaseModel):
"""


class PaginatedResult(Result):
nextCursor: Cursor | None = None
"""
An opaque token representing the pagination position after the last returned result.
If present, there may be more results available.
"""


RequestId = str | int


Expand Down Expand Up @@ -115,6 +133,7 @@ class JSONRPCResponse(BaseModel):
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601
INVALID_PARAMS = -32602
INTERNAL_ERROR = -32603


class ErrorData(BaseModel):
Expand Down Expand Up @@ -191,7 +210,7 @@ class ServerCapabilities(BaseModel):
class InitializeRequestParams(RequestParams):
"""Parameters for the initialize request."""

protocolVersion: Literal[1]
protocolVersion: str | int
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lenient version parsing for now.

"""The latest version of the Model Context Protocol that the client supports."""
capabilities: ClientCapabilities
clientInfo: Implementation
Expand All @@ -211,7 +230,7 @@ class InitializeRequest(Request):
class InitializeResult(Result):
"""After receiving an initialize request from the client, the server sends this."""

protocolVersion: Literal[1]
protocolVersion: str | int
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lenient version parsing for now.

"""The version of the Model Context Protocol that the server wants to use."""
capabilities: ServerCapabilities
serverInfo: Implementation
Expand Down Expand Up @@ -265,7 +284,7 @@ class ProgressNotification(Notification):
params: ProgressNotificationParams


class ListResourcesRequest(Request):
class ListResourcesRequest(PaginatedRequest):
"""Sent from the client to request a list of resources the server has."""

method: Literal["resources/list"]
Expand All @@ -277,6 +296,10 @@ class Resource(BaseModel):

uri: AnyUrl
"""The URI of this resource."""
name: str
"""A human-readable name for this resource."""
description: str | None = None
"""A description of what this resource represents."""
mimeType: str | None = None
"""The MIME type of this resource, if known."""
model_config = ConfigDict(extra="allow")
Expand All @@ -290,7 +313,7 @@ class ResourceTemplate(BaseModel):
A URI template (according to RFC 6570) that can be used to construct resource
URIs.
"""
name: str | None = None
name: str
"""A human-readable name for the type of resource this template refers to."""
description: str | None = None
"""A human-readable description of what this template is for."""
Expand All @@ -302,11 +325,23 @@ class ResourceTemplate(BaseModel):
model_config = ConfigDict(extra="allow")


class ListResourcesResult(Result):
class ListResourcesResult(PaginatedResult):
"""The server's response to a resources/list request from the client."""

resourceTemplates: list[ResourceTemplate] | None = None
resources: list[Resource] | None = None
resources: list[Resource]


class ListResourceTemplatesRequest(PaginatedRequest):
"""Sent from the client to request a list of resource templates the server has."""

method: Literal["resources/templates/list"]
params: RequestParams | None = None


class ListResourceTemplatesResult(PaginatedResult):
"""The server's response to a resources/templates/list request from the client."""

resourceTemplates: list[ResourceTemplate]


class ReadResourceRequestParams(RequestParams):
Expand Down Expand Up @@ -430,7 +465,7 @@ class ResourceUpdatedNotification(Notification):
params: ResourceUpdatedNotificationParams


class ListPromptsRequest(Request):
class ListPromptsRequest(PaginatedRequest):
"""Sent from the client to request a list of prompts and prompt templates."""

method: Literal["prompts/list"]
Expand Down Expand Up @@ -461,7 +496,7 @@ class Prompt(BaseModel):
model_config = ConfigDict(extra="allow")


class ListPromptsResult(Result):
class ListPromptsResult(PaginatedResult):
"""The server's response to a prompts/list request from the client."""

prompts: list[Prompt]
Expand Down Expand Up @@ -526,7 +561,17 @@ class GetPromptResult(Result):
messages: list[SamplingMessage]


class ListToolsRequest(Request):
class PromptListChangedNotification(Notification):
"""
An optional notification from the server to the client, informing it that the list
of prompts it offers has changed.
"""

method: Literal["notifications/prompts/list_changed"]
params: NotificationParams | None = None


class ListToolsRequest(PaginatedRequest):
"""Sent from the client to request a list of tools the server has."""

method: Literal["tools/list"]
Expand All @@ -545,7 +590,7 @@ class Tool(BaseModel):
model_config = ConfigDict(extra="allow")


class ListToolsResult(Result):
class ListToolsResult(PaginatedResult):
"""The server's response to a tools/list request from the client."""

tools: list[Tool]
Expand Down Expand Up @@ -742,6 +787,7 @@ class ClientRequest(
| GetPromptRequest
| ListPromptsRequest
| ListResourcesRequest
| ListResourceTemplatesRequest
| ReadResourceRequest
| SubscribeRequest
| UnsubscribeRequest
Expand Down Expand Up @@ -771,6 +817,7 @@ class ServerNotification(
| ResourceUpdatedNotification
| ResourceListChangedNotification
| ToolListChangedNotification
| PromptListChangedNotification
]
):
pass
Expand All @@ -784,6 +831,7 @@ class ServerResult(
| GetPromptResult
| ListPromptsResult
| ListResourcesResult
| ListResourceTemplatesResult
| ReadResourceResult
| CallToolResult
| ListToolsResult
Expand Down
5 changes: 3 additions & 2 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from mcp_python.client.session import ClientSession
from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification,
ClientRequest,
Implementation,
Expand Down Expand Up @@ -41,7 +42,7 @@ async def mock_server():

result = ServerResult(
InitializeResult(
protocolVersion=1,
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(
logging=None,
resources=None,
Expand Down Expand Up @@ -88,7 +89,7 @@ async def listen_session():

# Assert the result
assert isinstance(result, InitializeResult)
assert result.protocolVersion == 1
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
assert isinstance(result.capabilities, ServerCapabilities)
assert result.serverInfo == Implementation(name="mock-server", version="0.1.0")

Expand Down
11 changes: 8 additions & 3 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from mcp_python.types import ClientRequest, JSONRPCMessage, JSONRPCRequest
from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
ClientRequest,
JSONRPCMessage,
JSONRPCRequest,
)


def test_jsonrpc_request():
Expand All @@ -7,7 +12,7 @@ def test_jsonrpc_request():
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": 1,
"protocolVersion": LATEST_PROTOCOL_VERSION,
"capabilities": {"batch": None, "sampling": None},
"clientInfo": {"name": "mcp_python", "version": "0.1.0"},
},
Expand All @@ -21,4 +26,4 @@ def test_jsonrpc_request():
assert request.root.id == 1
assert request.root.method == "initialize"
assert request.root.params is not None
assert request.root.params["protocolVersion"] == 1
assert request.root.params["protocolVersion"] == LATEST_PROTOCOL_VERSION