diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py
index 262c978ad..d02f9cd81 100644
--- a/py/core/base/providers/database.py
+++ b/py/core/base/providers/database.py
@@ -1281,7 +1281,14 @@ async def search_documents(
     async def delete(
         self, filters: dict[str, Any]
     ) -> dict[str, dict[str, str]]:
-        return await self.vector_handler.delete(filters)
+        result = await self.vector_handler.delete(filters)
+        await self.graph_handler.entities.delete(
+            parent_id=filters["document_id"]["$eq"]
+        )
+        await self.graph_handler.relationships.delete(
+            parent_id=filters["document_id"]["$eq"]
+        )
+        return result
 
     async def assign_document_to_collection_vector(
         self,
diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py
index 0f383cb9d..69a9d6308 100644
--- a/py/core/main/api/v3/graph_router.py
+++ b/py/core/main/api/v3/graph_router.py
@@ -5,7 +5,7 @@
 
 from fastapi import Body, Depends, Path, Query
 
-from core.base import R2RException, RunType
+from core.base import KGEnrichmentStatus, R2RException, RunType
 from core.base.abstractions import KGRunType
 from core.base.api.models import (
     GenericBooleanResponse,
@@ -1646,6 +1646,10 @@ async def pull(
             collection_id: UUID = Path(
                 ..., description="The ID of the graph to initialize."
             ),
+            force: Optional[bool] = Body(
+                False,
+                description="If true, forces a re-pull of all entities and relationships.",
+            ),
             # document_ids: list[UUID] = Body(
             #     ..., description="List of document IDs to add to the graph."
             # ),
@@ -1736,17 +1740,20 @@ async def pull(
                     )
                     continue
                 if len(entities[0]) == 0:
-                    logger.warning(
-                        f"Document {document.id} has no entities, extraction may not have been called, skipping."
-                    )
-                    continue
+                    if not force:
+                        logger.warning(
+                            f"Document {document.id} has no entities, extraction may not have been called, skipping."
+                        )
+                        continue
+                    else:
+                        logger.warning(
+                            f"Document {document.id} has no entities, but force=True, continuing."
+                        )
                 success = (
                     await self.providers.database.graph_handler.add_documents(
                         id=collection_id,
-                        document_ids=[
-                            document.id
-                        ],  # [doc.id for doc in documents]
+                        document_ids=[document.id],
                     )
                 )
                 if not success:
                     logger.error(
                         f"No documents were added to graph {collection_id}, marking as failed."
@@ -1754,6 +1761,13 @@ async def pull(
                     )
 
+            if success:
+                await self.providers.database.set_workflow_status(
+                    id=collection_id,
+                    status_type="graph_sync_status",
+                    status=KGEnrichmentStatus.SUCCESS,
+                )
+
             return GenericBooleanResponse(success=success)  # type: ignore
 
         @self.router.delete(
diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py
index 04fae1bb4..136b4a7d1 100644
--- a/py/core/main/orchestration/hatchet/ingestion_workflow.py
+++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -11,6 +11,7 @@
 from core.base import (
     DocumentChunk,
     IngestionStatus,
+    KGEnrichmentStatus,
     OrchestrationProvider,
     generate_extraction_id,
     increment_version,
@@ -179,6 +180,16 @@ async def parse(self, context: Context) -> dict:
                         document_id=document_info.id,
                         collection_id=collection_id,
                     )
+                    await service.providers.database.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_sync_status",
+                        status=KGEnrichmentStatus.OUTDATED,
+                    )
+                    await service.providers.database.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_cluster_status",  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                        status=KGEnrichmentStatus.OUTDATED,
+                    )
                 else:
                     for collection_id in collection_ids:
                         try:
@@ -218,7 +229,16 @@ async def parse(self, context: Context) -> dict:
                                 document_id=document_info.id,
                                 collection_id=collection_id,
                             )
-
+                            await service.providers.database.set_workflow_status(
+                                id=collection_id,
+                                status_type="graph_sync_status",
+                                status=KGEnrichmentStatus.OUTDATED,
+                            )
+                            await service.providers.database.set_workflow_status(
+                                id=collection_id,
+                                status_type="graph_cluster_status",  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                                status=KGEnrichmentStatus.OUTDATED,
+                            )
             # get server chunk enrichment settings and override parts of it if provided in the ingestion config
             server_chunk_enrichment_settings = getattr(
                 service.providers.ingestion.config,
@@ -525,6 +545,16 @@ async def finalize(self, context: Context) -> dict:
                         document_id=document_info.id,
                         collection_id=collection_id,
                     )
+                    await service.providers.database.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_sync_status",
+                        status=KGEnrichmentStatus.OUTDATED,
+                    )
+                    await service.providers.database.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_cluster_status",  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                        status=KGEnrichmentStatus.OUTDATED,
+                    )
                 else:
                     for collection_id in collection_ids:
                         try:
@@ -556,6 +586,16 @@ async def finalize(self, context: Context) -> dict:
                                 document_id=document_info.id,
                                 collection_id=collection_id,
                             )
+                            await service.providers.database.set_workflow_status(
+                                id=collection_id,
+                                status_type="graph_sync_status",
+                                status=KGEnrichmentStatus.OUTDATED,
+                            )
+                            await service.providers.database.set_workflow_status(
+                                id=collection_id,
+                                status_type="graph_cluster_status",
+                                status=KGEnrichmentStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                            )
                         except Exception as e:
                             logger.error(
                                 f"Error during assigning document to collection: {str(e)}"
diff --git a/py/core/main/orchestration/simple/ingestion_workflow.py b/py/core/main/orchestration/simple/ingestion_workflow.py
index 07ce86dab..79fd94475 100644
--- a/py/core/main/orchestration/simple/ingestion_workflow.py
+++ b/py/core/main/orchestration/simple/ingestion_workflow.py
@@ -5,7 +5,12 @@
 from fastapi import HTTPException
 from litellm import AuthenticationError
 
-from core.base import DocumentChunk, R2RException, increment_version
+from core.base import (
+    DocumentChunk,
+    KGEnrichmentStatus,
+    R2RException,
+    increment_version,
+)
 from core.utils import (
     generate_default_user_collection_id,
     generate_extraction_id,
@@ -89,6 +94,16 @@ async def ingest_files(input_data):
                     document_id=document_info.id,
                     collection_id=collection_id,
                 )
+                await service.providers.database.set_workflow_status(
+                    id=collection_id,
+                    status_type="graph_sync_status",
+                    status=KGEnrichmentStatus.OUTDATED,
+                )
+                await service.providers.database.set_workflow_status(
+                    id=collection_id,
+                    status_type="graph_cluster_status",
+                    status=KGEnrichmentStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                )
             else:
                 print("collection_ids = ", collection_ids)
 
@@ -134,6 +149,17 @@ async def ingest_files(input_data):
                         document_id=document_info.id,
                         collection_id=collection_id,
                     )
+                    await service.providers.database.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_sync_status",
+                        status=KGEnrichmentStatus.OUTDATED,
+                    )
+                    await service.providers.database.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_cluster_status",
+                        status=KGEnrichmentStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                    )
+
                 except Exception as e:
                     logger.error(
                         f"Error during assigning document to collection: {str(e)}"
@@ -307,6 +333,17 @@ async def ingest_chunks(input_data):
                     document_id=document_info.id,
                     collection_id=collection_id,
                 )
+                await service.providers.database.set_workflow_status(
+                    id=collection_id,
+                    status_type="graph_sync_status",
+                    status=KGEnrichmentStatus.OUTDATED,
+                )
+                await service.providers.database.set_workflow_status(
+                    id=collection_id,
+                    status_type="graph_cluster_status",
+                    status=KGEnrichmentStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                )
+
             else:
                 print("collection_ids = ", collection_ids)
                 for collection_id in collection_ids:
@@ -344,6 +381,17 @@ async def ingest_chunks(input_data):
                         document_id=document_info.id,
                         collection_id=collection_id,
                     )
+                    await service.providers.database.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_sync_status",
+                        status=KGEnrichmentStatus.OUTDATED,
+                    )
+                    await service.providers.database.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_cluster_status",
+                        status=KGEnrichmentStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                    )
+
                 except Exception as e:
                     logger.error(
                         f"Error during assigning document to collection: {str(e)}"
diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py
index f576748c9..a3b1c96b8 100644
--- a/py/core/main/orchestration/simple/kg_workflow.py
+++ b/py/core/main/orchestration/simple/kg_workflow.py
@@ -120,7 +120,7 @@ async def enrich_graph(input_data):
         print("workflow_status = ", workflow_status)
         if workflow_status == KGEnrichmentStatus.SUCCESS:
             raise R2RException(
-                "Communities have already been built for this collection. To build communities again, first submit a POST request to `graphs/{collection_id}/reset`.",
+                "Communities have already been built for this collection. To build communities again, first submit a POST request to `graphs/{collection_id}/reset` to erase the previously built communities.",
                 400,
             )
diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py
index 03b61f8d8..f61989632 100644
--- a/py/core/main/services/management_service.py
+++ b/py/core/main/services/management_service.py
@@ -11,6 +11,7 @@
     AnalysisTypes,
     CollectionResponse,
     DocumentResponse,
+    KGEnrichmentStatus,
     LogFilterCriteria,
     LogProcessor,
     Message,
@@ -355,29 +356,46 @@ def process_filter(filter_dict: dict[str, Any]):
         )
 
         for document_id in document_ids_to_purge:
-            remaining_chunks = await self.providers.database.list_document_chunks(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
-                document_id=document_id,
-                offset=0,
-                limit=1000,
-            )
-            if remaining_chunks["total_entries"] == 0:
-                try:
-                    await self.providers.database.delete_from_documents_overview(
-                        document_id
-                    )
-                    logger.info(
-                        f"Deleted document ID {document_id} from documents_overview."
-                    )
-                except Exception as e:
-                    logger.error(
-                        f"Error deleting document ID {document_id} from documents_overview: {e}"
-                    )
-            await self.providers.database.graph_handler.entities.delete(
-                parent_id=document_id, store_type="document"
-            )
-            await self.providers.database.graph_handler.relationships.delete(
-                parent_id=document_id, store_type="document"
+            # remaining_chunks = await self.providers.database.list_document_chunks(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
+            #     document_id=document_id,
+            #     offset=0,
+            #     limit=1000,
+            # )
+            # if remaining_chunks["total_entries"] == 0:
+            #     try:
+            #         await self.providers.database.delete_from_documents_overview(
+            #             document_id
+            #         )
+            #         logger.info(
+            #             f"Deleted document ID {document_id} from documents_overview."
+            #         )
+            #     except Exception as e:
+            #         logger.error(
+            #             f"Error deleting document ID {document_id} from documents_overview: {e}"
+            #         )
+            # await self.providers.database.graph_handler.entities.delete(
+            #     parent_id=document_id, store_type="document"
+            # )
+            # await self.providers.database.graph_handler.relationships.delete(
+            #     parent_id=document_id, store_type="document"
+            # )
+            collections = (
+                await self.providers.database.get_collections_overview(
+                    offset=0, limit=1000, filter_document_ids=[document_id]
+                )
             )
+            # TODO - Loop over all collections
+            for collection in collections["results"]:
+                await self.providers.database.set_workflow_status(
+                    id=collection.id,
+                    status_type="graph_sync_status",
+                    status=KGEnrichmentStatus.OUTDATED,
+                )
+                await self.providers.database.set_workflow_status(
+                    id=collection.id,
+                    status_type="graph_cluster_status",
+                    status=KGEnrichmentStatus.OUTDATED,
+                )
 
         return None
@@ -435,6 +453,17 @@ async def assign_document_to_collection(
             await self.providers.database.assign_document_to_collection_relational(
                 document_id, collection_id
             )
+            await self.providers.database.set_workflow_status(
+                id=collection_id,
+                status_type="graph_sync_status",
+                status=KGEnrichmentStatus.OUTDATED,
+            )
+            await self.providers.database.set_workflow_status(
+                id=collection_id,
+                status_type="graph_cluster_status",
+                status=KGEnrichmentStatus.OUTDATED,
+            )
+
             return {"message": "Document assigned to collection successfully"}
 
         @telemetry_event("RemoveDocumentFromCollection")
diff --git a/py/core/providers/database/collection.py b/py/core/providers/database/collection.py
index 43e77a394..91cf052d7 100644
--- a/py/core/providers/database/collection.py
+++ b/py/core/providers/database/collection.py
@@ -46,6 +46,7 @@ async def create_tables(self) -> None:
             user_id UUID,
             name TEXT NOT NULL,
             description TEXT,
+            graph_sync_status TEXT DEFAULT 'pending',
             graph_cluster_status TEXT DEFAULT 'pending',
             created_at TIMESTAMPTZ DEFAULT NOW(),
             updated_at TIMESTAMPTZ DEFAULT NOW()
@@ -79,7 +80,7 @@ async def create_collection(
             INSERT INTO {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)}
             (collection_id, user_id, name, description)
             VALUES ($1, $2, $3, $4)
-            RETURNING collection_id, user_id, name, description, graph_cluster_status, created_at, updated_at
+            RETURNING collection_id, user_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
         """
         params = [
             collection_id or uuid4(),
@@ -104,6 +105,7 @@ async def create_collection(
                 name=result["name"],
                 description=result["description"],
                 graph_cluster_status=result["graph_cluster_status"],
+                graph_sync_status=result["graph_sync_status"],
                 created_at=result["created_at"],
                 updated_at=result["updated_at"],
                 user_count=0,
@@ -150,7 +152,7 @@ async def update_collection(
                 UPDATE {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)}
                 SET {', '.join(update_fields)}
                 WHERE collection_id = ${param_index}
-                RETURNING collection_id, user_id, name, description, graph_cluster_status, created_at, updated_at
+                RETURNING collection_id, user_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
             )
             SELECT
                 uc.*,
@@ -159,7 +161,7 @@ async def update_collection(
             FROM updated_collection uc
             LEFT JOIN {self._get_table_name('users')} u ON uc.collection_id = ANY(u.collection_ids)
             LEFT JOIN {self._get_table_name('document_info')} d ON uc.collection_id = ANY(d.collection_ids)
-            GROUP BY uc.collection_id, uc.user_id, uc.name, uc.description, uc.graph_cluster_status, uc.created_at, uc.updated_at
+            GROUP BY uc.collection_id, uc.user_id, uc.name, uc.description, uc.graph_sync_status, uc.graph_cluster_status, uc.created_at, uc.updated_at
         """
         try:
             result = await self.connection_manager.fetchrow_query(
@@ -175,6 +177,7 @@ async def update_collection(
                 user_id=result["user_id"],
                 name=result["name"],
                 description=result["description"],
+                graph_sync_status=result["graph_sync_status"],
                 graph_cluster_status=result["graph_cluster_status"],
                 created_at=result["created_at"],
                 updated_at=result["updated_at"],
@@ -322,6 +325,7 @@ async def get_collections_overview(
                 c.description,
                 c.created_at,
                 c.updated_at,
+                c.graph_sync_status,
                 c.graph_cluster_status,
                 COUNT(DISTINCT u.user_id) FILTER (WHERE u.user_id IS NOT NULL) as user_count,
                 COUNT(DISTINCT d.document_id) FILTER (WHERE d.document_id IS NOT NULL) as document_count
@@ -358,6 +362,7 @@ async def get_collections_overview(
                     user_id=row["user_id"],
                     name=row["name"],
                     description=row["description"],
+                    graph_sync_status=row["graph_sync_status"],
                     graph_cluster_status=row["graph_cluster_status"],
                     created_at=row["created_at"],
                     updated_at=row["updated_at"],
diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py
index ba6aba655..70283ee2b 100644
--- a/py/core/providers/database/document.py
+++ b/py/core/providers/database/document.py
@@ -307,6 +307,8 @@ def _get_status_model(self, status_type: str):
             return KGExtractionStatus
         elif status_type == "graph_cluster_status":
             return KGEnrichmentStatus
+        elif status_type == "graph_sync_status":
+            return KGEnrichmentStatus
         else:
             raise R2RException(
                 status_code=400, message=f"Invalid status type: {status_type}"
diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py
index de398223e..10ceb3a61 100644
--- a/py/shared/api/models/management/responses.py
+++ b/py/shared/api/models/management/responses.py
@@ -87,6 +87,7 @@ class CollectionResponse(BaseModel):
     name: str
     description: Optional[str]
     graph_cluster_status: str
+    graph_sync_status: str
     created_at: datetime
     updated_at: datetime
     user_count: int
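A note on the repeated invalidation: the patch inserts the same pair of `set_workflow_status` calls at every site where a document enters a collection (both orchestrators, plus `assign_document_to_collection` and the purge path). A follow-up could consolidate them behind one helper. The sketch below is hypothetical and not part of this diff; it assumes only the `service.providers.database.set_workflow_status` signature and the `KGEnrichmentStatus` enum already used above.

```python
from uuid import UUID

from core.base import KGEnrichmentStatus


async def mark_collection_graph_outdated(service, collection_id: UUID) -> None:
    """Flag a collection's graph as stale after its document set changes."""
    await service.providers.database.set_workflow_status(
        id=collection_id,
        status_type="graph_sync_status",
        status=KGEnrichmentStatus.OUTDATED,
    )
    # NOTE carried over from the diff: ideally we would first check whether
    # clustering has ever run; if not, the status should stay PENDING.
    await service.providers.database.set_workflow_status(
        id=collection_id,
        status_type="graph_cluster_status",
        status=KGEnrichmentStatus.OUTDATED,
    )
```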
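Separately, the cascade added to `delete` in `py/core/base/providers/database.py` reads `filters["document_id"]["$eq"]` unconditionally, so a delete whose filters do not target a single `document_id` via `$eq` would raise `KeyError` or `TypeError` after the vector rows are already gone. A defensive variant, shown here as an assumption rather than what the patch ships, could guard the graph cleanup:

```python
from typing import Any


async def delete(
    self, filters: dict[str, Any]
) -> dict[str, dict[str, str]]:
    result = await self.vector_handler.delete(filters)
    # Cascade into the graph store only when the filters pin down exactly
    # one document via an equality match.
    document_filter = filters.get("document_id")
    if isinstance(document_filter, dict) and "$eq" in document_filter:
        await self.graph_handler.entities.delete(
            parent_id=document_filter["$eq"]
        )
        await self.graph_handler.relationships.delete(
            parent_id=document_filter["$eq"]
        )
    return result
```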
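Finally, the `graph_sync_status TEXT DEFAULT 'pending'` column is added only in the CREATE TABLE statement in `create_tables`, which takes effect only for freshly created tables; on an existing deployment the new `RETURNING` and `GROUP BY` clauses would then reference a missing column. If the project does not already handle this elsewhere, a one-off migration along these lines would likely be needed (the table name below is a placeholder; the deployed name comes from `self._get_table_name(PostgresCollectionHandler.TABLE_NAME)`):

```python
# Hypothetical follow-up migration, not part of this diff.
ADD_GRAPH_SYNC_STATUS_SQL = """
    ALTER TABLE collections
    ADD COLUMN IF NOT EXISTS graph_sync_status TEXT DEFAULT 'pending';
"""
```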