diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index 5191e4d34..db5d1303f 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -217,7 +217,8 @@ async def enrich_graph( # If the run type is estimate, return an estimate of the enrichment cost if run_type is KGRunType.ESTIMATE: return await self.service.get_enrichment_estimate( - collection_id, server_kg_enrichment_settings + collection_id=collection_id, + kg_enrichment_settings=server_kg_enrichment_settings, ) # Otherwise, run the enrichment workflow @@ -291,7 +292,7 @@ async def get_entities( ) # for backwards compatibility with the old API - return {"entities": entities["entities"]}, { # type: ignore + return entities['entities'], { # type: ignore "total_entries": entities["total_entries"] } @@ -336,7 +337,7 @@ async def get_relationships( ) # for backwards compatibility with the old API - return {"triples": triples["relationships"]}, { # type: ignore + return triples["relationships"], { # type: ignore "total_entries": triples["total_entries"] } @@ -380,7 +381,7 @@ async def get_communities( ) # for backwards compatibility with the old API - return {"communities": results["communities"]}, { # type: ignore + return results["communities"], { # type: ignore "total_entries": results["total_entries"] } diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 894f99a81..e6f8661d3 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -125,7 +125,7 @@ async def _deduplicate_entities( ) workflow_input = { - "id": str(id), + "graph_id": str(id), "kg_entity_deduplication_settings": server_settings.model_dump_json(), "user": auth_user.model_dump_json(), } @@ -183,9 +183,13 @@ async def _create_communities( # If the run type is estimate, return an estimate of the enrichment cost if run_type is KGRunType.ESTIMATE: - return await self.services["kg"].get_enrichment_estimate( - id, server_kg_enrichment_settings - ) + return { + "message": "Ran community build estimate.", + "estimate": await self.services["kg"].get_enrichment_estimate( + graph_id=graph_id, + kg_enrichment_settings=server_kg_enrichment_settings, + ), + } else: if run_with_orchestration: @@ -369,6 +373,10 @@ async def list_entities( None, description="A list of attributes to return. 
By default, all attributes are returned.", ), + from_built_graph: Optional[bool] = Query( + False, + description="Whether to retrieve entities from the built graph.", + ), offset: int = Query( 0, ge=0, @@ -402,6 +410,7 @@ async def list_entities( entity_names=entity_names, entity_categories=entity_categories, attributes=attributes, + from_built_graph=from_built_graph, ) return entities, { # type: ignore @@ -992,15 +1001,14 @@ async def delete_relationship( ################### COMMUNITIES ################### @self.router.post( - "/graphs/{id}/build/communities", + "/graphs/{id}/build/communities/{run_type}", summary="Build communities in the graph", ) @self.base_endpoint async def create_communities( id: UUID = Path(...), settings: Optional[dict] = Body(None), - run_type: Optional[KGRunType] = Body( - default=None, + run_type: Optional[KGRunType] = Path( description="Run type for the graph creation process.", ), run_with_orchestration: bool = Query(True), @@ -1008,10 +1016,11 @@ async def create_communities( ) -> WrappedKGEnrichmentResponse: return await self._create_communities( - id=id, + graph_id=id, settings=settings, run_type=run_type, run_with_orchestration=run_with_orchestration, + auth_user=auth_user, ) @self.router.post( @@ -1547,25 +1556,32 @@ async def remove_data( raise R2RException("Invalid data type", 400) @self.router.post( - "/graphs/{id}/build", + "/graphs/{id}/build/{run_type}", summary="Build entities, relationships, and communities in the graph", ) @self.base_endpoint async def build( id: UUID = Path(...), + run_type: KGRunType = Path(...), settings: GraphBuildSettings = Body(GraphBuildSettings()), + run_with_orchestration: bool = Query(True), + auth_user=Depends(self.providers.auth.auth_wrapper), ): # build entities logger.info(f"Building entities for graph {id}") entities_result = await self._deduplicate_entities( - id, settings.entity_settings, run_type=KGRunType.RUN + id, settings.entity_settings.__dict__, run_type=run_type, + run_with_orchestration=run_with_orchestration, + auth_user=auth_user, ) # build communities logger.info(f"Building communities for graph {id}") communities_result = await self._create_communities( - id, settings.community_settings, run_type=KGRunType.RUN + id, settings.community_settings.__dict__, run_type=run_type, + run_with_orchestration=run_with_orchestration, + auth_user=auth_user, ) return { diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py index 0cdbf565f..ab21e132e 100644 --- a/py/core/main/orchestration/hatchet/kg_workflow.py +++ b/py/core/main/orchestration/hatchet/kg_workflow.py @@ -30,6 +30,9 @@ def get_input_data_dict(input_data): if key == "document_id": input_data[key] = uuid.UUID(value) + if key == "collection_id": + input_data[key] = uuid.UUID(value) + if key == "graph_id": input_data[key] = uuid.UUID(value) @@ -411,7 +414,8 @@ async def kg_community_summary(self, context: Context) -> dict: input_data = get_input_data_dict( context.workflow_input()["request"] ) - graph_id = input_data["graph_id"] + graph_id = input_data.get("graph_id", None) + collection_id = input_data.get("collection_id", None) num_communities = context.step_output("kg_clustering")[ "kg_clustering" ][0]["num_communities"] @@ -437,6 +441,7 @@ async def kg_community_summary(self, context: Context) -> dict: num_communities - offset, ), "graph_id": str(graph_id), + "collection_id": str(collection_id), **input_data["kg_enrichment_settings"], } }, @@ -454,7 +459,7 @@ async def 
kg_community_summary(self, context: Context) -> dict: document_ids = await self.kg_service.providers.database.get_document_ids_by_status( status_type="kg_extraction_status", status=KGExtractionStatus.SUCCESS, - collection_id=graph_id, + collection_id=collection_id, ) await self.kg_service.providers.database.set_workflow_status( diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index 51e9c8746..113b45fc6 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -22,6 +22,9 @@ def get_input_data_dict(input_data): if key == "collection_id": input_data[key] = uuid.UUID(value) + if key == "graph_id": + input_data[key] = uuid.UUID(value) + if key == "kg_creation_settings": input_data[key] = json.loads(value) input_data[key]["generation_config"] = GenerationConfig( @@ -76,7 +79,8 @@ async def enrich_graph(input_data): try: num_communities = await service.kg_clustering( - collection_id=input_data["collection_id"], + collection_id=input_data.get("collection_id", None), + graph_id=input_data.get("graph_id", None), **input_data["kg_enrichment_settings"], ) num_communities = num_communities[0]["num_communities"] @@ -144,11 +148,13 @@ async def entity_deduplication_workflow(input_data): input_data["kg_entity_deduplication_settings"] ) - collection_id = input_data["collection_id"] + collection_id = input_data.get("collection_id", None) + graph_id = input_data.get("graph_id", None) number_of_distinct_entities = ( await service.kg_entity_deduplication( collection_id=collection_id, + graph_id=graph_id, **input_data["kg_entity_deduplication_settings"], ) )[0]["num_entities"] diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 0956311c7..b07351763 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -140,6 +140,7 @@ async def list_entities( entity_names: Optional[list[str]] = None, entity_categories: Optional[list[str]] = None, attributes: Optional[list[str]] = None, + from_built_graph: Optional[bool] = False, ): return await self.providers.database.graph_handler.entities.get( level=level, @@ -149,6 +150,7 @@ async def list_entities( attributes=attributes, offset=offset, limit=limit, + from_built_graph=from_built_graph, ) @telemetry_event("update_entity") @@ -469,6 +471,7 @@ async def get_graph_status( async def kg_clustering( self, collection_id: UUID, + graph_id: UUID, generation_config: GenerationConfig, leiden_params: dict, **kwargs, @@ -477,10 +480,12 @@ async def kg_clustering( logger.info( f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}" ) + clustering_result = await self.pipes.kg_clustering_pipe.run( input=self.pipes.kg_clustering_pipe.Input( message={ "collection_id": collection_id, + "graph_id": graph_id, "generation_config": generation_config, "leiden_params": leiden_params, "logger": logger, @@ -580,13 +585,19 @@ async def get_creation_estimate( @telemetry_event("get_enrichment_estimate") async def get_enrichment_estimate( self, - collection_id: UUID, - kg_enrichment_settings: KGEnrichmentSettings, + collection_id: UUID | None = None, + graph_id: UUID | None = None, + kg_enrichment_settings: KGEnrichmentSettings | None = None, **kwargs, ): + if graph_id is None and collection_id is None: + raise ValueError("Either graph_id or collection_id must be provided") + return await self.providers.database.graph_handler.get_enrichment_estimate( - collection_id, 
kg_enrichment_settings + collection_id=collection_id, + graph_id=graph_id, + kg_enrichment_settings=kg_enrichment_settings, ) @telemetry_event("get_deduplication_estimate") @@ -597,13 +608,15 @@ async def get_deduplication_estimate( **kwargs, ): return await self.providers.database.graph_handler.get_deduplication_estimate( - collection_id, kg_deduplication_settings + collection_id=collection_id, + kg_deduplication_settings=kg_deduplication_settings, ) @telemetry_event("kg_entity_deduplication") async def kg_entity_deduplication( self, - id: UUID, + collection_id: UUID, + graph_id: UUID, kg_entity_deduplication_type: KGEntityDeduplicationType, kg_entity_deduplication_prompt: str, generation_config: GenerationConfig, @@ -612,7 +625,8 @@ async def kg_entity_deduplication( deduplication_results = await self.pipes.kg_entity_deduplication_pipe.run( input=self.pipes.kg_entity_deduplication_pipe.Input( message={ - "id": id, + "collection_id": collection_id, + "graph_id": graph_id, "kg_entity_deduplication_type": kg_entity_deduplication_type, "kg_entity_deduplication_prompt": kg_entity_deduplication_prompt, "generation_config": generation_config, diff --git a/py/core/pipes/kg/clustering.py b/py/core/pipes/kg/clustering.py index 6765096b6..5b4c70eeb 100644 --- a/py/core/pipes/kg/clustering.py +++ b/py/core/pipes/kg/clustering.py @@ -43,6 +43,7 @@ def __init__( async def cluster_kg( self, collection_id: UUID, + graph_id: UUID, leiden_params: dict, ): """ @@ -50,8 +51,9 @@ async def cluster_kg( """ num_communities = await self.database_provider.graph_handler.perform_graph_clustering( - collection_id, - leiden_params, + collection_id=collection_id, + graph_id=graph_id, + leiden_params=leiden_params, ) # type: ignore logger.info( @@ -74,7 +76,12 @@ async def _run_logic( # type: ignore Executes the KG clustering pipe: clustering entities and relationships into communities. 
""" - collection_id = input.message["collection_id"] + collection_id = input.message.get("collection_id", None) + graph_id = input.message.get("graph_id", None) leiden_params = input.message["leiden_params"] - yield await self.cluster_kg(collection_id, leiden_params) + yield await self.cluster_kg( + collection_id=collection_id, + graph_id=graph_id, + leiden_params=leiden_params, + ) diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index 702293f71..249b2a486 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -17,7 +17,7 @@ PostgresDBProvider, ) from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider - +from core.base.abstractions import DataLevel logger = logging.getLogger() @@ -46,127 +46,109 @@ def __init__( self.llm_provider = llm_provider self.embedding_provider = embedding_provider - async def kg_named_entity_deduplication(self, graph_id: UUID, **kwargs): - try: - entity_count = ( - await self.database_provider.graph_handler.get_entity_count( - graph_id=graph_id, collection_id=graph_id, distinct=True - ) - ) - logger.info( - f"KGEntityDeduplicationPipe: Getting entities for collection {graph_id}" - ) - logger.info( - f"KGEntityDeduplicationPipe: Entity count: {entity_count}" - ) + async def _get_entities(self, graph_id: UUID | None, collection_id: UUID | None): + if collection_id is not None: + return await self.database_provider.graph_handler.get_entities(collection_id, offset=0, limit=-1) + elif graph_id is not None: + # TODO: remove the tuple return type + return (await self.database_provider.graph_handler.entities.get(level=DataLevel.GRAPH, id=graph_id, offset=0, limit=-1, from_built_graph=False))[0] + else: + raise ValueError("Either graph_id or collection_id must be provided") - # TODO: FIX this method - entities = ( - await self.database_provider.graph_handler.get_entities( - graph_id=graph_id, - collection_id=graph_id, - offset=0, - limit=-1, - ) - )["entities"] + async def kg_named_entity_deduplication(self, graph_id: UUID | None, collection_id: UUID | None, **kwargs): - logger.info( - f"KGEntityDeduplicationPipe: Got {len(entities)} entities for collection {graph_id}" - ) + import numpy as np - # deduplicate entities by name - deduplicated_entities: dict[str, dict[str, list[str]]] = {} - deduplication_source_keys = [ - "chunk_ids", - "document_id", - "attributes", - ] - deduplication_target_keys = [ - "chunk_ids", - "document_ids", - "attributes", - ] - deduplication_keys = list( - zip(deduplication_source_keys, deduplication_target_keys) - ) - for entity in entities: - if entity.name not in deduplicated_entities: - deduplicated_entities[entity.name] = { - target_key: [] for _, target_key in deduplication_keys - } - for source_key, target_key in deduplication_keys: - value = getattr(entity, source_key) - if isinstance(value, list): - deduplicated_entities[entity.name][target_key].extend( - value - ) - else: - deduplicated_entities[entity.name][target_key].append( - value - ) - - logger.info( - f"KGEntityDeduplicationPipe: Deduplicated {len(deduplicated_entities)} entities" - ) + entities = await self._get_entities(graph_id, collection_id) - # upsert deduplcated entities in the graph_entity table - deduplicated_entities_list = [ - Entity( - name=name, - graph_id=graph_id, - chunk_ids=entity["chunk_ids"], - document_ids=entity["document_ids"], - attributes={}, - ) - for name, entity in deduplicated_entities.items() - ] + logger.info( + f"KGEntityDeduplicationPipe: Got {len(entities)} 
entities for {graph_id or collection_id}" + ) - logger.info( - f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}" - ) - await self.database_provider.graph_handler.add_entities( - deduplicated_entities_list, - table_name="graph_entity", - conflict_columns=["name", "graph_id", "attributes"], - ) + # deduplicate entities by name + deduplicated_entities: dict[str, dict[str, list[str]]] = {} + deduplication_source_keys = [ + "description", + "chunk_ids", + "document_id", + "attributes", + # "description_embedding", + ] + deduplication_target_keys = [ + "description", + "chunk_ids", + "document_ids", + "attributes", + # "description_embedding", + ] + deduplication_keys = list( + zip(deduplication_source_keys, deduplication_target_keys) + ) + for entity in entities: + if entity.name not in deduplicated_entities: + deduplicated_entities[entity.name] = { + target_key: [] for _, target_key in deduplication_keys + } + # deduplicated_entities[entity.name]['total_entries'] = 0 + # deduplicated_entities[entity.name]['description_embedding'] = np.zeros(len(json.loads(entity.description_embedding))) + + for source_key, target_key in deduplication_keys: + value = getattr(entity, source_key) + + # if source_key == "description_embedding": + # deduplicated_entities[entity.name]['total_entries'] += 1 + # deduplicated_entities[entity.name][target_key] += np.array(json.loads(value)) + + if isinstance(value, list): + deduplicated_entities[entity.name][target_key].extend( + value + ) + else: + deduplicated_entities[entity.name][target_key].append( + value + ) - yield { - "result": f"successfully deduplicated {len(entities)} entities to {len(deduplicated_entities)} entities for collection {graph_id}", - "num_entities": len(deduplicated_entities), - } - except Exception as e: - logger.error( - f"KGEntityDeduplicationPipe: Error in entity deduplication: {str(e)}" - ) - raise HTTPException( - status_code=500, - detail=f"KGEntityDeduplicationPipe: Error deduplicating entities: {str(e)}", + # upsert deduplcated entities in the graph_entity table + deduplicated_entities_list = [ + Entity( + name=name, + # description="\n".join(entity["description"]), + # description_embedding=json.dumps((entity["description_embedding"] / entity['total_entries']).tolist()), + collection_id=collection_id, + graph_id=graph_id, + chunk_ids=list(set(entity["chunk_ids"])), + document_ids=list(set(entity["document_ids"])), + attributes={}, ) + for name, entity in deduplicated_entities.items() + ] + + logger.info( + f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}" + ) + + await self.database_provider.graph_handler.add_entities( + deduplicated_entities_list, + table_name="graph_entity", + ) + + yield { + "result": f"successfully deduplicated {len(entities)} entities to {len(deduplicated_entities)} entities for collection {graph_id}", + "num_entities": len(deduplicated_entities), + } async def kg_description_entity_deduplication( - self, graph_id: UUID, **kwargs + self, graph_id: UUID | None, collection_id: UUID | None, **kwargs ): from sklearn.cluster import DBSCAN - entities = ( - await self.database_provider.graph_handler.get_entities( - graph_id=graph_id, - collection_id=graph_id, # deprecated - offset=0, - limit=-1, - extra_columns=["description_embedding"], - ) - )["entities"] - + entities = await self._get_entities(graph_id, collection_id) for entity in entities: entity.description_embedding = 
json.loads( entity.description_embedding ) - logger.info( - f"KGEntityDeduplicationPipe: Got {len(entities)} entities for collection {graph_id}" - ) deduplication_source_keys = [ "chunk_ids", @@ -241,6 +223,7 @@ async def kg_description_entity_deduplication( name=longest_name, description=description, graph_id=graph_id, + collection_id=collection_id, chunk_ids=chunk_ids_list, document_ids=document_ids_list, attributes={ @@ -263,7 +246,7 @@ async def kg_description_entity_deduplication( "num_entities": len(deduplicated_entities), } - async def kg_llm_entity_deduplication(self, graph_id: UUID, **kwargs): + async def kg_llm_entity_deduplication(self, graph_id: UUID, collection_id: UUID, **kwargs): # TODO: implement LLM based entity deduplication raise NotImplementedError( "LLM entity deduplication is not implemented yet" @@ -279,18 +262,19 @@ async def _run_logic( ): # TODO: figure out why the return type AsyncGenerator[dict, None] is not working - graph_id = input.message["graph_id"] + graph_id = input.message.get("graph_id", None) + collection_id = input.message.get("collection_id", None) + if graph_id and collection_id: + raise ValueError("graph_id and collection_id cannot both be provided") + kg_entity_deduplication_type = input.message[ "kg_entity_deduplication_type" ] if kg_entity_deduplication_type == KGEntityDeduplicationType.BY_NAME: - logger.info( - f"KGEntityDeduplicationPipe: Running named entity deduplication for collection {graph_id}" - ) async for result in self.kg_named_entity_deduplication( - graph_id, **kwargs + graph_id=graph_id, collection_id=collection_id, **kwargs ): yield result @@ -298,20 +282,14 @@ async def _run_logic( kg_entity_deduplication_type == KGEntityDeduplicationType.BY_DESCRIPTION ): - logger.info( - f"KGEntityDeduplicationPipe: Running description entity deduplication for collection {graph_id}" - ) - async for result in self.kg_description_entity_deduplication( # type: ignore - graph_id, **kwargs + async for result in self.kg_description_entity_deduplication( + graph_id=graph_id, collection_id=collection_id, **kwargs ): yield result elif kg_entity_deduplication_type == KGEntityDeduplicationType.BY_LLM: - logger.info( - f"KGEntityDeduplicationPipe: Running LLM entity deduplication for collection {graph_id}" - ) - async for result in self.kg_llm_entity_deduplication( # type: ignore - graph_id, **kwargs + async for result in self.kg_llm_entity_deduplication( + graph_id=graph_id, collection_id=collection_id, **kwargs ): yield result diff --git a/py/core/pipes/kg/deduplication_summary.py b/py/core/pipes/kg/deduplication_summary.py index 3d0379b3c..35937fa8c 100644 --- a/py/core/pipes/kg/deduplication_summary.py +++ b/py/core/pipes/kg/deduplication_summary.py @@ -131,6 +131,28 @@ async def _prepare_and_upsert_entities( for entity in entities_batch: yield entity + + async def _get_entities(self, graph_id: UUID | None, collection_id: UUID | None, offset: int, limit: int, level): + + if graph_id is not None: + return await self.database_provider.graph_handler.entities.get( + id=graph_id, + offset=offset, + limit=limit, + level=level, + ) + + elif collection_id is not None: + return await self.database_provider.graph_handler.get_entities( + collection_id=collection_id, + entity_table_name=f"{level}_entity", + offset=offset, + limit=limit, + ) + + else: + raise ValueError("Either graph_id or collection_id must be provided") + async def _run_logic( self, input: AsyncPipe.Input, @@ -141,7 +163,8 @@ async def _run_logic( ): # TODO: figure out why the return type
AsyncGenerator[dict, None] is not working - graph_id = input.message["graph_id"] + graph_id = input.message.get("graph_id", None) + collection_id = input.message.get("collection_id", None) offset = input.message["offset"] limit = input.message["limit"] kg_entity_deduplication_type = input.message[ @@ -156,20 +179,14 @@ async def _run_logic( f"Running kg_entity_deduplication_summary for graph {graph_id} with settings kg_entity_deduplication_type: {kg_entity_deduplication_type}, kg_entity_deduplication_prompt: {kg_entity_deduplication_prompt}, generation_config: {generation_config}" ) - entities = ( - await self.database_provider.graph_handler.get_entities( - graph_id=graph_id, - entity_table_name="graph_entity", - offset=offset, - limit=limit, - ) - )["entities"] + entities = await self._get_entities(graph_id, collection_id, offset, limit, level) entity_names = [entity.name for entity in entities] entity_descriptions = ( await self.database_provider.graph_handler.get_entities( graph_id=graph_id, + collection_id=collection_id, entity_names=entity_names, entity_table_name="document_entity", offset=offset, diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 63cb3c856..0259a706a 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -164,12 +164,13 @@ async def create(self, entities: list[Entity]) -> list[UUID]: # type: ignore async def get( self, level: DataLevel, + offset: int, + limit: int, id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, entity_categories: Optional[list[str]] = None, attributes: Optional[list[str]] = None, - offset: int = 0, - limit: int = -1, + from_built_graph: Optional[bool] = False, ): """Retrieve entities from the database based on various filters. 
@@ -198,9 +199,12 @@ async def get( filter = { DataLevel.CHUNK: "chunk_ids = ANY($1)", DataLevel.DOCUMENT: "document_id = $1", - DataLevel.GRAPH: "graph_id = $1", + DataLevel.GRAPH: "graph_id = $1" if from_built_graph else "$1 = ANY(graph_ids)", }[level] + if not from_built_graph and level == DataLevel.GRAPH: + level = DataLevel.DOCUMENT + if entity_names: filter += " AND name = ANY($2)" params.append(entity_names) @@ -209,12 +213,19 @@ async def get( filter += " AND category = ANY($3)" params.append(entity_categories) - QUERY = f""" + # Build query with conditional LIMIT + base_query = f""" SELECT * from {self._get_table_name(level + "_entity")} WHERE {filter} - OFFSET ${len(params)+1} LIMIT ${len(params) + 2} + OFFSET ${len(params)+1} """ - - params.extend([offset, limit]) + + params.append(offset) + + if limit != -1: + base_query += f" LIMIT ${len(params)+1}" + params.append(limit) + + QUERY = base_query output = await self.connection_manager.fetch_query(QUERY, params) @@ -229,7 +240,7 @@ async def get( SELECT COUNT(*) from {self._get_table_name(level + "_entity")} WHERE {filter} """ count = ( - await self.connection_manager.fetch_query(QUERY, params[:-2]) + await self.connection_manager.fetch_query(QUERY, params[:-2 + (limit == -1)]) )[0]["count"] if count == 0 and level == DataLevel.GRAPH: @@ -412,13 +423,21 @@ async def get( filter += " AND predicate = ANY($3)" params.append(relationship_types) # type: ignore - QUERY = f""" + # Build query with conditional LIMIT + base_query = f""" SELECT * FROM {self._get_table_name(level + "_relationship")} WHERE {filter} - OFFSET ${len(params)+1} LIMIT ${len(params) + 2} + OFFSET ${len(params)+1} """ + + params.append(offset) + + if limit != -1: + base_query += f" LIMIT ${len(params)+1}" + params.append(limit) + + QUERY = base_query - params.extend([offset, limit]) # type: ignore rows = await self.connection_manager.fetch_query(QUERY, params) QUERY_COUNT = f""" @@ -962,36 +981,57 @@ async def get_creation_estimate( } async def get_enrichment_estimate( - self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings + self, collection_id: UUID | None = None, graph_id: UUID | None = None, kg_enrichment_settings: KGEnrichmentSettings | None = None ): """Get the estimated cost and time for enriching a KG.""" + if collection_id is not None: - document_ids = [ - doc.id - for doc in ( - await self.collection_handler.documents_in_collection(collection_id) # type: ignore - )["results"] - ] + document_ids = [ + doc.id + for doc in ( + await self.collection_handler.documents_in_collection(collection_id, offset=0, limit=-1) # type: ignore + )["results"] + ] - # Get entity and relationship counts - entity_count = ( - await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1);", - [document_ids], - ) - )[0]["count"] + # Get entity and relationship counts + entity_count = ( + await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1);", + [document_ids], + ) + )[0]["count"] - if not entity_count: - raise ValueError( - "No entities found in the graph. Please run `create-graph` first." - ) + if not entity_count: + raise ValueError( + "No entities found in the graph. Please run `create-graph` first." 
+ ) + + relationship_count = ( + await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1);", + [document_ids], + ) + )[0]["count"] - relationship_count = ( - await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1);", - [document_ids], - ) - )[0]["count"] + else: + entity_count = ( + await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('document_entity')} WHERE $1 = ANY(graph_ids);", + [graph_id], + ) + )[0]["count"] + + if not entity_count: + raise ValueError( + "No entities found in the graph. Please run `create-graph` first." + ) + + relationship_count = ( + await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('chunk_relationship')} WHERE $1 = ANY(graph_ids);", + [graph_id], + ) + )[0]["count"] # Calculate estimates estimated_llm_calls = (entity_count // 10, entity_count // 5) @@ -1130,7 +1170,7 @@ async def get_entities( if entity_table_name == "graph_entity": query = f""" - SELECT sid as id, name, description, chunk_ids, document_ids {", " + ", ".join(extra_columns) if extra_columns else ""} + SELECT sid as id, name, description, chunk_ids, document_ids, graph_id {", " + ", ".join(extra_columns) if extra_columns else ""} FROM {self._get_table_name(entity_table_name)} WHERE collection_id = $1 {" AND " + " AND ".join(conditions) if conditions else ""} @@ -1139,7 +1179,7 @@ """ else: query = f""" - SELECT sid as id, name, description, chunk_ids, document_id {", " + ", ".join(extra_columns) if extra_columns else ""} + SELECT sid as id, name, description, chunk_ids, document_id, graph_ids {", " + ", ".join(extra_columns) if extra_columns else ""} FROM {self._get_table_name(entity_table_name)} WHERE document_id = ANY( SELECT document_id FROM {self._get_table_name("document_info")} @@ -1275,27 +1317,39 @@ async def add_relationships( ) async def get_all_relationships( - self, collection_id: UUID, document_ids: Optional[list[UUID]] = None + self, collection_id: UUID | None, graph_id: UUID | None, document_ids: Optional[list[UUID]] = None ) -> list[Relationship]: - # getting all documents for a collection - if document_ids is None: + + if collection_id is not None: + + # getting all documents for a collection + if document_ids is None: + QUERY = f""" + select distinct document_id from {self._get_table_name("document_info")} where $1 = ANY(collection_ids) + """ + document_ids_list = await self.connection_manager.fetch_query( + QUERY, [collection_id] + ) + document_ids = [ + doc_id["document_id"] for doc_id in document_ids_list + ] + QUERY = f""" - select distinct document_id from {self._get_table_name("document_info")} where $1 = ANY(collection_ids) + SELECT sid as id, subject, predicate, weight, object, document_id FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY($1) """ - document_ids_list = await self.connection_manager.fetch_query( - QUERY, [collection_id] + relationships = await self.connection_manager.fetch_query( + QUERY, [document_ids] + ) + + else: + QUERY = f""" + SELECT sid as id, subject, predicate, weight, object, document_id FROM {self._get_table_name("chunk_relationship")} WHERE
$1 = ANY(graph_ids) + """ + relationships = await self.connection_manager.fetch_query( + QUERY, [graph_id] ) - document_ids = [ - doc_id["document_id"] for doc_id in document_ids_list - ] - QUERY = f""" - SELECT sid as id, subject, predicate, weight, object, document_id FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY($1) - """ - relationships = await self.connection_manager.fetch_query( - QUERY, [document_ids] - ) return [Relationship(**relationship) for relationship in relationships] # DEPRECATED @@ -1607,7 +1661,8 @@ async def delete_graph_for_collection( async def perform_graph_clustering( self, - collection_id: UUID, + collection_id: UUID | None, + graph_id: UUID | None, leiden_params: dict[str, Any], use_community_cache: bool = False, ) -> int: @@ -1630,7 +1685,7 @@ async def perform_graph_clustering( start_time = time.time() - relationships = await self.get_all_relationships(collection_id) + relationships = await self.get_all_relationships(collection_id, graph_id) logger.info(f"Clustering with settings: {leiden_params}") @@ -1638,7 +1693,8 @@ async def perform_graph_clustering( relationships ) - if use_community_cache and await self._use_community_cache( + # incremental clustering isn't enabled for v3 yet. + if not graph_id and await self._use_community_cache( collection_id, relationship_ids_cache ): num_communities = await self._incremental_clustering( @@ -1646,10 +1702,11 @@ async def perform_graph_clustering( ) else: num_communities = await self._cluster_and_add_community_info( - relationships, - relationship_ids_cache, - leiden_params, - collection_id, + relationships = relationships, + relationship_ids_cache = relationship_ids_cache, + leiden_params = leiden_params, + collection_id = collection_id, + graph_id = graph_id, ) return num_communities @@ -1896,23 +1953,25 @@ async def _cluster_and_add_community_info( relationship_ids_cache: dict[str, list[int]], leiden_params: dict[str, Any], collection_id: UUID, + graph_id: UUID, ) -> int: # clear if there is any old information QUERY = f""" - DELETE FROM {self._get_table_name("graph_community_info")} WHERE collection_id = $1 + DELETE FROM {self._get_table_name("graph_community_info")} WHERE collection_id = $1 OR graph_id = $2 """ - await self.connection_manager.execute_query(QUERY, [collection_id]) + await self.connection_manager.execute_query(QUERY, [collection_id, graph_id]) QUERY = f""" - DELETE FROM {self._get_table_name("graph_community")} WHERE collection_id = $1 + DELETE FROM {self._get_table_name("graph_community")} WHERE collection_id = $1 OR graph_id = $2 """ - await self.connection_manager.execute_query(QUERY, [collection_id]) + await self.connection_manager.execute_query(QUERY, [collection_id, graph_id]) start_time = time.time() hierarchical_communities = await self._create_graph_and_cluster( - relationships, leiden_params + relationships = relationships, + leiden_params = leiden_params, ) logger.info( @@ -1936,6 +1995,7 @@ def relationship_ids(node: str) -> list[int]: is_final_cluster=item.is_final_cluster, relationship_ids=relationship_ids(item.node), collection_id=collection_id, + graph_id=graph_id, ) for item in hierarchical_communities ] @@ -2319,7 +2379,7 @@ async def _add_objects( sample_value = cleaned_objects[0][col] if "embedding" in col: pg_type = "vector" - elif "chunk_ids" in col: + elif "chunk_ids" in col or "document_ids" in col or "graph_ids" in col: pg_type = "uuid[]" elif col == "id" or "_id" in col: pg_type = "uuid" diff --git a/py/shared/abstractions/graph.py 
b/py/shared/abstractions/graph.py index 52dd5644e..231493559 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -34,13 +34,13 @@ class Entity(R2RSerializable): """An entity extracted from a document.""" name: str - id: Optional[UUID] = None - sid: Optional[str] = None + id: Optional[UUID | int] = None + sid: Optional[int] = None level: Optional[DataLevel] = None category: Optional[str] = None description: Optional[str] = None description_embedding: Optional[list[float] | str] = None - community_numbers: Optional[list[str]] = None + community_numbers: Optional[list[int]] = None chunk_ids: Optional[list[UUID]] = None graph_id: Optional[UUID] = None graph_ids: Optional[list[UUID]] = None @@ -73,7 +73,7 @@ class Relationship(R2RSerializable): """A relationship between two entities. This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" id: Optional[UUID | int] = None - sid: Optional[str] = None + sid: Optional[int] = None level: Optional[DataLevel] = None subject: Optional[str] = None predicate: Optional[str] = None diff --git a/py/shared/api/models/kg/responses.py b/py/shared/api/models/kg/responses.py index f04f555d5..c063fcf72 100644 --- a/py/shared/api/models/kg/responses.py +++ b/py/shared/api/models/kg/responses.py @@ -163,8 +163,12 @@ class KGEnrichmentResponse(BaseModel): ..., description="A message describing the result of the KG enrichment request.", ) - task_id: UUID = Field( - ..., + id: Optional[UUID] = Field( + None, + description="The ID of the created object.", + ) + task_id: Optional[UUID] = Field( + None, description="The task ID of the KG enrichment request.", ) estimate: Optional[KGEnrichmentEstimate] = Field( @@ -176,6 +180,7 @@ class Config: json_schema_extra = { "example": { "message": "Graph enrichment queued successfuly.", + "id": "c68dc72e-fc23-5452-8f49-d7bd46088a96", "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96", "estimate": { "total_entities": 1000,