Skip to content

Commit

Permalink
Revert "Refactor auth code to output auth scheme in OpenAPI spec"
Browse files Browse the repository at this point in the history
This reverts commit 6442d2d.
  • Loading branch information
nikochiko committed Sep 16, 2024
1 parent 0713227 commit 9a5f74b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 78 deletions.
91 changes: 32 additions & 59 deletions auth/token_authentication.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,39 @@
from fastapi import Request
import threading

from fastapi import Header
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType
from fastapi.security.base import SecurityBase
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN

from app_users.models import AppUser
from auth.auth_backend import authlocal
from daras_ai_v2 import db
from daras_ai_v2.crypto import PBKDF2PasswordHasher

auth_keyword = "Bearer"

class AuthenticationError(HTTPException):
status_code = HTTP_401_UNAUTHORIZED

def __init__(self, msg: str):
super().__init__(status_code=self.status_code, detail={"error": msg})

def api_auth_header(
authorization: str = Header(
alias="Authorization",
description=f"{auth_keyword} $GOOEY_API_KEY",
),
) -> AppUser:
if authlocal:
return authlocal[0]
return authenticate(authorization)

class AuthorizationError(HTTPException):
status_code = HTTP_403_FORBIDDEN

def __init__(self, msg: str):
super().__init__(status_code=self.status_code, detail={"error": msg})
def authenticate(auth_token: str) -> AppUser:
auth = auth_token.split()
if not auth or auth[0].lower() != auth_keyword.lower():
msg = "Invalid Authorization header."
raise HTTPException(status_code=401, detail={"error": msg})
if len(auth) == 1:
msg = "Invalid Authorization header. No credentials provided."
raise HTTPException(status_code=401, detail={"error": msg})
elif len(auth) > 2:
msg = "Invalid Authorization header. Token string should not contain spaces."
raise HTTPException(status_code=401, detail={"error": msg})
return authenticate_credentials(auth[1])


def authenticate_credentials(token: str) -> AppUser:
Expand All @@ -36,7 +48,12 @@ def authenticate_credentials(token: str) -> AppUser:
.get()[0]
)
except IndexError:
raise AuthorizationError("Invalid API Key.")
raise HTTPException(
status_code=403,
detail={
"error": "Invalid API Key.",
},
)

uid = doc.get("uid")
user = AppUser.objects.get_or_create_from_uid(uid)[0]
Expand All @@ -45,50 +62,6 @@ def authenticate_credentials(token: str) -> AppUser:
"Your Gooey.AI account has been disabled for violating our Terms of Service. "
"Contact us at [email protected] if you think this is a mistake."
)
raise AuthenticationError(msg)
raise HTTPException(status_code=401, detail={"error": msg})

return user


class APIAuth(SecurityBase):
"""
### Usage:
```python
api_auth = APIAuth(scheme_name="Bearer", description="Bearer $GOOEY_API_KEY")
@app.get("/api/users")
def get_users(authenticated_user: AppUser = Depends(api_auth)):
...
```
"""

def __init__(self, scheme_name: str, description: str):
self.model = HTTPBaseModel(
type=SecuritySchemeType.http, scheme=scheme_name, description=description
)
self.scheme_name = scheme_name
self.description = description

def __call__(self, request: Request) -> AppUser:
if authlocal: # testing only!
return authlocal[0]

auth = request.headers.get("Authorization", "").split()
if not auth or auth[0].lower() != self.scheme_name.lower():
raise AuthenticationError("Invalid Authorization header.")
if len(auth) == 1:
raise AuthenticationError(
"Invalid Authorization header. No credentials provided."
)
elif len(auth) > 2:
raise AuthenticationError(
"Invalid Authorization header. Token string should not contain spaces."
)
return authenticate_credentials(auth[1])


auth_scheme = "Bearer"
api_auth_header = APIAuth(
scheme_name=auth_scheme, description=f"{auth_scheme} $GOOEY_API_KEY"
)
38 changes: 19 additions & 19 deletions daras_ai_v2/api_examples_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from furl import furl

import gooey_gui as gui
from auth.token_authentication import auth_scheme
from auth.token_authentication import auth_keyword
from daras_ai_v2 import settings
from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url

Expand Down Expand Up @@ -48,12 +48,12 @@ def api_example_generator(
if as_form_data:
curl_code = r"""
curl %(api_url)s \
-H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \
-H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \
%(files)s \
-F json=%(json)s
""" % dict(
api_url=shlex.quote(api_url),
auth_scheme=auth_scheme,
auth_keyword=auth_keyword,
files=" \\\n ".join(
f"-F {key}=@{shlex.quote(filename)}" for key, filename in filenames
),
Expand All @@ -62,12 +62,12 @@ def api_example_generator(
else:
curl_code = r"""
curl %(api_url)s \
-H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \
-H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \
-H 'Content-Type: application/json' \
-d %(json)s
""" % dict(
api_url=shlex.quote(api_url),
auth_scheme=auth_scheme,
auth_keyword=auth_keyword,
json=shlex.quote(json.dumps(request_body, indent=2)),
)
if as_async:
Expand All @@ -77,7 +77,7 @@ def api_example_generator(
)
while true; do
result=$(curl $status_url -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY")
result=$(curl $status_url -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY")
status=$(echo $result | jq -r '.status')
if [ "$status" = "completed" ]; then
echo $result
Expand All @@ -91,7 +91,7 @@ def api_example_generator(
""" % dict(
curl_code=indent(curl_code.strip(), " " * 2),
api_url=shlex.quote(api_url),
auth_scheme=auth_scheme,
auth_keyword=auth_keyword,
json=shlex.quote(json.dumps(request_body, indent=2)),
)

Expand Down Expand Up @@ -128,7 +128,7 @@ def api_example_generator(
response = requests.post(
"%(api_url)s",
headers={
"Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"],
"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"],
},
files=files,
data={"json": json.dumps(payload)},
Expand All @@ -140,7 +140,7 @@ def api_example_generator(
),
json=repr(request_body),
api_url=api_url,
auth_scheme=auth_scheme,
auth_keyword=auth_keyword,
)
else:
py_code = r"""
Expand All @@ -152,14 +152,14 @@ def api_example_generator(
response = requests.post(
"%(api_url)s",
headers={
"Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"],
"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"],
},
json=payload,
)
assert response.ok, response.content
""" % dict(
api_url=api_url,
auth_scheme=auth_scheme,
auth_keyword=auth_keyword,
json=repr(request_body),
)
if as_async:
Expand All @@ -168,7 +168,7 @@ def api_example_generator(
status_url = response.headers["Location"]
while True:
response = requests.get(status_url, headers={"Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"]})
response = requests.get(status_url, headers={"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"]})
assert response.ok, response.content
result = response.json()
if result["status"] == "completed":
Expand All @@ -181,7 +181,7 @@ def api_example_generator(
sleep(3)
""" % dict(
api_url=api_url,
auth_scheme=auth_scheme,
auth_keyword=auth_keyword,
)
else:
py_code += r"""
Expand Down Expand Up @@ -229,7 +229,7 @@ def api_example_generator(
const response = await fetch("%(api_url)s", {
method: "POST",
headers: {
"Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"],
"Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"],
},
body: formData,
});
Expand All @@ -243,7 +243,7 @@ def api_example_generator(
" " * 2,
),
api_url=api_url,
auth_scheme=auth_scheme,
auth_keyword=auth_keyword,
)

else:
Expand All @@ -256,14 +256,14 @@ def api_example_generator(
const response = await fetch("%(api_url)s", {
method: "POST",
headers: {
"Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"],
"Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"],
"Content-Type": "application/json",
},
body: JSON.stringify(payload),
});
""" % dict(
api_url=api_url,
auth_scheme=auth_scheme,
auth_keyword=auth_keyword,
json=json.dumps(request_body, indent=2),
)

Expand All @@ -280,7 +280,7 @@ def api_example_generator(
const response = await fetch(status_url, {
method: "GET",
headers: {
"Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"],
"Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"],
},
});
if (!response.ok) {
Expand All @@ -299,7 +299,7 @@ def api_example_generator(
}
}""" % dict(
api_url=api_url,
auth_scheme=auth_scheme,
auth_keyword=auth_keyword,
)
else:
js_code += """
Expand Down

0 comments on commit 9a5f74b

Please sign in to comment.