Skip to content

Commit

Permalink
More cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Nov 30, 2024
1 parent ed7f734 commit a21003e
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 161 deletions.
2 changes: 0 additions & 2 deletions py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
KGEntityResult,
KGGlobalResult,
KGRelationshipResult,
KGSearchMethod,
KGSearchResultType,
SearchSettings,
)
Expand Down Expand Up @@ -131,7 +130,6 @@
# Search abstractions
"AggregateSearchResult",
"GraphSearchResult",
"KGSearchMethod",
"KGSearchResultType",
"KGEntityResult",
"KGRelationshipResult",
Expand Down
67 changes: 35 additions & 32 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,13 +601,11 @@ async def get_entities(
"The currently authenticated user does not have access to the specified graph.",
403,
)
# return await self.services["kg"].get_entities(
# id, offset, limit, auth_user
# )
entities, count = (
await self.providers.database.graph_handler.get_entities(
collection_id, offset, limit
)

entities, count = await self.services["kg"].get_entities(
parent_id=collection_id,
offset=offset,
limit=limit,
)

return entities, { # type: ignore
Expand Down Expand Up @@ -775,7 +773,11 @@ async def get_entity(
)

result = await self.providers.database.graph_handler.entities.get(
collection_id, "graph", entity_ids=[entity_id]
parent_id=collection_id,
store_type="graph",
offset=0,
limit=1,
entity_ids=[entity_id],
)
if len(result) == 0 or len(result[0]) == 0:
raise R2RException("Entity not found", 404)
Expand Down Expand Up @@ -888,9 +890,11 @@ async def delete_entity(
403,
)

await self.providers.database.graph_handler.entities.delete(
collection_id, [entity_id], "graph"
self.services["kg"].delete_entity(
parent_id=collection_id,
entity_id=entity_id,
)

return GenericBooleanResponse(success=True) # type: ignore

@self.router.get(
Expand Down Expand Up @@ -963,13 +967,10 @@ async def get_relationships(
403,
)

relationships, count = (
await self.providers.database.graph_handler.relationships.get(
parent_id=collection_id,
store_type="graph",
offset=offset,
limit=limit,
)
relationships, count = await self.services["kg"].get_relationships(
parent_id=collection_id,
offset=offset,
limit=limit,
)

return relationships, { # type: ignore
Expand Down Expand Up @@ -1042,7 +1043,11 @@ async def get_relationship(

results = (
await self.providers.database.graph_handler.relationships.get(
collection_id, "graph", relationship_ids=[relationship_id]
parent_id=collection_id,
store_type="graph",
offset=0,
limit=1,
relationship_ids=[relationship_id],
)
)
if len(results) == 0 or len(results[0]) == 0:
Expand Down Expand Up @@ -1175,14 +1180,11 @@ async def delete_relationship(
403,
)

# return await self.services[
# "kg"
# ].documents.graph_handler.relationships.remove_from_graph(
# id, relationship_id, auth_user
# )
await self.providers.database.graph_handler.relationships.delete(
collection_id, [relationship_id], "graph"
await self.services["kg"].delete_relationship(
parent_id=collection_id,
relationship_id=relationship_id,
)

return GenericBooleanResponse(success=True) # type: ignore

@self.router.post(
Expand Down Expand Up @@ -1358,12 +1360,10 @@ async def get_communities(
403,
)

communities, count = (
await self.providers.database.graph_handler.get_communities(
collection_id=collection_id,
offset=offset,
limit=limit,
)
communities, count = await self.services["kg"].get_communities(
parent_id=collection_id,
offset=offset,
limit=limit,
)

return communities, { # type: ignore
Expand Down Expand Up @@ -1720,7 +1720,10 @@ async def pull(
)
entities = (
await self.providers.database.graph_handler.entities.get(
document.id, store_type="document"
parent_id=document.id,
store_type="document",
offset=0,
limit=100,
)
)
has_document = (
Expand Down
115 changes: 33 additions & 82 deletions py/core/main/services/kg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,29 +157,6 @@ async def create_entity(
metadata=metadata,
)

@telemetry_event("list_entities")
async def list_entities(
self,
offset: int,
limit: int,
id: Optional[UUID] = None,
graph_id: Optional[UUID] = None,
document_id: Optional[UUID] = None,
entity_names: Optional[list[str]] = None,
include_embeddings: Optional[bool] = False,
user_id: Optional[UUID] = None,
):
return await self.providers.database.graph_handler.entities.get(
id=id,
graph_id=graph_id,
document_id=document_id,
entity_names=entity_names,
include_embeddings=include_embeddings,
offset=offset,
limit=limit,
user_id=user_id,
)

@telemetry_event("update_entity")
async def update_entity(
self,
Expand All @@ -198,69 +175,43 @@ async def update_entity(

return await self.providers.database.graph_handler.entities.update(
entity_id=entity_id,
store_type="graph", # type: ignore
name=name,
description=description,
category=category,
description_embedding=description_embedding,
category=category,
metadata=metadata,
store_type="graph", # type: ignore
)

@telemetry_event("delete_entity")
async def delete_entity(
self,
id: UUID,
parent_id: UUID,
entity_id: UUID,
level: DataLevel,
):
return await self.providers.database.graph_handler.entities.delete(
id=id,
entity_id=entity_id,
level=level,
parent_id=parent_id,
entity_ids=[entity_id],
store_type="graph", # type: ignore
)

# TODO: deprecate this
@telemetry_event("get_entities")
async def get_entities(
self,
collection_id: Optional[UUID] = None,
entity_ids: Optional[list[str]] = None,
entity_table_name: str = "entity",
offset: Optional[int] = None,
limit: Optional[int] = None,
**kwargs,
):
return await self.providers.database.graph_handler.get_entities(
collection_id=collection_id,
entity_ids=entity_ids,
entity_table_name=entity_table_name,
offset=offset or 0,
limit=limit or -1,
)

################### RELATIONSHIPS ###################

@telemetry_event("list_relationships_v3")
async def list_relationships_v3(
self,
id: UUID,
level: DataLevel,
parent_id: UUID,
offset: int,
limit: int,
entity_ids: Optional[list[UUID]] = None,
entity_names: Optional[list[str]] = None,
relationship_types: Optional[list[str]] = None,
attributes: Optional[list[str]] = None,
relationship_id: Optional[UUID] = None,
include_embeddings: bool = False,
):
return await self.providers.database.graph_handler.relationships.get(
id=id,
level=level,
entity_names=entity_names,
relationship_types=relationship_types,
attributes=attributes,
return await self.providers.database.graph_handler.get_entities(
parent_id=parent_id,
offset=offset,
limit=limit,
relationship_id=relationship_id,
entity_ids=entity_ids,
entity_names=entity_names,
include_embeddings=include_embeddings,
)

@telemetry_event("create_relationship")
Expand Down Expand Up @@ -301,13 +252,14 @@ async def create_relationship(
@telemetry_event("delete_relationship")
async def delete_relationship(
self,
id: UUID,
parent_id: UUID,
relationship_id: UUID,
):
return (
await self.providers.database.graph_handler.relationships.delete(
id=id,
relationship_id=relationship_id,
parent_id=parent_id,
relationship_ids=[relationship_id],
store_type="graph", # type: ignore
)
)

Expand Down Expand Up @@ -347,27 +299,24 @@ async def update_relationship(
)
)

# TODO: deprecate this
@telemetry_event("get_triples")
@telemetry_event("get_relationships")
async def get_relationships(
self,
parent_id: UUID,
offset: int,
limit: int,
collection_id: UUID,
entity_names: Optional[list[str]] = None,
relationship_ids: Optional[list[UUID]] = None,
entity_names: Optional[list[str]] = None,
):
return await self.providers.database.graph_handler.relationships.get(
parent_id=collection_id,
parent_id=parent_id,
store_type="graph", # type: ignore
entity_names=entity_names,
relationship_ids=relationship_ids,
offset=offset,
limit=limit,
relationship_ids=relationship_ids,
entity_names=entity_names,
)

################### COMMUNITIES ###################

@telemetry_event("create_community")
async def create_community(
self,
Expand Down Expand Up @@ -447,17 +396,19 @@ async def list_communities(
@telemetry_event("get_communities")
async def get_communities(
self,
collection_id: UUID,
community_ids: Optional[list[int]] = None,
offset: Optional[int] = None,
limit: Optional[int] = None,
**kwargs,
parent_id: UUID,
offset: int,
limit: int,
community_ids: Optional[list[UUID]] = None,
community_names: Optional[list[str]] = None,
include_embeddings: bool = False,
):
return await self.providers.database.graph_handler.get_communities(
collection_id=collection_id,
community_ids=community_ids,
parent_id=parent_id,
offset=offset,
limit=limit,
community_ids=community_ids,
include_embeddings=include_embeddings,
)

# @telemetry_event("create_new_graph")
Expand Down
18 changes: 7 additions & 11 deletions py/core/pipes/kg/deduplication.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
import logging
from typing import Any, Union
from typing import Any
from uuid import UUID

from fastapi import HTTPException

from core.base import AsyncState
from core.base.abstractions import DataLevel, Entity, KGEntityDeduplicationType
from core.base.pipes import AsyncPipe
Expand All @@ -26,14 +24,12 @@ def __init__(
self,
config: AsyncPipe.PipeConfig,
database_provider: PostgresDBProvider,
llm_provider: Union[
OpenAICompletionProvider, LiteLLMCompletionProvider
],
embedding_provider: Union[
LiteLLMEmbeddingProvider,
OpenAIEmbeddingProvider,
OllamaEmbeddingProvider,
],
llm_provider: OpenAICompletionProvider | LiteLLMCompletionProvider,
embedding_provider: (
LiteLLMEmbeddingProvider
| OpenAIEmbeddingProvider
| OllamaEmbeddingProvider
),
logging_provider: SqlitePersistentLoggingProvider,
**kwargs,
):
Expand Down
4 changes: 1 addition & 3 deletions py/core/pipes/retrieval/kg_search_pipe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, AsyncGenerator, Optional
from typing import Any, AsyncGenerator
from uuid import UUID

from core.base import (
Expand All @@ -15,7 +15,6 @@
KGCommunityResult,
KGEntityResult,
KGRelationshipResult,
KGSearchMethod,
KGSearchResultType,
SearchSettings,
)
Expand Down Expand Up @@ -234,7 +233,6 @@ async def search(
rating_explanation=search_result["rating_explanation"],
findings=search_result["findings"],
),
# method=KGSearchMethod.LOCAL,
result_type=KGSearchResultType.COMMUNITY,
metadata=(
{
Expand Down
Loading

0 comments on commit a21003e

Please sign in to comment.