From 2a3f06a541e9643205b9b66134f6076ee27a522a Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Date: Sun, 17 Nov 2024 20:45:35 -0800 Subject: [PATCH 1/2] expose reset data to admin (#1602) --- py/core/base/api/models/__init__.py | 2 ++ py/core/base/providers/database.py | 6 ++-- py/core/main/api/auth_router.py | 32 +++++++++++++++++++- py/core/main/services/auth_service.py | 24 +++++++++++++-- py/core/pipes/kg/triples_extraction.py | 3 +- py/core/providers/database/user.py | 2 +- py/sdk/mixins/auth.py | 14 +++++++++ py/shared/api/models/management/responses.py | 13 ++++++++ 8 files changed, 88 insertions(+), 8 deletions(-) diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index 46d3007db..9b7590f42 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -60,6 +60,7 @@ WrappedGetPromptsResponse, WrappedLogResponse, WrappedPromptMessageResponse, + WrappedResetDataResult, WrappedServerStatsResponse, WrappedUserCollectionResponse, WrappedUserOverviewResponse, @@ -86,6 +87,7 @@ "WrappedUserResponse", "WrappedVerificationResult", "WrappedGenericMessageResponse", + "WrappedResetDataResult", # Ingestion Responses "IngestionResponse", "WrappedIngestionResponse", diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 4411e0a1c..83432c8f0 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -495,7 +495,7 @@ async def get_users_overview( pass @abstractmethod - async def get_user_verification_data( + async def get_user_validation_data( self, user_id: UUID, *args, **kwargs ) -> dict: """ @@ -1393,10 +1393,10 @@ async def get_users_overview( user_ids, offset, limit ) - async def get_user_verification_data( + async def get_user_validation_data( self, user_id: UUID, *args, **kwargs ) -> dict: - return await self.user_handler.get_user_verification_data(user_id) + return await self.user_handler.get_user_validation_data(user_id) # Vector handler methods async def upsert(self, entry: VectorEntry) -> None: diff --git a/py/core/main/api/auth_router.py b/py/core/main/api/auth_router.py index b439017e1..8bcbe3c48 100644 --- a/py/core/main/api/auth_router.py +++ b/py/core/main/api/auth_router.py @@ -9,6 +9,7 @@ from core.base.api.models import ( GenericMessageResponse, WrappedGenericMessageResponse, + WrappedResetDataResult, WrappedTokenResponse, WrappedUserResponse, WrappedVerificationResult, @@ -280,7 +281,36 @@ async def get_user_verification_code( raise R2RException( status_code=400, message="Invalid user ID format" ) - result = await self.service.get_user_verification_data(user_uuid) + result = await self.service.get_user_verification_code(user_uuid) + return result + + @self.router.get("/user/{user_id}/reset_token") + @self.base_endpoint + async def get_user_reset_token( + user_id: str = Path(..., description="User ID"), + auth_user=Depends(self.service.providers.auth.auth_wrapper), + ) -> WrappedResetDataResult: + """ + Get only the verification code for a specific user. + Only accessible by superusers. + """ + if not auth_user.is_superuser: + raise R2RException( + status_code=403, + message="Only superusers can access verification codes", + ) + + try: + user_uuid = UUID(user_id) + except ValueError: + raise R2RException( + status_code=400, message="Invalid user ID format" + ) + result = await self.service.get_user_reset_token(user_uuid) + if not result["reset_token"]: + raise R2RException( + status_code=404, message="No reset token found" + ) return result # Add to AuthRouter class (auth_router.py) diff --git a/py/core/main/services/auth_service.py b/py/core/main/services/auth_service.py index f72e3a1c5..8eb350605 100644 --- a/py/core/main/services/auth_service.py +++ b/py/core/main/services/auth_service.py @@ -185,7 +185,7 @@ async def clean_expired_blacklisted_tokens( ) @telemetry_event("GetUserVerificationCode") - async def get_user_verification_data( + async def get_user_verification_code( self, user_id: UUID, *args, **kwargs ) -> dict: """ @@ -193,7 +193,7 @@ async def get_user_verification_data( This method should be called after superuser authorization has been verified. """ verification_data = ( - await self.providers.database.get_user_verification_data(user_id) + await self.providers.database.get_user_validation_data(user_id) ) return { "verification_code": verification_data["verification_data"][ @@ -204,6 +204,26 @@ async def get_user_verification_data( ], } + @telemetry_event("GetUserVerificationCode") + async def get_user_reset_token( + self, user_id: UUID, *args, **kwargs + ) -> dict: + """ + Get only the verification code data for a specific user. + This method should be called after superuser authorization has been verified. + """ + verification_data = ( + await self.providers.database.get_user_validation_data(user_id) + ) + return { + "reset_token": verification_data["verification_data"][ + "reset_token" + ], + "expiry": verification_data["verification_data"][ + "reset_token_expiry" + ], + } + @telemetry_event("SendResetEmail") async def send_reset_email(self, email: str) -> dict: """ diff --git a/py/core/pipes/kg/triples_extraction.py b/py/core/pipes/kg/triples_extraction.py index 9037e22c3..e0643e169 100644 --- a/py/core/pipes/kg/triples_extraction.py +++ b/py/core/pipes/kg/triples_extraction.py @@ -289,7 +289,8 @@ async def _run_logic( # type: ignore # sort the extractions accroding to chunk_order field in metadata in ascending order extractions = sorted( - extractions, key=lambda x: x.metadata.get("chunk_order", float('inf')) + extractions, + key=lambda x: x.metadata.get("chunk_order", float("inf")), ) # group these extractions into groups of extraction_merge_count diff --git a/py/core/providers/database/user.py b/py/core/providers/database/user.py index 0c45d761a..a6974a1d4 100644 --- a/py/core/providers/database/user.py +++ b/py/core/providers/database/user.py @@ -580,7 +580,7 @@ async def _collection_exists(self, collection_id: UUID) -> bool: ) return result is not None - async def get_user_verification_data( + async def get_user_validation_data( self, user_id: UUID, *args, **kwargs ) -> dict: """ diff --git a/py/sdk/mixins/auth.py b/py/sdk/mixins/auth.py index c6eab6baa..d3ca201ac 100644 --- a/py/sdk/mixins/auth.py +++ b/py/sdk/mixins/auth.py @@ -213,6 +213,20 @@ async def get_user_verification_code( "GET", f"user/{user_id}/verification_data" ) + async def get_user_reset_token(self, user_id: Union[str, UUID]) -> dict: + """ + Retrieves only the verification code for a specific user. Requires superuser access. + + Args: + user_id (Union[str, UUID]): The ID of the user to get verification code for. + + Returns: + dict: Contains verification code and its expiry date + """ + return await self._make_request( # type: ignore + "GET", f"user/{user_id}/reset_token" + ) + async def send_reset_email(self, email: str) -> dict: """ Generates a new verification code and sends a reset email to the user. diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py index 924c5276e..14cd54bb3 100644 --- a/py/shared/api/models/management/responses.py +++ b/py/shared/api/models/management/responses.py @@ -145,6 +145,18 @@ class VerificationResult(BaseModel): message: Optional[str] = None +class VerificationResult(BaseModel): + verification_code: str + expiry: datetime + message: Optional[str] = None + + +class ResetDataResult(BaseModel): + reset_token: str + expiry: datetime + message: Optional[str] = None + + class AddUserResponse(BaseModel): result: bool @@ -178,6 +190,7 @@ class AddUserResponse(BaseModel): ] WrappedDeleteResponse = ResultsWrapper[None] WrappedVerificationResult = ResultsWrapper[VerificationResult] +WrappedResetDataResult = ResultsWrapper[ResetDataResult] WrappedConversationsOverviewResponse = PaginatedResultsWrapper[ list[ConversationOverviewResponse] ] From b96312161e8a9821b56099949f857cc9eab8a5d7 Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Date: Sun, 17 Nov 2024 21:09:52 -0800 Subject: [PATCH 2/2] up (#1603) * up * up --- py/core/main/services/management_service.py | 2 +- py/core/providers/logger/r2r_logger.py | 19 +++++++++++++------ py/sdk/mixins/management.py | 15 ++++++++++----- py/shared/api/models/management/responses.py | 2 +- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index efdd08893..86c3268f8 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -674,7 +674,7 @@ async def get_conversation( conversation_id: str, branch_id: Optional[str] = None, auth_user=None, - ) -> Tuple[str, list[Message]]: + ) -> Tuple[str, list[Message], list[dict]]: return await self.logging_connection.get_conversation( conversation_id, branch_id ) diff --git a/py/core/providers/logger/r2r_logger.py b/py/core/providers/logger/r2r_logger.py index 1cf25c031..14aff8bdc 100644 --- a/py/core/providers/logger/r2r_logger.py +++ b/py/core/providers/logger/r2r_logger.py @@ -655,25 +655,32 @@ async def get_conversation( # Get all messages for this branch async with self.conn.execute( """ - WITH RECURSIVE branch_messages(id, content, parent_id, depth, created_at) AS ( - SELECT m.id, m.content, m.parent_id, 0, m.created_at + WITH RECURSIVE branch_messages(id, content, parent_id, depth, created_at, metadata) AS ( + SELECT m.id, m.content, m.parent_id, 0, m.created_at, m.metadata FROM messages m JOIN message_branches mb ON m.id = mb.message_id WHERE mb.branch_id = ? AND m.parent_id IS NULL UNION - SELECT m.id, m.content, m.parent_id, bm.depth + 1, m.created_at + SELECT m.id, m.content, m.parent_id, bm.depth + 1, m.created_at, m.metadata FROM messages m JOIN message_branches mb ON m.id = mb.message_id JOIN branch_messages bm ON m.parent_id = bm.id WHERE mb.branch_id = ? ) - SELECT id, content, parent_id FROM branch_messages + SELECT id, content, parent_id, metadata FROM branch_messages ORDER BY created_at ASC - """, + """, (branch_id, branch_id), ) as cursor: rows = await cursor.fetchall() - return [(row[0], Message.parse_raw(row[1])) for row in rows] + return [ + ( + row[0], # id + Message.parse_raw(row[1]), # message content + json.loads(row[3]) if row[3] else {}, # metadata + ) + for row in rows + ] async def get_branches_overview(self, conversation_id: str) -> list[dict]: if not self.conn: diff --git a/py/sdk/mixins/management.py b/py/sdk/mixins/management.py index bb0d7fda2..21e304fbe 100644 --- a/py/sdk/mixins/management.py +++ b/py/sdk/mixins/management.py @@ -695,7 +695,7 @@ async def create_conversation(self) -> dict: async def add_message( self, conversation_id: Union[str, UUID], - message: Message, + message: dict, parent_id: Optional[str] = None, metadata: Optional[dict[str, Any]] = None, ) -> dict: @@ -716,9 +716,14 @@ async def add_message( data["parent_id"] = parent_id if metadata is not None: data["metadata"] = metadata - return await self._make_request( # type: ignore - "POST", f"add_message/{str(conversation_id)}", data=data - ) + if len(data) == 1: + return await self._make_request( # type: ignore + "POST", f"add_message/{str(conversation_id)}", json=data + ) + else: + return await self._make_request( # type: ignore + "POST", f"add_message/{str(conversation_id)}", data=data + ) async def update_message( self, @@ -755,7 +760,7 @@ async def update_message_metadata( dict: The response from the server. """ return await self._make_request( # type: ignore - "PATCH", f"messages/{message_id}/metadata", data=metadata + "PATCH", f"messages/{message_id}/metadata", json=metadata ) async def branches_overview( diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py index 14cd54bb3..5702cd8bf 100644 --- a/py/shared/api/models/management/responses.py +++ b/py/shared/api/models/management/responses.py @@ -171,7 +171,7 @@ class AddUserResponse(BaseModel): WrappedUserOverviewResponse = PaginatedResultsWrapper[ list[UserOverviewResponse] ] -WrappedConversationResponse = ResultsWrapper[list[Tuple[str, Message]]] +WrappedConversationResponse = ResultsWrapper[list[Tuple[str, Message, dict]]] WrappedDocumentOverviewResponse = PaginatedResultsWrapper[ list[DocumentOverviewResponse] ]