diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 8b4b1a5e5..3cc40c678 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -65,7 +65,6 @@ KGEntityResult, KGGlobalResult, KGRelationshipResult, - KGSearchMethod, KGSearchResultType, SearchSettings, ) @@ -131,7 +130,6 @@ # Search abstractions "AggregateSearchResult", "GraphSearchResult", - "KGSearchMethod", "KGSearchResultType", "KGEntityResult", "KGRelationshipResult", diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index a8f189d23..3ba13b1a5 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -601,13 +601,11 @@ async def get_entities( "The currently authenticated user does not have access to the specified graph.", 403, ) - # return await self.services["kg"].get_entities( - # id, offset, limit, auth_user - # ) - entities, count = ( - await self.providers.database.graph_handler.get_entities( - collection_id, offset, limit - ) + + entities, count = await self.services["kg"].get_entities( + parent_id=collection_id, + offset=offset, + limit=limit, ) return entities, { # type: ignore @@ -775,7 +773,11 @@ async def get_entity( ) result = await self.providers.database.graph_handler.entities.get( - collection_id, "graph", entity_ids=[entity_id] + parent_id=collection_id, + store_type="graph", + offset=0, + limit=1, + entity_ids=[entity_id], ) if len(result) == 0 or len(result[0]) == 0: raise R2RException("Entity not found", 404) @@ -888,9 +890,11 @@ async def delete_entity( 403, ) - await self.providers.database.graph_handler.entities.delete( - collection_id, [entity_id], "graph" + await self.services["kg"].delete_entity( + parent_id=collection_id, + entity_id=entity_id, ) + return GenericBooleanResponse(success=True) # type: ignore @self.router.get( @@ -963,13 +967,10 @@ async def get_relationships( 403, ) - relationships, count = ( - await 
self.providers.database.graph_handler.relationships.get( - parent_id=collection_id, - store_type="graph", - offset=offset, - limit=limit, - ) + relationships, count = await self.services["kg"].get_relationships( + parent_id=collection_id, + offset=offset, + limit=limit, ) return relationships, { # type: ignore @@ -1042,7 +1043,11 @@ async def get_relationship( results = ( await self.providers.database.graph_handler.relationships.get( - collection_id, "graph", relationship_ids=[relationship_id] + parent_id=collection_id, + store_type="graph", + offset=0, + limit=1, + relationship_ids=[relationship_id], ) ) if len(results) == 0 or len(results[0]) == 0: @@ -1175,14 +1180,11 @@ async def delete_relationship( 403, ) - # return await self.services[ - # "kg" - # ].documents.graph_handler.relationships.remove_from_graph( - # id, relationship_id, auth_user - # ) - await self.providers.database.graph_handler.relationships.delete( - collection_id, [relationship_id], "graph" + await self.services["kg"].delete_relationship( + parent_id=collection_id, + relationship_id=relationship_id, ) + return GenericBooleanResponse(success=True) # type: ignore @self.router.post( @@ -1358,12 +1360,10 @@ async def get_communities( 403, ) - communities, count = ( - await self.providers.database.graph_handler.get_communities( - collection_id=collection_id, - offset=offset, - limit=limit, - ) + communities, count = await self.services["kg"].get_communities( + parent_id=collection_id, + offset=offset, + limit=limit, ) return communities, { # type: ignore @@ -1720,7 +1720,10 @@ async def pull( ) entities = ( await self.providers.database.graph_handler.entities.get( - document.id, store_type="document" + parent_id=document.id, + store_type="document", + offset=0, + limit=100, ) ) has_document = ( diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index b5c446f66..22c90a96f 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py 
@@ -157,29 +157,6 @@ async def create_entity( metadata=metadata, ) - @telemetry_event("list_entities") - async def list_entities( - self, - offset: int, - limit: int, - id: Optional[UUID] = None, - graph_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - include_embeddings: Optional[bool] = False, - user_id: Optional[UUID] = None, - ): - return await self.providers.database.graph_handler.entities.get( - id=id, - graph_id=graph_id, - document_id=document_id, - entity_names=entity_names, - include_embeddings=include_embeddings, - offset=offset, - limit=limit, - user_id=user_id, - ) - @telemetry_event("update_entity") async def update_entity( self, @@ -198,69 +175,43 @@ async def update_entity( return await self.providers.database.graph_handler.entities.update( entity_id=entity_id, + store_type="graph", # type: ignore name=name, description=description, - category=category, description_embedding=description_embedding, + category=category, metadata=metadata, - store_type="graph", # type: ignore ) @telemetry_event("delete_entity") async def delete_entity( self, - id: UUID, + parent_id: UUID, entity_id: UUID, - level: DataLevel, ): return await self.providers.database.graph_handler.entities.delete( - id=id, - entity_id=entity_id, - level=level, + parent_id=parent_id, + entity_ids=[entity_id], + store_type="graph", # type: ignore ) - # TODO: deprecate this @telemetry_event("get_entities") async def get_entities( self, - collection_id: Optional[UUID] = None, - entity_ids: Optional[list[str]] = None, - entity_table_name: str = "entity", - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, - ): - return await self.providers.database.graph_handler.get_entities( - collection_id=collection_id, - entity_ids=entity_ids, - entity_table_name=entity_table_name, - offset=offset or 0, - limit=limit or -1, - ) - - ################### RELATIONSHIPS ################### - - 
@telemetry_event("list_relationships_v3") - async def list_relationships_v3( - self, - id: UUID, - level: DataLevel, + parent_id: UUID, offset: int, limit: int, + entity_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, - relationship_types: Optional[list[str]] = None, - attributes: Optional[list[str]] = None, - relationship_id: Optional[UUID] = None, + include_embeddings: bool = False, ): - return await self.providers.database.graph_handler.relationships.get( - id=id, - level=level, - entity_names=entity_names, - relationship_types=relationship_types, - attributes=attributes, + return await self.providers.database.graph_handler.get_entities( + parent_id=parent_id, offset=offset, limit=limit, - relationship_id=relationship_id, + entity_ids=entity_ids, + entity_names=entity_names, + include_embeddings=include_embeddings, ) @telemetry_event("create_relationship") @@ -301,13 +252,14 @@ async def create_relationship( @telemetry_event("delete_relationship") async def delete_relationship( self, - id: UUID, + parent_id: UUID, relationship_id: UUID, ): return ( await self.providers.database.graph_handler.relationships.delete( - id=id, - relationship_id=relationship_id, + parent_id=parent_id, + relationship_ids=[relationship_id], + store_type="graph", # type: ignore ) ) @@ -347,27 +299,24 @@ async def update_relationship( ) ) - # TODO: deprecate this - @telemetry_event("get_triples") + @telemetry_event("get_relationships") async def get_relationships( self, + parent_id: UUID, offset: int, limit: int, - collection_id: UUID, - entity_names: Optional[list[str]] = None, relationship_ids: Optional[list[UUID]] = None, + entity_names: Optional[list[str]] = None, ): return await self.providers.database.graph_handler.relationships.get( - parent_id=collection_id, + parent_id=parent_id, store_type="graph", # type: ignore - entity_names=entity_names, - relationship_ids=relationship_ids, offset=offset, limit=limit, + relationship_ids=relationship_ids, + 
entity_names=entity_names, ) - ################### COMMUNITIES ################### - @telemetry_event("create_community") async def create_community( self, @@ -447,17 +396,19 @@ async def list_communities( @telemetry_event("get_communities") async def get_communities( self, - collection_id: UUID, - community_ids: Optional[list[int]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, + parent_id: UUID, + offset: int, + limit: int, + community_ids: Optional[list[UUID]] = None, + community_names: Optional[list[str]] = None, + include_embeddings: bool = False, ): return await self.providers.database.graph_handler.get_communities( - collection_id=collection_id, - community_ids=community_ids, + parent_id=parent_id, offset=offset, limit=limit, + community_ids=community_ids, + include_embeddings=include_embeddings, ) # @telemetry_event("create_new_graph") diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index 5838d1b7c..9ce3f62ba 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -1,10 +1,8 @@ import json import logging -from typing import Any, Union +from typing import Any from uuid import UUID -from fastapi import HTTPException - from core.base import AsyncState from core.base.abstractions import DataLevel, Entity, KGEntityDeduplicationType from core.base.pipes import AsyncPipe @@ -26,14 +24,12 @@ def __init__( self, config: AsyncPipe.PipeConfig, database_provider: PostgresDBProvider, - llm_provider: Union[ - OpenAICompletionProvider, LiteLLMCompletionProvider - ], - embedding_provider: Union[ - LiteLLMEmbeddingProvider, - OpenAIEmbeddingProvider, - OllamaEmbeddingProvider, - ], + llm_provider: OpenAICompletionProvider | LiteLLMCompletionProvider, + embedding_provider: ( + LiteLLMEmbeddingProvider + | OpenAIEmbeddingProvider + | OllamaEmbeddingProvider + ), logging_provider: SqlitePersistentLoggingProvider, **kwargs, ): diff --git 
a/py/core/pipes/retrieval/kg_search_pipe.py b/py/core/pipes/retrieval/kg_search_pipe.py index a8ba17177..ff6625351 100644 --- a/py/core/pipes/retrieval/kg_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_pipe.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, AsyncGenerator, Optional +from typing import Any, AsyncGenerator from uuid import UUID from core.base import ( @@ -15,7 +15,6 @@ KGCommunityResult, KGEntityResult, KGRelationshipResult, - KGSearchMethod, KGSearchResultType, SearchSettings, ) @@ -234,7 +233,6 @@ async def search( rating_explanation=search_result["rating_explanation"], findings=search_result["findings"], ), - # method=KGSearchMethod.LOCAL, result_type=KGSearchResultType.COMMUNITY, metadata=( { diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index be559217a..c1fa9bac4 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -14,7 +14,6 @@ from core.base.abstractions import ( Community, CommunityInfo, - DataLevel, Entity, Graph, KGCreationSettings, @@ -178,8 +177,8 @@ async def get( self, parent_id: UUID, store_type: StoreType, - offset: int = 0, - limit: int = 100, + offset: int, + limit: int, entity_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, include_embeddings: bool = False, @@ -268,7 +267,7 @@ async def update( """Update an entity in the specified store.""" table_name = self._get_entity_table_for_store(store_type) update_fields = [] - params: list = [] + params: list[Any] = [] param_index = 1 if isinstance(metadata, str): @@ -340,7 +339,7 @@ async def delete( parent_id: UUID, entity_ids: Optional[list[UUID]] = None, store_type: StoreType = StoreType.GRAPH, - ) -> list[UUID]: + ) -> None: """ Delete entities from the specified store. If entity_ids is not provided, deletes all entities for the given parent_id. 
@@ -387,8 +386,6 @@ async def delete( 404, ) - return [row["id"] for row in results] - class PostgresRelationshipHandler(RelationshipHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -538,8 +535,8 @@ async def get( self, parent_id: UUID, store_type: StoreType, - offset: int = 0, - limit: int = 100, + offset: int, + limit: int, relationship_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, @@ -747,7 +744,7 @@ async def delete( parent_id: UUID, relationship_ids: Optional[list[UUID]] = None, store_type: StoreType = StoreType.GRAPH, - ) -> list[UUID]: + ) -> None: """ Delete relationships from the specified store. If relationship_ids is not provided, deletes all relationships for the given parent_id. @@ -791,8 +788,6 @@ async def delete( 404, ) - return [row["id"] for row in results] - class PostgresCommunityHandler(CommunityHandler): @@ -2060,7 +2055,7 @@ async def get_deduplication_estimate( async def get_entities( self, - graph_id: UUID, + parent_id: UUID, offset: int, limit: int, entity_ids: Optional[list[UUID]] = None, @@ -2073,7 +2068,7 @@ async def get_entities( Args: offset: Number of records to skip limit: Maximum number of records to return (-1 for no limit) - graph_id: UUID of the graph + parent_id: UUID of the collection entity_ids: Optional list of entity IDs to filter by entity_names: Optional list of entity names to filter by include_embeddings: Whether to include embeddings in the response @@ -2082,7 +2077,7 @@ async def get_entities( Tuple of (list of entities, total count) """ conditions = ["parent_id = $1"] - params: [Any] = [graph_id] + params: list[Any] = [parent_id] param_index = 2 if entity_ids: @@ -2320,8 +2315,8 @@ async def get( self, parent_id: UUID, store_type: StoreType, - offset: int = 0, - limit: int = 100, + offset: int, + limit: int, entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, relationship_id: 
Optional[UUID] = None, @@ -2527,7 +2522,7 @@ async def add_community_info( async def get_communities( self, - collection_id: UUID, + parent_id: UUID, offset: int, limit: int, community_ids: Optional[list[UUID]] = None, @@ -2547,7 +2542,7 @@ async def get_communities( Tuple of (list of communities, total count) """ conditions = ["collection_id = $1"] - params: list[Any] = [collection_id] + params: list[Any] = [parent_id] param_index = 2 if community_ids: diff --git a/py/sdk/models.py b/py/sdk/models.py index 987a8479d..2891c4216 100644 --- a/py/sdk/models.py +++ b/py/sdk/models.py @@ -12,7 +12,6 @@ KGGlobalResult, KGRelationshipResult, KGRunType, - KGSearchMethod, KGSearchResultType, Message, MessageType, @@ -38,7 +37,6 @@ "KGGlobalResult", "KGRelationshipResult", "KGRunType", - "KGSearchMethod", "GraphSearchResult", "KGSearchResultType", "GraphSearchSettings", diff --git a/py/sdk/v3/conversations.py b/py/sdk/v3/conversations.py index bc6b0aad6..4fa3887a8 100644 --- a/py/sdk/v3/conversations.py +++ b/py/sdk/v3/conversations.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from uuid import UUID from shared.api.models.base import WrappedBooleanResponse @@ -125,7 +125,7 @@ async def add_message( Returns: dict: Result of the operation, including the new message ID """ - data = { + data: dict[str, Any] = { "content": content, "role": role, } diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index b7954180f..ef819502e 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from uuid import UUID from shared.api.models.base import ( @@ -387,7 +387,7 @@ async def update_community( Returns: dict: Updated community information """ - data = {} + data: dict[str, Any] = {} if name is not None: data["name"] = name if summary is not None: diff --git a/py/shared/abstractions/__init__.py b/py/shared/abstractions/__init__.py index 8c55df1b8..8e70a0008 100644 --- 
a/py/shared/abstractions/__init__.py +++ b/py/shared/abstractions/__init__.py @@ -50,7 +50,6 @@ KGEntityResult, KGGlobalResult, KGRelationshipResult, - KGSearchMethod, KGSearchResultType, SearchSettings, ) @@ -110,7 +109,6 @@ # Search abstractions "AggregateSearchResult", "GraphSearchResult", - "KGSearchMethod", "KGSearchResultType", "KGEntityResult", "KGRelationshipResult", diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index 991588282..fd3502691 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -62,10 +62,6 @@ class KGSearchResultType(str, Enum): COMMUNITY = "community" -class KGSearchMethod(str, Enum): - LOCAL = "local" - - class KGEntityResult(R2RSerializable): name: str description: str