Skip to content

Commit

Permalink
up
Browse files — browse the repository at this point in the history
  • Loading branch information
shreyaspimpalgaonkar committed Nov 19, 2024
1 parent f076dcc commit 76210b5
Show file tree
Hide file tree
Showing 11 changed files with 338 additions and 229 deletions.
9 changes: 5 additions & 4 deletions py/core/main/api/v2/kg_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ async def enrich_graph(
# If the run type is estimate, return an estimate of the enrichment cost
if run_type is KGRunType.ESTIMATE:
return await self.service.get_enrichment_estimate(
collection_id, server_kg_enrichment_settings
collection_id=collection_id,
kg_enrichment_settings=server_kg_enrichment_settings,
)

# Otherwise, run the enrichment workflow
Expand Down Expand Up @@ -291,7 +292,7 @@ async def get_entities(
)

# for backwards compatibility with the old API
return {"entities": entities["entities"]}, { # type: ignore
return entities['entities'], { # type: ignore
"total_entries": entities["total_entries"]
}

Expand Down Expand Up @@ -336,7 +337,7 @@ async def get_relationships(
)

# for backwards compatibility with the old API
return {"triples": triples["relationships"]}, { # type: ignore
return triples["relationships"], { # type: ignore
"total_entries": triples["total_entries"]
}

Expand Down Expand Up @@ -380,7 +381,7 @@ async def get_communities(
)

# for backwards compatibility with the old API
return {"communities": results["communities"]}, { # type: ignore
return results["communities"], { # type: ignore
"total_entries": results["total_entries"]
}

Expand Down
38 changes: 27 additions & 11 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def _deduplicate_entities(
)

workflow_input = {
"id": str(id),
"graph_id": str(id),
"kg_entity_deduplication_settings": server_settings.model_dump_json(),
"user": auth_user.model_dump_json(),
}
Expand Down Expand Up @@ -183,9 +183,13 @@ async def _create_communities(

# If the run type is estimate, return an estimate of the enrichment cost
if run_type is KGRunType.ESTIMATE:
return await self.services["kg"].get_enrichment_estimate(
id, server_kg_enrichment_settings
)
return {
"message": "Ran community build estimate.",
"estimate": await self.services["kg"].get_enrichment_estimate(
graph_id=graph_id,
kg_enrichment_settings=server_kg_enrichment_settings,
),
}

else:
if run_with_orchestration:
Expand Down Expand Up @@ -369,6 +373,10 @@ async def list_entities(
None,
description="A list of attributes to return. By default, all attributes are returned.",
),
from_built_graph: Optional[bool] = Query(
False,
description="Whether to retrieve entities from the built graph.",
),
offset: int = Query(
0,
ge=0,
Expand Down Expand Up @@ -402,6 +410,7 @@ async def list_entities(
entity_names=entity_names,
entity_categories=entity_categories,
attributes=attributes,
from_built_graph=from_built_graph,
)

return entities, { # type: ignore
Expand Down Expand Up @@ -992,26 +1001,26 @@ async def delete_relationship(
################### COMMUNITIES ###################

@self.router.post(
"/graphs/{id}/build/communities",
"/graphs/{id}/build/communities/{run_type}",
summary="Build communities in the graph",
)
@self.base_endpoint
async def create_communities(
id: UUID = Path(...),
settings: Optional[dict] = Body(None),
run_type: Optional[KGRunType] = Body(
default=None,
run_type: Optional[KGRunType] = Path(
description="Run type for the graph creation process.",
),
run_with_orchestration: bool = Query(True),
auth_user=Depends(self.providers.auth.auth_wrapper),
) -> WrappedKGEnrichmentResponse:

return await self._create_communities(
id=id,
graph_id=id,
settings=settings,
run_type=run_type,
run_with_orchestration=run_with_orchestration,
auth_user=auth_user,
)

@self.router.post(
Expand Down Expand Up @@ -1547,25 +1556,32 @@ async def remove_data(
raise R2RException("Invalid data type", 400)

@self.router.post(
"/graphs/{id}/build",
"/graphs/{id}/build/{run_type}",
summary="Build entities, relationships, and communities in the graph",
)
@self.base_endpoint
async def build(
id: UUID = Path(...),
run_type: KGRunType = Path(...),
settings: GraphBuildSettings = Body(GraphBuildSettings()),
run_with_orchestration: bool = Query(True),
auth_user=Depends(self.providers.auth.auth_wrapper),
):

# build entities
logger.info(f"Building entities for graph {id}")
entities_result = await self._deduplicate_entities(
id, settings.entity_settings, run_type=KGRunType.RUN
id, settings.entity_settings.__dict__, run_type=run_type,
run_with_orchestration=run_with_orchestration,
auth_user=auth_user,
)

# build communities
logger.info(f"Building communities for graph {id}")
communities_result = await self._create_communities(
id, settings.community_settings, run_type=KGRunType.RUN
id, settings.community_settings.__dict__, run_type=run_type,
run_with_orchestration=run_with_orchestration,
auth_user=auth_user,
)

return {
Expand Down
9 changes: 7 additions & 2 deletions py/core/main/orchestration/hatchet/kg_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def get_input_data_dict(input_data):
if key == "document_id":
input_data[key] = uuid.UUID(value)

if key == "collection_id":
input_data[key] = uuid.UUID(value)

if key == "graph_id":
input_data[key] = uuid.UUID(value)

Expand Down Expand Up @@ -411,7 +414,8 @@ async def kg_community_summary(self, context: Context) -> dict:
input_data = get_input_data_dict(
context.workflow_input()["request"]
)
graph_id = input_data["graph_id"]
graph_id = input_data.get("graph_id", None)
collection_id = input_data.get("collection_id", None)
num_communities = context.step_output("kg_clustering")[
"kg_clustering"
][0]["num_communities"]
Expand All @@ -437,6 +441,7 @@ async def kg_community_summary(self, context: Context) -> dict:
num_communities - offset,
),
"graph_id": str(graph_id),
"collection_id": str(collection_id),
**input_data["kg_enrichment_settings"],
}
},
Expand All @@ -454,7 +459,7 @@ async def kg_community_summary(self, context: Context) -> dict:
document_ids = await self.kg_service.providers.database.get_document_ids_by_status(
status_type="kg_extraction_status",
status=KGExtractionStatus.SUCCESS,
collection_id=graph_id,
collection_id=collection_id,
)

await self.kg_service.providers.database.set_workflow_status(
Expand Down
10 changes: 8 additions & 2 deletions py/core/main/orchestration/simple/kg_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def get_input_data_dict(input_data):
if key == "collection_id":
input_data[key] = uuid.UUID(value)

if key == "graph_id":
input_data[key] = uuid.UUID(value)

if key == "kg_creation_settings":
input_data[key] = json.loads(value)
input_data[key]["generation_config"] = GenerationConfig(
Expand Down Expand Up @@ -76,7 +79,8 @@ async def enrich_graph(input_data):

try:
num_communities = await service.kg_clustering(
collection_id=input_data["collection_id"],
collection_id=input_data.get("collection_id", None),
graph_id=input_data.get("graph_id", None),
**input_data["kg_enrichment_settings"],
)
num_communities = num_communities[0]["num_communities"]
Expand Down Expand Up @@ -144,11 +148,13 @@ async def entity_deduplication_workflow(input_data):
input_data["kg_entity_deduplication_settings"]
)

collection_id = input_data["collection_id"]
collection_id = input_data.get("collection_id", None)
graph_id = input_data.get("graph_id", None)

number_of_distinct_entities = (
await service.kg_entity_deduplication(
collection_id=collection_id,
graph_id=graph_id,
**input_data["kg_entity_deduplication_settings"],
)
)[0]["num_entities"]
Expand Down
26 changes: 20 additions & 6 deletions py/core/main/services/kg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ async def list_entities(
entity_names: Optional[list[str]] = None,
entity_categories: Optional[list[str]] = None,
attributes: Optional[list[str]] = None,
from_built_graph: Optional[bool] = False,
):
return await self.providers.database.graph_handler.entities.get(
level=level,
Expand All @@ -149,6 +150,7 @@ async def list_entities(
attributes=attributes,
offset=offset,
limit=limit,
from_built_graph=from_built_graph,
)

@telemetry_event("update_entity")
Expand Down Expand Up @@ -469,6 +471,7 @@ async def get_graph_status(
async def kg_clustering(
self,
collection_id: UUID,
graph_id: UUID,
generation_config: GenerationConfig,
leiden_params: dict,
**kwargs,
Expand All @@ -477,10 +480,12 @@ async def kg_clustering(
logger.info(
f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
)

clustering_result = await self.pipes.kg_clustering_pipe.run(
input=self.pipes.kg_clustering_pipe.Input(
message={
"collection_id": collection_id,
"graph_id": graph_id,
"generation_config": generation_config,
"leiden_params": leiden_params,
"logger": logger,
Expand Down Expand Up @@ -580,13 +585,19 @@ async def get_creation_estimate(
@telemetry_event("get_enrichment_estimate")
async def get_enrichment_estimate(
self,
collection_id: UUID,
kg_enrichment_settings: KGEnrichmentSettings,
collection_id: UUID | None = None,
graph_id: UUID | None = None,
kg_enrichment_settings: KGEnrichmentSettings | None = None,
**kwargs,
):

if graph_id is None and collection_id is None:
raise ValueError("Either graph_id or collection_id must be provided")

return await self.providers.database.graph_handler.get_enrichment_estimate(
collection_id, kg_enrichment_settings
collection_id=collection_id,
graph_id=graph_id,
kg_enrichment_settings=kg_enrichment_settings,
)

@telemetry_event("get_deduplication_estimate")
Expand All @@ -597,13 +608,15 @@ async def get_deduplication_estimate(
**kwargs,
):
return await self.providers.database.graph_handler.get_deduplication_estimate(
collection_id, kg_deduplication_settings
collection_id=collection_id,
kg_deduplication_settings=kg_deduplication_settings,
)

@telemetry_event("kg_entity_deduplication")
async def kg_entity_deduplication(
self,
id: UUID,
collection_id: UUID,
graph_id: UUID,
kg_entity_deduplication_type: KGEntityDeduplicationType,
kg_entity_deduplication_prompt: str,
generation_config: GenerationConfig,
Expand All @@ -612,7 +625,8 @@ async def kg_entity_deduplication(
deduplication_results = await self.pipes.kg_entity_deduplication_pipe.run(
input=self.pipes.kg_entity_deduplication_pipe.Input(
message={
"id": id,
"collection_id": collection_id,
"graph_id": graph_id,
"kg_entity_deduplication_type": kg_entity_deduplication_type,
"kg_entity_deduplication_prompt": kg_entity_deduplication_prompt,
"generation_config": generation_config,
Expand Down
15 changes: 11 additions & 4 deletions py/core/pipes/kg/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,17 @@ def __init__(
async def cluster_kg(
self,
collection_id: UUID,
graph_id: UUID,
leiden_params: dict,
):
"""
Clusters the knowledge graph relationships into communities using the hierarchical Leiden algorithm. Uses the graspologic library.
"""

num_communities = await self.database_provider.graph_handler.perform_graph_clustering(
collection_id,
leiden_params,
collection_id=collection_id,
graph_id=graph_id,
leiden_params=leiden_params,
) # type: ignore

logger.info(
Expand All @@ -74,7 +76,12 @@ async def _run_logic( # type: ignore
Executes the KG clustering pipe: clustering entities and relationships into communities.
"""

collection_id = input.message["collection_id"]
collection_id = input.message.get("collection_id", None)
graph_id = input.message.get("graph_id", None)
leiden_params = input.message["leiden_params"]

yield await self.cluster_kg(collection_id, leiden_params)
yield await self.cluster_kg(
collection_id=collection_id,
graph_id=graph_id,
leiden_params=leiden_params,
)
Loading

0 comments on commit 76210b5

Please sign in to comment.