From 04e6c8c814252090c43f7d2c6bf78c53ea2ec18c Mon Sep 17 00:00:00 2001 From: Harpo Harbert Date: Tue, 24 Oct 2023 14:36:41 -0700 Subject: [PATCH] Allows (str, list, set) as permissions --- chowda/auth/utils.py | 17 +++++++++-------- chowda/routers/events.py | 2 +- chowda/routers/sony_ci.py | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/chowda/auth/utils.py b/chowda/auth/utils.py index 61304c8f..a40c02a7 100644 --- a/chowda/auth/utils.py +++ b/chowda/auth/utils.py @@ -1,4 +1,4 @@ -from typing import Annotated, List +from typing import Annotated, List, Set from fastapi import Depends, HTTPException, Request, status from pydantic import BaseModel, Field @@ -122,22 +122,23 @@ def verified_access_token( ) from exc -def permissions(permissions: List[str]) -> None: +def permissions(permissions: str | List[str] | Set[str]) -> None: """Dependency function to check if token has required permissions. Args: - permissions (List[str]): List of required permissions + permissions (str, List, Set): Required permissions. Can be a str, list, or set. Examples: - @app.get('/users/', dependencies=[Depends(permissions(['read:user']))]) - - + @app.get('/users/', dependencies=[Depends(permissions('read:users'))]) """ + if isinstance(permissions, (str, list)): + permissions: set = {permissions} def _permissions( token: Annotated[OAuthAccessToken, Depends(verified_access_token)], ) -> None: - """Check if user has required permissions.""" - missing_permissions = set(permissions) - set(token.permissions) + """Verify token has all required permissions, or raise a 403 Forbidden exception + with the missing permissions in the detail message.""" + missing_permissions = permissions - set(token.permissions) if missing_permissions: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/chowda/routers/events.py b/chowda/routers/events.py index 358df5eb..779c26c2 100644 --- a/chowda/routers/events.py +++ b/chowda/routers/events.py @@ -11,7 +11,7 @@ events = APIRouter() -@events.post('/', dependencies=[Depends(permissions(['create:event']))]) +@events.post('/', dependencies=[Depends(permissions('create:event'))]) async def event(event: dict): """Receive an event from Argo Events.""" print('Chowda event received', event) diff --git a/chowda/routers/sony_ci.py b/chowda/routers/sony_ci.py index 85d51f5b..fce321a0 100644 --- a/chowda/routers/sony_ci.py +++ b/chowda/routers/sony_ci.py @@ -14,7 +14,7 @@ class SyncResponse(BaseModel): @sony_ci.post( - '/sync', tags=['sync'], dependencies=[Depends(permissions(['sync:sonyci']))] + '/sync', tags=['sync'], dependencies=[Depends(permissions('sync:sonyci'))] ) async def sony_ci_sync() -> SyncResponse: try: