diff --git a/auth/token_authentication.py b/auth/token_authentication.py index 483e291b6..b33bbbbd0 100644 --- a/auth/token_authentication.py +++ b/auth/token_authentication.py @@ -1,27 +1,39 @@ -from fastapi import Request +import threading + +from fastapi import Header from fastapi.exceptions import HTTPException -from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType -from fastapi.security.base import SecurityBase -from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from app_users.models import AppUser from auth.auth_backend import authlocal from daras_ai_v2 import db from daras_ai_v2.crypto import PBKDF2PasswordHasher +auth_keyword = "Bearer" -class AuthenticationError(HTTPException): - status_code = HTTP_401_UNAUTHORIZED - - def __init__(self, msg: str): - super().__init__(status_code=self.status_code, detail={"error": msg}) +def api_auth_header( + authorization: str = Header( + alias="Authorization", + description=f"{auth_keyword} $GOOEY_API_KEY", + ), +) -> AppUser: + if authlocal: + return authlocal[0] + return authenticate(authorization) -class AuthorizationError(HTTPException): - status_code = HTTP_403_FORBIDDEN - def __init__(self, msg: str): - super().__init__(status_code=self.status_code, detail={"error": msg}) +def authenticate(auth_token: str) -> AppUser: + auth = auth_token.split() + if not auth or auth[0].lower() != auth_keyword.lower(): + msg = "Invalid Authorization header." + raise HTTPException(status_code=401, detail={"error": msg}) + if len(auth) == 1: + msg = "Invalid Authorization header. No credentials provided." + raise HTTPException(status_code=401, detail={"error": msg}) + elif len(auth) > 2: + msg = "Invalid Authorization header. Token string should not contain spaces." + raise HTTPException(status_code=401, detail={"error": msg}) + return authenticate_credentials(auth[1]) def authenticate_credentials(token: str) -> AppUser: @@ -36,7 +48,12 @@ def authenticate_credentials(token: str) -> AppUser: .get()[0] ) except IndexError: - raise AuthorizationError("Invalid API Key.") + raise HTTPException( + status_code=403, + detail={ + "error": "Invalid API Key.", + }, + ) uid = doc.get("uid") user = AppUser.objects.get_or_create_from_uid(uid)[0] @@ -45,50 +62,6 @@ def authenticate_credentials(token: str) -> AppUser: "Your Gooey.AI account has been disabled for violating our Terms of Service. " "Contact us at support@gooey.ai if you think this is a mistake." ) - raise AuthenticationError(msg) + raise HTTPException(status_code=401, detail={"error": msg}) return user - - -class APIAuth(SecurityBase): - """ - ### Usage: - - ```python - api_auth = APIAuth(scheme_name="Bearer", description="Bearer $GOOEY_API_KEY") - - @app.get("/api/users") - def get_users(authenticated_user: AppUser = Depends(api_auth)): - ... - ``` - """ - - def __init__(self, scheme_name: str, description: str): - self.model = HTTPBaseModel( - type=SecuritySchemeType.http, scheme=scheme_name, description=description - ) - self.scheme_name = scheme_name - self.description = description - - def __call__(self, request: Request) -> AppUser: - if authlocal: # testing only! - return authlocal[0] - - auth = request.headers.get("Authorization", "").split() - if not auth or auth[0].lower() != self.scheme_name.lower(): - raise AuthenticationError("Invalid Authorization header.") - if len(auth) == 1: - raise AuthenticationError( - "Invalid Authorization header. No credentials provided." - ) - elif len(auth) > 2: - raise AuthenticationError( - "Invalid Authorization header. Token string should not contain spaces." - ) - return authenticate_credentials(auth[1]) - - -auth_scheme = "Bearer" -api_auth_header = APIAuth( - scheme_name=auth_scheme, description=f"{auth_scheme} $GOOEY_API_KEY" -) diff --git a/daras_ai_v2/api_examples_widget.py b/daras_ai_v2/api_examples_widget.py index ef54058c2..9cd9beeb1 100644 --- a/daras_ai_v2/api_examples_widget.py +++ b/daras_ai_v2/api_examples_widget.py @@ -6,7 +6,7 @@ from furl import furl import gooey_gui as gui -from auth.token_authentication import auth_scheme +from auth.token_authentication import auth_keyword from daras_ai_v2 import settings from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url @@ -48,12 +48,12 @@ def api_example_generator( if as_form_data: curl_code = r""" curl %(api_url)s \ - -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \ + -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \ %(files)s \ -F json=%(json)s """ % dict( api_url=shlex.quote(api_url), - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, files=" \\\n ".join( f"-F {key}=@{shlex.quote(filename)}" for key, filename in filenames ), @@ -62,12 +62,12 @@ def api_example_generator( else: curl_code = r""" curl %(api_url)s \ - -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \ + -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \ -H 'Content-Type: application/json' \ -d %(json)s """ % dict( api_url=shlex.quote(api_url), - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, json=shlex.quote(json.dumps(request_body, indent=2)), ) if as_async: @@ -77,7 +77,7 @@ def api_example_generator( ) while true; do - result=$(curl $status_url -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY") + result=$(curl $status_url -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY") status=$(echo $result | jq -r '.status') if [ "$status" = "completed" ]; then echo $result @@ -91,7 +91,7 @@ def api_example_generator( """ % dict( curl_code=indent(curl_code.strip(), " " * 2), api_url=shlex.quote(api_url), - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, json=shlex.quote(json.dumps(request_body, indent=2)), ) @@ -128,7 +128,7 @@ def api_example_generator( response = requests.post( "%(api_url)s", headers={ - "Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"], }, files=files, data={"json": json.dumps(payload)}, @@ -140,7 +140,7 @@ def api_example_generator( ), json=repr(request_body), api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, ) else: py_code = r""" @@ -152,14 +152,14 @@ def api_example_generator( response = requests.post( "%(api_url)s", headers={ - "Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"], }, json=payload, ) assert response.ok, response.content """ % dict( api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, json=repr(request_body), ) if as_async: @@ -168,7 +168,7 @@ def api_example_generator( status_url = response.headers["Location"] while True: - response = requests.get(status_url, headers={"Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"]}) + response = requests.get(status_url, headers={"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"]}) assert response.ok, response.content result = response.json() if result["status"] == "completed": @@ -181,7 +181,7 @@ def api_example_generator( sleep(3) """ % dict( api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, ) else: py_code += r""" @@ -229,7 +229,7 @@ def api_example_generator( const response = await fetch("%(api_url)s", { method: "POST", headers: { - "Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"], }, body: formData, }); @@ -243,7 +243,7 @@ def api_example_generator( " " * 2, ), api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, ) else: @@ -256,14 +256,14 @@ def api_example_generator( const response = await fetch("%(api_url)s", { method: "POST", headers: { - "Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"], "Content-Type": "application/json", }, body: JSON.stringify(payload), }); """ % dict( api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, json=json.dumps(request_body, indent=2), ) @@ -280,7 +280,7 @@ def api_example_generator( const response = await fetch(status_url, { method: "GET", headers: { - "Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"], }, }); if (!response.ok) { @@ -299,7 +299,7 @@ def api_example_generator( } }""" % dict( api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, ) else: js_code += """