-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "Refactor auth code to output auth scheme in OpenAPI spec"
This reverts commit 6442d2d.
- Loading branch information
Showing
2 changed files
with
51 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,39 @@ | ||
from fastapi import Request | ||
import threading | ||
|
||
from fastapi import Header | ||
from fastapi.exceptions import HTTPException | ||
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType | ||
from fastapi.security.base import SecurityBase | ||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN | ||
|
||
from app_users.models import AppUser | ||
from auth.auth_backend import authlocal | ||
from daras_ai_v2 import db | ||
from daras_ai_v2.crypto import PBKDF2PasswordHasher | ||
|
||
auth_keyword = "Bearer" | ||
|
||
class AuthenticationError(HTTPException): | ||
status_code = HTTP_401_UNAUTHORIZED | ||
|
||
def __init__(self, msg: str): | ||
super().__init__(status_code=self.status_code, detail={"error": msg}) | ||
|
||
def api_auth_header( | ||
authorization: str = Header( | ||
alias="Authorization", | ||
description=f"{auth_keyword} $GOOEY_API_KEY", | ||
), | ||
) -> AppUser: | ||
if authlocal: | ||
return authlocal[0] | ||
return authenticate(authorization) | ||
|
||
class AuthorizationError(HTTPException): | ||
status_code = HTTP_403_FORBIDDEN | ||
|
||
def __init__(self, msg: str): | ||
super().__init__(status_code=self.status_code, detail={"error": msg}) | ||
def authenticate(auth_token: str) -> AppUser: | ||
auth = auth_token.split() | ||
if not auth or auth[0].lower() != auth_keyword.lower(): | ||
msg = "Invalid Authorization header." | ||
raise HTTPException(status_code=401, detail={"error": msg}) | ||
if len(auth) == 1: | ||
msg = "Invalid Authorization header. No credentials provided." | ||
raise HTTPException(status_code=401, detail={"error": msg}) | ||
elif len(auth) > 2: | ||
msg = "Invalid Authorization header. Token string should not contain spaces." | ||
raise HTTPException(status_code=401, detail={"error": msg}) | ||
return authenticate_credentials(auth[1]) | ||
|
||
|
||
def authenticate_credentials(token: str) -> AppUser: | ||
|
@@ -36,7 +48,12 @@ def authenticate_credentials(token: str) -> AppUser: | |
.get()[0] | ||
) | ||
except IndexError: | ||
raise AuthorizationError("Invalid API Key.") | ||
raise HTTPException( | ||
status_code=403, | ||
detail={ | ||
"error": "Invalid API Key.", | ||
}, | ||
) | ||
|
||
uid = doc.get("uid") | ||
user = AppUser.objects.get_or_create_from_uid(uid)[0] | ||
|
@@ -45,50 +62,6 @@ def authenticate_credentials(token: str) -> AppUser: | |
"Your Gooey.AI account has been disabled for violating our Terms of Service. " | ||
"Contact us at [email protected] if you think this is a mistake." | ||
) | ||
raise AuthenticationError(msg) | ||
raise HTTPException(status_code=401, detail={"error": msg}) | ||
|
||
return user | ||
|
||
|
||
class APIAuth(SecurityBase): | ||
""" | ||
### Usage: | ||
```python | ||
api_auth = APIAuth(scheme_name="Bearer", description="Bearer $GOOEY_API_KEY") | ||
@app.get("/api/users") | ||
def get_users(authenticated_user: AppUser = Depends(api_auth)): | ||
... | ||
``` | ||
""" | ||
|
||
def __init__(self, scheme_name: str, description: str): | ||
self.model = HTTPBaseModel( | ||
type=SecuritySchemeType.http, scheme=scheme_name, description=description | ||
) | ||
self.scheme_name = scheme_name | ||
self.description = description | ||
|
||
def __call__(self, request: Request) -> AppUser: | ||
if authlocal: # testing only! | ||
return authlocal[0] | ||
|
||
auth = request.headers.get("Authorization", "").split() | ||
if not auth or auth[0].lower() != self.scheme_name.lower(): | ||
raise AuthenticationError("Invalid Authorization header.") | ||
if len(auth) == 1: | ||
raise AuthenticationError( | ||
"Invalid Authorization header. No credentials provided." | ||
) | ||
elif len(auth) > 2: | ||
raise AuthenticationError( | ||
"Invalid Authorization header. Token string should not contain spaces." | ||
) | ||
return authenticate_credentials(auth[1]) | ||
|
||
|
||
auth_scheme = "Bearer" | ||
api_auth_header = APIAuth( | ||
scheme_name=auth_scheme, description=f"{auth_scheme} $GOOEY_API_KEY" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters