From 2ffc1de8461aa25473fa882cef73be1ff8b1d058 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Fri, 14 Feb 2025 09:22:45 +0000 Subject: [PATCH] Correctly close streams, better handle edge cases --- src/mcp_grafana/__init__.py | 15 +++--- src/mcp_grafana/transports/http.py | 78 ++++++++++++++++++------------ 2 files changed, 53 insertions(+), 40 deletions(-) diff --git a/src/mcp_grafana/__init__.py b/src/mcp_grafana/__init__.py index bce25e6..316cfda 100644 --- a/src/mcp_grafana/__init__.py +++ b/src/mcp_grafana/__init__.py @@ -4,7 +4,6 @@ import anyio import uvicorn from mcp.server import FastMCP -from starlette.requests import Request from .tools import add_tools @@ -18,14 +17,14 @@ class Transport(enum.StrEnum): class GrafanaMCP(FastMCP): async def run_http_async(self) -> None: from starlette.applications import Starlette - from starlette.routing import Route + from starlette.routing import Mount from .transports.http import handle_message - async def handle_http(request: Request): - async with handle_message( - request.scope, request.receive, request._send - ) as ( + async def handle_http(scope, receive, send): + if scope["type"] != "http": + raise ValueError("Expected HTTP request") + async with handle_message(scope, receive, send) as ( read_stream, write_stream, ): @@ -37,9 +36,7 @@ async def handle_http(request: Request): starlette_app = Starlette( debug=self.settings.debug, - routes=[ - Route("/mcp", endpoint=handle_http, methods=["POST"]), - ], + routes=[Mount("/", app=handle_http)], ) config = uvicorn.Config( diff --git a/src/mcp_grafana/transports/http.py b/src/mcp_grafana/transports/http.py index 32aa063..696563b 100644 --- a/src/mcp_grafana/transports/http.py +++ b/src/mcp_grafana/transports/http.py @@ -113,39 +113,55 @@ async def handle_message(scope: Scope, receive: Receive, send: Send): read_stream, read_stream_writer, write_stream, write_stream_reader = make_streams() async def handle_post_message(): - request = Request(scope, receive) try: - json = await request.json() - except JSONDecodeError as err: - logger.error(f"Failed to parse message: {err}") - response = Response("Could not parse message", status_code=400) - await response(scope, receive, send) - return - try: - client_message = types.JSONRPCMessage.model_validate(json) - logger.debug(f"Validated client message: {client_message}") - except ValidationError as err: - logger.error(f"Failed to parse message: {err}") - response = Response("Could not parse message", status_code=400) + request = Request(scope, receive) + if request.method != "POST": + response = Response("Method not allowed", status_code=405) + await response(scope, receive, send) + return + if scope["path"] != "/mcp": + response = Response("Not found", status_code=404) + await response(scope, receive, send) + return + try: + json = await request.json() + except JSONDecodeError as err: + logger.error(f"Failed to parse message: {err}") + response = Response("Could not parse message", status_code=400) + await response(scope, receive, send) + return + + try: + client_message = types.JSONRPCMessage.model_validate(json) + logger.debug(f"Validated client message: {client_message}") + except ValidationError as err: + logger.error(f"Failed to parse message: {err}") + response = Response("Could not parse message", status_code=400) + await response(scope, receive, send) + return + + # As part of the MCP spec we need to initialize first. + # In a stateful flow (e.g. stdio or sse transports) the client would + # send an initialize request to the server, and the server would send + # a response back to the client. In this case we're trying to be stateless, + # so we'll handle the initialization ourselves. + logger.debug("Initializing server") + await initialize(read_stream_writer, write_stream_reader) + + # Alright, now we can send the client message. + logger.debug("Sending client message") + await read_stream_writer.send(client_message) + + # Wait for the server's response, and forward it to the client. + server_message = await write_stream_reader.receive() + obj = server_message.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + response = JSONResponse(obj) await response(scope, receive, send) - return - - # As part of the MCP spec we need to initialize first. - # In a stateful flow (e.g. stdio or sse transports) the client would - # send an initialize request to the server, and the server would send - # a response back to the client. In this case we're trying to be stateless, - # so we'll handle the initialization ourselves. - logger.debug("Initializing server") - await initialize(read_stream_writer, write_stream_reader) - - # Alright, now we can send the client message. - logger.debug("Sending client message") - await read_stream_writer.send(client_message) - # Wait for the server's response, and forward it to the client. - server_message = await write_stream_reader.receive() - obj = server_message.model_dump(by_alias=True, mode="json", exclude_none=True) - response = JSONResponse(obj) - await response(scope, receive, send) + finally: + await read_stream_writer.aclose() + await write_stream_reader.aclose() async with anyio.create_task_group() as tg: tg.start_soon(handle_post_message)