Skip to content

Commit

Permalink
Change the way API keys are generated and used (#1221)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomchop authored Feb 12, 2025
1 parent 56f330c commit 85c1294
Show file tree
Hide file tree
Showing 19 changed files with 342 additions and 116 deletions.
89 changes: 76 additions & 13 deletions core/schemas/user.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import re
import datetime
import json
import secrets
from typing import ClassVar, Literal

from jose import jwt
from passlib.context import CryptContext
from pydantic import ConfigDict, Field, computed_field
from pydantic import BaseModel, ConfigDict, Field, computed_field

from core import database_arango
from core.config.config import yeti_config
from core.helpers import now
from core.schemas import graph, rbac, roles
from core.schemas.model import YetiModel

Expand All @@ -16,10 +19,37 @@
"reader": roles.Role.READER,
"writer": roles.Role.WRITER,
}
SECRET_KEY = yeti_config.get("auth", "secret_key")
ALGORITHM = yeti_config.get("auth", "algorithm")


def create_access_token(
data: dict, expires_delta: datetime.timedelta | None = None
) -> str:
to_encode = data.copy()
expire = None
if expires_delta:
expire = datetime.datetime.now(datetime.timezone.utc) + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


class RegisteredApiKey(BaseModel):
name: str
sub: str
scopes: list[str]
created: datetime.datetime = Field(default_factory=now)
exp: datetime.datetime | None = None
last_used: datetime.datetime | None = None
enabled: bool = True


def generate_api_key():
return secrets.token_hex(32)
@computed_field
@property
def expired(self) -> bool:
if self.exp is None:
return False
return self.exp > datetime.datetime.now(tz=datetime.timezone.utc)


class User(YetiModel, database_arango.ArangoYetiConnector):
Expand All @@ -31,7 +61,7 @@ class User(YetiModel, database_arango.ArangoYetiConnector):
username: str
enabled: bool = True
admin: bool = False
api_key: str = Field(default_factory=generate_api_key)
api_keys: dict[str, RegisteredApiKey] | None = {}

global_role: int = RBAC_DEFAULT_ROLES[
str(yeti_config.get("rbac", "default_global_role", default="writer"))
Expand All @@ -46,13 +76,46 @@ def root_type(self):
def load(cls, object: dict) -> "User":
return cls(**object)

def reset_api_key(self, api_key=None) -> None:
if api_key:
if not re.match(r"^[a-f0-9]{64}$", api_key):
raise ValueError("Invalid API key: must match ^[a-f0-9]{64}$")
self.api_key = api_key
else:
self.api_key = secrets.token_hex(32)
def create_api_key(
self,
key_name: str,
scopes: list[str] | None = None,
expiration_delta: datetime.timedelta | None = None,
) -> str:
exp = None
if expiration_delta:
exp = datetime.datetime.now(datetime.timezone.utc) + expiration_delta
api_key = RegisteredApiKey(
name=key_name,
sub=self.username,
scopes=scopes or ["all"],
exp=exp,
)
self.api_keys[key_name] = api_key
self.save()
return create_access_token(json.loads(api_key.model_dump_json()))

def validate_api_key_payload(self, payload) -> RegisteredApiKey:
sub = payload.get("sub")
key_name = payload.get("name")
if key_name not in self.api_keys or sub != self.username:
raise ValueError("Could not validate credentials")

key = self.api_keys[key_name]
if not key.enabled:
raise ValueError("API key disabled.")
if key.expired:
raise ValueError("API key expired.")

return key

def delete_api_key(self, api_key_name) -> None:
api_keys = self.api_keys
del api_keys[api_key_name]
self.api_keys = None
self.save()
self.api_keys = api_keys
self.save()

def has_permissions(self, target: str, permissions: roles.Permission) -> bool:
return graph.RoleRelationship.has_permissions(self, target, permissions)
Expand Down
47 changes: 30 additions & 17 deletions core/web/apiv2/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from starlette.requests import Request

from core.config.config import yeti_config
from core.schemas.user import User, UserSensitive
from core.schemas.user import User, UserSensitive, create_access_token

ACCESS_TOKEN_EXPIRE_DELTA = datetime.timedelta(
minutes=yeti_config.get("auth", "access_token_expire_minutes", default=30)
Expand Down Expand Up @@ -65,17 +65,6 @@ def get_oauth_client() -> OAuth:
return client


def create_access_token(data: dict, expires_delta: datetime.timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.datetime.utcnow() + expires_delta
else:
expire = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


def get_current_user(
request: Request,
token: str = Depends(oauth2_scheme),
Expand Down Expand Up @@ -277,26 +266,50 @@ def login(


@router.post("/api-token")
def login_api(x_yeti_api_key: str = Security(api_key_header)) -> dict[str, str]:
user = UserSensitive.find(api_key=x_yeti_api_key)
def login_api(x_yeti_api_key_token: str = Security(api_key_header)) -> dict[str, str]:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "x-yeti-apikey"},
)

try:
payload = jwt.decode(
x_yeti_api_key_token,
SECRET_KEY,
algorithms=[ALGORITHM],
options={"verify_exp": False, "requires_exp": False},
)
except JWTError:
raise credentials_exception

user = UserSensitive.find(username=payload.get("sub"))
if not user:
raise credentials_exception

try:
user.validate_api_key_payload(payload)
except ValueError as error:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
detail=str(error),
headers={"WWW-Authenticate": "x-yeti-apikey"},
)

if not user.enabled:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account disabled. Please contact your server admin.",
headers={"WWW-Authenticate": "Bearer"},
headers={"WWW-Authenticate": "x-yeti-apikey"},
)

access_token = create_access_token(
data={"sub": user.username, "enabled": user.enabled},
expires_delta=ACCESS_TOKEN_EXPIRE_DELTA,
)
user.api_keys[payload["name"]].last_used = datetime.datetime.now(
tz=datetime.timezone.utc
)
return {"access_token": access_token, "token_type": "bearer"}


Expand Down
103 changes: 95 additions & 8 deletions core/web/apiv2/users.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import datetime
from enum import Enum

from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, ConfigDict

from core.schemas import rbac, roles
from core.schemas.user import User, UserSensitive
from core.schemas.user import RegisteredApiKey, User, UserSensitive
from core.web.apiv2.auth import GetCurrentUserWithPermissions, get_current_user


Expand Down Expand Up @@ -47,10 +48,41 @@ class PatchRoleRequest(BaseModel):
role: roles.Permission


class ResetApiKeyRequest(BaseModel):
class NewApiKeyRequest(BaseModel):
model_config = ConfigDict(extra="forbid")

user_id: str
name: str
scopes: list[str] = []
expiration: datetime.timedelta | None = None


class NewAPIKeyResponse(BaseModel):
model_config = ConfigDict(extra="forbid")

name: str
token: str
api_keys: dict[str, RegisteredApiKey]


class DeleteApiKeyRequest(BaseModel):
model_config = ConfigDict(extra="forbid")

user_id: str
name: str


class DeleteApiKeyResponse(BaseModel):
model_config = ConfigDict(extra="forbid")

api_keys: dict[str, RegisteredApiKey]


class ToggleApiKeyRequest(BaseModel):
model_config = ConfigDict(extra="forbid")

user_id: str
name: str


class ResetPasswordRequest(BaseModel):
Expand Down Expand Up @@ -137,22 +169,73 @@ def update_user_role(
return user.save()


@router.post("/reset-api-key")
@router.post("/new-api-key")
def new_api_key(
request: NewApiKeyRequest, current_user: UserSensitive = Depends(get_current_user)
) -> NewAPIKeyResponse:
if not current_user.admin and current_user.id != request.user_id:
raise HTTPException(
status_code=401, detail="cannot create API keys for other users"
)

user = UserSensitive.get(request.user_id)
if not user:
raise HTTPException(status_code=404, detail="user {user_id} not found")

token = user.create_api_key(
request.name, scopes=request.scopes, expiration_delta=request.expiration
)
user.save()
return NewAPIKeyResponse(name=request.name, token=token, api_keys=user.api_keys)


@router.post("/toggle-api-key")
def toggle_api_key(
request: ToggleApiKeyRequest,
current_user: UserSensitive = Depends(get_current_user),
) -> RegisteredApiKey:
if not current_user.admin and current_user.id != request.user_id:
raise HTTPException(
status_code=401, detail="cannot create API keys for other users"
)

user = UserSensitive.get(request.user_id)
if not user:
raise HTTPException(status_code=404, detail="user {user_id} not found")

assert isinstance(user.api_keys, dict)
if request.name not in user.api_keys:
raise HTTPException(
status_code=401, detail=f"{request.name}: invalid API key name"
)

user.api_keys[request.name].enabled = not user.api_keys[request.name].enabled
user.save()
return user.api_keys[request.name]


@router.post("/delete-api-key")
def reset_api_key(
request: ResetApiKeyRequest, current_user: UserSensitive = Depends(get_current_user)
) -> User:
request: DeleteApiKeyRequest,
current_user: UserSensitive = Depends(get_current_user),
) -> DeleteApiKeyResponse:
"""Resets a user's API key."""
if not current_user.admin and current_user.id != request.user_id:
raise HTTPException(
status_code=401, detail="cannot reset API keys for other users"
status_code=401, detail="cannot delete API keys for other users"
)

user = UserSensitive.get(request.user_id)
if not user:
raise HTTPException(status_code=404, detail="user {user_id} not found")

user.reset_api_key()
return user.save()
if request.name not in user.api_keys:
raise HTTPException(
status_code=401, detail=f"{request.name}: invalid API key name"
)
user.delete_api_key(request.name)

return DeleteApiKeyResponse(api_keys=user.api_keys)


@router.post("/reset-password")
Expand Down Expand Up @@ -198,10 +281,14 @@ def create(
user = user.save()

all_users = rbac.Group.find(name="All users")
if not all_users:
all_users = rbac.Group(name="All users").save()
if not request.admin:
user.link_to_acl(all_users, roles.Role.READER)
else:
admins = rbac.Group.find(name="Admins")
if not admins:
admins = rbac.Group(name="Admins").save()
user.link_to_acl(admins, roles.Role.OWNER)
user.link_to_acl(all_users, roles.Role.OWNER)

Expand Down
Loading

0 comments on commit 85c1294

Please sign in to comment.