|
11 | 11 | from typing import List, Optional
|
12 | 12 | from urllib.parse import parse_qs, urlparse
|
13 | 13 |
|
| 14 | +import anyio |
14 | 15 | import httpx
|
15 | 16 | import pytest
|
16 | 17 | from httpx_sse import aconnect_sse
|
@@ -993,130 +994,132 @@ async def test_fastmcp_with_auth(
|
993 | 994 | def test_tool(x: int) -> str:
|
994 | 995 | return f"Result: {x}"
|
995 | 996 |
|
996 |
| - transport = StreamingASGITransport(app=mcp.starlette_app()) # pyright: ignore |
997 |
| - test_client = httpx.AsyncClient( |
998 |
| - transport=transport, base_url="http://mcptest.com" |
999 |
| - ) |
1000 |
| - # test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com") |
| 997 | + async with anyio.create_task_group() as task_group: |
| 998 | + transport = StreamingASGITransport(app=mcp.starlette_app(), task_group=task_group) # pyright: ignore |
| 999 | + test_client = httpx.AsyncClient( |
| 1000 | + transport=transport, base_url="http://mcptest.com" |
| 1001 | + ) |
| 1002 | + # test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com") |
1001 | 1003 |
|
1002 |
| - # Test metadata endpoint |
1003 |
| - response = await test_client.get("/.well-known/oauth-authorization-server") |
1004 |
| - assert response.status_code == 200 |
| 1004 | + # Test metadata endpoint |
| 1005 | + response = await test_client.get("/.well-known/oauth-authorization-server") |
| 1006 | + assert response.status_code == 200 |
1005 | 1007 |
|
1006 |
| - # Test that auth is required for protected endpoints |
1007 |
| - response = await test_client.get("/sse") |
1008 |
| - # TODO: we should return 401/403 depending on whether authn or authz fails |
1009 |
| - assert response.status_code == 403 |
| 1008 | + # Test that auth is required for protected endpoints |
| 1009 | + response = await test_client.get("/sse") |
| 1010 | + # TODO: we should return 401/403 depending on whether authn or authz fails |
| 1011 | + assert response.status_code == 403 |
1010 | 1012 |
|
1011 |
| - response = await test_client.post("/messages/") |
1012 |
| - # TODO: we should return 401/403 depending on whether authn or authz fails |
1013 |
| - assert response.status_code == 403, response.content |
| 1013 | + response = await test_client.post("/messages/") |
| 1014 | + # TODO: we should return 401/403 depending on whether authn or authz fails |
| 1015 | + assert response.status_code == 403, response.content |
1014 | 1016 |
|
1015 |
| - response = await test_client.post( |
1016 |
| - "/messages/", |
1017 |
| - headers={"Authorization": "invalid"}, |
1018 |
| - ) |
1019 |
| - assert response.status_code == 403 |
1020 |
| - |
1021 |
| - response = await test_client.post( |
1022 |
| - "/messages/", |
1023 |
| - headers={"Authorization": "Bearer invalid"}, |
1024 |
| - ) |
1025 |
| - assert response.status_code == 403 |
| 1017 | + response = await test_client.post( |
| 1018 | + "/messages/", |
| 1019 | + headers={"Authorization": "invalid"}, |
| 1020 | + ) |
| 1021 | + assert response.status_code == 403 |
1026 | 1022 |
|
1027 |
| - # now, become authenticated and try to go through the flow again |
1028 |
| - client_metadata = { |
1029 |
| - "redirect_uris": ["https://client.example.com/callback"], |
1030 |
| - "client_name": "Test Client", |
1031 |
| - } |
| 1023 | + response = await test_client.post( |
| 1024 | + "/messages/", |
| 1025 | + headers={"Authorization": "Bearer invalid"}, |
| 1026 | + ) |
| 1027 | + assert response.status_code == 403 |
1032 | 1028 |
|
1033 |
| - response = await test_client.post( |
1034 |
| - "/register", |
1035 |
| - json=client_metadata, |
1036 |
| - ) |
1037 |
| - assert response.status_code == 201 |
1038 |
| - client_info = response.json() |
| 1029 | + # now, become authenticated and try to go through the flow again |
| 1030 | + client_metadata = { |
| 1031 | + "redirect_uris": ["https://client.example.com/callback"], |
| 1032 | + "client_name": "Test Client", |
| 1033 | + } |
1039 | 1034 |
|
1040 |
| - # Request authorization using POST with form-encoded data |
1041 |
| - response = await test_client.post( |
1042 |
| - "/authorize", |
1043 |
| - data={ |
1044 |
| - "response_type": "code", |
1045 |
| - "client_id": client_info["client_id"], |
1046 |
| - "redirect_uri": "https://client.example.com/callback", |
1047 |
| - "code_challenge": pkce_challenge["code_challenge"], |
1048 |
| - "code_challenge_method": "S256", |
1049 |
| - "state": "test_state", |
1050 |
| - }, |
1051 |
| - ) |
1052 |
| - assert response.status_code == 302 |
| 1035 | + response = await test_client.post( |
| 1036 | + "/register", |
| 1037 | + json=client_metadata, |
| 1038 | + ) |
| 1039 | + assert response.status_code == 201 |
| 1040 | + client_info = response.json() |
1053 | 1041 |
|
1054 |
| - # Extract the authorization code from the redirect URL |
1055 |
| - redirect_url = response.headers["location"] |
1056 |
| - parsed_url = urlparse(redirect_url) |
1057 |
| - query_params = parse_qs(parsed_url.query) |
| 1042 | + # Request authorization using POST with form-encoded data |
| 1043 | + response = await test_client.post( |
| 1044 | + "/authorize", |
| 1045 | + data={ |
| 1046 | + "response_type": "code", |
| 1047 | + "client_id": client_info["client_id"], |
| 1048 | + "redirect_uri": "https://client.example.com/callback", |
| 1049 | + "code_challenge": pkce_challenge["code_challenge"], |
| 1050 | + "code_challenge_method": "S256", |
| 1051 | + "state": "test_state", |
| 1052 | + }, |
| 1053 | + ) |
| 1054 | + assert response.status_code == 302 |
1058 | 1055 |
|
1059 |
| - assert "code" in query_params |
1060 |
| - auth_code = query_params["code"][0] |
| 1056 | + # Extract the authorization code from the redirect URL |
| 1057 | + redirect_url = response.headers["location"] |
| 1058 | + parsed_url = urlparse(redirect_url) |
| 1059 | + query_params = parse_qs(parsed_url.query) |
1061 | 1060 |
|
1062 |
| - # Exchange the authorization code for tokens |
1063 |
| - response = await test_client.post( |
1064 |
| - "/token", |
1065 |
| - data={ |
1066 |
| - "grant_type": "authorization_code", |
1067 |
| - "client_id": client_info["client_id"], |
1068 |
| - "client_secret": client_info["client_secret"], |
1069 |
| - "code": auth_code, |
1070 |
| - "code_verifier": pkce_challenge["code_verifier"], |
1071 |
| - "redirect_uri": "https://client.example.com/callback", |
1072 |
| - }, |
1073 |
| - ) |
1074 |
| - assert response.status_code == 200 |
| 1061 | + assert "code" in query_params |
| 1062 | + auth_code = query_params["code"][0] |
1075 | 1063 |
|
1076 |
| - token_response = response.json() |
1077 |
| - assert "access_token" in token_response |
1078 |
| - authorization = f"Bearer {token_response['access_token']}" |
1079 |
| - |
1080 |
| - # Test the authenticated endpoint with valid token |
1081 |
| - async with aconnect_sse( |
1082 |
| - test_client, "GET", "/sse", headers={"Authorization": authorization} |
1083 |
| - ) as event_source: |
1084 |
| - assert event_source.response.status_code == 200 |
1085 |
| - events = event_source.aiter_sse() |
1086 |
| - sse = await events.__anext__() |
1087 |
| - assert sse.event == "endpoint" |
1088 |
| - assert sse.data.startswith("/messages/?session_id=") |
1089 |
| - messages_uri = sse.data |
1090 |
| - |
1091 |
| - # verify that we can now post to the /messages endpoint, and get a response |
1092 |
| - # on the /sse endpoint |
| 1064 | + # Exchange the authorization code for tokens |
1093 | 1065 | response = await test_client.post(
|
1094 |
| - messages_uri, |
1095 |
| - headers={"Authorization": authorization}, |
1096 |
| - content=JSONRPCRequest( |
1097 |
| - jsonrpc="2.0", |
1098 |
| - id="123", |
1099 |
| - method="initialize", |
1100 |
| - params={ |
1101 |
| - "protocolVersion": "2024-11-05", |
1102 |
| - "capabilities": { |
1103 |
| - "roots": {"listChanged": True}, |
1104 |
| - "sampling": {}, |
1105 |
| - }, |
1106 |
| - "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, |
1107 |
| - }, |
1108 |
| - ).model_dump_json(), |
1109 |
| - ) |
1110 |
| - assert response.status_code == 202 |
1111 |
| - assert response.content == b"Accepted" |
1112 |
| - |
1113 |
| - sse = await events.__anext__() |
1114 |
| - assert sse.event == "message" |
1115 |
| - sse_data = json.loads(sse.data) |
1116 |
| - assert sse_data["id"] == "123" |
1117 |
| - assert set(sse_data["result"]["capabilities"].keys()) == set( |
1118 |
| - ("experimental", "prompts", "resources", "tools") |
| 1066 | + "/token", |
| 1067 | + data={ |
| 1068 | + "grant_type": "authorization_code", |
| 1069 | + "client_id": client_info["client_id"], |
| 1070 | + "client_secret": client_info["client_secret"], |
| 1071 | + "code": auth_code, |
| 1072 | + "code_verifier": pkce_challenge["code_verifier"], |
| 1073 | + "redirect_uri": "https://client.example.com/callback", |
| 1074 | + }, |
1119 | 1075 | )
|
| 1076 | + assert response.status_code == 200 |
| 1077 | + |
| 1078 | + token_response = response.json() |
| 1079 | + assert "access_token" in token_response |
| 1080 | + authorization = f"Bearer {token_response['access_token']}" |
| 1081 | + |
| 1082 | + # Test the authenticated endpoint with valid token |
| 1083 | + async with aconnect_sse( |
| 1084 | + test_client, "GET", "/sse", headers={"Authorization": authorization} |
| 1085 | + ) as event_source: |
| 1086 | + assert event_source.response.status_code == 200 |
| 1087 | + events = event_source.aiter_sse() |
| 1088 | + sse = await events.__anext__() |
| 1089 | + assert sse.event == "endpoint" |
| 1090 | + assert sse.data.startswith("/messages/?session_id=") |
| 1091 | + messages_uri = sse.data |
| 1092 | + |
| 1093 | + # verify that we can now post to the /messages endpoint, and get a response |
| 1094 | + # on the /sse endpoint |
| 1095 | + response = await test_client.post( |
| 1096 | + messages_uri, |
| 1097 | + headers={"Authorization": authorization}, |
| 1098 | + content=JSONRPCRequest( |
| 1099 | + jsonrpc="2.0", |
| 1100 | + id="123", |
| 1101 | + method="initialize", |
| 1102 | + params={ |
| 1103 | + "protocolVersion": "2024-11-05", |
| 1104 | + "capabilities": { |
| 1105 | + "roots": {"listChanged": True}, |
| 1106 | + "sampling": {}, |
| 1107 | + }, |
| 1108 | + "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, |
| 1109 | + }, |
| 1110 | + ).model_dump_json(), |
| 1111 | + ) |
| 1112 | + assert response.status_code == 202 |
| 1113 | + assert response.content == b"Accepted" |
| 1114 | + |
| 1115 | + sse = await events.__anext__() |
| 1116 | + assert sse.event == "message" |
| 1117 | + sse_data = json.loads(sse.data) |
| 1118 | + assert sse_data["id"] == "123" |
| 1119 | + assert set(sse_data["result"]["capabilities"].keys()) == set( |
| 1120 | + ("experimental", "prompts", "resources", "tools") |
| 1121 | + ) |
| 1122 | + task_group.cancel_scope.cancel() |
1120 | 1123 |
|
1121 | 1124 |
|
1122 | 1125 | class TestAuthorizeEndpointErrors:
|
|
0 commit comments