diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 25e9436..04bbb1b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -525,7 +525,7 @@ async def _handle_message( async def _handle_request( self, - message: RequestResponder, + message: RequestResponder[types.ClientRequest, types.ServerResult], req: Any, session: ServerSession, lifespan_context: LifespanResultT, @@ -546,6 +546,7 @@ async def _handle_request( message.request_meta, session, lifespan_context, + message.request.root.headers or {}, ) ) response = await handler(req) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 0127753..2ce4be7 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -160,7 +160,8 @@ async def handle_post_message( logger.debug(f"Received JSON: {json}") try: - message = types.JSONRPCMessage.model_validate(json) + message_with_headers = {**json, "headers": dict(request.headers)} + message = types.JSONRPCMessage.model_validate(message_with_headers) logger.debug(f"Validated client message: {message}") except ValidationError as err: logger.error(f"Failed to parse message: {err}") diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index a45fdac..6140d9e 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Generic, TypeVar from mcp.shared.session import BaseSession @@ -14,3 +14,4 @@ class RequestContext(Generic[SessionT, LifespanContextT]): meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + headers: dict[str, str] = field(default_factory=dict) diff --git a/src/mcp/types.py b/src/mcp/types.py index 7d867bd..41f737f 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -74,6 +74,7 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]): method: MethodT params: RequestParamsT + headers: dict[str, str] | None = None model_config = ConfigDict(extra="allow")