Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Github workflows for ruff and pyright #17

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/check-format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: ruff

on:
push:
branches: [main]
pull_request:

jobs:
format:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version-file: ".python-version"

- name: Install the project
run: uv sync --frozen --all-extras --dev

- name: Run ruff format check
run: uv run --frozen ruff check .
29 changes: 29 additions & 0 deletions .github/workflows/check-types.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: typecheck

on:
push:
branches: [main]
pull_request:

jobs:
typecheck:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version-file: ".python-version"

- name: Install the project
run: uv sync --frozen --all-extras --dev

- name: Run pyright
run: uv run --frozen pyright
19 changes: 14 additions & 5 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,19 @@ jobs:

steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version-file: ".python-version"

- name: Install the project
run: uv sync --frozen --all-extras --dev

- run: pip install .
- run: pip install -U pytest trio
- run: pytest
- name: Run pytest
run: uv run --frozen pytest
4 changes: 3 additions & 1 deletion mcp_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
ReadResourceResult,
Resource,
ResourceUpdatedNotification,
Role as SamplingRole,
SamplingMessage,
ServerCapabilities,
ServerNotification,
Expand All @@ -49,6 +48,9 @@
Tool,
UnsubscribeRequest,
)
from .types import (
Role as SamplingRole,
)

__all__ = [
"CallToolRequest",
Expand Down
3 changes: 2 additions & 1 deletion mcp_python/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ async def initialize(self) -> InitializeResult:

if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION:
raise RuntimeError(
f"Unsupported protocol version from the server: {result.protocolVersion}"
"Unsupported protocol version from the server: "
f"{result.protocolVersion}"
)

await self.send_notification(
Expand Down
24 changes: 19 additions & 5 deletions mcp_python/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@ def remove_request_params(url: str) -> str:


@asynccontextmanager
async def sse_client(url: str, headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5):
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
):
"""
Client transport for SSE.

`sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
Expand Down Expand Up @@ -67,7 +73,10 @@ async def sse_reader(
or url_parsed.scheme
!= endpoint_parsed.scheme
):
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
error_msg = (
"Endpoint origin does not match "
f"connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)

Expand Down Expand Up @@ -104,11 +113,16 @@ async def post_writer(endpoint_url: str):
logger.debug(f"Sending client message: {message}")
response = await client.post(
endpoint_url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
json=message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(
f"Client message sent successfully: {response.status_code}"
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
Expand Down
3 changes: 2 additions & 1 deletion mcp_python/client/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class StdioServerParameters(BaseModel):
@asynccontextmanager
async def stdio_client(server: StdioServerParameters):
"""
Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout.
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
Expand Down
57 changes: 33 additions & 24 deletions mcp_python/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ def __init__(self, name: str):

def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""

def pkg_version(package: str) -> str:
try:
from importlib.metadata import version

return version(package)
except Exception:
return "unknown"
Expand All @@ -69,16 +71,17 @@ def pkg_version(package: str) -> str:
)

def get_capabilities(self) -> ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object."""
def get_capability(req_type: type) -> dict[str, Any] | None:
return {} if req_type in self.request_handlers else None
"""Convert existing handlers to a ServerCapabilities object."""

return ServerCapabilities(
prompts=get_capability(ListPromptsRequest),
resources=get_capability(ListResourcesRequest),
tools=get_capability(ListPromptsRequest),
logging=get_capability(SetLevelRequest)
)
def get_capability(req_type: type) -> dict[str, Any] | None:
return {} if req_type in self.request_handlers else None

return ServerCapabilities(
prompts=get_capability(ListPromptsRequest),
resources=get_capability(ListResourcesRequest),
tools=get_capability(ListPromptsRequest),
logging=get_capability(SetLevelRequest),
)

@property
def request_context(self) -> RequestContext:
Expand All @@ -87,7 +90,7 @@ def request_context(self) -> RequestContext:

def list_prompts(self):
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
logger.debug(f"Registering handler for PromptListRequest")
logger.debug("Registering handler for PromptListRequest")

async def handler(_: Any):
prompts = await func()
Expand All @@ -103,17 +106,19 @@ def get_prompt(self):
GetPromptRequest,
GetPromptResult,
ImageContent,
Role as Role,
SamplingMessage,
TextContent,
)
from mcp_python.types import (
Role as Role,
)

def decorator(
func: Callable[
[str, dict[str, str] | None], Awaitable[types.PromptResponse]
],
):
logger.debug(f"Registering handler for GetPromptRequest")
logger.debug("Registering handler for GetPromptRequest")

async def handler(req: GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments)
Expand Down Expand Up @@ -149,7 +154,7 @@ async def handler(req: GetPromptRequest):

def list_resources(self):
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
logger.debug(f"Registering handler for ListResourcesRequest")
logger.debug("Registering handler for ListResourcesRequest")

async def handler(_: Any):
resources = await func()
Expand All @@ -169,7 +174,7 @@ def read_resource(self):
)

def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
logger.debug(f"Registering handler for ReadResourceRequest")
logger.debug("Registering handler for ReadResourceRequest")

async def handler(req: ReadResourceRequest):
result = await func(req.params.uri)
Expand Down Expand Up @@ -204,7 +209,7 @@ def set_logging_level(self):
from mcp_python.types import EmptyResult

def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
logger.debug(f"Registering handler for SetLevelRequest")
logger.debug("Registering handler for SetLevelRequest")

async def handler(req: SetLevelRequest):
await func(req.params.level)
Expand All @@ -219,7 +224,7 @@ def subscribe_resource(self):
from mcp_python.types import EmptyResult

def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug(f"Registering handler for SubscribeRequest")
logger.debug("Registering handler for SubscribeRequest")

async def handler(req: SubscribeRequest):
await func(req.params.uri)
Expand All @@ -234,7 +239,7 @@ def unsubscribe_resource(self):
from mcp_python.types import EmptyResult

def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug(f"Registering handler for UnsubscribeRequest")
logger.debug("Registering handler for UnsubscribeRequest")

async def handler(req: UnsubscribeRequest):
await func(req.params.uri)
Expand All @@ -249,7 +254,7 @@ def call_tool(self):
from mcp_python.types import CallToolResult

def decorator(func: Callable[..., Awaitable[Any]]):
logger.debug(f"Registering handler for CallToolRequest")
logger.debug("Registering handler for CallToolRequest")

async def handler(req: CallToolRequest):
result = await func(req.params.name, **(req.params.arguments or {}))
Expand All @@ -264,7 +269,7 @@ def progress_notification(self):
def decorator(
func: Callable[[str | int, float, float | None], Awaitable[None]],
):
logger.debug(f"Registering handler for ProgressNotification")
logger.debug("Registering handler for ProgressNotification")

async def handler(req: ProgressNotification):
await func(
Expand All @@ -286,7 +291,7 @@ def decorator(
Awaitable[Completion | None],
],
):
logger.debug(f"Registering handler for CompleteRequest")
logger.debug("Registering handler for CompleteRequest")

async def handler(req: CompleteRequest):
completion = await func(req.params.ref, req.params.argument)
Expand All @@ -307,10 +312,12 @@ async def run(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions
initialization_options: types.InitializationOptions,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(read_stream, write_stream, initialization_options) as session:
async with ServerSession(
read_stream, write_stream, initialization_options
) as session:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")

Expand Down Expand Up @@ -359,14 +366,16 @@ async def run(

handler = self.notification_handlers[type(notify)]
logger.debug(
f"Dispatching notification of type {type(notify).__name__}"
f"Dispatching notification of type "
f"{type(notify).__name__}"
)

try:
await handler(notify)
except Exception as err:
logger.error(
f"Uncaught exception in notification handler: {err}"
f"Uncaught exception in notification handler: "
f"{err}"
)

for warning in w:
Expand Down
18 changes: 15 additions & 3 deletions mcp_python/server/__main__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import importlib.metadata
import logging
import sys
import importlib.metadata

import anyio

from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions
from mcp_python.server.stdio import stdio_server
from mcp_python.server.types import InitializationOptions
from mcp_python.types import ServerCapabilities

if not sys.warnoptions:
Expand All @@ -30,7 +31,18 @@ async def receive_loop(session: ServerSession):
async def main():
version = importlib.metadata.version("mcp_python")
async with stdio_server() as (read_stream, write_stream):
async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream:
async with (
ServerSession(
read_stream,
write_stream,
InitializationOptions(
server_name="mcp_python",
server_version=version,
capabilities=ServerCapabilities(),
),
) as session,
write_stream,
):
await receive_loop(session)


Expand Down
Loading