Skip to content

Commit 4add482

Browse files
committed
Avoid asyncio dependency in tests
1 parent a855877 commit 4add482

File tree

2 files changed

+120
-114
lines changed

2 files changed

+120
-114
lines changed

tests/server/fastmcp/auth/streaming_asgi_transport.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
the connection is closed.
77
"""
88

9-
import asyncio
109
import typing
1110
from typing import Any, Dict, Tuple
1211

1312
import anyio
13+
import anyio.abc
1414
import anyio.streams.memory
15+
1516
from httpx._models import Request, Response
1617
from httpx._transports.base import AsyncBaseTransport
1718
from httpx._types import AsyncByteStream
@@ -41,6 +42,7 @@ class StreamingASGITransport(AsyncBaseTransport):
4142
def __init__(
4243
self,
4344
app: typing.Callable,
45+
task_group: anyio.abc.TaskGroup,
4446
raise_app_exceptions: bool = True,
4547
root_path: str = "",
4648
client: Tuple[str, int] = ("127.0.0.1", 123),
@@ -49,6 +51,7 @@ def __init__(
4951
self.raise_app_exceptions = raise_app_exceptions
5052
self.root_path = root_path
5153
self.client = client
54+
self.task_group = task_group
5255

5356
async def handle_async_request(
5457
self,
@@ -161,8 +164,8 @@ async def process_messages() -> None:
161164
response_complete.set()
162165

163166
# Create tasks for running the app and processing messages
164-
asyncio.create_task(run_app())
165-
asyncio.create_task(process_messages())
167+
self.task_group.start_soon(run_app)
168+
self.task_group.start_soon(process_messages)
166169

167170
# Wait for the initial response or timeout
168171
await initial_response_ready.wait()

tests/server/fastmcp/auth/test_auth_integration.py

+114-111
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import List, Optional
1212
from urllib.parse import parse_qs, urlparse
1313

14+
import anyio
1415
import httpx
1516
import pytest
1617
from httpx_sse import aconnect_sse
@@ -993,130 +994,132 @@ async def test_fastmcp_with_auth(
993994
def test_tool(x: int) -> str:
994995
return f"Result: {x}"
995996

996-
transport = StreamingASGITransport(app=mcp.starlette_app()) # pyright: ignore
997-
test_client = httpx.AsyncClient(
998-
transport=transport, base_url="http://mcptest.com"
999-
)
1000-
# test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com")
997+
async with anyio.create_task_group() as task_group:
998+
transport = StreamingASGITransport(app=mcp.starlette_app(), task_group=task_group) # pyright: ignore
999+
test_client = httpx.AsyncClient(
1000+
transport=transport, base_url="http://mcptest.com"
1001+
)
1002+
# test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com")
10011003

1002-
# Test metadata endpoint
1003-
response = await test_client.get("/.well-known/oauth-authorization-server")
1004-
assert response.status_code == 200
1004+
# Test metadata endpoint
1005+
response = await test_client.get("/.well-known/oauth-authorization-server")
1006+
assert response.status_code == 200
10051007

1006-
# Test that auth is required for protected endpoints
1007-
response = await test_client.get("/sse")
1008-
# TODO: we should return 401/403 depending on whether authn or authz fails
1009-
assert response.status_code == 403
1008+
# Test that auth is required for protected endpoints
1009+
response = await test_client.get("/sse")
1010+
# TODO: we should return 401/403 depending on whether authn or authz fails
1011+
assert response.status_code == 403
10101012

1011-
response = await test_client.post("/messages/")
1012-
# TODO: we should return 401/403 depending on whether authn or authz fails
1013-
assert response.status_code == 403, response.content
1013+
response = await test_client.post("/messages/")
1014+
# TODO: we should return 401/403 depending on whether authn or authz fails
1015+
assert response.status_code == 403, response.content
10141016

1015-
response = await test_client.post(
1016-
"/messages/",
1017-
headers={"Authorization": "invalid"},
1018-
)
1019-
assert response.status_code == 403
1020-
1021-
response = await test_client.post(
1022-
"/messages/",
1023-
headers={"Authorization": "Bearer invalid"},
1024-
)
1025-
assert response.status_code == 403
1017+
response = await test_client.post(
1018+
"/messages/",
1019+
headers={"Authorization": "invalid"},
1020+
)
1021+
assert response.status_code == 403
10261022

1027-
# now, become authenticated and try to go through the flow again
1028-
client_metadata = {
1029-
"redirect_uris": ["https://client.example.com/callback"],
1030-
"client_name": "Test Client",
1031-
}
1023+
response = await test_client.post(
1024+
"/messages/",
1025+
headers={"Authorization": "Bearer invalid"},
1026+
)
1027+
assert response.status_code == 403
10321028

1033-
response = await test_client.post(
1034-
"/register",
1035-
json=client_metadata,
1036-
)
1037-
assert response.status_code == 201
1038-
client_info = response.json()
1029+
# now, become authenticated and try to go through the flow again
1030+
client_metadata = {
1031+
"redirect_uris": ["https://client.example.com/callback"],
1032+
"client_name": "Test Client",
1033+
}
10391034

1040-
# Request authorization using POST with form-encoded data
1041-
response = await test_client.post(
1042-
"/authorize",
1043-
data={
1044-
"response_type": "code",
1045-
"client_id": client_info["client_id"],
1046-
"redirect_uri": "https://client.example.com/callback",
1047-
"code_challenge": pkce_challenge["code_challenge"],
1048-
"code_challenge_method": "S256",
1049-
"state": "test_state",
1050-
},
1051-
)
1052-
assert response.status_code == 302
1035+
response = await test_client.post(
1036+
"/register",
1037+
json=client_metadata,
1038+
)
1039+
assert response.status_code == 201
1040+
client_info = response.json()
10531041

1054-
# Extract the authorization code from the redirect URL
1055-
redirect_url = response.headers["location"]
1056-
parsed_url = urlparse(redirect_url)
1057-
query_params = parse_qs(parsed_url.query)
1042+
# Request authorization using POST with form-encoded data
1043+
response = await test_client.post(
1044+
"/authorize",
1045+
data={
1046+
"response_type": "code",
1047+
"client_id": client_info["client_id"],
1048+
"redirect_uri": "https://client.example.com/callback",
1049+
"code_challenge": pkce_challenge["code_challenge"],
1050+
"code_challenge_method": "S256",
1051+
"state": "test_state",
1052+
},
1053+
)
1054+
assert response.status_code == 302
10581055

1059-
assert "code" in query_params
1060-
auth_code = query_params["code"][0]
1056+
# Extract the authorization code from the redirect URL
1057+
redirect_url = response.headers["location"]
1058+
parsed_url = urlparse(redirect_url)
1059+
query_params = parse_qs(parsed_url.query)
10611060

1062-
# Exchange the authorization code for tokens
1063-
response = await test_client.post(
1064-
"/token",
1065-
data={
1066-
"grant_type": "authorization_code",
1067-
"client_id": client_info["client_id"],
1068-
"client_secret": client_info["client_secret"],
1069-
"code": auth_code,
1070-
"code_verifier": pkce_challenge["code_verifier"],
1071-
"redirect_uri": "https://client.example.com/callback",
1072-
},
1073-
)
1074-
assert response.status_code == 200
1061+
assert "code" in query_params
1062+
auth_code = query_params["code"][0]
10751063

1076-
token_response = response.json()
1077-
assert "access_token" in token_response
1078-
authorization = f"Bearer {token_response['access_token']}"
1079-
1080-
# Test the authenticated endpoint with valid token
1081-
async with aconnect_sse(
1082-
test_client, "GET", "/sse", headers={"Authorization": authorization}
1083-
) as event_source:
1084-
assert event_source.response.status_code == 200
1085-
events = event_source.aiter_sse()
1086-
sse = await events.__anext__()
1087-
assert sse.event == "endpoint"
1088-
assert sse.data.startswith("/messages/?session_id=")
1089-
messages_uri = sse.data
1090-
1091-
# verify that we can now post to the /messages endpoint, and get a response
1092-
# on the /sse endpoint
1064+
# Exchange the authorization code for tokens
10931065
response = await test_client.post(
1094-
messages_uri,
1095-
headers={"Authorization": authorization},
1096-
content=JSONRPCRequest(
1097-
jsonrpc="2.0",
1098-
id="123",
1099-
method="initialize",
1100-
params={
1101-
"protocolVersion": "2024-11-05",
1102-
"capabilities": {
1103-
"roots": {"listChanged": True},
1104-
"sampling": {},
1105-
},
1106-
"clientInfo": {"name": "ExampleClient", "version": "1.0.0"},
1107-
},
1108-
).model_dump_json(),
1109-
)
1110-
assert response.status_code == 202
1111-
assert response.content == b"Accepted"
1112-
1113-
sse = await events.__anext__()
1114-
assert sse.event == "message"
1115-
sse_data = json.loads(sse.data)
1116-
assert sse_data["id"] == "123"
1117-
assert set(sse_data["result"]["capabilities"].keys()) == set(
1118-
("experimental", "prompts", "resources", "tools")
1066+
"/token",
1067+
data={
1068+
"grant_type": "authorization_code",
1069+
"client_id": client_info["client_id"],
1070+
"client_secret": client_info["client_secret"],
1071+
"code": auth_code,
1072+
"code_verifier": pkce_challenge["code_verifier"],
1073+
"redirect_uri": "https://client.example.com/callback",
1074+
},
11191075
)
1076+
assert response.status_code == 200
1077+
1078+
token_response = response.json()
1079+
assert "access_token" in token_response
1080+
authorization = f"Bearer {token_response['access_token']}"
1081+
1082+
# Test the authenticated endpoint with valid token
1083+
async with aconnect_sse(
1084+
test_client, "GET", "/sse", headers={"Authorization": authorization}
1085+
) as event_source:
1086+
assert event_source.response.status_code == 200
1087+
events = event_source.aiter_sse()
1088+
sse = await events.__anext__()
1089+
assert sse.event == "endpoint"
1090+
assert sse.data.startswith("/messages/?session_id=")
1091+
messages_uri = sse.data
1092+
1093+
# verify that we can now post to the /messages endpoint, and get a response
1094+
# on the /sse endpoint
1095+
response = await test_client.post(
1096+
messages_uri,
1097+
headers={"Authorization": authorization},
1098+
content=JSONRPCRequest(
1099+
jsonrpc="2.0",
1100+
id="123",
1101+
method="initialize",
1102+
params={
1103+
"protocolVersion": "2024-11-05",
1104+
"capabilities": {
1105+
"roots": {"listChanged": True},
1106+
"sampling": {},
1107+
},
1108+
"clientInfo": {"name": "ExampleClient", "version": "1.0.0"},
1109+
},
1110+
).model_dump_json(),
1111+
)
1112+
assert response.status_code == 202
1113+
assert response.content == b"Accepted"
1114+
1115+
sse = await events.__anext__()
1116+
assert sse.event == "message"
1117+
sse_data = json.loads(sse.data)
1118+
assert sse_data["id"] == "123"
1119+
assert set(sse_data["result"]["capabilities"].keys()) == set(
1120+
("experimental", "prompts", "resources", "tools")
1121+
)
1122+
task_group.cancel_scope.cancel()
11201123

11211124

11221125
class TestAuthorizeEndpointErrors:

0 commit comments

Comments
 (0)