Skip to content

Commit

Permalink
add collection extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Nov 28, 2024
1 parent e40884f commit 8767153
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 35 deletions.
22 changes: 11 additions & 11 deletions py/core/main/api/v2/kg_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,20 @@ def _register_workflows(self):

workflow_messages = {}
if self.orchestration_provider.config.provider == "hatchet":
workflow_messages["create-graph"] = (
workflow_messages["extract-triples"] = (
"Graph creation task queued successfully."
)
workflow_messages["enrich-graph"] = (
workflow_messages["build-communities"] = (
"Graph enrichment task queued successfully."
)
workflow_messages["entity-deduplication"] = (
"KG Entity Deduplication task queued successfully."
)
else:
workflow_messages["create-graph"] = (
workflow_messages["extract-triples"] = (
"Document entities and relationships extracted successfully. To generate GraphRAG communities, run cluster on the collection this document belongs to."
)
workflow_messages["enrich-graph"] = (
workflow_messages["build-communities"] = (
"Graph communities created successfully. You can view the communities at http://localhost:7272/v2/communities"
)
workflow_messages["entity-deduplication"] = (
Expand Down Expand Up @@ -119,7 +119,7 @@ async def create_graph(
auth_user.id
)

logger.info(f"Running create-graph on collection {collection_id}")
logger.info(f"Running extract-triples on collection {collection_id}")

# If no run type is provided, default to estimate
if not run_type:
Expand Down Expand Up @@ -152,14 +152,14 @@ async def create_graph(
}

return await self.orchestration_provider.run_workflow( # type: ignore
"create-graph", {"request": workflow_input}, {}
"extract-triples", {"request": workflow_input}, {}
)
else:
from core.main.orchestration import simple_kg_factory

logger.info("Running create-graph without orchestration.")
logger.info("Running extract-triples without orchestration.")
simple_kg = simple_kg_factory(self.service)
await simple_kg["create-graph"](workflow_input)
await simple_kg["extract-triples"](workflow_input)
return {
"message": "Graph created successfully.",
"task_id": None,
Expand Down Expand Up @@ -229,14 +229,14 @@ async def enrich_graph(
}

return await self.orchestration_provider.run_workflow( # type: ignore
"enrich-graph", {"request": workflow_input}, {}
"build-communities", {"request": workflow_input}, {}
)
else:
from core.main.orchestration import simple_kg_factory

logger.info("Running enrich-graph without orchestration.")
logger.info("Running build-communities without orchestration.")
simple_kg = simple_kg_factory(self.service)
await simple_kg["enrich-graph"](workflow_input)
await simple_kg["build-communities"](workflow_input)
return {
"message": "Graph communities created successfully.",
"task_id": None,
Expand Down
111 changes: 109 additions & 2 deletions py/core/main/api/v3/collections_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fastapi import Body, Depends, Path, Query

from core.base import R2RException, RunType, KGRunType
from core.base import R2RException, RunType, KGRunType, KGCreationSettings
from core.base.api.models import (
GenericBooleanResponse,
GenericMessageResponse,
Expand All @@ -15,6 +15,7 @@
WrappedDocumentsResponse,
WrappedGenericMessageResponse,
WrappedUsersResponse,
WrappedKGCreationResponse
)
from core.providers import (
HatchetOrchestrationProvider,
Expand Down Expand Up @@ -1021,4 +1022,110 @@ async def remove_user_from_collection(
await self.services["management"].remove_user_from_collection(
user_id, id
)
return GenericBooleanResponse(success=True) # type: ignore
return GenericBooleanResponse(success=True) # type: ignore


@self.router.post(
    "/collections/{id}/extract",
    summary="Extract entities and relationships",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                # NOTE(review): the sample still calls
                # `client.documents.extract`; confirm the SDK method name
                # for collection-level extraction and update it.
                "source": textwrap.dedent(
                    """
                    from r2r import R2RClient
                    client = R2RClient("http://localhost:7272")
                    # when using auth, do client.login(...)
                    result = client.documents.extract(
                        id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1"
                    )
                    """
                ),
            },
        ],
    },
)
@self.base_endpoint
async def extract(
    id: UUID = Path(
        ...,
        description="The ID of the collection to extract entities and relationships from.",
    ),
    run_type: KGRunType = Query(
        default=KGRunType.RUN,
        description="Whether to return an estimate of the creation cost or to actually extract the document.",
    ),
    settings: Optional[KGCreationSettings] = Body(
        default=None,
        description="Settings for the entities and relationships extraction process.",
    ),
    run_with_orchestration: Optional[bool] = Query(
        default=True,
        description="Whether to run the entities and relationships extraction process with orchestration.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper),
) -> WrappedKGCreationResponse:  # type: ignore
    """
    Extracts entities and relationships from the documents in a collection.
    The entities and relationships extraction process involves:
    1. Parsing documents into semantic chunks
    2. Extracting entities and relationships using LLMs

    The `extract-triples` workflow is either queued on the orchestration
    provider or, when `run_with_orchestration` is false, awaited in-process
    through `simple_kg_factory`.
    """

    settings_dict = settings.dict() if settings else None  # type: ignore
    if not auth_user.is_superuser:
        logger.warning("Implement permission checks here.")

    # If no run type is provided, default to estimate.
    # NOTE(review): `run_type` is not consulted below because the estimate
    # path is not implemented for collections yet.
    if not run_type:
        run_type = KGRunType.ESTIMATE

    # Apply runtime settings overrides on top of the server defaults.
    server_kg_creation_settings = (
        self.providers.database.config.kg_creation_settings
    )
    if settings_dict:
        server_kg_creation_settings = update_settings_from_dict(
            server_settings=server_kg_creation_settings,
            settings_dict=settings_dict,  # type: ignore
        )

    # TODO: restore the KGRunType.ESTIMATE branch once
    # `self.services["kg"].get_creation_estimate` supports collections.

    # Build the payload once, before branching: both execution paths
    # consume it. Previously it was only assigned inside the orchestration
    # branch, so the in-process path raised UnboundLocalError.
    workflow_input = {
        "collection_id": str(id),
        "kg_creation_settings": server_kg_creation_settings.model_dump_json(),
        "user": auth_user.json(),
    }

    if run_with_orchestration:
        return await self.orchestration_provider.run_workflow(  # type: ignore
            "extract-triples", {"request": workflow_input}, {}
        )

    # Imported lazily, as in the other routers' non-orchestrated paths.
    from core.main.orchestration import simple_kg_factory

    logger.info("Running extract-triples without orchestration.")
    simple_kg = simple_kg_factory(self.services["kg"])
    await simple_kg["extract-triples"](workflow_input)  # type: ignore
    return {  # type: ignore
        "message": "Graph created successfully.",
        "task_id": None,
    }
6 changes: 3 additions & 3 deletions py/core/main/api/v3/documents_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,14 +1316,14 @@ async def extract(
}

return await self.orchestration_provider.run_workflow( # type: ignore
"create-graph", {"request": workflow_input}, {}
"extract-triples", {"request": workflow_input}, {}
)
else:
from core.main.orchestration import simple_kg_factory

logger.info("Running create-graph without orchestration.")
logger.info("Running extract-triples without orchestration.")
simple_kg = simple_kg_factory(self.services["kg"])
await simple_kg["create-graph"](workflow_input) # type: ignore
await simple_kg["extract-triples"](workflow_input) # type: ignore
return { # type: ignore
"message": "Graph created successfully.",
"task_id": None,
Expand Down
6 changes: 3 additions & 3 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,14 +499,14 @@ async def build_communities(
}

# return await self.orchestration_provider.run_workflow( # type: ignore
# "enrich-graph", {"request": workflow_input}, {}
# "build-communities", {"request": workflow_input}, {}
# )
# else:
from core.main.orchestration import simple_kg_factory

logger.info("Running enrich-graph without orchestration.")
logger.info("Running build-communities without orchestration.")
simple_kg = simple_kg_factory(self.services["kg"])
await simple_kg["enrich-graph"](workflow_input)
await simple_kg["build-communities"](workflow_input)
return {
"message": "Graph communities created successfully.",
"task_id": None,
Expand Down
8 changes: 4 additions & 4 deletions py/core/main/orchestration/hatchet/kg_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async def on_failure(self, context: Context) -> None:
f"Failed to update document status for {document_id}: {e}"
)

@orchestration_provider.workflow(name="create-graph", timeout="600m")
@orchestration_provider.workflow(name="extract-triples", timeout="600m")
class CreateGraphWorkflow:
def __init__(self, kg_service: KgService):
self.kg_service = kg_service
Expand Down Expand Up @@ -373,7 +373,7 @@ async def kg_entity_deduplication_summary(
"result": f"successfully queued kg entity deduplication summary for collection {graph_id}"
}

@orchestration_provider.workflow(name="enrich-graph", timeout="360m")
@orchestration_provider.workflow(name="build-communities", timeout="360m")
class EnrichGraphWorkflow:
def __init__(self, kg_service: KgService):
self.kg_service = kg_service
Expand Down Expand Up @@ -539,8 +539,8 @@ async def kg_community_summary(self, context: Context) -> dict:

return {
"kg-extract": KGExtractDescribeEmbedWorkflow(service),
"create-graph": CreateGraphWorkflow(service),
"enrich-graph": EnrichGraphWorkflow(service),
"extract-triples": CreateGraphWorkflow(service),
"build-communities": EnrichGraphWorkflow(service),
"kg-community-summary": KGCommunitySummaryWorkflow(service),
"kg-entity-deduplication": EntityDeduplicationWorkflow(service),
"kg-entity-deduplication-summary": EntityDeduplicationSummaryWorkflow(
Expand Down
39 changes: 32 additions & 7 deletions py/core/main/orchestration/simple/kg_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,42 @@ def get_input_data_dict(input_data):
)
return input_data

async def create_graph(input_data):
async def extract_triples(input_data):

input_data = get_input_data_dict(input_data)

if input_data.get("document_id"):
document_ids = [input_data.get("document_id")]
else:
document_ids = await service.get_document_ids_for_create_graph(
collection_id=input_data.get("collection_id"),
**input_data["kg_creation_settings"],
)
documents = []
collection_id = input_data.get("collection_id")
batch_size = 100
offset = 0
while True:
# Fetch current batch
batch = (await service.providers.database.collections_handler.documents_in_collection(
collection_id=collection_id,
offset=offset,
limit=batch_size
))["results"]

# If no documents returned, we've reached the end
if not batch:
break

# Add current batch to results
documents.extend(batch)

# Update offset for next batch
offset += batch_size

# Optional: If batch is smaller than batch_size, we've reached the end
if len(batch) < batch_size:
break

# documents = service.providers.database.collections_handler.documents_in_collection(input_data.get("collection_id"), offset=0, limit=1000)
print('extracting for documents = ', documents)
document_ids = [document.id for document in documents]

logger.info(
f"Creating graph for {len(document_ids)} documents with IDs: {document_ids}"
Expand Down Expand Up @@ -176,8 +201,8 @@ async def entity_deduplication_workflow(input_data):
)

return {
"create-graph": create_graph,
"enrich-graph": enrich_graph,
"extract-triples": extract_triples,
"build-communities": enrich_graph,
"kg-community-summary": kg_community_summary,
"entity-deduplication": entity_deduplication_workflow,
}
10 changes: 5 additions & 5 deletions py/core/providers/database/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1743,7 +1743,7 @@ async def get_creation_estimate(
)

return {
"message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `create-graph` with `--run` in the cli, or `run_type="run"` in the client.',
"message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `extract-triples` with `--run` in the cli, or `run_type="run"` in the client.',
"document_count": len(document_ids),
"number_of_jobs_created": len(document_ids) + 1,
"total_chunks": total_chunks,
Expand Down Expand Up @@ -1792,7 +1792,7 @@ async def get_enrichment_estimate(

if not entity_count:
raise ValueError(
"No entities found in the graph. Please run `create-graph` first."
"No entities found in the graph. Please run `extract-triples` first."
)

relationship_count = (
Expand All @@ -1812,7 +1812,7 @@ async def get_enrichment_estimate(

if not entity_count:
raise ValueError(
"No entities found in the graph. Please run `create-graph` first."
"No entities found in the graph. Please run `extract-triples` first."
)

relationship_count = (
Expand All @@ -1838,7 +1838,7 @@ async def get_enrichment_estimate(
)

return {
"message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_type="run"` in the client.',
"message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `build-communities` with `--run` in the cli, or `run_type="run"` in the client.',
"total_entities": entity_count,
"total_relationships": relationship_count,
"estimated_llm_calls": _get_str_estimation_output(
Expand Down Expand Up @@ -1911,7 +1911,7 @@ async def get_deduplication_estimate(
}
except UndefinedTableError:
raise R2RException(
"Entity embedding table not found. Please run `create-graph` first.",
"Entity embedding table not found. Please run `extract-triples` first.",
404,
)
except Exception as e:
Expand Down

0 comments on commit 8767153

Please sign in to comment.