Skip to content

Commit

Permalink
Graph refactor (#1611)
Browse files Browse the repository at this point in the history
* up

* up

* add back routers

* up

* pre-commit

* update tests

* revert test change

* up
  • Loading branch information
shreyaspimpalgaonkar authored Nov 20, 2024
1 parent 7a9b81c commit 934a66a
Show file tree
Hide file tree
Showing 14 changed files with 1,980 additions and 1,404 deletions.
426 changes: 426 additions & 0 deletions py/core/main/api/v3/communities_router.py

Large diffs are not rendered by default.

119 changes: 118 additions & 1 deletion py/core/main/api/v3/documents_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional
from uuid import UUID

from fastapi import Depends, File, Form, Path, Query, UploadFile
from fastapi import Depends, File, Form, Path, Query, UploadFile, Body
from fastapi.responses import StreamingResponse
from pydantic import Json

Expand All @@ -20,12 +20,23 @@
WrappedDocumentResponse,
WrappedDocumentsResponse,
WrappedIngestionResponse,
WrappedKGCreationResponse,
)
from core.providers import (
HatchetOrchestrationProvider,
SimpleOrchestrationProvider,
)

from core.base.abstractions import (
Entity,
KGCreationSettings,
KGRunType,
Relationship,
GraphBuildSettings,
)

from core.utils import update_settings_from_dict

from .base_router import BaseRouterV3

logger = logging.getLogger()
Expand Down Expand Up @@ -1205,6 +1216,112 @@ async def get_document_collections(
"total_entries": collections_response["total_entries"]
}

@self.router.post(
    "/documents/{id}/entities_and_relationships",
    summary="Extract entities and relationships from a document",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent(
                    """
                    from r2r import R2RClient
                    client = R2RClient("http://localhost:7272")
                    # when using auth, do client.login(...)
                    result = client.documents.extract_entities_and_relationships(
                        id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1"
                    )
                    """
                ),
            },
        ],
        "operationId": "documents_extract_entities_and_relationships_v3_documents__id__entities_and_relationships_post_documents",
    },
)
@self.base_endpoint
async def extract_entities_and_relationships(
    id: UUID = Path(
        ...,
        description="The ID of the document to extract entities and relationships from.",
    ),
    run_type: KGRunType = Query(
        default=KGRunType.ESTIMATE,
        description="Whether to return an estimate of the creation cost or to actually create the graph.",
    ),
    settings: Optional[KGCreationSettings] = Body(
        default=None,
        description="Settings for the entities and relationships extraction process.",
    ),
    run_with_orchestration: Optional[bool] = Query(
        default=True,
        description="Whether to run the entities and relationships extraction process with orchestration.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper),
) -> WrappedKGCreationResponse:  # type: ignore
    """
    Extracts entities and relationships from a document.
    The entities and relationships extraction process involves:
    1. Parsing documents into semantic chunks
    2. Extracting entities and relationships using LLMs
    """
    # Implementation notes (kept out of the docstring, which FastAPI
    # publishes as the endpoint description):
    #   * run_type ESTIMATE returns a cost estimate and starts no work.
    #   * Any other run_type triggers the "create-graph" workflow, either
    #     through the orchestration provider or inline in this request.

    # Flatten the optional request-body overrides into a plain dict so they
    # can be passed to update_settings_from_dict below.
    settings = settings.dict() if settings else None  # type: ignore
    if not auth_user.is_superuser:
        logger.warning("Implement permission checks here.")

    # If no run type is provided, default to estimate
    if not run_type:
        run_type = KGRunType.ESTIMATE

    # Start from the server-side configuration and apply any per-request
    # overrides on top of it.
    server_kg_creation_settings = (
        self.providers.database.config.kg_creation_settings
    )
    if settings:
        server_kg_creation_settings = update_settings_from_dict(
            server_settings=server_kg_creation_settings,
            settings_dict=settings,  # type: ignore
        )

    # If the run type is estimate, return an estimate of the creation cost
    if run_type is KGRunType.ESTIMATE:
        return {  # type: ignore
            "message": "Estimate retrieved successfully",
            "task_id": None,
            "id": id,
            "estimate": await self.services["kg"].get_creation_estimate(
                document_id=id,
                kg_creation_settings=server_kg_creation_settings,
            ),
        }

    # Otherwise, create the graph. The workflow payload is built *before*
    # choosing the execution mode: both the orchestrated and the inline
    # paths consume it, and defining it inside only one branch left the
    # other branch failing with UnboundLocalError.
    workflow_input = {
        "document_id": str(id),
        "kg_creation_settings": server_kg_creation_settings.model_dump_json(),
        "user": auth_user.json(),
    }

    if run_with_orchestration:
        return await self.orchestration_provider.run_workflow(  # type: ignore
            "create-graph", {"request": workflow_input}, {}
        )

    # Imported lazily, as in the original code.
    # NOTE(review): presumably this avoids a circular import at module
    # load time — confirm before hoisting it to the top of the file.
    from core.main.orchestration import simple_kg_factory

    logger.info("Running create-graph without orchestration.")
    simple_kg = simple_kg_factory(self.services["kg"])
    await simple_kg["create-graph"](workflow_input)  # type: ignore
    return {  # type: ignore
        "message": "Graph created successfully.",
        "task_id": None,
    }

@staticmethod
async def _process_file(file):
import base64
Expand Down
Loading

0 comments on commit 934a66a

Please sign in to comment.