Skip to content

Commit

Permalink
add Websoket as a Union type alongside of Request
Browse files Browse the repository at this point in the history
  • Loading branch information
torkashvandmt committed Apr 9, 2024
1 parent 6e95446 commit 303030a
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions oauth2_lib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from httpx import AsyncClient, NetworkError
from pydantic import BaseModel
from starlette.requests import ClientDisconnect
from starlette.websockets import WebSocket
from structlog import get_logger

from oauth2_lib.settings import oauth2lib_settings
Expand Down Expand Up @@ -78,7 +79,7 @@ def user_name(self) -> str:
RequestPath = str
AuthenticationFunc = Callable[[Request, Optional[str]], Awaitable[Optional[dict]]]
AuthorizationFunc = Callable[[Request, OIDCUserModel, Any], Awaitable[bool]]
GraphqlAuthorizationFunc = Callable[[RequestPath, OIDCUserModel, Optional[AsyncClient], Any], Awaitable[bool]]
GraphqlAuthorizationFunc = Callable[[str, OIDCUserModel], Awaitable[Optional[bool]]]


class OIDCConfig(BaseModel):
Expand Down Expand Up @@ -124,7 +125,7 @@ class Authentication(ABC):
"""

@abstractmethod
async def authenticate(self, request: Request, token: Optional[str] = None) -> Optional[dict]:
async def authenticate(self, request: Union[Request, WebSocket], token: Optional[str] = None) -> Optional[dict]:
"""Authenticate the user."""
pass

Expand Down Expand Up @@ -177,7 +178,9 @@ def __init__(

self.openid_config: Optional[OIDCConfig] = None

async def authenticate(self, request: Request, token: Optional[str] = None) -> Optional[OIDCUserModel]:
async def authenticate(
self, request: Union[Request, WebSocket], token: Optional[str] = None
) -> Optional[OIDCUserModel]:
"""Return the OIDC user from OIDC introspect endpoint.
This is used as a security module in Fastapi projects
Expand All @@ -196,15 +199,21 @@ async def authenticate(self, request: Request, token: Optional[str] = None) -> O
async with AsyncClient(http1=True, verify=HTTPX_SSL_CONTEXT) as async_client:
await self.check_openid_config(async_client)

if token is None:
extracted_id_token = await self.id_token_extractor.extract(request)
if not extracted_id_token:
# Handle WebSocket requests separately only to check for token presence.
if isinstance(request, WebSocket):
if token is None:
return None
token_or_extracted_id_token = extracted_id_token
elif await self.is_bypassable_request(request):
return None
else:
token_or_extracted_id_token = token
else:
if token is None:
extracted_id_token = await self.id_token_extractor.extract(request)
if not extracted_id_token:
return None
token_or_extracted_id_token = extracted_id_token
elif await self.is_bypassable_request(request):
return None
else:
token_or_extracted_id_token = token

user_info = await self.userinfo(async_client, token_or_extracted_id_token)
logger.debug("OIDCUserModel object.", user_info=user_info)
Expand Down Expand Up @@ -255,7 +264,7 @@ class Authorization(ABC):

@abstractmethod
async def authorize(
self, request: Request, user: OIDCUserModel = Depends(oidc_instance.authenticate)
self, request: Union[Request, WebSocket], user: OIDCUserModel = Depends(oidc_instance.authenticate)
) -> Optional[bool]:
pass

Expand Down Expand Up @@ -320,13 +329,16 @@ class OPAAuthorization(Authorization, OPAMixin):
"""

async def authorize(
self, request: Request, user_info: OIDCUserModel = Depends(oidc_instance.authenticate)
self, request: Union[Request, WebSocket], user_info: OIDCUserModel = Depends(oidc_instance.authenticate)
) -> Optional[bool]:
if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
return None

try:
json = await request.json()
if isinstance(request, WebSocket):
json = {}
else:
json = await request.json()
# Silencing the Decode error or Type error when request.json() does not return anything sane.
# Some requests do not have a json response therefore as this code gets called on every request
# we need to suppress the `None` case (TypeError) or the `other than json` case (JSONDecodeError)
Expand Down

0 comments on commit 303030a

Please sign in to comment.