Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HTTPBearer token is set, Auth button not shown on /api/docs #67

Merged
merged 27 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4ae5b2d
WIP: HTTPBearer token is set
Igoranze Oct 10, 2024
a611595
Fix: Authenticate refactor for GraphQL and other
Igoranze Oct 14, 2024
ee5085b
Fix: raise 403 on empty token
Igoranze Oct 14, 2024
23b0e16
Fix: Add test with token extractor
Igoranze Oct 14, 2024
bb5588b
Fix: changed bool to False auto_error
Igoranze Oct 14, 2024
8643e6c
Fix artefact package
pboers1988 Oct 14, 2024
8382c10
Linting
pboers1988 Oct 14, 2024
37433f3
Fix: Linting issues
Igoranze Oct 15, 2024
08c14b5
Fix: space removal black check .
Igoranze Oct 15, 2024
1239b56
Refactor: make it more abstract
Igoranze Oct 17, 2024
27e6a7f
Fix: linting
Igoranze Oct 17, 2024
effa3d8
Test: fix tests
Igoranze Oct 17, 2024
588231b
Fix: linting imports
Igoranze Oct 17, 2024
b3f4957
Fix: linting return type
Igoranze Oct 17, 2024
2e45752
Merge branch 'main' into auth-b-shown
Igoranze Oct 22, 2024
6079861
Add init to HTTPBearerExtractor
Igoranze Oct 22, 2024
f9e4f4c
Fix: Black --check
Igoranze Oct 22, 2024
a64b4ea
Fix typing decorator and removed unused code
Igoranze Oct 24, 2024
a7103f7
fix: black
Igoranze Oct 24, 2024
feac762
Update oauth2_lib/fastapi.py
Igoranze Oct 24, 2024
41b1cd0
Update oauth2_lib/fastapi.py
Igoranze Oct 24, 2024
576f944
Fix: ruff
Igoranze Oct 24, 2024
ce63c77
Merge branch 'auth-b-shown' of github.com:workfloworchestrator/oauth2…
Igoranze Oct 24, 2024
e163eac
Rollback Typing checks
Igoranze Oct 24, 2024
eeffe05
Add: typing ignore on decorator tests
Igoranze Oct 24, 2024
96ad751
bump version
Igoranze Oct 24, 2024
944473b
bump version
Igoranze Oct 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 21 additions & 35 deletions oauth2_lib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from fastapi import HTTPException
from fastapi.requests import Request
from fastapi.security.http import HTTPBearer
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from httpx import AsyncClient, NetworkError
from pydantic import BaseModel
from starlette.requests import ClientDisconnect, HTTPConnection
Expand Down Expand Up @@ -126,7 +126,7 @@ class Authentication(ABC):
"""

@abstractmethod
async def authenticate(self, request: HTTPConnection, token: str | None = None) -> dict | None:
async def authenticate(self, request: Request, token: str | None = None) -> dict | None:
"""Authenticate the user."""
pass

Expand All @@ -142,17 +142,24 @@ async def extract(self, request: Request) -> str | None:
pass


class HttpBearerExtractor(IdTokenExtractor):
class HttpBearerExtractor(HTTPBearer, IdTokenExtractor):
"""Extracts bearer tokens using FastAPI's HTTPBearer.

Specifically designed for HTTP Authorization header token extraction.
"""

def __init__(self, auto_error: bool = False):
super().__init__(auto_error=auto_error)

async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]:
"""Extract the Authorization header from the request."""
return await super().__call__(request)

async def extract(self, request: Request) -> str | None:
http_bearer = HTTPBearer(auto_error=False)
credential = await http_bearer(request)
"""Extract the token from the Authorization header in the request."""
http_auth_credentials = await super().__call__(request)

return credential.credentials if credential else None
return http_auth_credentials.credentials if http_auth_credentials else None


class OIDCAuth(Authentication):
Expand All @@ -168,11 +175,7 @@ def __init__(
resource_server_id: str,
resource_server_secret: str,
oidc_user_model_cls: type[OIDCUserModel],
id_token_extractor: IdTokenExtractor | None = None,
):
if not id_token_extractor:
self.id_token_extractor = HttpBearerExtractor()
Comment on lines -171 to -174
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe keep the kwarg in __init__ and when a value is passed log a (deprecation) warning instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great spot by you (@Igoranze): if a custom id_token_extractor is passed then it's not actually assigned to self.id_token_extractor so authenticate() raises an AttributeError when trying to call it.

The only way this can have worked for others is if they override authenticate() as well. So, we don't need to worry about backwards compatibility


self.openid_url = openid_url
self.openid_config_url = openid_config_url
self.resource_server_id = resource_server_id
Expand All @@ -181,7 +184,7 @@ def __init__(

self.openid_config: OIDCConfig | None = None

async def authenticate(self, request: HTTPConnection, token: str | None = None) -> OIDCUserModel | None:
async def authenticate(self, request: Request, token: str | None = None) -> OIDCUserModel | None:
"""Return the OIDC user from OIDC introspect endpoint.

This is used as a security module in Fastapi projects
Expand All @@ -197,33 +200,16 @@ async def authenticate(self, request: HTTPConnection, token: str | None = None)
if not oauth2lib_settings.OAUTH2_ACTIVE:
return None

async with AsyncClient(http1=True, verify=HTTPX_SSL_CONTEXT) as async_client:
await self.check_openid_config(async_client)

# Handle WebSocket requests separately only to check for token presence.
if isinstance(request, WebSocket):
if token is None:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Not authenticated",
)
token_or_extracted_id_token = token
else:
request = cast(Request, request)

if await self.is_bypassable_request(request):
return None
if await self.is_bypassable_request(request):
return None

if token is None:
extracted_id_token = await self.id_token_extractor.extract(request)
if not extracted_id_token:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")
if not token:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you did this because the new WebSocket implementation passes the token in a header. Nice!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct :) And it no longer servers only for the websocket since that should not matter


token_or_extracted_id_token = extracted_id_token
else:
token_or_extracted_id_token = token
async with AsyncClient(http1=True, verify=HTTPX_SSL_CONTEXT) as async_client:
await self.check_openid_config(async_client)

user_info: OIDCUserModel = await self.userinfo(async_client, token_or_extracted_id_token)
user_info: OIDCUserModel = await self.userinfo(async_client, token)
logger.debug("OIDCUserModel object.", user_info=user_info)
return user_info

Expand Down
100 changes: 67 additions & 33 deletions oauth2_lib/strawberry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# limitations under the License.
from collections.abc import Callable
from enum import StrEnum, auto
from typing import Any
from functools import wraps
from typing import Any, TypeVar, cast

import asyncstdlib
import strawberry
Expand All @@ -23,15 +24,17 @@
from strawberry import BasePermission
from strawberry.fastapi import BaseContext
from strawberry.types import Info
from strawberry.types.fields.resolver import StrawberryResolver
from strawberry.types.info import RootValueType

from oauth2_lib.fastapi import AuthManager, OIDCUserModel
from oauth2_lib.fastapi import AuthManager, HttpBearerExtractor, OIDCUserModel
from oauth2_lib.settings import oauth2lib_settings

logger = structlog.get_logger(__name__)


F = TypeVar("F", bound=Callable[..., Any])


class OauthContext(BaseContext):
def __init__(
self,
Expand All @@ -56,7 +59,11 @@ async def get_current_user(self) -> OIDCUserModel | None:
return None

try:
return await self.auth_manager.authentication.authenticate(self.request)
http_bearer_extractor = HttpBearerExtractor(auto_error=False)
http_authorization_credentials = await http_bearer_extractor(self.request)

token = http_authorization_credentials.credentials if http_authorization_credentials else None
return await self.auth_manager.authentication.authenticate(self.request, token)
except HTTPException as exc:
logger.debug("User is not authenticated", status_code=exc.status_code, detail=exc.detail)
return None
Expand Down Expand Up @@ -197,48 +204,75 @@ async def has_permission(self, source: Any, info: OauthInfo, **kwargs) -> bool:

def authenticated_field(
description: str,
resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None,
deprecation_reason: str | None = None,
permission_classes: list[type[BasePermission]] | None = None,
) -> Any:
permissions = permission_classes if permission_classes else []
return strawberry.field(
description=description,
resolver=resolver, # type: ignore
deprecation_reason=deprecation_reason,
permission_classes=[IsAuthenticatedForQuery, IsAuthorizedForQuery] + permissions,
)
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

permissions = permission_classes if permission_classes else []
return cast(
F,
strawberry.field(
description=description,
resolver=wrapper,
deprecation_reason=deprecation_reason,
permission_classes=[IsAuthenticatedForQuery, IsAuthorizedForQuery] + permissions,
),
)

return decorator


def authenticated_mutation_field(
description: str,
resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None,
deprecation_reason: str | None = None,
permission_classes: list[type[BasePermission]] | None = None,
) -> Any:
permissions = permission_classes if permission_classes else []
return strawberry.field(
description=description,
resolver=resolver, # type: ignore
deprecation_reason=deprecation_reason,
permission_classes=[IsAuthenticatedForMutation, IsAuthorizedForMutation] + permissions,
)
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

permissions = permission_classes if permission_classes else []
return cast(
F,
strawberry.field(
description=description,
resolver=wrapper,
deprecation_reason=deprecation_reason,
permission_classes=[IsAuthenticatedForMutation, IsAuthorizedForMutation] + permissions,
),
)

return decorator


def authenticated_federated_field( # type: ignore
description: str,
resolver: StrawberryResolver | Callable | staticmethod | classmethod | None = None,
deprecation_reason: str | None = None,
requires: list[str] | None = None,
permission_classes: list[type[BasePermission]] | None = None,
**kwargs,
) -> Any:
permissions = permission_classes if permission_classes else []
return strawberry.federation.field(
description=description,
resolver=resolver, # type: ignore
deprecation_reason=deprecation_reason,
permission_classes=[IsAuthenticatedForQuery, IsAuthorizedForQuery] + permissions,
requires=requires,
**kwargs,
)
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

permissions = permission_classes if permission_classes else []
return cast(
F,
strawberry.federation.field(
description=description,
resolver=wrapper,
deprecation_reason=deprecation_reason,
permission_classes=[IsAuthenticatedForQuery, IsAuthorizedForQuery] + permissions,
requires=requires,
**kwargs,
),
)

return decorator
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def make_mock_async_client():
Pass a MockResponse for single or list for multiple sequential HTTP responses.
"""

def _make_mock_async_client(mock_response: MockResponse | list[MockResponse] | None = None):
def _make_mock_async_client(mock_response: MockResponse | list[MockResponse] | None = None) -> AsyncClientMock:
mock_async_client = AsyncMock(spec=AsyncClient)

mock_responses = ([mock_response] if isinstance(mock_response, MockResponse) else mock_response) or []
Expand Down
9 changes: 5 additions & 4 deletions tests/strawberry/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import strawberry
from fastapi import Depends, FastAPI
from starlette.requests import Request
from httpx._client import AsyncClient
from starlette.testclient import TestClient
from strawberry.fastapi import GraphQLRouter

Expand All @@ -17,7 +17,7 @@

async def get_oidc_authentication():
class OIDCAuthMock(OIDCAuth):
async def userinfo(self, request: Request, token: str | None = None) -> OIDCUserModel | None:
async def userinfo(self, async_request: AsyncClient, token: str) -> OIDCUserModel:
return user_info_matching

return OIDCAuthMock("openid_url", "openid_url/.well-known/openid-configuration", "id", "secret", OIDCUserModel)
Expand Down Expand Up @@ -57,7 +57,8 @@ class Query:
def book(self) -> BookType:
return BookType(title="test title", author="test author")

@strawberry.field(description="query book nested auth test")
# Known issue with strawberry: https://github.com/strawberry-graphql/strawberry/issues/1929
@strawberry.field(description="query book nested auth test") # type: ignore
def book_nested_auth(self) -> bookNestedAuthType:
return bookNestedAuthType(title="test title")

Expand All @@ -76,7 +77,7 @@ async def get_context(auth_manager=Depends(get_auth_manger)) -> OauthContext: #

app = FastAPI()
schema = strawberry.Schema(query=Query, mutation=Mutation)
graphql_app = GraphQLRouter(schema, context_getter=get_context)
graphql_app: GraphQLRouter = GraphQLRouter(schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

return TestClient(app)
Expand Down
13 changes: 12 additions & 1 deletion tests/test_async_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,18 @@ class FakeApiClient:
def __init__(self, *args, **kwargs):
pass

def request(self, method, url, query_params, headers, *args):
def request(
self,
method,
url,
query_params=None,
headers=None,
post_params=None,
body=None,
_preload_content=True,
_request_timeout=None,
):
headers = {} if headers is None else headers
http = urllib3.PoolManager()
response = http.request(method, url, headers=headers)
if not 200 <= response.status <= 299:
Expand Down
15 changes: 6 additions & 9 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,25 +134,19 @@ async def test_userinfo_success_with_mock(oidc_auth):
assert user["sub"] == "hoi", "User info not retrieved correctly"


def test_oidc_auth_initialization_default_extractor(oidc_auth):
assert isinstance(
oidc_auth.id_token_extractor, HttpBearerExtractor
), "Default ID token extractor should be HttpBearerExtractor"


@pytest.mark.asyncio
async def test_extract_token_success():
request = mock.MagicMock()
request.headers = {"Authorization": "Bearer example_token"}
extractor = HttpBearerExtractor()
extractor = HttpBearerExtractor(auto_error=False)
assert await extractor.extract(request) == "example_token", "Token extraction failed"


@pytest.mark.asyncio
async def test_extract_token_returns_none():
request = mock.MagicMock()
request.headers = {}
extractor = HttpBearerExtractor()
extractor = HttpBearerExtractor(auto_error=False)
assert await extractor.extract(request) is None


Expand All @@ -165,7 +159,10 @@ async def test_authenticate_success(make_mock_async_client, discovery, oidc_auth
request = mock.MagicMock(spec=Request)
request.headers = {"Authorization": "Bearer valid_token"}

user = await oidc_auth.authenticate(request)
http_bearer_extractor = HttpBearerExtractor(auto_error=False)
token = await http_bearer_extractor(request)
# token = await oidc_auth.id_token_extractor.extract(request)
user = await oidc_auth.authenticate(request, token)
assert user == user_info_matching, "Authentication failed for a valid token"


Expand Down
Loading