From 8767153c415ff692b8386948719fd1dc811c254f Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Wed, 27 Nov 2024 16:08:00 -0800 Subject: [PATCH] add collection extraction --- py/core/main/api/v2/kg_router.py | 22 ++-- py/core/main/api/v3/collections_router.py | 111 +++++++++++++++++- py/core/main/api/v3/documents_router.py | 6 +- py/core/main/api/v3/graph_router.py | 6 +- .../main/orchestration/hatchet/kg_workflow.py | 8 +- .../main/orchestration/simple/kg_workflow.py | 39 ++++-- py/core/providers/database/graph.py | 10 +- 7 files changed, 167 insertions(+), 35 deletions(-) diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index 526cd3835..dd652b133 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -57,20 +57,20 @@ def _register_workflows(self): workflow_messages = {} if self.orchestration_provider.config.provider == "hatchet": - workflow_messages["create-graph"] = ( + workflow_messages["extract-triples"] = ( "Graph creation task queued successfully." ) - workflow_messages["enrich-graph"] = ( + workflow_messages["build-communities"] = ( "Graph enrichment task queued successfully." ) workflow_messages["entity-deduplication"] = ( "KG Entity Deduplication task queued successfully." ) else: - workflow_messages["create-graph"] = ( + workflow_messages["extract-triples"] = ( "Document entities and relationships extracted successfully. To generate GraphRAG communities, run cluster on the collection this document belongs to." ) - workflow_messages["enrich-graph"] = ( + workflow_messages["build-communities"] = ( "Graph communities created successfully. 
You can view the communities at http://localhost:7272/v2/communities" ) workflow_messages["entity-deduplication"] = ( @@ -119,7 +119,7 @@ async def create_graph( auth_user.id ) - logger.info(f"Running create-graph on collection {collection_id}") + logger.info(f"Running extract-triples on collection {collection_id}") # If no run type is provided, default to estimate if not run_type: @@ -152,14 +152,14 @@ async def create_graph( } return await self.orchestration_provider.run_workflow( # type: ignore - "create-graph", {"request": workflow_input}, {} + "extract-triples", {"request": workflow_input}, {} ) else: from core.main.orchestration import simple_kg_factory - logger.info("Running create-graph without orchestration.") + logger.info("Running extract-triples without orchestration.") simple_kg = simple_kg_factory(self.service) - await simple_kg["create-graph"](workflow_input) + await simple_kg["extract-triples"](workflow_input) return { "message": "Graph created successfully.", "task_id": None, @@ -229,14 +229,14 @@ async def enrich_graph( } return await self.orchestration_provider.run_workflow( # type: ignore - "enrich-graph", {"request": workflow_input}, {} + "build-communities", {"request": workflow_input}, {} ) else: from core.main.orchestration import simple_kg_factory - logger.info("Running enrich-graph without orchestration.") + logger.info("Running build-communities without orchestration.") simple_kg = simple_kg_factory(self.service) - await simple_kg["enrich-graph"](workflow_input) + await simple_kg["build-communities"](workflow_input) return { "message": "Graph communities created successfully.", "task_id": None, diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index 4684282bd..305e89764 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -5,7 +5,7 @@ from fastapi import Body, Depends, Path, Query -from core.base import R2RException, RunType, KGRunType 
+from core.base import R2RException, RunType, KGRunType, KGCreationSettings
 from core.base.api.models import (
     GenericBooleanResponse,
     GenericMessageResponse,
@@ -15,6 +15,7 @@
     WrappedDocumentsResponse,
     WrappedGenericMessageResponse,
     WrappedUsersResponse,
+    WrappedKGCreationResponse
 )
 from core.providers import (
     HatchetOrchestrationProvider,
@@ -1021,4 +1022,110 @@ async def remove_user_from_collection(
             await self.services["management"].remove_user_from_collection(
                 user_id, id
             )
-            return GenericBooleanResponse(success=True)  # type: ignore
\ No newline at end of file
+            return GenericBooleanResponse(success=True)  # type: ignore
+
+
+        @self.router.post(
+            "/collections/{id}/extract",
+            summary="Extract entities and relationships",
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            result = client.documents.extract(
+                                id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1"
+                            )
+                            """
+                        ),
+                    },
+                ],
+            },
+        )
+        @self.base_endpoint
+        async def extract(
+            id: UUID = Path(
+                ...,
+                description="The ID of the collection to extract entities and relationships from.",
+            ),
+            run_type: KGRunType = Query(
+                default=KGRunType.RUN,
+                description="Whether to return an estimate of the creation cost or to actually run the extraction on the collection.",
+            ),
+            settings: Optional[KGCreationSettings] = Body(
+                default=None,
+                description="Settings for the entities and relationships extraction process.",
+            ),
+            run_with_orchestration: Optional[bool] = Query(
+                default=True,
+                description="Whether to run the entities and relationships extraction process with orchestration.",
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper),
+        ) -> WrappedKGCreationResponse:  # type: ignore
+            """
+            Extracts entities and relationships from the documents in a collection.
+            The entities and relationships extraction process involves:
+            1. Parsing documents into semantic chunks
+            2. Extracting entities and relationships using LLMs
+            """
+
+            settings = settings.dict() if settings else None  # type: ignore
+            if not auth_user.is_superuser:
+                logger.warning("Implement permission checks here.")
+
+            # If no run type is provided, default to estimate
+            if not run_type:
+                run_type = KGRunType.ESTIMATE
+
+            # Apply runtime settings overrides
+            server_kg_creation_settings = (
+                self.providers.database.config.kg_creation_settings
+            )
+
+            if settings:
+                server_kg_creation_settings = update_settings_from_dict(
+                    server_settings=server_kg_creation_settings,
+                    settings_dict=settings,  # type: ignore
+                )
+
+            # If the run type is estimate, return an estimate of the creation cost
+            # if run_type is KGRunType.ESTIMATE:
+            #     return {  # type: ignore
+            #         "message": "Estimate retrieved successfully",
+            #         "task_id": None,
+            #         "id": id,
+            #         "estimate": await self.services[
+            #             "kg"
+            #         ].get_creation_estimate(
+            #             document_id=id,
+            #             kg_creation_settings=server_kg_creation_settings,
+            #         ),
+            #     }
+            # else:
+            # Build the workflow input up front: both branches below consume it.
+            workflow_input = {
+                "collection_id": str(id),
+                "kg_creation_settings": server_kg_creation_settings.model_dump_json(),
+                "user": auth_user.json(),
+            }
+
+            if run_with_orchestration:
+                return await self.orchestration_provider.run_workflow(  # type: ignore
+                    "extract-triples", {"request": workflow_input}, {}
+                )
+            else:
+                from core.main.orchestration import simple_kg_factory
+
+                logger.info("Running extract-triples without orchestration.")
+                simple_kg = simple_kg_factory(self.services["kg"])
+                await simple_kg["extract-triples"](workflow_input)  # type: ignore
+                return {  # type: ignore
+                    "message": "Graph created successfully.",
+                    "task_id": None,
+                }
diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py
index 356177fdc..509ad8879 100644
--- a/py/core/main/api/v3/documents_router.py
+++ b/py/core/main/api/v3/documents_router.py
@@ -1316,14 +1316,14 @@ async def extract(
                 }
 
                 return await
self.orchestration_provider.run_workflow( # type: ignore - "create-graph", {"request": workflow_input}, {} + "extract-triples", {"request": workflow_input}, {} ) else: from core.main.orchestration import simple_kg_factory - logger.info("Running create-graph without orchestration.") + logger.info("Running extract-triples without orchestration.") simple_kg = simple_kg_factory(self.services["kg"]) - await simple_kg["create-graph"](workflow_input) # type: ignore + await simple_kg["extract-triples"](workflow_input) # type: ignore return { # type: ignore "message": "Graph created successfully.", "task_id": None, diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index f1105b666..f8e36697b 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -499,14 +499,14 @@ async def build_communities( } # return await self.orchestration_provider.run_workflow( # type: ignore - # "enrich-graph", {"request": workflow_input}, {} + # "build-communities", {"request": workflow_input}, {} # ) # else: from core.main.orchestration import simple_kg_factory - logger.info("Running enrich-graph without orchestration.") + logger.info("Running build-communities without orchestration.") simple_kg = simple_kg_factory(self.services["kg"]) - await simple_kg["enrich-graph"](workflow_input) + await simple_kg["build-communities"](workflow_input) return { "message": "Graph communities created successfully.", "task_id": None, diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py index 588cdaaf6..3e7a53463 100644 --- a/py/core/main/orchestration/hatchet/kg_workflow.py +++ b/py/core/main/orchestration/hatchet/kg_workflow.py @@ -158,7 +158,7 @@ async def on_failure(self, context: Context) -> None: f"Failed to update document status for {document_id}: {e}" ) - @orchestration_provider.workflow(name="create-graph", timeout="600m") + 
@orchestration_provider.workflow(name="extract-triples", timeout="600m") class CreateGraphWorkflow: def __init__(self, kg_service: KgService): self.kg_service = kg_service @@ -373,7 +373,7 @@ async def kg_entity_deduplication_summary( "result": f"successfully queued kg entity deduplication summary for collection {graph_id}" } - @orchestration_provider.workflow(name="enrich-graph", timeout="360m") + @orchestration_provider.workflow(name="build-communities", timeout="360m") class EnrichGraphWorkflow: def __init__(self, kg_service: KgService): self.kg_service = kg_service @@ -539,8 +539,8 @@ async def kg_community_summary(self, context: Context) -> dict: return { "kg-extract": KGExtractDescribeEmbedWorkflow(service), - "create-graph": CreateGraphWorkflow(service), - "enrich-graph": EnrichGraphWorkflow(service), + "extract-triples": CreateGraphWorkflow(service), + "build-communities": EnrichGraphWorkflow(service), "kg-community-summary": KGCommunitySummaryWorkflow(service), "kg-entity-deduplication": EntityDeduplicationWorkflow(service), "kg-entity-deduplication-summary": EntityDeduplicationSummaryWorkflow( diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index b0a125f50..dce1ae352 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -40,17 +40,42 @@ def get_input_data_dict(input_data): ) return input_data - async def create_graph(input_data): + async def extract_triples(input_data): input_data = get_input_data_dict(input_data) if input_data.get("document_id"): document_ids = [input_data.get("document_id")] else: - document_ids = await service.get_document_ids_for_create_graph( - collection_id=input_data.get("collection_id"), - **input_data["kg_creation_settings"], - ) + documents = [] + collection_id = input_data.get("collection_id") + batch_size = 100 + offset = 0 + while True: + # Fetch current batch + batch = (await 
service.providers.database.collections_handler.documents_in_collection( + collection_id=collection_id, + offset=offset, + limit=batch_size + ))["results"] + + # If no documents returned, we've reached the end + if not batch: + break + + # Add current batch to results + documents.extend(batch) + + # Update offset for next batch + offset += batch_size + + # Optional: If batch is smaller than batch_size, we've reached the end + if len(batch) < batch_size: + break + + # documents = service.providers.database.collections_handler.documents_in_collection(input_data.get("collection_id"), offset=0, limit=1000) + print('extracting for documents = ', documents) + document_ids = [document.id for document in documents] logger.info( f"Creating graph for {len(document_ids)} documents with IDs: {document_ids}" @@ -176,8 +201,8 @@ async def entity_deduplication_workflow(input_data): ) return { - "create-graph": create_graph, - "enrich-graph": enrich_graph, + "extract-triples": extract_triples, + "build-communities": enrich_graph, "kg-community-summary": kg_community_summary, "entity-deduplication": entity_deduplication_workflow, } diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index bcd7e7565..84f206f01 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -1743,7 +1743,7 @@ async def get_creation_estimate( ) return { - "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `create-graph` with `--run` in the cli, or `run_type="run"` in the client.', + "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. 
To run the KG creation process, run `extract-triples` with `--run` in the cli, or `run_type="run"` in the client.', "document_count": len(document_ids), "number_of_jobs_created": len(document_ids) + 1, "total_chunks": total_chunks, @@ -1792,7 +1792,7 @@ async def get_enrichment_estimate( if not entity_count: raise ValueError( - "No entities found in the graph. Please run `create-graph` first." + "No entities found in the graph. Please run `extract-triples` first." ) relationship_count = ( @@ -1812,7 +1812,7 @@ async def get_enrichment_estimate( if not entity_count: raise ValueError( - "No entities found in the graph. Please run `create-graph` first." + "No entities found in the graph. Please run `extract-triples` first." ) relationship_count = ( @@ -1838,7 +1838,7 @@ async def get_enrichment_estimate( ) return { - "message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_type="run"` in the client.', + "message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `build-communities` with `--run` in the cli, or `run_type="run"` in the client.', "total_entities": entity_count, "total_relationships": relationship_count, "estimated_llm_calls": _get_str_estimation_output( @@ -1911,7 +1911,7 @@ async def get_deduplication_estimate( } except UndefinedTableError: raise R2RException( - "Entity embedding table not found. Please run `create-graph` first.", + "Entity embedding table not found. Please run `extract-triples` first.", 404, ) except Exception as e: