Skip to content

Commit

Permalink
V3 graph testing (#1606)
Browse files Browse the repository at this point in the history
* up

* up

* up

* graph crud

* up

* community endpts

* up

* up

* up

* up

* up

* up
  • Loading branch information
shreyaspimpalgaonkar authored Nov 19, 2024
1 parent 1085576 commit 1f28eea
Show file tree
Hide file tree
Showing 24 changed files with 1,991 additions and 1,274 deletions.
12 changes: 10 additions & 2 deletions py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
CommunityInfo,
Community,
Entity,
EntityLevel,
DataLevel,
EntityType,
Graph,
KGExtraction,
Expand All @@ -39,6 +39,10 @@
KGCreationSettings,
KGEnrichmentSettings,
KGEntityDeduplicationSettings,
GraphBuildSettings,
GraphEntitySettings,
GraphRelationshipSettings,
GraphCommunitySettings,
KGEntityDeduplicationType,
KGRunType,
)
Expand Down Expand Up @@ -113,7 +117,7 @@
"CommunityInfo",
"KGExtraction",
"Relationship",
"EntityLevel",
"DataLevel",
# Index abstractions
"IndexConfig",
# LLM abstractions
Expand Down Expand Up @@ -141,6 +145,10 @@
"KGCreationSettings",
"KGEnrichmentSettings",
"KGEntityDeduplicationSettings",
"GraphBuildSettings",
"GraphEntitySettings",
"GraphRelationshipSettings",
"GraphCommunitySettings",
"KGEntityDeduplicationType",
"KGRunType",
# User abstractions
Expand Down
24 changes: 10 additions & 14 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
from abc import ABC, abstractmethod
from datetime import datetime
from io import BytesIO
from typing import Any, BinaryIO, Optional, Sequence, Tuple
from typing import (
Any,
BinaryIO,
Optional,
Sequence,
Tuple,
Union,
)
from uuid import UUID

from pydantic import BaseModel
Expand Down Expand Up @@ -692,26 +699,15 @@ async def get(self, *args: Any, **kwargs: Any) -> list[Graph]:
pass

@abstractmethod
async def update(self, *args: Any, **kwargs: Any) -> None:
async def update(self, graph: Graph) -> UUID:
"""Update graph in storage."""
pass

@abstractmethod
async def delete(self, *args: Any, **kwargs: Any) -> None:
async def delete(self, graph_id: UUID, cascade: bool = False) -> UUID:
"""Delete graph from storage."""
pass

# add documents to the graph
@abstractmethod
async def add_document(self, *args: Any, **kwargs: Any) -> None:
"""Add document to graph."""
pass

@abstractmethod
async def remove_document(self, *args: Any, **kwargs: Any) -> None:
"""Delete document from graph."""
pass


class PromptHandler(Handler):
"""Abstract base class for prompt handling operations."""
Expand Down
39 changes: 27 additions & 12 deletions py/core/main/api/v2/kg_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi import Body, Depends, Query

from core.base import Workflow
from core.base.abstractions import EntityLevel, KGRunType
from core.base.abstractions import DataLevel, KGRunType
from core.base.api.models import (
WrappedCommunitiesResponse,
WrappedKGCreationResponse,
Expand Down Expand Up @@ -140,7 +140,8 @@ async def create_graph(
# If the run type is estimate, return an estimate of the creation cost
if run_type is KGRunType.ESTIMATE:
return await self.service.get_creation_estimate(
collection_id, server_kg_creation_settings
collection_id=collection_id,
kg_creation_settings=server_kg_creation_settings,
)
else:

Expand Down Expand Up @@ -216,7 +217,8 @@ async def enrich_graph(
# If the run type is estimate, return an estimate of the enrichment cost
if run_type is KGRunType.ESTIMATE:
return await self.service.get_enrichment_estimate(
collection_id, server_kg_enrichment_settings
collection_id=collection_id,
kg_enrichment_settings=server_kg_enrichment_settings,
)

# Otherwise, run the enrichment workflow
Expand Down Expand Up @@ -248,8 +250,8 @@ async def get_entities(
collection_id: Optional[UUID] = Query(
None, description="Collection ID to retrieve entities from."
),
entity_level: Optional[EntityLevel] = Query(
default=EntityLevel.DOCUMENT,
entity_level: Optional[DataLevel] = Query(
default=DataLevel.DOCUMENT,
description="Type of entities to retrieve. Options are: raw, dedup_document, dedup_collection.",
),
entity_ids: Optional[list[str]] = Query(
Expand All @@ -262,7 +264,7 @@ async def get_entities(
description="Number of items to return. Use -1 to return all items.",
),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedEntitiesResponse:
) -> WrappedEntitiesResponse: # type: ignore
"""
Retrieve entities from the knowledge graph.
"""
Expand All @@ -274,12 +276,12 @@ async def get_entities(
auth_user.id
)

if entity_level == EntityLevel.CHUNK:
if entity_level == DataLevel.CHUNK:
entity_table_name = "chunk_entity"
elif entity_level == EntityLevel.DOCUMENT:
elif entity_level == DataLevel.DOCUMENT:
entity_table_name = "document_entity"
else:
entity_table_name = "collection_entity"
entity_table_name = "graph_entity"

entities = await self.service.get_entities(
collection_id=collection_id,
Expand All @@ -289,7 +291,10 @@ async def get_entities(
limit=limit,
)

return entities
# for backwards compatibility with the old API
return entities["entities"], { # type: ignore
"total_entries": entities["total_entries"]
}

@self.router.get("/triples")
@self.base_endpoint
Expand Down Expand Up @@ -323,14 +328,19 @@ async def get_relationships(
auth_user.id
)

return await self.service.get_relationships(
triples = await self.service.get_relationships(
offset=offset,
limit=limit,
collection_id=collection_id,
entity_names=entity_names,
relationship_ids=triple_ids,
)

# for backwards compatibility with the old API
return triples["relationships"], { # type: ignore
"total_entries": triples["total_entries"]
}

@self.router.get("/communities")
@self.base_endpoint
async def get_communities(
Expand Down Expand Up @@ -362,14 +372,19 @@ async def get_communities(
auth_user.id
)

return await self.service.get_communities(
results = await self.service.get_communities(
offset=offset,
limit=limit,
collection_id=collection_id,
levels=levels,
community_numbers=community_numbers,
)

# for backwards compatibility with the old API
return results["communities"], { # type: ignore
"total_entries": results["total_entries"]
}

@self.router.post("/deduplicate_entities")
@self.base_endpoint
async def deduplicate_entities(
Expand Down
Loading

0 comments on commit 1f28eea

Please sign in to comment.