Skip to content

Commit

Permalink
Fix broken user token rotation API (#1487)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored Aug 1, 2024
1 parent 5aebcee commit fcbf4df
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/dstack/_internal/server/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ async def update_user(
async def refresh_token(
body: RefreshTokenRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(GlobalAdmin()),
user: UserModel = Depends(Authenticated()),
) -> UserWithCreds:
res = await users.refresh_user_token(session=session, username=body.username)
res = await users.refresh_user_token(session=session, user=user, username=body.username)
if res is None:
raise ResourceNotExistsError()
return users.user_model_to_user_with_creds(res)
Expand Down
11 changes: 9 additions & 2 deletions src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dstack._internal.core.errors import ResourceExistsError
from dstack._internal.core.models.users import GlobalRole, User, UserTokenCreds, UserWithCreds
from dstack._internal.server.models import UserModel
from dstack._internal.server.utils.routers import error_forbidden

_ADMIN_USERNAME = "admin"

Expand Down Expand Up @@ -90,9 +91,15 @@ async def update_user(
return await get_user_model_by_name_or_error(session=session, username=username)


async def refresh_user_token(session: AsyncSession, username: str) -> Optional[UserModel]:
async def refresh_user_token(
session: AsyncSession,
user: UserModel,
username: str,
) -> Optional[UserModel]:
if user.global_role != GlobalRole.ADMIN and user.name != username:
raise error_forbidden()
await session.execute(
update(UserModel).where(UserModel.name == username).values(token=uuid.uuid4())
update(UserModel).where(UserModel.name == username).values(token=str(uuid.uuid4()))
)
await session.commit()
return await get_user_model_by_name(session=session, username=username)
Expand Down
48 changes: 48 additions & 0 deletions src/tests/_internal/server/routers/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,51 @@ async def test_deletes_users(self, test_db, session: AsyncSession):
assert response.status_code == 200
res = await session.execute(select(UserModel).where(UserModel.name == user.name))
assert len(res.scalars().all()) == 0


class TestRefreshToken:
def test_returns_40x_if_not_authenticated(self):
response = client.post("/api/users/refresh_token")
assert response.status_code in [401, 403]

@pytest.mark.asyncio
async def test_refreshes_token(self, test_db, session: AsyncSession):
user1 = await create_user(name="user1", session=session)
old_token = user1.token
response = client.post(
"/api/users/refresh_token",
headers=get_auth_headers(user1.token),
json={"username": user1.name},
)
assert response.status_code == 200
assert response.json()["creds"]["token"] != old_token
await session.refresh(user1)
assert user1.token != old_token

@pytest.mark.asyncio
async def test_returns_403_if_non_admin_refreshes_for_other_user(
self, test_db, session: AsyncSession
):
user1 = await create_user(name="user1", session=session, global_role=GlobalRole.USER)
user2 = await create_user(name="user2", session=session)
response = client.post(
"/api/users/refresh_token",
headers=get_auth_headers(user1.token),
json={"username": user2.name},
)
assert response.status_code == 403

@pytest.mark.asyncio
async def test_global_admin_refreshes_token(self, test_db, session: AsyncSession):
user1 = await create_user(name="user1", session=session, global_role=GlobalRole.ADMIN)
user2 = await create_user(name="user2", session=session)
old_token = user2.token
response = client.post(
"/api/users/refresh_token",
headers=get_auth_headers(user1.token),
json={"username": user2.name},
)
assert response.status_code == 200
assert response.json()["creds"]["token"] != old_token
await session.refresh(user2)
assert user2.token != old_token

0 comments on commit fcbf4df

Please sign in to comment.