Skip to content

Commit

Permalink
expose reset data to admin (#1602)
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty authored Nov 18, 2024
1 parent 8a2723a commit 2a3f06a
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 8 deletions.
2 changes: 2 additions & 0 deletions py/core/base/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
WrappedGetPromptsResponse,
WrappedLogResponse,
WrappedPromptMessageResponse,
WrappedResetDataResult,
WrappedServerStatsResponse,
WrappedUserCollectionResponse,
WrappedUserOverviewResponse,
Expand All @@ -86,6 +87,7 @@
"WrappedUserResponse",
"WrappedVerificationResult",
"WrappedGenericMessageResponse",
"WrappedResetDataResult",
# Ingestion Responses
"IngestionResponse",
"WrappedIngestionResponse",
Expand Down
6 changes: 3 additions & 3 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ async def get_users_overview(
pass

@abstractmethod
async def get_user_verification_data(
async def get_user_validation_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
Expand Down Expand Up @@ -1393,10 +1393,10 @@ async def get_users_overview(
user_ids, offset, limit
)

async def get_user_verification_data(
async def get_user_validation_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
return await self.user_handler.get_user_verification_data(user_id)
return await self.user_handler.get_user_validation_data(user_id)

# Vector handler methods
async def upsert(self, entry: VectorEntry) -> None:
Expand Down
32 changes: 31 additions & 1 deletion py/core/main/api/auth_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from core.base.api.models import (
GenericMessageResponse,
WrappedGenericMessageResponse,
WrappedResetDataResult,
WrappedTokenResponse,
WrappedUserResponse,
WrappedVerificationResult,
Expand Down Expand Up @@ -280,7 +281,36 @@ async def get_user_verification_code(
raise R2RException(
status_code=400, message="Invalid user ID format"
)
result = await self.service.get_user_verification_data(user_uuid)
result = await self.service.get_user_verification_code(user_uuid)
return result

@self.router.get("/user/{user_id}/reset_token")
@self.base_endpoint
async def get_user_reset_token(
user_id: str = Path(..., description="User ID"),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedResetDataResult:
"""
Get only the verification code for a specific user.
Only accessible by superusers.
"""
if not auth_user.is_superuser:
raise R2RException(
status_code=403,
message="Only superusers can access verification codes",
)

try:
user_uuid = UUID(user_id)
except ValueError:
raise R2RException(
status_code=400, message="Invalid user ID format"
)
result = await self.service.get_user_reset_token(user_uuid)
if not result["reset_token"]:
raise R2RException(
status_code=404, message="No reset token found"
)
return result

# Add to AuthRouter class (auth_router.py)
Expand Down
24 changes: 22 additions & 2 deletions py/core/main/services/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,15 @@ async def clean_expired_blacklisted_tokens(
)

@telemetry_event("GetUserVerificationCode")
async def get_user_verification_data(
async def get_user_verification_code(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
Get only the verification code data for a specific user.
This method should be called after superuser authorization has been verified.
"""
verification_data = (
await self.providers.database.get_user_verification_data(user_id)
await self.providers.database.get_user_validation_data(user_id)
)
return {
"verification_code": verification_data["verification_data"][
Expand All @@ -204,6 +204,26 @@ async def get_user_verification_data(
],
}

@telemetry_event("GetUserVerificationCode")
async def get_user_reset_token(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
Get only the verification code data for a specific user.
This method should be called after superuser authorization has been verified.
"""
verification_data = (
await self.providers.database.get_user_validation_data(user_id)
)
return {
"reset_token": verification_data["verification_data"][
"reset_token"
],
"expiry": verification_data["verification_data"][
"reset_token_expiry"
],
}

@telemetry_event("SendResetEmail")
async def send_reset_email(self, email: str) -> dict:
"""
Expand Down
3 changes: 2 additions & 1 deletion py/core/pipes/kg/triples_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ async def _run_logic( # type: ignore

# sort the extractions accroding to chunk_order field in metadata in ascending order
extractions = sorted(
extractions, key=lambda x: x.metadata.get("chunk_order", float('inf'))
extractions,
key=lambda x: x.metadata.get("chunk_order", float("inf")),
)

# group these extractions into groups of extraction_merge_count
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ async def _collection_exists(self, collection_id: UUID) -> bool:
)
return result is not None

async def get_user_verification_data(
async def get_user_validation_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
Expand Down
14 changes: 14 additions & 0 deletions py/sdk/mixins/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,20 @@ async def get_user_verification_code(
"GET", f"user/{user_id}/verification_data"
)

async def get_user_reset_token(self, user_id: Union[str, UUID]) -> dict:
"""
Retrieves only the verification code for a specific user. Requires superuser access.
Args:
user_id (Union[str, UUID]): The ID of the user to get verification code for.
Returns:
dict: Contains verification code and its expiry date
"""
return await self._make_request( # type: ignore
"GET", f"user/{user_id}/reset_token"
)

async def send_reset_email(self, email: str) -> dict:
"""
Generates a new verification code and sends a reset email to the user.
Expand Down
13 changes: 13 additions & 0 deletions py/shared/api/models/management/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ class VerificationResult(BaseModel):
message: Optional[str] = None


class VerificationResult(BaseModel):
verification_code: str
expiry: datetime
message: Optional[str] = None


class ResetDataResult(BaseModel):
reset_token: str
expiry: datetime
message: Optional[str] = None


class AddUserResponse(BaseModel):
result: bool

Expand Down Expand Up @@ -178,6 +190,7 @@ class AddUserResponse(BaseModel):
]
WrappedDeleteResponse = ResultsWrapper[None]
WrappedVerificationResult = ResultsWrapper[VerificationResult]
WrappedResetDataResult = ResultsWrapper[ResetDataResult]
WrappedConversationsOverviewResponse = PaginatedResultsWrapper[
list[ConversationOverviewResponse]
]

0 comments on commit 2a3f06a

Please sign in to comment.