From 3d6ebe95c8758fdfc5a2cd87ea9d54cc6592a673 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Tue, 12 Nov 2024 18:09:07 -0800 Subject: [PATCH] More cleanups --- py/compose.full.yaml | 40 +++++++++---------- py/compose.yaml | 42 ++++++++++---------- py/core/base/api/models/__init__.py | 33 ++++++--------- py/core/main/api/v2/management_router.py | 32 ++++++++------- py/core/main/api/v3/chunks_router.py | 18 ++++----- py/core/main/api/v3/collections_router.py | 5 +-- py/core/main/api/v3/documents_router.py | 4 +- py/core/main/api/v3/prompts_router.py | 13 +++--- py/core/main/api/v3/users_router.py | 9 ++--- py/core/main/services/management_service.py | 4 +- py/core/providers/database/postgres.py | 1 - py/core/providers/database/user.py | 2 +- py/shared/abstractions/user.py | 3 +- py/shared/api/models/__init__.py | 34 +++++++--------- py/shared/api/models/auth/responses.py | 25 +----------- py/shared/api/models/management/responses.py | 34 +++++----------- 16 files changed, 123 insertions(+), 176 deletions(-) diff --git a/py/compose.full.yaml b/py/compose.full.yaml index 8c7ea4a19..c68dfdb8a 100644 --- a/py/compose.full.yaml +++ b/py/compose.full.yaml @@ -26,11 +26,11 @@ services: image: pgvector/pgvector:pg16 profiles: [postgres] environment: - - POSTGRES_USER=${R2R_POSTGRES_USER:-${POSTGRES_USER:-postgres}} # Eventually get rid of POSTGRES_USER, but for now keep it for backwards compatibility - - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-${POSTGRES_PASSWORD:-postgres}} # Eventually get rid of POSTGRES_PASSWORD, but for now keep it for backwards compatibility - - POSTGRES_HOST=${R2R_POSTGRES_HOST:-${POSTGRES_HOST:-postgres}} # Eventually get rid of POSTGRES_HOST, but for now keep it for backwards compatibility - - POSTGRES_PORT=${R2R_POSTGRES_PORT:-${POSTGRES_PORT:-5432}} # Eventually get rid of POSTGRES_PORT, but for now keep it for backwards compatibility - - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-${POSTGRES_MAX_CONNECTIONS:-1024}} # Eventually get rid of POSTGRES_MAX_CONNECTIONS, but for now keep it for backwards compatibility + - POSTGRES_USER=${R2R_POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres} + - POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres} + - POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432} + - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} volumes: - postgres_data:/var/lib/postgresql/data networks: @@ -274,30 +274,30 @@ services: build: context: . args: - PORT: ${R2R_PORT:-${PORT:-7272}} # Eventually get rid of PORT, but for now keep it for backwards compatibility + PORT: ${R2R_PORT:-7272} R2R_PORT: ${R2R_PORT:-${PORT:-7272}} - HOST: ${R2R_HOST:-${HOST:-0.0.0.0}} # Eventually get rid of HOST, but for now keep it for backwards compatibility - R2R_HOST: ${R2R_HOST:-${HOST:-0.0.0.0}} + HOST: ${R2R_HOST:-0.0.0.0} + R2R_HOST: ${R2R_HOST:-0.0.0.0} ports: - - "${R2R_PORT:-${PORT:-7272}}:${R2R_PORT:-${PORT:-7272}}" + - "${R2R_PORT:-7272}:${R2R_PORT:-7272}" environment: - PYTHONUNBUFFERED=1 - - R2R_PORT=${R2R_PORT:-${PORT:-7272}} # Eventually get rid of PORT, but for now keep it for backwards compatibility - - R2R_HOST=${R2R_HOST:-${HOST:-0.0.0.0}} # Eventually get rid of HOST, but for now keep it for backwards compatibility + - R2R_PORT=${R2R_PORT:-7272} + - R2R_HOST=${R2R_HOST:-0.0.0.0} # R2R - - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-${CONFIG_NAME:-}} # Eventually get rid of CONFIG_NAME, but for now keep it for backwards compatibility - - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-${CONFIG_PATH:-}} # Eventually get rid of CONFIG_PATH, but for now keep it for backwards compatibility + - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-} + - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-} - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default} # Postgres - - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-${POSTGRES_USER:-postgres}} # Eventually get rid of POSTGRES_USER, but for now keep it for backwards compatibility - - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-${POSTGRES_PASSWORD:-postgres}} # Eventually get rid of POSTGRES_PASSWORD, but for now keep it for backwards compatibility - - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-${POSTGRES_HOST:-postgres}} # Eventually get rid of POSTGRES_HOST, but for now keep it for backwards compatibility - - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-${POSTGRES_PORT:-5432}} # Eventually get rid of POSTGRES_PORT, but for now keep it for backwards compatibility - - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-${POSTGRES_DBNAME:-postgres}} # Eventually get rid of POSTGRES_DBNAME, but for now keep it for backwards compatibility - - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-${POSTGRES_MAX_CONNECTIONS:-1024}} # Eventually get rid of POSTGRES_MAX_CONNECTIONS, but for now keep it for backwards compatibility - - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-${POSTGRES_PROJECT_NAME:-r2r_default}} # Eventually get rid of POSTGRES_PROJECT_NAME, but for now keep it for backwards compatibility + - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres} + - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres} + - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres} + - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432} + - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres} + - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} + - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default} # OpenAI - OPENAI_API_KEY=${OPENAI_API_KEY:-} diff --git a/py/compose.yaml b/py/compose.yaml index b764817ed..652e3940c 100644 --- a/py/compose.yaml +++ b/py/compose.yaml @@ -14,11 +14,11 @@ services: image: pgvector/pgvector:pg16 profiles: [postgres] environment: - - POSTGRES_USER=${R2R_POSTGRES_USER:-${POSTGRES_USER:-postgres}} # Eventually get rid of POSTGRES_USER, but for now keep it for backwards compatibility - - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-${POSTGRES_PASSWORD:-postgres}} # Eventually get rid of POSTGRES_PASSWORD, but for now keep it for backwards compatibility - - POSTGRES_HOST=${R2R_POSTGRES_HOST:-${POSTGRES_HOST:-postgres}} # Eventually get rid of POSTGRES_HOST, but for now keep it for backwards compatibility - - POSTGRES_PORT=${R2R_POSTGRES_PORT:-${POSTGRES_PORT:-5432}} # Eventually get rid of POSTGRES_PORT, but for now keep it for backwards compatibility - - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-${POSTGRES_MAX_CONNECTIONS:-1024}} # Eventually get rid of POSTGRES_MAX_CONNECTIONS, but for now keep it for backwards compatibility + - POSTGRES_USER=${R2R_POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres} + - POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres} + - POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432} + - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} volumes: - postgres_data:/var/lib/postgresql/data networks: @@ -40,30 +40,30 @@ services: build: context: . args: - PORT: ${R2R_PORT:-${PORT:-7272}} # Eventually get rid of PORT, but for now keep it for backwards compatibility - R2R_PORT: ${R2R_PORT:-${PORT:-7272}} - HOST: ${R2R_HOST:-${HOST:-0.0.0.0}} # Eventually get rid of HOST, but for now keep it for backwards compatibility - R2R_HOST: ${R2R_HOST:-${HOST:-0.0.0.0}} + PORT: ${R2R_PORT:-7272} + R2R_PORT: ${R2R_PORT:-7272} + HOST: ${R2R_HOST:-0.0.0.0} + R2R_HOST: ${R2R_HOST:-0.0.0.0} ports: - - "${R2R_PORT:-${PORT:-7272}}:${R2R_PORT:-${PORT:-7272}}" + - "${R2R_PORT:-7272}:${R2R_PORT:-7272}" environment: - PYTHONUNBUFFERED=1 - - R2R_PORT=${R2R_PORT:-${PORT:-7272}} # Eventually get rid of PORT, but for now keep it for backwards compatibility - - R2R_HOST=${R2R_HOST:-${HOST:-0.0.0.0}} # Eventually get rid of HOST, but for now keep it for backwards compatibility + - R2R_PORT=${R2R_PORT:-7272} + - R2R_HOST=${R2R_HOST:-0.0.0.0} # R2R - - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-${CONFIG_NAME:-}} # Eventually get rid of CONFIG_NAME, but for now keep it for backwards compatibility - - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-${CONFIG_PATH:-}} # Eventually get rid of CONFIG_PATH, but for now keep it for backwards compatibility + - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:--} + - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:--} - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default} # Postgres - - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-${POSTGRES_USER:-postgres}} # Eventually get rid of POSTGRES_USER, but for now keep it for backwards compatibility - - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-${POSTGRES_PASSWORD:-postgres}} # Eventually get rid of POSTGRES_PASSWORD, but for now keep it for backwards compatibility - - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-${POSTGRES_HOST:-postgres}} # Eventually get rid of POSTGRES_HOST, but for now keep it for backwards compatibility - - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-${POSTGRES_PORT:-5432}} # Eventually get rid of POSTGRES_PORT, but for now keep it for backwards compatibility - - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-${POSTGRES_DBNAME:-postgres}} # Eventually get rid of POSTGRES_DBNAME, but for now keep it for backwards compatibility - - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-${POSTGRES_MAX_CONNECTIONS:-1024}} # Eventually get rid of POSTGRES_MAX_CONNECTIONS, but for now keep it for backwards compatibility - - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-${POSTGRES_PROJECT_NAME:-r2r_default}} # Eventually get rid of POSTGRES_PROJECT_NAME, but for now keep it for backwards compatibility + - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres} + - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres} + - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres} + - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432} + - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres} + - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} + - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default} # OpenAI - OPENAI_API_KEY=${OPENAI_API_KEY:-} diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index aecd3cf17..76c3af398 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -8,9 +8,7 @@ ) from shared.api.models.auth.responses import ( TokenResponse, - UserResponse, WrappedTokenResponse, - WrappedUserResponse, ) from shared.api.models.ingestion.responses import ( IngestionResponse, @@ -40,13 +38,12 @@ AppSettingsResponse, CollectionResponse, ConversationResponse, - DocumentChunkResponse, + ChunkResponse, + UserResponse, LogResponse, PromptResponse, ScoreCompletionResponse, ServerStats, - UserOverviewResponse, - WrappedAddUserResponse, WrappedAnalyticsResponse, WrappedAppSettingsResponse, WrappedCollectionResponse, @@ -61,18 +58,18 @@ WrappedPromptResponse, WrappedPromptsResponse, WrappedLogResponse, - WrappedPromptMessageResponse, # Chunk Responses - WrappedDocumentChunkResponse, - WrappedDocumentChunksResponse, + WrappedChunkResponse, + WrappedChunksResponse, # Conversation Responses WrappedMessageResponse, WrappedMessagesResponse, WrappedBranchResponse, WrappedBranchesResponse, # User Responses - WrappedUserOverviewResponse, - WrappedUsersOverviewResponse, + WrappedUserResponse, + WrappedUsersResponse, + # TODO: anything below this hasn't been reviewed WrappedServerStatsResponse, WrappedUserCollectionResponse, WrappedUsersInCollectionResponse, @@ -93,9 +90,7 @@ __all__ = [ # Auth Responses "TokenResponse", - "UserResponse", "WrappedTokenResponse", - "WrappedUserResponse", "WrappedVerificationResult", "WrappedGenericMessageResponse", # Ingestion Responses @@ -123,10 +118,8 @@ "AnalyticsResponse", "AppSettingsResponse", "ScoreCompletionResponse", - "UserOverviewResponse", - "DocumentChunkResponse", + "ChunkResponse", "CollectionResponse", - "WrappedPromptMessageResponse", "WrappedServerStatsResponse", "WrappedLogResponse", "WrappedAnalyticsResponse", @@ -135,7 +128,6 @@ "WrappedDocumentsResponse", "WrappedCollectionResponse", "WrappedCollectionsResponse", - "WrappedAddUserResponse", "WrappedUsersInCollectionResponse", # Conversation Responses "ConversationResponse", @@ -150,11 +142,12 @@ "WrappedBranchResponse", "WrappedBranchesResponse", # Chunk Responses - "WrappedDocumentChunkResponse", - "WrappedDocumentChunksResponse", + "WrappedChunkResponse", + "WrappedChunksResponse", # User Responses - "WrappedUserOverviewResponse", - "WrappedUsersOverviewResponse", + "UserResponse", + "WrappedUserResponse", + "WrappedUsersResponse", # Base Responses "PaginatedResultsWrapper", "ResultsWrapper", diff --git a/py/core/main/api/v2/management_router.py b/py/core/main/api/v2/management_router.py index b24dfad04..9ddb7722e 100644 --- a/py/core/main/api/v2/management_router.py +++ b/py/core/main/api/v2/management_router.py @@ -12,7 +12,8 @@ from core.base import Message, R2RException from core.base.api.models import ( - WrappedAddUserResponse, + WrappedBooleanResponse, + GenericMessageResponse, WrappedAnalyticsResponse, WrappedAppSettingsResponse, WrappedCollectionResponse, @@ -20,14 +21,15 @@ WrappedConversationResponse, WrappedConversationsResponse, WrappedDeleteResponse, - WrappedDocumentChunksResponse, + WrappedChunksResponse, + WrappedGenericMessageResponse, + WrappedBooleanResponse, WrappedDocumentsResponse, WrappedPromptsResponse, WrappedLogResponse, - WrappedPromptMessageResponse, WrappedServerStatsResponse, WrappedUserCollectionResponse, - WrappedUsersOverviewResponse, + WrappedUsersResponse, WrappedUsersInCollectionResponse, ) from core.base.logger import AnalysisTypes, LogFilterCriteria @@ -91,7 +93,7 @@ async def update_prompt_app( {}, description="Input types" ), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedPromptMessageResponse: + ) -> WrappedGenericMessageResponse: if not auth_user.is_superuser: raise R2RException( "Only a superuser can call the `update_prompt` endpoint.", @@ -101,7 +103,7 @@ async def update_prompt_app( result = await self.service.update_prompt( name, template, input_types ) - return result # type: ignore + return GenericMessageResponse(message=result) @self.router.post("/add_prompt") @self.base_endpoint @@ -110,14 +112,14 @@ async def add_prompt_app( template: str = Body(..., description="Prompt template"), input_types: dict[str, str] = Body({}, description="Input types"), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedPromptMessageResponse: + ) -> WrappedGenericMessageResponse: if not auth_user.is_superuser: raise R2RException( "Only a superuser can call the `add_prompt` endpoint.", 403, ) result = await self.service.add_prompt(name, template, input_types) - return result # type: ignore + return GenericMessageResponse(message=result) @self.router.get("/get_prompt/{prompt_name}") @self.base_endpoint @@ -130,7 +132,7 @@ async def get_prompt_app( None, description="Prompt override" ), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedPromptMessageResponse: + ) -> WrappedGenericMessageResponse: if not auth_user.is_superuser: raise R2RException( "Only a superuser can call the `get_prompt` endpoint.", @@ -139,7 +141,7 @@ async def get_prompt_app( result = await self.service.get_cached_prompt( prompt_name, inputs, prompt_override ) - return result # type: ignore + return GenericMessageResponse(message=result) @self.router.get("/get_all_prompts") @self.base_endpoint @@ -236,7 +238,7 @@ async def users_overview_app( offset: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedUsersOverviewResponse: + ) -> WrappedUsersResponse: if not auth_user.is_superuser: raise R2RException( "Only a superuser can call the `users_overview` endpoint.", @@ -373,7 +375,7 @@ async def document_chunks_app( limit: Optional[int] = Query(100, ge=0), include_vectors: Optional[bool] = Query(False), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedDocumentChunksResponse: + ) -> WrappedChunksResponse: document_uuid = UUID(document_id) document_chunks = await self.service.list_document_chunks( @@ -431,7 +433,7 @@ async def document_chunks_app( limit: Optional[int] = Query(100, ge=0), include_vectors: Optional[bool] = Query(False), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedDocumentChunksResponse: + ) -> WrappedChunksResponse: document_uuid = UUID(document_id) list_document_chunks = await self.service.list_document_chunks( @@ -627,7 +629,7 @@ async def add_user_to_collection_app( user_id: str = Body(..., description="User ID"), collection_id: str = Body(..., description="Collection ID"), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedAddUserResponse: + ) -> WrappedBooleanResponse: collection_uuid = UUID(collection_id) user_uuid = UUID(user_id) if ( @@ -642,7 +644,7 @@ async def add_user_to_collection_app( result = await self.service.add_user_to_collection( user_uuid, collection_uuid ) - return result # type: ignore + return WrappedBooleanResponse(result=result) @self.router.post("/remove_user_from_collection") @self.base_endpoint diff --git a/py/core/main/api/v3/chunks_router.py b/py/core/main/api/v3/chunks_router.py index 19b614702..6d98a79a2 100644 --- a/py/core/main/api/v3/chunks_router.py +++ b/py/core/main/api/v3/chunks_router.py @@ -15,13 +15,13 @@ UnprocessedChunk, UpdateChunk, VectorSearchSettings, - DocumentChunkResponse, + ChunkResponse, ) from core.base.api.models import ( GenericBooleanResponse, WrappedBooleanResponse, - WrappedDocumentChunkResponse, - WrappedDocumentChunksResponse, + WrappedChunkResponse, + WrappedChunksResponse, WrappedVectorSearchResponse, ) from core.providers import ( @@ -373,7 +373,7 @@ async def search_chunks( async def retrieve_chunk( id: UUID = Path(...), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedDocumentChunkResponse: + ) -> WrappedChunkResponse: """ Get a specific chunk by its ID. @@ -392,7 +392,7 @@ async def retrieve_chunk( ): raise R2RException("Not authorized to access this chunk", 403) - return DocumentChunkResponse( # type: ignore + return ChunkResponse( # type: ignore id=chunk["chunk_id"], document_id=chunk["document_id"], user_id=chunk["user_id"], @@ -453,7 +453,7 @@ async def update_chunk( chunk_update: UpdateChunk = Body(...), # TODO: Run with orchestration? auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedDocumentChunkResponse: + ) -> WrappedChunkResponse: """ Update an existing chunk's content and/or metadata. @@ -486,7 +486,7 @@ async def update_chunk( ) await simple_ingestor["update-chunk"](workflow_input) - return DocumentChunkResponse( + return ChunkResponse( id=chunk_update.id, document_id=existing_chunk["document_id"], user_id=existing_chunk["user_id"], @@ -665,7 +665,7 @@ async def list_chunks( description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedDocumentChunksResponse: + ) -> WrappedChunksResponse: """ List chunks with pagination support. @@ -696,7 +696,7 @@ async def list_chunks( # Convert to response format chunks = [ - DocumentChunkResponse( + ChunkResponse( id=chunk["chunk_id"], document_id=chunk["document_id"], user_id=chunk["user_id"], diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index 799a48d75..0bddb752d 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -8,7 +8,6 @@ from core.base import R2RException, RunType from core.base.api.models import ( GenericBooleanResponse, - WrappedAddUserResponse, WrappedBooleanResponse, WrappedCollectionResponse, WrappedCollectionsResponse, @@ -866,7 +865,7 @@ async def add_user_to_collection( ..., description="The unique identifier of the user to add" ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedAddUserResponse: + ) -> WrappedBooleanResponse: """ Add a user to a collection. @@ -885,7 +884,7 @@ async def add_user_to_collection( result = await self.services["management"].add_user_to_collection( user_id, id ) - return result # type: ignore + return GenericBooleanResponse(success=result) @self.router.delete( "/collections/{id}/users/{user_id}", diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index b1d0aa7b9..49f473382 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -17,7 +17,7 @@ WrappedDocumentsResponse, WrappedIngestionResponse, WrappedBooleanResponse, - WrappedDocumentChunksResponse, + WrappedChunksResponse, ) from pydantic import Json @@ -706,7 +706,7 @@ async def list_chunks( description="Whether to include vector embeddings in the response.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedDocumentChunksResponse: + ) -> WrappedChunksResponse: """ Retrieves the text chunks that were generated from a document during ingestion. Chunks represent semantic sections of the document and are used for retrieval diff --git a/py/core/main/api/v3/prompts_router.py b/py/core/main/api/v3/prompts_router.py index 46b2a8915..70d631eca 100644 --- a/py/core/main/api/v3/prompts_router.py +++ b/py/core/main/api/v3/prompts_router.py @@ -6,10 +6,11 @@ from core.base import R2RException, RunType from core.base.api.models import ( GenericBooleanResponse, + GenericMessageResponse, + WrappedBooleanResponse, + WrappedGenericMessageResponse, WrappedPromptResponse, WrappedPromptsResponse, - WrappedPromptMessageResponse, - WrappedBooleanResponse, ) from core.providers import ( HatchetOrchestrationProvider, @@ -99,7 +100,7 @@ async def create_prompt( description="A dictionary mapping input names to their types", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedPromptMessageResponse: + ) -> WrappedGenericMessageResponse: """ Create a new prompt with the given configuration. @@ -113,7 +114,7 @@ async def create_prompt( result = await self.services["management"].add_prompt( name, template, input_types ) - return result + return GenericMessageResponse(message=result) @self.router.get( "/prompts", @@ -334,7 +335,7 @@ async def update_prompt( description="A dictionary mapping input names to their types", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedPromptMessageResponse: + ) -> WrappedGenericMessageResponse: """ Update an existing prompt's template and/or input types. @@ -348,7 +349,7 @@ async def update_prompt( result = await self.services["management"].update_prompt( name, template, input_types ) - return result # type: ignore + return GenericMessageResponse(message=result) @self.router.delete( "/prompts/{name}", diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index 0d4a86d42..81e89aba7 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -13,8 +13,7 @@ GenericBooleanResponse, WrappedGenericMessageResponse, WrappedTokenResponse, - WrappedUserOverviewResponse, - WrappedUsersOverviewResponse, + WrappedUsersResponse, WrappedUserResponse, WrappedBooleanResponse, WrappedCollectionsResponse, @@ -575,7 +574,6 @@ async def list_users( # is_active: Optional[bool] = Query(None, example=True), # is_superuser: Optional[bool] = Query(None, example=False), # auth_user=Depends(self.providers.auth.auth_wrapper), - # ) -> PaginatedResultsWrapper[List[UserOverviewResponse]]: user_ids: Optional[list[UUID]] = Query( None, description="List of user IDs to filter by" ), @@ -591,7 +589,7 @@ async def list_users( description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedUsersOverviewResponse: + ) -> WrappedUsersResponse: """ List all users with pagination and filtering options. Only accessible by superusers. @@ -613,7 +611,6 @@ async def list_users( @self.router.get( "/users/{id}", summary="Get User Details", - # response_model=ResultsWrapper[UserOverviewResponse], openapi_extra={ "x-codeSamples": [ { @@ -668,7 +665,7 @@ async def get_user( ..., example="550e8400-e29b-41d4-a716-446655440000" ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedUserOverviewResponse: + ) -> WrappedUserResponse: """ Get detailed information about a specific user. Users can only access their own information unless they are superusers. diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index e78277274..95428d064 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -622,7 +622,7 @@ async def add_prompt( await self.providers.database.add_prompt( name, template, input_types ) - return {"message": f"Prompt '{name}' added successfully."} + return f"Prompt '{name}' added successfully." except ValueError as e: raise R2RException(status_code=400, message=str(e)) @@ -673,7 +673,7 @@ async def update_prompt( await self.providers.database.update_prompt( name, template, input_types ) - return {"message": f"Prompt '{name}' updated successfully."} + return f"Prompt '{name}' updated successfully." except ValueError as e: raise R2RException(status_code=404, message=str(e)) diff --git a/py/core/providers/database/postgres.py b/py/core/providers/database/postgres.py index 655473192..dfbf3541f 100644 --- a/py/core/providers/database/postgres.py +++ b/py/core/providers/database/postgres.py @@ -6,7 +6,6 @@ from core.base import ( DatabaseConfig, - DatabaseConnectionManager, DatabaseProvider, PostgresConfigurationSettings, VectorQuantizationType, diff --git a/py/core/providers/database/user.py b/py/core/providers/database/user.py index e79343006..f55d461d7 100644 --- a/py/core/providers/database/user.py +++ b/py/core/providers/database/user.py @@ -549,7 +549,7 @@ async def get_users_overview( users = [ UserStats( - user_id=row[0], + id=row[0], email=row[1], is_superuser=row[2], is_active=row[3], diff --git a/py/shared/abstractions/user.py b/py/shared/abstractions/user.py index 7cdaa625e..ac0961167 100644 --- a/py/shared/abstractions/user.py +++ b/py/shared/abstractions/user.py @@ -34,8 +34,9 @@ class TokenData(BaseModel): exp: Optional[datetime] = None +# TODO: Seems like an unnecessary abstraction class UserStats(BaseModel): - user_id: UUID + id: UUID email: str is_superuser: bool is_active: bool diff --git a/py/shared/api/models/__init__.py b/py/shared/api/models/__init__.py index 5098e7ad2..d7172405a 100644 --- a/py/shared/api/models/__init__.py +++ b/py/shared/api/models/__init__.py @@ -8,9 +8,7 @@ ) from shared.api.models.auth.responses import ( TokenResponse, - UserResponse, WrappedTokenResponse, - WrappedUserResponse, ) from shared.api.models.ingestion.responses import ( IngestionResponse, @@ -31,21 +29,20 @@ AppSettingsResponse, CollectionResponse, ConversationResponse, - DocumentChunkResponse, + ChunkResponse, + UserResponse, LogResponse, PromptResponse, ScoreCompletionResponse, ServerStats, - UserOverviewResponse, - WrappedAddUserResponse, WrappedAnalyticsResponse, WrappedAppSettingsResponse, WrappedCollectionResponse, WrappedCollectionsResponse, WrappedConversationResponse, WrappedConversationsResponse, - WrappedDocumentChunkResponse, - WrappedDocumentChunksResponse, + WrappedChunkResponse, + WrappedChunksResponse, # Document Responses WrappedDocumentResponse, WrappedDocumentsResponse, @@ -53,10 +50,10 @@ WrappedPromptResponse, WrappedPromptsResponse, # Collection Responses - WrappedUserOverviewResponse, - WrappedUsersOverviewResponse, + # User Responses + WrappedUserResponse, + WrappedUsersResponse, WrappedLogResponse, - WrappedPromptMessageResponse, WrappedServerStatsResponse, WrappedUserCollectionResponse, WrappedUsersInCollectionResponse, @@ -76,9 +73,7 @@ # Auth Responses "GenericMessageResponse", "TokenResponse", - "UserResponse", "WrappedTokenResponse", - "WrappedUserResponse", "WrappedGenericMessageResponse", # Ingestion Responses "IngestionResponse", @@ -97,11 +92,9 @@ "AnalyticsResponse", "AppSettingsResponse", "ScoreCompletionResponse", - "UserOverviewResponse", - "DocumentChunkResponse", + "ChunkResponse", "CollectionResponse", "ConversationResponse", - "WrappedPromptMessageResponse", "WrappedServerStatsResponse", "WrappedLogResponse", "WrappedAnalyticsResponse", @@ -113,17 +106,18 @@ # Collection Responses "WrappedCollectionResponse", "WrappedCollectionsResponse", - "WrappedAddUserResponse", "WrappedUsersInCollectionResponse", # Prompt Responses "WrappedPromptResponse", "WrappedPromptsResponse", # Chunk Responses - "WrappedDocumentChunkResponse", - "WrappedDocumentChunksResponse", + "WrappedChunkResponse", + "WrappedChunksResponse", # Conversation Responses - "WrappedUserOverviewResponse", - "WrappedUsersOverviewResponse", + # User Responses + "UserResponse", + "WrappedUserResponse", + "WrappedUsersResponse", # Base Responses "PaginatedResultsWrapper", "ResultsWrapper", diff --git a/py/shared/api/models/auth/responses.py b/py/shared/api/models/auth/responses.py index ec8c89723..ccb44687f 100644 --- a/py/shared/api/models/auth/responses.py +++ b/py/shared/api/models/auth/responses.py @@ -1,10 +1,6 @@ -from datetime import datetime -from typing import Optional -from uuid import UUID - from pydantic import BaseModel -from shared.abstractions import R2RSerializable, Token +from shared.abstractions import Token from shared.api.models.base import ResultsWrapper @@ -13,24 +9,5 @@ class TokenResponse(BaseModel): refresh_token: Token -class UserResponse(R2RSerializable): - id: UUID - email: str - is_active: bool = True - is_superuser: bool = False - created_at: datetime = datetime.now() - updated_at: datetime = datetime.now() - is_verified: bool = False - collection_ids: list[UUID] = [] - - # Optional fields (to update or set at creation) - hashed_password: Optional[str] = None - verification_code_expiry: Optional[datetime] = None - name: Optional[str] = None - bio: Optional[str] = None - profile_picture: Optional[str] = None - - # Create wrapped versions of each response WrappedTokenResponse = ResultsWrapper[TokenResponse] -WrappedUserResponse = ResultsWrapper[UserResponse] diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py index a2e455010..10ec9656d 100644 --- a/py/shared/api/models/management/responses.py +++ b/py/shared/api/models/management/responses.py @@ -9,9 +9,7 @@ from shared.abstractions.llm import Message - -class UpdatePromptResponse(BaseModel): - message: str +from shared.abstractions import R2RSerializable class PromptResponse(BaseModel): @@ -60,16 +58,7 @@ class ScoreCompletionResponse(BaseModel): message: str -# TODO: This should just be a UserResponse... -class UserOverviewResponse(BaseModel): - user_id: UUID - num_files: int - total_size_in_bytes: int - document_ids: list[UUID] - - -# FIXME: Why are we redefining this and not using the model in py/shared/api/models/auth/responses.py? -class UserResponse(BaseModel): +class UserResponse(R2RSerializable): id: UUID email: str is_active: bool = True @@ -87,7 +76,7 @@ class UserResponse(BaseModel): profile_picture: Optional[str] = None -class DocumentChunkResponse(BaseModel): +class ChunkResponse(BaseModel): id: UUID document_id: UUID user_id: UUID @@ -140,6 +129,10 @@ class AddUserResponse(BaseModel): result: bool +# Chunk Responses +WrappedChunkResponse = ResultsWrapper[ChunkResponse] +WrappedChunksResponse = PaginatedResultsWrapper[list[ChunkResponse]] + # Collection Responses WrappedCollectionResponse = ResultsWrapper[CollectionResponse] WrappedCollectionsResponse = PaginatedResultsWrapper[list[CollectionResponse]] @@ -164,10 +157,8 @@ class AddUserResponse(BaseModel): WrappedPromptsResponse = PaginatedResultsWrapper[list[PromptResponse]] # User Responses -WrappedUserOverviewResponse = ResultsWrapper[UserOverviewResponse] -WrappedUsersOverviewResponse = PaginatedResultsWrapper[ - list[UserOverviewResponse] -] +WrappedUserResponse = ResultsWrapper[UserResponse] +WrappedUsersResponse = PaginatedResultsWrapper[list[UserResponse]] # TODO: anything below this hasn't been reviewed WrappedServerStatsResponse = ResultsWrapper[ServerStats] @@ -175,17 +166,10 @@ class AddUserResponse(BaseModel): WrappedAnalyticsResponse = ResultsWrapper[AnalyticsResponse] WrappedAppSettingsResponse = ResultsWrapper[AppSettingsResponse] -WrappedPromptMessageResponse = ResultsWrapper[UpdatePromptResponse] - -WrappedAddUserResponse = ResultsWrapper[None] WrappedUsersInCollectionResponse = PaginatedResultsWrapper[list[UserResponse]] WrappedUserCollectionResponse = PaginatedResultsWrapper[ list[CollectionResponse] ] -WrappedDocumentChunkResponse = ResultsWrapper[DocumentChunkResponse] -WrappedDocumentChunksResponse = PaginatedResultsWrapper[ - list[DocumentChunkResponse] -] WrappedDeleteResponse = ResultsWrapper[None] WrappedVerificationResult = ResultsWrapper[VerificationResult]