From 38c642d543b305abd843357f7ca423176d5ca04c Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:21:38 -0800 Subject: [PATCH 01/28] SDK First pass --- js/sdk/src/types.ts | 43 +++ js/sdk/src/v3/clients/documents.ts | 19 ++ js/sdk/src/v3/clients/graphs.ts | 231 ++++++++++++- py/core/providers/database/graph.py | 1 + py/sdk/v3/collections.py | 4 +- py/sdk/v3/documents.py | 20 ++ py/sdk/v3/graphs.py | 511 ++++++++++------------------ 7 files changed, 494 insertions(+), 335 deletions(-) diff --git a/js/sdk/src/types.ts b/js/sdk/src/types.ts index 75c2707b3..b9ae60f32 100644 --- a/js/sdk/src/types.ts +++ b/js/sdk/src/types.ts @@ -47,6 +47,10 @@ export interface CollectionResponse { document_count: number; } +//TODO: Sync this with the finished API response model +// Community types +export interface CommunityResponse {} + // Conversation types export interface ConversationResponse { id: string; @@ -89,6 +93,25 @@ export interface DocumentResponse { summary_embedding?: string; } +// Entity types +export interface EntityResponse { + id: string; + sid?: string; + name: string; + category?: string; + description?: string; + chunk_ids: string[]; + description_embedding?: string; + document_id: string; + document_ids: string[]; + graph_ids: string[]; + user_id: string; + last_modified_by: string; + created_at: string; + updated_at: string; + attributes?: Record; +} + // Graph types export interface GraphResponse { id: string; @@ -134,6 +157,10 @@ export interface PromptResponse { input_types: string[]; } +//TODO: Sync this with the finished API response model +// Relationship types +export interface RelationshipResponse {} + // Retrieval types export interface VectorSearchResult { chunk_id: string; @@ -218,6 +245,12 @@ export type WrappedCollectionsResponse = PaginatedResultsWrapper< CollectionResponse[] >; +// Community Responses +export type WrappedCommunityResponse = ResultsWrapper; +export type WrappedCommunitiesResponse = PaginatedResultsWrapper< + CommunityResponse[] +>; + // Conversation Responses export type WrappedConversationMessagesResponse = ResultsWrapper< MessageResponse[] @@ -240,6 +273,10 @@ export type WrappedDocumentsResponse = PaginatedResultsWrapper< DocumentResponse[] >; +// Entity Responses +export type WrappedEntityResponse = ResultsWrapper; +export type WrappedEntitiesResponse = PaginatedResultsWrapper; + // Graph Responses export type WrappedGraphResponse = ResultsWrapper; export type WrappedGraphsResponse = PaginatedResultsWrapper; @@ -254,6 +291,12 @@ export type WrappedListVectorIndicesResponse = ResultsWrapper; export type WrappedPromptResponse = ResultsWrapper; export type WrappedPromptsResponse = PaginatedResultsWrapper; +// Relationship Responses +export type WrappedRelationshipResponse = ResultsWrapper; +export type WrappedRelationshipsResponse = PaginatedResultsWrapper< + RelationshipResponse[] +>; + // Retrieval Responses export type WrappedVectorSearchResponse = ResultsWrapper; export type WrappedSearchResponse = ResultsWrapper; diff --git a/js/sdk/src/v3/clients/documents.ts b/js/sdk/src/v3/clients/documents.ts index 8b328b8ca..0a1922092 100644 --- a/js/sdk/src/v3/clients/documents.ts +++ b/js/sdk/src/v3/clients/documents.ts @@ -344,4 +344,23 @@ export class DocumentsClient { data: options.filters, }); } + + async extract(options: { + id: string; + runType?: string; + runWithOrchestration?: boolean; + }): Promise { + const data: Record = {}; + + if (options.runType) { + data.runType = options.runType; + } + if (options.runWithOrchestration !== undefined) { + data.runWithOrchestration = options.runWithOrchestration; + } + + return this.client.makeRequest("POST", `documents/${options.id}/extract`, { + data, + }); + } } diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index 2daeba3bb..aa7128715 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -3,6 +3,13 @@ import { WrappedGraphResponse, WrappedBooleanResponse, WrappedGraphsResponse, + WrappedGenericMessageResponse, + WrappedEntityResponse, + WrappedEntitiesResponse, + WrappedRelationshipsResponse, + WrappedRelationshipResponse, + WrappedCommunitiesResponse, + WrappedCommunityResponse, } from "../../types"; export class GraphsClient { @@ -49,6 +56,11 @@ export class GraphsClient { }); } + /** + * Get detailed information about a specific graph. + * @param id Graph ID to retrieve + * @returns + */ async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `graphs/${options.id}`); } @@ -76,11 +88,228 @@ export class GraphsClient { } /** - * + * Delete a graph. * @param options * @returns */ async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `graphs/${options.id}`); } + + /** + * FIXME: Should this be `addEntity` or `createEntity`? + * Add an entity to a graph. + * @param id Graph ID + * @param entityId Entity ID to add + * @returns + */ + async addEntity(options: { + id: string; + entityId: string; + }): Promise { + return this.client.makeRequest( + "POST", + `graphs/${options.id}/entities/${options.entityId}`, + ); + } + + /** + * Remove an entity from a graph. + * @param id Graph ID + * @param entityId Entity ID to remove + * @returns + */ + async removeEntity(options: { + id: string; + entityId: string; + }): Promise { + return this.client.makeRequest( + "DELETE", + `graphs/${options.id}/entities/${options.entityId}`, + ); + } + + /** + * List all entities in a graph. + * @param id Graph ID + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @returns + */ + async listEntities(options: { + id: string; + offset?: number; + limit?: number; + }): Promise { + const params: Record = { + offset: options?.offset ?? 0, + limit: options?.limit ?? 100, + }; + + return this.client.makeRequest("GET", `graphs/${options.id}/entities`, { + params, + }); + } + + /** + * Retrieve an entity from a graph. + * @param id Graph ID + * @param entityId Entity ID to retrieve + * @returns + */ + async getEntity(options: { + id: string; + entityId: string; + }): Promise { + return this.client.makeRequest( + "GET", + `graphs/${options.id}/entities/${options.entityId}`, + ); + } + + /** + * FIXME: Should this be `addRelationship` or `createRelationship`? + * Add a relationship to a graph. + * @param id Graph ID + * @param relationshipId Relationship ID to add + * @returns + */ + async addRelationship(options: { + id: string; + relationshipId: string; + }): Promise { + return this.client.makeRequest( + "POST", + `graphs/${options.id}/relationships/${options.relationshipId}`, + ); + } + + /** + * Remove a relationship from a graph. + * @param id Graph ID + * @param relationshipId Relationship ID to remove + * @returns + */ + async removeRelationship(options: { + id: string; + relationshipId: string; + }): Promise { + return this.client.makeRequest( + "DELETE", + `graphs/${options.id}/relationships/${options.relationshipId}`, + ); + } + + /** + * List all relationships in a graph. + * @param id Graph ID + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @returns + */ + async listRelationships(options: { + id: string; + offset?: number; + limit?: number; + }): Promise { + const params: Record = { + offset: options?.offset ?? 0, + limit: options?.limit ?? 100, + }; + + return this.client.makeRequest( + "GET", + `graphs/${options.id}/relationships`, + { + params, + }, + ); + } + + /** + * Retrieve a relationship from a graph. + * @param id Graph ID + * @param relationshipId Relationship ID to retrieve + * @returns + */ + async getRelationship(options: { + id: string; + relationshipId: string; + }): Promise { + return this.client.makeRequest( + "GET", + `graphs/${options.id}/relationships/${options.relationshipId}`, + ); + } + + /** + * FIXME: Should this be `addCommunity` or `createCommunity`? + * Add a community to a graph. + * @param id Graph ID + * @param communityId Community ID to add + * @returns + */ + async addCommunity(options: { + id: string; + communityId: string; + }): Promise { + return this.client.makeRequest( + "POST", + `graphs/${options.id}/communities/${options.communityId}`, + ); + } + + /** + * Remove a community from a graph. + * @param id Graph ID + * @param communityId Community ID to remove + * @returns + */ + async removeCommunity(options: { + id: string; + communityId: string; + }): Promise { + return this.client.makeRequest( + "DELETE", + `graphs/${options.id}/communities/${options.communityId}`, + ); + } + + /** + * List all communities in a graph. + * @param id Graph ID + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @returns + */ + async listCommunities(options: { + id: string; + offset?: number; + limit?: number; + }): Promise { + const params: Record = { + offset: options?.offset ?? 0, + limit: options?.limit ?? 100, + }; + + return this.client.makeRequest("GET", `graphs/${options.id}/communities`, { + params, + }); + } + + /** + * Retrieve a community from a graph. + * @param id Graph ID + * @param communityId Community ID to retrieve + * @returns + */ + async getCommunity(options: { + id: string; + communityId: string; + }): Promise { + return this.client.makeRequest( + "GET", + `graphs/${options.id}/communities/${options.communityId}`, + ); + } } diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 13cdcc179..958afdb80 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -1120,6 +1120,7 @@ async def delete( # return [row["id"] for row in results] + class PostgresCommunityHandler(CommunityHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: diff --git a/py/sdk/v3/collections.py b/py/sdk/v3/collections.py index 2b4b05e29..ef2ac9fd1 100644 --- a/py/sdk/v3/collections.py +++ b/py/sdk/v3/collections.py @@ -36,7 +36,7 @@ async def create( return await self.client._make_request( "POST", "collections", - json=data, # {"config": data} + json=data, version="v3", ) @@ -111,7 +111,7 @@ async def update( return await self.client._make_request( "POST", f"collections/{str(id)}", - json=data, # {"config": data} + json=data, version="v3", ) diff --git a/py/sdk/v3/documents.py b/py/sdk/v3/documents.py index 6b634a2ab..7d1bcbed2 100644 --- a/py/sdk/v3/documents.py +++ b/py/sdk/v3/documents.py @@ -326,3 +326,23 @@ async def delete_by_filter( params={"filters": filters_json}, version="v3", ) + + async def extract( + self, + id: str | UUID, + run_type: Optional[str] = "estimate", + run_with_orchestration: Optional[bool] = True, + ): + data = {} + + if run_type: + data["run_type"] = run_type + if run_with_orchestration is not None: + data["run_with_orchestration"] = str(run_with_orchestration) + + return await self.client._make_request( + "POST", + f"documents/{str(id)}/extract", + data=data, + version="v3", + ) diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index 99905029c..44847bfce 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -1,9 +1,21 @@ from typing import Optional from uuid import UUID -from core.base.abstractions import DataLevel, KGRunType - -from ..models import KGCreationSettings, KGRunType +from shared.api.models.base import ( + WrappedBooleanResponse, + WrappedGenericMessageResponse, +) + +from shared.api.models.kg.responses import ( + WrappedGraphResponse, + WrappedGraphsResponse, + WrappedEntitiesResponse, + WrappedEntityResponse, + WrappedRelationshipResponse, + WrappedRelationshipsResponse, + # WrappedCommunitiesResponse, + # WrappedCommunityResponse, +) class GraphsSDK: @@ -16,557 +28,392 @@ def __init__(self, client): async def create( self, - collection_id: str | UUID, - run_type: Optional[str | KGRunType] = None, - settings: Optional[dict | KGCreationSettings] = None, - run_with_orchestration: Optional[bool] = True, - ): + name: str, + description: Optional[str] = None, + ) -> WrappedGraphResponse: """ - Create a new knowledge graph for a collection. + Create a new graph. Args: - collection_id (str | UUID): Collection ID to create graph for - settings (Optional[dict]): Graph creation settings - run_with_orchestration (Optional[bool]): Whether to run with task orchestration + name (str): Name of the graph + description (Optional[str]): Description of the graph Returns: - WrappedKGCreationResponse: Creation results + dict: Created graph information """ - if isinstance(settings, KGCreationSettings): - settings = settings.model_dump() - - data = { - # "collection_id": str(collection_id) if collection_id else None, - "run_type": str(run_type) if run_type else None, - "settings": settings or {}, - "run_with_orchestration": run_with_orchestration or True, - } - - return await self.client._make_request("POST", f"graphs/{collection_id}", json=data) # type: ignore + data = {"name": name, "description": description} + return await self.client._make_request( + "POST", + "graphs", + json=data, + version="v3", + ) - async def get_status(self, collection_id: str | UUID) -> dict: + async def list( + self, + ids: Optional[list[str | UUID]] = None, + offset: Optional[int] = 0, + limit: Optional[int] = 100, + ) -> WrappedGraphsResponse: """ - Get the status of a graph. + List graphs with pagination and filtering options. Args: - collection_id (str | UUID): Collection ID to get graph status for + ids (Optional[list[str | UUID]]): Filter graphs by ids + offset (int, optional): Specifies the number of objects to skip. Defaults to 0. + limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: - dict: Graph status information + dict: List of graphs and pagination information """ + params: dict = { + "offset": offset, + "limit": limit, + } + if ids: + params["ids"] = ids + return await self.client._make_request( - "GET", f"graphs/{str(collection_id)}" + "GET", "graphs", params=params, version="v3" ) - async def delete( + async def retrieve( self, - collection_id: str | UUID, - cascade: bool = False, - ) -> dict: + id: str | UUID, + ) -> WrappedGraphResponse: """ - Delete a graph. + Get detailed information about a specific graph. Args: - collection_id (str | UUID): Collection ID of graph to delete - cascade (bool): Whether to delete associated entities and relationships + id (str | UUID): Graph ID to retrieve Returns: - dict: Deletion confirmation + dict: Detailed graph information """ - params = {"cascade": cascade} return await self.client._make_request( - "DELETE", f"graphs/{str(collection_id)}", params=params + "GET", f"graphs/{str(id)}", version="v3" ) - # Entity operations - async def create_entity( + async def update( self, - collection_id: str | UUID, - entity: dict, - ) -> dict: + id: str | UUID, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> WrappedGraphResponse: """ - Create a new entity in the graph. + Update graph information. Args: - collection_id (str | UUID): Collection ID to create entity in - entity (dict): Entity data including name, type, and metadata + id (str | UUID): Graph ID to update + name (Optional[str]): Optional new name for the graph + description (Optional[str]): Optional new description for the graph Returns: - dict: Created entity information + dict: Updated graph information """ + data = {} + if name is not None: + data["name"] = name + if description is not None: + data["description"] = description + return await self.client._make_request( "POST", - f"graphs/{str(collection_id)}/entities", - json=entity, + f"graphs/{str(id)}", + json=data, version="v3", ) - async def get_entity( + async def delete( self, - collection_id: str | UUID, - entity_id: str | int, - include_embeddings: bool = False, - ) -> dict: + id: str | UUID, + ) -> WrappedBooleanResponse: """ - Get details of a specific entity. + Delete a graph. Args: - collection_id (str | UUID): Collection ID containing the entity - entity_id (str | UUID): Entity ID to retrieve - include_embeddings (bool): Whether to include vector embeddings + id (str | UUID): Graph ID to delete Returns: - dict: Entity details + bool: True if deletion was successful """ - params = {"include_embeddings": include_embeddings} - return await self.client._make_request( - "GET", - f"graphs/{str(collection_id)}/entities/{str(entity_id)}", - params=params, - version="v3", + result = await self.client._make_request( + "DELETE", f"graphs/{str(id)}", version="v3" ) + return result.get("results", True) - async def update_entity( + async def add_entity( self, - collection_id: str | UUID, + id: str | UUID, entity_id: str | UUID, - entity_update: dict, - ) -> dict: + ) -> WrappedGenericMessageResponse: """ - Update an existing entity. + Add an entity to a graph. Args: - collection_id (str | UUID): Collection ID containing the entity - entity_id (str | UUID): Entity ID to update - entity_update (dict): Updated entity data + id (str | UUID): Graph ID to add entity to + entity_id (str | UUID): Entity ID to add to the graph Returns: - dict: Updated entity information + dict: Success message """ return await self.client._make_request( "POST", - f"graphs/{str(collection_id)}/entities/{str(entity_id)}", - json=entity_update, + f"graphs/{str(id)}/entities/{str(entity_id)}", version="v3", ) - async def delete_entity( + async def remove_entity( self, - collection_id: str | UUID, + id: str | UUID, entity_id: str | UUID, - cascade: bool = False, - ) -> dict: + ) -> WrappedBooleanResponse: """ - Delete an entity. + Remove an entity from a graph. Args: - collection_id (str | UUID): Collection ID containing the entity - entity_id (str | UUID): Entity ID to delete - cascade (bool): Whether to delete related relationships + id (str | UUID): Graph ID to remove entity from + entity_id (str | UUID): Entity ID to remove from the graph Returns: - dict: Deletion confirmation + dict: Success message """ - params = {"cascade": cascade} return await self.client._make_request( "DELETE", - f"graphs/{str(collection_id)}/entities/{str(entity_id)}", - params=params, + f"graphs/{str(id)}/entities/{str(entity_id)}", version="v3", ) async def list_entities( self, - collection_id: str | UUID, - level=DataLevel.DOCUMENT, - include_embeddings: bool = False, + id: str | UUID, offset: Optional[int] = 0, limit: Optional[int] = 100, - ) -> dict: + ) -> WrappedEntitiesResponse: """ - List entities in the graph. + List entities in a graph. Args: - collection_id (str | UUID): Collection ID to list entities from - level (DataLevel): Entity level filter - include_embeddings (bool): Whether to include vector embeddings + id (str | UUID): Graph ID to list entities from offset (int, optional): Specifies the number of objects to skip. Defaults to 0. limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: dict: List of entities and pagination information """ - params = { - "level": level, + params: dict = { "offset": offset, "limit": limit, - "include_embeddings": include_embeddings, } - return await self.client._make_request( - "GET", - f"graphs/{str(collection_id)}/entities", - params=params, - version="v3", - ) - - async def deduplicate_entities( - self, - collection_id: str | UUID, - settings: Optional[dict] = None, - run_type: str = "ESTIMATE", - run_with_orchestration: bool = True, - ): - """ - Deduplicate entities in the graph. - - Args: - collection_id (str | UUID): Collection ID to deduplicate entities in - settings (Optional[dict]): Deduplication settings - run_type (str): Whether to estimate cost or run deduplication - run_with_orchestration (bool): Whether to run with task orchestration - - Returns: - WrappedKGEntityDeduplicationResponse: Deduplication results or cost estimate - """ - params = { - "run_type": run_type, - "run_with_orchestration": run_with_orchestration, - } - data = {} - if settings: - data["settings"] = settings return await self.client._make_request( - "POST", - f"graphs/{str(collection_id)}/entities/deduplicate", - json=data, + "GET", + f"graphs/{str(id)}/entities", params=params, version="v3", ) - # Relationship operations - async def create_relationship( - self, collection_id: str | UUID, relationship: dict - ) -> dict: - """ - Create a new relationship between entities. - - Args: - collection_id (str | UUID): Collection ID to create relationship in - relationship (dict): Relationship data including source, target, and type - - Returns: - dict: Created relationship information - """ - return await self.client._make_request( - "POST", - f"graphs/{str(collection_id)}/relationships", - json=relationship, - version="v3", - ) - - async def get_relationship( + async def get_entity( self, - collection_id: str | UUID, - relationship_id: str | UUID, - ) -> dict: + id: str | UUID, + entity_id: str | UUID, + ) -> WrappedEntityResponse: """ - Get details of a specific relationship. + Get entity information in a graph. Args: - collection_id (str | UUID): Collection ID containing the relationship - relationship_id (str | UUID): Relationship ID to retrieve + id (str | UUID): Graph ID to get entity from + entity_id (str | UUID): Entity ID to get from the graph Returns: - dict: Relationship details + dict: Entity information """ return await self.client._make_request( "GET", - f"graphs/{str(collection_id)}/relationships/{str(relationship_id)}", + f"graphs/{str(id)}/entities/{str(entity_id)}", version="v3", ) - async def update_relationship( + async def add_relationship( self, - collection_id: str | UUID, + id: str | UUID, relationship_id: str | UUID, - relationship_update: dict, - ) -> dict: + ) -> WrappedGenericMessageResponse: """ - Update an existing relationship. + Add a relationship to a graph. Args: - collection_id (str | UUID): Collection ID containing the relationship - relationship_id (str | UUID): Relationship ID to update - relationship_update (dict): Updated relationship data + id (str | UUID): Graph ID to add relationship to + relationship_id (str | UUID): Relationship ID to add to the graph Returns: - dict: Updated relationship information + dict: Success message """ return await self.client._make_request( "POST", - f"graphs/{str(collection_id)}/relationships/{str(relationship_id)}", - json=relationship_update, + f"graphs/{str(id)}/relationships/{str(relationship_id)}", version="v3", ) - async def delete_relationship( + async def remove_relationship( self, - collection_id: str | UUID, + id: str | UUID, relationship_id: str | UUID, - ) -> dict: + ) -> WrappedBooleanResponse: """ - Delete a relationship. + Remove a relationship from a graph. Args: - collection_id (str | UUID): Collection ID containing the relationship - relationship_id (str | UUID): Relationship ID to delete + id (str | UUID): Graph ID to remove relationship from + relationship_id (str | UUID): Relationship ID to remove from the graph Returns: - dict: Deletion confirmation + dict: Success message """ return await self.client._make_request( "DELETE", - f"graphs/{str(collection_id)}/relationships/{str(relationship_id)}", + f"graphs/{str(id)}/relationships/{str(relationship_id)}", version="v3", ) async def list_relationships( self, - collection_id: str | UUID, - source_id: Optional[str | UUID] = None, - target_id: Optional[str | UUID] = None, - relationship_type: Optional[str] = None, + id: str | UUID, offset: Optional[int] = 0, limit: Optional[int] = 100, - ) -> dict: + ) -> WrappedRelationshipsResponse: """ - List relationships in the graph. + List relationships in a graph. Args: - collection_id (str | UUID): Collection ID to list relationships from - source_id (Optional[str | UUID]): Filter by source entity - target_id (Optional[str | UUID]): Filter by target entity - relationship_type (Optional[str]): Filter by relationship type + id (str | UUID): Graph ID to list relationships from offset (int, optional): Specifies the number of objects to skip. Defaults to 0. limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: dict: List of relationships and pagination information """ - params = { + params: dict = { "offset": offset, "limit": limit, } - if source_id: - params["source_id"] = str(source_id) - if target_id: - params["target_id"] = str(target_id) - if relationship_type: - params["relationship_type"] = relationship_type return await self.client._make_request( "GET", - f"graphs/{str(collection_id)}/relationships", + f"graphs/{str(id)}/relationships", params=params, version="v3", ) - # Community operations - async def create_communities( + async def get_relationship( self, - collection_id: str | UUID, - run_type: Optional[str | KGRunType] = None, - settings: Optional[dict] = None, - run_with_orchestration: bool = True, - ): # -> WrappedKGCommunitiesResponse: + id: str | UUID, + relationship_id: str | UUID, + ) -> WrappedRelationshipResponse: """ - Create communities in the graph. + Get relationship information in a graph. Args: - collection_id (str | UUID): Collection ID to create communities in - settings (Optional[dict]): Community detection settings - run_with_orchestration (bool): Whether to run with task orchestration + id (str | UUID): Graph ID to get relationship from + relationship_id (str | UUID): Relationship ID to get from the graph Returns: - WrappedKGCommunitiesResponse: Community creation results + dict: Relationship information """ - params = {"run_with_orchestration": run_with_orchestration} - data = {} - if settings: - data["settings"] = settings - - if run_type: - data["run_type"] = str(run_type) - return await self.client._make_request( - "POST", - f"graphs/{str(collection_id)}/communities", - json=data, - params=params, + "GET", + f"graphs/{str(id)}/relationships/{str(relationship_id)}", version="v3", ) - async def get_community( + async def add_community( self, - collection_id: str | UUID, + id: str | UUID, community_id: str | UUID, - ) -> dict: + ) -> WrappedGenericMessageResponse: """ - Get details of a specific community. + Add a community to a graph. Args: - collection_id (str | UUID): Collection ID containing the community - community_id (str | UUID): Community ID to retrieve + id (str | UUID): Graph ID to add community to + community_id (str | UUID): Community ID to add to the graph Returns: - dict: Community details + dict: Success message """ return await self.client._make_request( - "GET", - f"graphs/{str(collection_id)}/communities/{str(community_id)}", + "POST", + f"graphs/{str(id)}/communities/{str(community_id)}", version="v3", ) - async def update_community( + async def remove_community( self, - collection_id: str | UUID, + id: str | UUID, community_id: str | UUID, - community_update: dict, - ) -> dict: + ) -> WrappedBooleanResponse: """ - Update a community. + Remove a community from a graph. Args: - collection_id (str | UUID): Collection ID containing the community - community_id (str | UUID): Community ID to update - community_update (dict): Updated community data + id (str | UUID): Graph ID to remove community from + community_id (str | UUID): Community ID to remove from the graph Returns: - dict: Updated community information + dict: Success message """ return await self.client._make_request( - "POST", - f"graphs/{str(collection_id)}/communities/{str(community_id)}", - json=community_update, + "DELETE", + f"graphs/{str(id)}/communities/{str(community_id)}", version="v3", ) async def list_communities( self, - collection_id: str | UUID, - level: Optional[int] = None, + id: str | UUID, offset: Optional[int] = 0, limit: Optional[int] = 100, - ) -> dict: + ): # -> WrappedCommunitiesResponse """ - List communities in the graph. + List communities in a graph. Args: - collection_id (str | UUID): Collection ID to list communities from - level (Optional[int]): Filter by community level + id (str | UUID): Graph ID to list communities from offset (int, optional): Specifies the number of objects to skip. Defaults to 0. limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: dict: List of communities and pagination information """ - params = { + params: dict = { "offset": offset, "limit": limit, } - if level is not None: - params["level"] = level return await self.client._make_request( "GET", - f"graphs/{str(collection_id)}/communities", + f"graphs/{str(id)}/communities", params=params, version="v3", ) - async def delete_community( + async def get_community( self, - collection_id: str | UUID, + id: str | UUID, community_id: str | UUID, - ) -> dict: - """ - Delete a specific community. - - Args: - collection_id (str | UUID): Collection ID containing the community - community_id (str | UUID): Community ID to delete - - Returns: - dict: Deletion confirmation - """ - return await self.client._make_request( - "DELETE", - f"graphs/{str(collection_id)}/communities/{str(community_id)}", - version="v3", - ) - - async def delete_communities( - self, - collection_id: str | UUID, - level: Optional[int] = None, - ) -> dict: + ): # -> WrappedCommunityResponse """ - Delete communities from the graph. + Get community information in a graph. Args: - collection_id (str | UUID): Collection ID to delete communities from - level (Optional[int]): Specific level to delete, or None for all levels + id (str | UUID): Graph ID to get community from + community_id (str | UUID): Community ID to get from the graph Returns: - dict: Deletion confirmation + dict: Community information """ - params = {} - if level is not None: - params["level"] = level - return await self.client._make_request( - "DELETE", - f"graphs/{str(collection_id)}/communities", - params=params, - version="v3", - ) - - async def tune_prompt( - self, - collection_id: str | UUID, - prompt_name: str, - documents_offset: Optional[int] = 0, - documents_limit: Optional[int] = 100, - chunks_offset: Optional[int] = 0, - chunks_limit: Optional[int] = 100, - ): # -> WrappedKGTunePromptResponse: - """ - Tune a graph-related prompt using collection data. - - Args: - collection_id (Union[str, UUID]): Collection ID to tune prompt for - prompt_name (str): Name of prompt to tune (graphrag_relationships_extraction_few_shot, - graphrag_entity_description, or graphrag_communities) - documents_offset (int): Document pagination offset - documents_limit (int): Maximum number of documents to use - chunks_offset (int): Chunk pagination offset - chunks_limit (int): Maximum number of chunks to use - - Returns: - WrappedKGTunePromptResponse: Tuned prompt results - """ - data = { - "prompt_name": prompt_name, - "documents_offset": documents_offset, - "documents_limit": documents_limit, - "chunks_offset": chunks_offset, - "chunks_limit": chunks_limit, - } - - return await self.client._make_request( - "POST", - f"graphs/{str(collection_id)}/tune-prompt", - json=data, + "GET", + f"graphs/{str(id)}/communities/{str(community_id)}", version="v3", ) From 10d4bdd96a4213570e508de3d2e6b1c718d4f2b9 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 22 Nov 2024 08:59:07 -0800 Subject: [PATCH 02/28] Add feature tracking --- js/sdk/src/v3/clients/chunks.ts | 6 ++ js/sdk/src/v3/clients/collections.ts | 14 +++- js/sdk/src/v3/clients/conversations.ts | 8 ++ js/sdk/src/v3/clients/documents.ts | 11 +++ js/sdk/src/v3/clients/graphs.ts | 18 +++++ js/sdk/src/v3/clients/indices.ts | 5 ++ js/sdk/src/v3/clients/prompts.ts | 45 +++++++++++ js/sdk/src/v3/clients/retrieval.ts | 102 ++++++++++++++++++++++++- js/sdk/src/v3/clients/system.ts | 8 +- js/sdk/src/v3/clients/users.ts | 25 +++++- 10 files changed, 237 insertions(+), 5 deletions(-) diff --git a/js/sdk/src/v3/clients/chunks.ts b/js/sdk/src/v3/clients/chunks.ts index 86593be7e..a2ce30ada 100644 --- a/js/sdk/src/v3/clients/chunks.ts +++ b/js/sdk/src/v3/clients/chunks.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { UnprocessedChunk, @@ -20,6 +21,7 @@ export class ChunksClient { * @param runWithOrchestration Optional flag to run with orchestration * @returns */ + @feature("chunks.create") async create(options: { chunks: UnprocessedChunk[]; runWithOrchestration?: boolean; @@ -39,6 +41,7 @@ export class ChunksClient { * @param metadata Optional new metadata for the chunk * @returns */ + @feature("chunks.update") async update(options: { id: string; text?: string; @@ -54,6 +57,7 @@ export class ChunksClient { * @param id ID of the chunk to retrieve * @returns */ + @feature("chunks.retrieve") async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `chunks/${options.id}`); } @@ -63,6 +67,7 @@ export class ChunksClient { * @param id ID of the chunk to delete * @returns */ + @feature("chunks.delete") async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `chunks/${options.id}`); } @@ -75,6 +80,7 @@ export class ChunksClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("chunks.list") async list(options?: { includeVectors?: boolean; metadataFilters?: Record; diff --git a/js/sdk/src/v3/clients/collections.ts b/js/sdk/src/v3/clients/collections.ts index a6fe32ddb..956aceea7 100644 --- a/js/sdk/src/v3/clients/collections.ts +++ b/js/sdk/src/v3/clients/collections.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedBooleanResponse, @@ -17,6 +18,7 @@ export class CollectionsClient { * @param description Optional description of the collection * @returns A promise that resolves with the created collection */ + @feature("collections.create") async create(options: { name: string; description?: string; @@ -33,6 +35,7 @@ export class CollectionsClient { * @param limit Optional limit for pagination * @returns */ + @feature("collections.list") async list(options?: { ids?: string[]; offset?: number; @@ -57,6 +60,7 @@ export class CollectionsClient { * @param id Collection ID to retrieve * @returns */ + @feature("collections.retrieve") async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `collections/${options.id}`); } @@ -68,6 +72,7 @@ export class CollectionsClient { * @param description Optional new description for the collection * @returns */ + @feature("collections.update") async update(options: { id: string; name?: string; @@ -88,6 +93,7 @@ export class CollectionsClient { * @param id Collection ID to delete * @returns */ + @feature("collections.delete") async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `collections/${options.id}`); } @@ -99,6 +105,7 @@ export class CollectionsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("collections.listDocuments") async listDocuments(options: { id: string; offset?: number; @@ -124,6 +131,7 @@ export class CollectionsClient { * @param documentId Document ID to add * @returns */ + @feature("collections.addDocument") async addDocument(options: { id: string; documentId: string; @@ -140,6 +148,7 @@ export class CollectionsClient { * @param documentId Document ID to remove * @returns */ + @feature("collections.removeDocument") async removeDocument(options: { id: string; documentId: string; @@ -157,6 +166,7 @@ export class CollectionsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("collections.listUsers") async listUsers(options: { id: string; offset?: number; @@ -178,6 +188,7 @@ export class CollectionsClient { * @param userId User ID to add * @returns */ + @feature("collections.addUser") async addUser(options: { id: string; userId: string; @@ -194,13 +205,14 @@ export class CollectionsClient { * @param userId User ID to remove * @returns */ + @feature("collections.removeUser") async removeUser(options: { id: string; userId: string; }): Promise { return this.client.makeRequest( "DELETE", - `collections/${options.id}/users/${options.userId}`, + `collecstions/${options.id}/users/${options.userId}`, ); } } diff --git a/js/sdk/src/v3/clients/conversations.ts b/js/sdk/src/v3/clients/conversations.ts index 425f6dab6..796c1edc7 100644 --- a/js/sdk/src/v3/clients/conversations.ts +++ b/js/sdk/src/v3/clients/conversations.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedBooleanResponse, @@ -15,6 +16,7 @@ export class ConversationsClient { * Create a new conversation. * @returns */ + @feature("conversations.create") async create(): Promise { return this.client.makeRequest("POST", "conversations"); } @@ -26,6 +28,7 @@ export class ConversationsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("conversations.list") async list(options?: { ids?: string[]; offset?: number; @@ -51,6 +54,7 @@ export class ConversationsClient { * @param branchID The ID of the branch to retrieve * @returns */ + @feature("conversations.retrieve") async retrieve(options: { id: string; branchID?: string; @@ -69,6 +73,7 @@ export class ConversationsClient { * @param id The ID of the conversation to delete * @returns */ + @feature("conversations.delete") async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `conversations/${options.id}`); } @@ -82,6 +87,7 @@ export class ConversationsClient { * @param metadata Additional metadata to attach to the message * @returns */ + @feature("conversations.addMessage") async addMessage(options: { id: string; content: string; @@ -112,6 +118,7 @@ export class ConversationsClient { * @param content The new content of the message * @returns */ + @feature("conversations.updateMessage") async updateMessage(options: { id: string; messageID: string; @@ -135,6 +142,7 @@ export class ConversationsClient { * @param id The ID of the conversation to list branches for * @returns */ + @feature("conversations.listBranches") async listBranches(options: { id: string; offset?: number; diff --git a/js/sdk/src/v3/clients/documents.ts b/js/sdk/src/v3/clients/documents.ts index 0a1922092..3b3a1c8e9 100644 --- a/js/sdk/src/v3/clients/documents.ts +++ b/js/sdk/src/v3/clients/documents.ts @@ -8,6 +8,7 @@ import { WrappedDocumentsResponse, WrappedIngestionResponse, } from "../../types"; +import { feature } from "../../feature"; let fs: any; if (typeof window === "undefined") { @@ -32,6 +33,7 @@ export class DocumentsClient { * @param runWithOrchestration Optional flag to run with orchestration * @returns */ + @feature("documents.create") async create(options: { file?: FileInput; content?: string; @@ -144,6 +146,7 @@ export class DocumentsClient { * @param runWithOrchestration Whether to run with orchestration * @returns */ + @feature("documents.update") async update(options: { id: string; file?: FileInput; @@ -236,6 +239,7 @@ export class DocumentsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("documents.retrieve") async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `documents/${options.id}`); } @@ -247,6 +251,7 @@ export class DocumentsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("documents.list") async list(options?: { ids?: string[]; offset?: number; @@ -271,6 +276,7 @@ export class DocumentsClient { * @param id ID of document to download * @returns */ + @feature("documents.download") async download(options: { id: string }): Promise { return this.client.makeRequest("GET", `documents/${options.id}/download`, { responseType: "blob", @@ -282,6 +288,7 @@ export class DocumentsClient { * @param id ID of document to delete * @returns */ + @feature("documents.delete") async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `documents/${options.id}`); } @@ -294,6 +301,7 @@ export class DocumentsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("documents.listChunks") async listChunks(options: { id: string; includeVectors?: boolean; @@ -318,6 +326,7 @@ export class DocumentsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("documents.listCollections") async listCollections(options: { id: string; offset?: number; @@ -337,6 +346,7 @@ export class DocumentsClient { ); } + @feature("documents.deleteByFilter") async deleteByFilter(options: { filters: Record; }): Promise { @@ -345,6 +355,7 @@ export class DocumentsClient { }); } + @feature("documents.extract") async extract(options: { id: string; runType?: string; diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index aa7128715..d7906f539 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedGraphResponse, @@ -21,6 +22,7 @@ export class GraphsClient { * @param description Optional description of the graph * @returns The created graph */ + @feature("graphs.create") async create(options: { name: string; description?: string; @@ -37,6 +39,7 @@ export class GraphsClient { * @param limit Optional limit for pagination * @returns */ + @feature("graphs.list") async list(options?: { ids?: string[]; offset?: number; @@ -61,6 +64,7 @@ export class GraphsClient { * @param id Graph ID to retrieve * @returns */ + @feature("graphs.retrieve") async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `graphs/${options.id}`); } @@ -72,6 +76,7 @@ export class GraphsClient { * @param description Optional new description for the graph * @returns */ + @feature("graphs.update") async update(options: { id: string; name?: string; @@ -92,6 +97,7 @@ export class GraphsClient { * @param options * @returns */ + @feature("graphs.delete") async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `graphs/${options.id}`); } @@ -103,6 +109,7 @@ export class GraphsClient { * @param entityId Entity ID to add * @returns */ + @feature("graphs.addEntity") async addEntity(options: { id: string; entityId: string; @@ -119,6 +126,7 @@ export class GraphsClient { * @param entityId Entity ID to remove * @returns */ + @feature("graphs.removeEntity") async removeEntity(options: { id: string; entityId: string; @@ -136,6 +144,7 @@ export class GraphsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("graphs.listEntities") async listEntities(options: { id: string; offset?: number; @@ -157,6 +166,7 @@ export class GraphsClient { * @param entityId Entity ID to retrieve * @returns */ + @feature("graphs.getEntity") async getEntity(options: { id: string; entityId: string; @@ -174,6 +184,7 @@ export class GraphsClient { * @param relationshipId Relationship ID to add * @returns */ + @feature("graphs.addRelationship") async addRelationship(options: { id: string; relationshipId: string; @@ -190,6 +201,7 @@ export class GraphsClient { * @param relationshipId Relationship ID to remove * @returns */ + @feature("graphs.removeRelationship") async removeRelationship(options: { id: string; relationshipId: string; @@ -207,6 +219,7 @@ export class GraphsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("graphs.listRelationships") async listRelationships(options: { id: string; offset?: number; @@ -232,6 +245,7 @@ export class GraphsClient { * @param relationshipId Relationship ID to retrieve * @returns */ + @feature("graphs.getRelationship") async getRelationship(options: { id: string; relationshipId: string; @@ -249,6 +263,7 @@ export class GraphsClient { * @param communityId Community ID to add * @returns */ + @feature("graphs.addCommunity") async addCommunity(options: { id: string; communityId: string; @@ -265,6 +280,7 @@ export class GraphsClient { * @param communityId Community ID to remove * @returns */ + @feature("graphs.removeCommunity") async removeCommunity(options: { id: string; communityId: string; @@ -282,6 +298,7 @@ export class GraphsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("graphs.listCommunities") async listCommunities(options: { id: string; offset?: number; @@ -303,6 +320,7 @@ export class GraphsClient { * @param communityId Community ID to retrieve * @returns */ + @feature("graphs.getCommunity") async getCommunity(options: { id: string; communityId: string; diff --git a/js/sdk/src/v3/clients/indices.ts b/js/sdk/src/v3/clients/indices.ts index 384093a52..cae88989b 100644 --- a/js/sdk/src/v3/clients/indices.ts +++ b/js/sdk/src/v3/clients/indices.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { IndexConfig, @@ -14,6 +15,7 @@ export class IndiciesClient { * @param runWithOrchestration Whether to run index creation as an orchestrated task. * @returns */ + @feature("indices.create") async create(options: { config: IndexConfig; runWithOrchestration?: boolean; @@ -37,6 +39,7 @@ export class IndiciesClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("indices.list") async list(options?: { filters?: Record; offset?: number; @@ -62,6 +65,7 @@ export class IndiciesClient { * @param tableName The name of the table where the index is stored. * @returns */ + @feature("indices.retrieve") async retrieve(options: { tableName: string; indexName: string; @@ -78,6 +82,7 @@ export class IndiciesClient { * @param tableName The name of the table where the index is stored. * @returns */ + @feature("indices.delete") async delete(options: { tableName: string; indexName: string; diff --git a/js/sdk/src/v3/clients/prompts.ts b/js/sdk/src/v3/clients/prompts.ts index 1cdf804bf..d247251f9 100644 --- a/js/sdk/src/v3/clients/prompts.ts +++ b/js/sdk/src/v3/clients/prompts.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedBooleanResponse, @@ -9,6 +10,17 @@ import { export class PromptsClient { constructor(private client: r2rClient) {} + /** + * Create a new prompt with the given configuration. + * + * This endpoint allows superusers to create a new prompt with a + * specified name, template, and input types. + * @param name The name of the prompt + * @param template The template string for the prompt + * @param inputTypes A dictionary mapping input names to their types + * @returns + */ + @feature("prompts.create") async create(options: { name: string; template: string; @@ -19,10 +31,28 @@ export class PromptsClient { }); } + /** + * List all available prompts. + * + * This endpoint retrieves a list of all prompts in the system. + * Only superusers can access this endpoint. + * @returns + */ + @feature("prompts.list") async list(): Promise { return this.client.makeRequest("GET", "prompts"); } + /** + * Get a specific prompt by name, optionally with inputs and override. + * + * This endpoint retrieves a specific prompt and allows for optional + * inputs and template override. + * Only superusers can access this endpoint. + * @param options + * @returns + */ + @feature("prompts.retrieve") async retrieve(options: { name: string; inputs?: string[]; @@ -41,6 +71,14 @@ export class PromptsClient { }); } + /** + * Update an existing prompt's template and/or input types. + * + * This endpoint allows superusers to update the template and input types of an existing prompt. + * @param options + * @returns + */ + @feature("prompts.update") async update(options: { name: string; template?: string; @@ -61,6 +99,13 @@ export class PromptsClient { }); } + /** + * Delete a prompt by name. + * + * This endpoint allows superusers to delete an existing prompt. + * @param name The name of the prompt to delete + * @returns + */ async delete(options: { name: string }): Promise { return this.client.makeRequest("DELETE", `prompts/${options.name}`); } diff --git a/js/sdk/src/v3/clients/retrieval.ts b/js/sdk/src/v3/clients/retrieval.ts index 914ee2d7e..08b9f5028 100644 --- a/js/sdk/src/v3/clients/retrieval.ts +++ b/js/sdk/src/v3/clients/retrieval.ts @@ -6,10 +6,27 @@ import { KGSearchSettings, GenerationConfig, } from "../../models"; +import { feature } from "../../feature"; export class RetrievalClient { constructor(private client: r2rClient) {} + /** + * Perform a search query on the vector database and knowledge graph and + * any other configured search engines. + * + * This endpoint allows for complex filtering of search results using + * PostgreSQL-based queries. Filters can be applied to various fields + * such as document_id, and internal metadata values. + * + * Allowed operators include: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, + * `like`, `ilike`, `in`, and `nin`. + * @param query Search query to find relevant documents + * @param VectorSearchSettings Settings for vector-based search + * @param KGSearchSettings Settings for knowledge graph search + * @returns + */ + @feature("retrieval.search") async search(options: { query: string; vectorSearchSettings?: ChunkSearchSettings | Record; @@ -30,6 +47,23 @@ export class RetrievalClient { }); } + /** + * Execute a RAG (Retrieval-Augmented Generation) query. + * + * This endpoint combines search results with language model generation. + * It supports the same filtering capabilities as the search endpoint, + * allowing for precise control over the retrieved context. + * + * The generation process can be customized using the `rag_generation_config` parameter. + * @param query + * @param ragGenerationConfig Configuration for RAG generation + * @param vectorSearchSettings Settings for vector-based search + * @param kgSearchSettings Settings for knowledge graph search + * @param taskPromptOverride Optional custom prompt to override default + * @param includeTitleIfAvailable Include document titles in responses when available + * @returns + */ + @feature("retrieval.rag") async rag(options: { query: string; ragGenerationConfig?: GenerationConfig | Record; @@ -66,6 +100,7 @@ export class RetrievalClient { } } + @feature("retrieval.streamRag") private async streamRag( ragData: Record, ): Promise> { @@ -82,8 +117,55 @@ export class RetrievalClient { ); } + /** + * Engage with an intelligent RAG-powered conversational agent for complex + * information retrieval and analysis. + * + * This advanced endpoint combines retrieval-augmented generation (RAG) + * with a conversational AI agent to provide detailed, context-aware + * responses based on your document collection. + * + * The agent can: + * - Maintain conversation context across multiple interactions + * - Dynamically search and retrieve relevant information from both + * vector and knowledge graph sources + * - Break down complex queries into sub-questions for comprehensive + * answers + * - Cite sources and provide evidence-based responses + * - Handle follow-up questions and clarifications + * - Navigate complex topics with multi-step reasoning + * + * Key Features: + * - Hybrid search combining vector and knowledge graph approaches + * - Contextual conversation management with conversation_id tracking + * - Customizable generation parameters for response style and length + * - Source document citation with optional title inclusion + * - Streaming support for real-time responses + * - Branch management for exploring different conversation paths + * + * Common Use Cases: + * - Research assistance and literature review + * - Document analysis and summarization + * - Technical support and troubleshooting + * - Educational Q&A and tutoring + * - Knowledge base exploration + * + * The agent uses both vector search and knowledge graph capabilities to + * find and synthesize information, providing detailed, factual responses + * with proper attribution to source documents. + * @param message Current message to process + * @param ragGenerationConfig Configuration for RAG generation + * @param vectorSearchSettings Settings for vector-based search + * @param kgSearchSettings Settings for knowledge graph search + * @param taskPromptOverride Optional custom prompt to override default + * @param includeTitleIfAvailable Include document titles in responses when available + * @param conversationId ID of the conversation + * @param branchId ID of the conversation branch + * @returns + */ + @feature("retrieval.agent") async agent(options: { - messages: Message[]; + message: Message; ragGenerationConfig?: GenerationConfig | Record; vectorSearchSettings?: ChunkSearchSettings | Record; kgSearchSettings?: KGSearchSettings | Record; @@ -93,7 +175,7 @@ export class RetrievalClient { branchId?: string; }): Promise> { const data: Record = { - messages: options.messages, + message: options.message, ...(options.vectorSearchSettings && { vectorSearchSettings: options.vectorSearchSettings, }), @@ -126,6 +208,7 @@ export class RetrievalClient { } } + @feature("retrieval.streamAgent") private async streamAgent( agentData: Record, ): Promise> { @@ -142,6 +225,20 @@ export class RetrievalClient { ); } + /** + * Generate completions for a list of messages. + * + * This endpoint uses the language model to generate completions for + * the provided messages. The generation process can be customized using + * the generation_config parameter. + * + * The messages list should contain alternating user and assistant + * messages, with an optional system message at the start. Each message + * should have a 'role' and 'content'. + * @param messages List of messages to generate completion for + * @returns + */ + @feature("retrieval.completion") async completion(options: { messages: Message[]; generationConfig?: GenerationConfig | Record; @@ -162,6 +259,7 @@ export class RetrievalClient { } } + @feature("retrieval.streamCompletion") private async streamCompletion( ragData: Record, ): Promise> { diff --git a/js/sdk/src/v3/clients/system.ts b/js/sdk/src/v3/clients/system.ts index 6d32ec828..7fdf30773 100644 --- a/js/sdk/src/v3/clients/system.ts +++ b/js/sdk/src/v3/clients/system.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedGenericMessageResponse, @@ -12,6 +13,7 @@ export class SystemClient { /** * Check the health of the R2R server. */ + @feature("system.health") async health(): Promise { return await this.client.makeRequest("GET", "health"); } @@ -21,6 +23,7 @@ export class SystemClient { * @param options * @returns */ + @feature("system.logs") async logs(options: { runTypeFilter?: string; offset?: number; @@ -42,14 +45,17 @@ export class SystemClient { * Get the configuration settings for the R2R server. * @returns */ + @feature("system.settings") async settings(): Promise { return await this.client.makeRequest("GET", "system/settings"); } /** - * Get statistics about the server, including the start time, uptime, CPU usage, and memory usage. + * Get statistics about the server, including the start time, uptime, + * CPU usage, and memory usage. * @returns */ + @feature("system.status") async status(): Promise { return await this.client.makeRequest("GET", "system/status"); } diff --git a/js/sdk/src/v3/clients/users.ts b/js/sdk/src/v3/clients/users.ts index b36e86aac..eadd37a7a 100644 --- a/js/sdk/src/v3/clients/users.ts +++ b/js/sdk/src/v3/clients/users.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedBooleanResponse, @@ -17,6 +18,7 @@ export class UsersClient { * @param password User's password * @returns */ + @feature("users.register") async register(options: { email: string; password: string; @@ -33,6 +35,7 @@ export class UsersClient { * @param password User's password * @returns */ + @feature("users.delete") async delete(options: { id: string; password: string; @@ -49,6 +52,7 @@ export class UsersClient { * @param email User's email address * @param verificationCode Verification code sent to the user's email */ + @feature("users.verifyEmail") async verifyEmail(options: { email: string; verificationCode: string; @@ -64,6 +68,7 @@ export class UsersClient { * @param password User's password * @returns */ + @feature("users.login") async login(options: { email: string; password: string }): Promise { const response = await this.client.makeRequest("POST", "users/login", { data: { @@ -90,7 +95,7 @@ export class UsersClient { * @param accessToken Existing access token * @returns */ - // FIXME: What is going on here... + @feature("users.loginWithToken") async loginWithToken(options: { accessToken: string }): Promise { this.client.setTokens(options.accessToken, null); @@ -115,6 +120,7 @@ export class UsersClient { * Log out the current user. * @returns */ + @feature("users.logout") async logout(): Promise { const response = await this.client.makeRequest("POST", "users/logout"); this.client.setTokens(null, null); @@ -125,6 +131,7 @@ export class UsersClient { * Refresh the access token using the refresh token. * @returns */ + @feature("users.refreshAccessToken") async refreshAccessToken(): Promise { const refreshToken = this.client.getRefreshToken(); if (!refreshToken) { @@ -160,6 +167,7 @@ export class UsersClient { * @param new_password User's new password * @returns */ + @feature("users.changePassword") async changePassword(options: { current_password: string; new_password: string; @@ -174,6 +182,7 @@ export class UsersClient { * @param email User's email address * @returns */ + @feature("users.requestPasswordReset") async requestPasswordReset(options: { email: string; }): Promise { @@ -182,6 +191,13 @@ export class UsersClient { }); } + /** + * Reset a user's password using a reset token. + * @param reset_token Reset token sent to the user's email + * @param new_password New password for the user + * @returns + */ + @feature("users.resetPassword") async resetPassword(options: { reset_token: string; new_password: string; @@ -200,6 +216,7 @@ export class UsersClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("users.list") async list(options?: { email?: string; is_active?: boolean; @@ -232,6 +249,7 @@ export class UsersClient { * @param id User ID to retrieve * @returns */ + @feature("users.retrieve") async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `users/${options.id}`); } @@ -240,6 +258,7 @@ export class UsersClient { * Get detailed information about the currently authenticated user. * @returns */ + @feature("users.me") async me(): Promise { return this.client.makeRequest("GET", `users/me`); } @@ -254,6 +273,7 @@ export class UsersClient { * @param profilePicture Optional new profile picture for the user * @returns */ + @feature("users.update") async update(options: { id: string; email?: string; @@ -284,6 +304,7 @@ export class UsersClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("users.listCollections") async listCollections(options: { id: string; offset?: number; @@ -305,6 +326,7 @@ export class UsersClient { * @param collectionId Collection ID to add the user to * @returns */ + @feature("users.addToCollection") async addToCollection(options: { id: string; collectionId: string; @@ -321,6 +343,7 @@ export class UsersClient { * @param collectionId Collection ID to remove the user from * @returns */ + @feature("users.removeFromCollection") async removeFromCollection(options: { id: string; collectionId: string; From 651b7864e62c9f6a453e11a6f2d8a0625ba2eeec Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 22 Nov 2024 16:04:18 -0800 Subject: [PATCH 03/28] Typo --- js/sdk/src/v3/clients/collections.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/sdk/src/v3/clients/collections.ts b/js/sdk/src/v3/clients/collections.ts index 956aceea7..a975f90c8 100644 --- a/js/sdk/src/v3/clients/collections.ts +++ b/js/sdk/src/v3/clients/collections.ts @@ -212,7 +212,7 @@ export class CollectionsClient { }): Promise { return this.client.makeRequest( "DELETE", - `collecstions/${options.id}/users/${options.userId}`, + `collections/${options.id}/users/${options.userId}`, ); } } From 3a7f419b650dcc96c922d991749c69a370282302 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Tue, 26 Nov 2024 13:14:42 -0600 Subject: [PATCH 04/28] Check in --- .../RetrievalIntegrationSuperUser.test.ts | 18 ++-- js/sdk/src/types.ts | 19 +++- js/sdk/src/v3/clients/documents.ts | 86 ++++++++++++++++++ py/compose.yaml | 2 +- py/core/main/api/v3/prompts_router.py | 89 +++++++++++++++++++ py/core/main/app.py | 10 +-- py/shared/abstractions/graph.py | 1 - 7 files changed, 202 insertions(+), 23 deletions(-) diff --git a/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts b/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts index 10e3f6465..59d346a53 100644 --- a/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts @@ -3,16 +3,10 @@ import { describe, test, beforeAll, expect } from "@jest/globals"; const baseUrl = "http://localhost:7272"; -const messages = [ - { - role: "system" as const, - content: "You are a helpful assistant.", - }, - { - role: "user" as const, - content: "Tell me about Sonia.", - }, -]; +const message = { + role: "user" as const, + content: "Tell me about Sonia.", +}; /** * sonia.txt will have an id of 28ce9a4c-4d15-5287-b0c6-67834b9c4546 @@ -85,7 +79,7 @@ describe("r2rClient V3 Documents Integration Tests", () => { test("Agent with no parameters", async () => { const response = await client.retrieval.agent({ - messages: messages, + message: message, }); expect(response.results).toBeDefined(); @@ -93,7 +87,7 @@ describe("r2rClient V3 Documents Integration Tests", () => { test("Streaming RAG", async () => { const stream = await client.retrieval.agent({ - messages: messages, + message: message, ragGenerationConfig: { stream: true, }, diff --git a/js/sdk/src/types.ts b/js/sdk/src/types.ts index b9ae60f32..efa06c0df 100644 --- a/js/sdk/src/types.ts +++ b/js/sdk/src/types.ts @@ -86,8 +86,8 @@ export interface DocumentResponse { size_in_bytes?: number; ingestion_status: string; kg_extraction_status: string; - created_date: string; - updated_date: string; + created_at: string; + updated_at: string; ingestion_attempt_number?: number; summary?: string; summary_embedding?: string; @@ -157,9 +157,20 @@ export interface PromptResponse { input_types: string[]; } -//TODO: Sync this with the finished API response model // Relationship types -export interface RelationshipResponse {} +export interface RelationshipResponse { + id: string; + subject: string; + predicate: string; + object: string; + description?: string; + subject_id: string; + object_id: string; + weight: number; + chunk_ids: string[]; + parent_id: string; + metadata: Record; +} // Retrieval types export interface VectorSearchResult { diff --git a/js/sdk/src/v3/clients/documents.ts b/js/sdk/src/v3/clients/documents.ts index 3b3a1c8e9..f43f1f1f9 100644 --- a/js/sdk/src/v3/clients/documents.ts +++ b/js/sdk/src/v3/clients/documents.ts @@ -6,6 +6,7 @@ import { WrappedCollectionsResponse, WrappedDocumentResponse, WrappedDocumentsResponse, + WrappedEntitiesResponse, WrappedIngestionResponse, } from "../../types"; import { feature } from "../../feature"; @@ -355,6 +356,15 @@ export class DocumentsClient { }); } + /** + * Extracts entities and relationships from a document. + * + * The entities and relationships extraction process involves: + * 1. Parsing documents into semantic chunks + * 2. Extracting entities and relationships using LLMs + * @param options + * @returns + */ @feature("documents.extract") async extract(options: { id: string; @@ -374,4 +384,80 @@ export class DocumentsClient { data, }); } + + /** + * Retrieves the entities that were extracted from a document. These + * represent important semantic elements like people, places, + * organizations, concepts, etc. + * + * Users can only access entities from documents they own or have access + * to through collections. Entity embeddings are only included if + * specifically requested. + * + * Results are returned in the order they were extracted from the document. + * @param id Document ID to retrieve entities for + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @param includeEmbeddings Whether to include vector embeddings in the response. + * @returns + */ + @feature("documents.listEntities") + async listEntities(options: { + id: string; + offset?: number; + limit?: number; + includeVectors?: boolean; + }): Promise { + const params: Record = { + offset: options.offset ?? 0, + limit: options.limit ?? 100, + includeVectors: options.includeVectors ?? false, + }; + + return this.client.makeRequest("GET", `documents/${options.id}/entities`, { + params, + }); + } + + /** + * Retrieves the relationships between entities that were extracted from + * a document. These represent connections and interactions between + * entities found in the text. + * + * Users can only access relationships from documents they own or have + * access to through collections. Results can be filtered by entity names + * and relationship types. + * + * Results are returned in the order they were extracted from the document. + * @param id Document ID to retrieve relationships for + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @param includeEmbeddings Whether to include vector embeddings in the response. + * @param entityNames Filter relationships by specific entity names. + * @param relationshipTypes Filter relationships by specific relationship types. + * @returns + */ + @feature("documents.listRelationships") + async listRelationships(options: { + id: string; + offset?: number; + limit?: number; + includeVectors?: boolean; + entityNames?: string[]; + relationshipTypes?: string[]; + }): Promise { + const params: Record = { + offset: options.offset ?? 0, + limit: options.limit ?? 100, + includeVectors: options.includeVectors ?? false, + }; + + return this.client.makeRequest( + "GET", + `documents/${options.id}/relationships`, + { + params, + }, + ); + } } diff --git a/py/compose.yaml b/py/compose.yaml index 2621c06d7..7024c83ba 100644 --- a/py/compose.yaml +++ b/py/compose.yaml @@ -36,7 +36,7 @@ services: -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} r2r: - image: r2r/test + image: ${R2R_IMAGE:-ragtoriches/prod:latest} build: context: . args: diff --git a/py/core/main/api/v3/prompts_router.py b/py/core/main/api/v3/prompts_router.py index b84efb3cf..5de440f7a 100644 --- a/py/core/main/api/v3/prompts_router.py +++ b/py/core/main/api/v3/prompts_router.py @@ -195,6 +195,95 @@ async def get_prompts( }, ) + @self.router.post( + "/prompts/{name}", + summary="Get a specific prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.prompts.get( + "greeting_prompt", + inputs={"name": "John"}, + prompt_override="Hi, {name}!" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.prompts.retrieve({ + name: "greeting_prompt", + inputs: { name: "John" }, + promptOverride: "Hi, {name}!", + }); + } + + main(); + """ + ), + }, + { + "lang": "CLI", + "source": textwrap.dedent( + """ + r2r prompts retrieve greeting_prompt --inputs '{"name": "John"}' --prompt-override "Hi, {name}!" + """ + ), + }, + { + "lang": "cURL", + "source": textwrap.dedent( + """ + curl -X POST "https://api.example.com/v3/prompts/greeting_prompt?inputs=%7B%22name%22%3A%22John%22%7D&prompt_override=Hi%2C%20%7Bname%7D!" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def get_prompt( + name: str = Path(..., description="Prompt name"), + inputs: Optional[dict[str, str]] = Body( + None, description="Prompt inputs" + ), + prompt_override: Optional[str] = Query( + None, description="Prompt override" + ), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> WrappedPromptResponse: + """ + Get a specific prompt by name, optionally with inputs and override. + + This endpoint retrieves a specific prompt and allows for optional inputs and template override. + Only superusers can access this endpoint. + """ + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can retrieve prompts.", + 403, + ) + result = await self.services["management"].get_prompt( + name, inputs, prompt_override + ) + return result # type: ignore + @self.router.put( "/prompts/{name}", summary="Update an existing prompt", diff --git a/py/core/main/app.py b/py/core/main/app.py index 5fe984d45..de18c0972 100644 --- a/py/core/main/app.py +++ b/py/core/main/app.py @@ -85,11 +85,11 @@ async def r2r_exception_handler(request: Request, exc: R2RException): def _setup_routes(self): # Include routers in the app - # self.app.include_router(self.ingestion_router, prefix="/v2") - # self.app.include_router(self.management_router, prefix="/v2") - # self.app.include_router(self.retrieval_router, prefix="/v2") - # self.app.include_router(self.auth_router, prefix="/v2") - # self.app.include_router(self.kg_router, prefix="/v2") + self.app.include_router(self.ingestion_router, prefix="/v2") + self.app.include_router(self.management_router, prefix="/v2") + self.app.include_router(self.retrieval_router, prefix="/v2") + self.app.include_router(self.auth_router, prefix="/v2") + self.app.include_router(self.kg_router, prefix="/v2") self.app.include_router(self.documents_router, prefix="/v3") self.app.include_router(self.chunks_router, prefix="/v3") diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 2036aefc3..de2a794dc 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -69,7 +69,6 @@ def __init__(self, **kwargs): class Relationship(R2RSerializable): """A relationship between two entities. This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" - # id is Union of UUID and int for backwards compatibility id: Optional[UUID] = None subject: str predicate: str From d96957e9984fab4fe3e0d8d0a91b704017520003 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Tue, 26 Nov 2024 13:19:23 -0600 Subject: [PATCH 05/28] Rebase --- py/core/base/api/models/__init__.py | 8 ---- py/core/main/api/v2/kg_router.py | 9 ++-- py/core/main/api/v3/collections_router.py | 1 - py/core/main/api/v3/documents_router.py | 4 +- py/core/main/api/v3/graph_router.py | 29 +++++------- py/core/main/api/v3/retrieval_router.py | 12 ----- py/core/providers/database/graph.py | 2 +- py/sdk/models.py | 4 -- py/sdk/v2/kg.py | 3 +- py/sdk/v2/sync_kg.py | 3 +- py/shared/api/models/__init__.py | 8 ---- py/shared/api/models/kg/responses.py | 57 ----------------------- 12 files changed, 19 insertions(+), 121 deletions(-) diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index 2cedafbd5..a976e423f 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -21,18 +21,14 @@ from shared.api.models.kg.responses import ( Community, Entity, - KGCreationResponse, KGEnrichmentResponse, - KGEntityDeduplicationResponse, KGTunePromptResponse, Relationship, WrappedCommunitiesResponse, WrappedCommunityResponse, WrappedEntitiesResponse, WrappedEntityResponse, - WrappedKGCreationResponse, WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, WrappedKGTunePromptResponse, WrappedRelationshipResponse, WrappedRelationshipsResponse, @@ -105,20 +101,16 @@ "Entity", "Relationship", "Community", - "KGCreationResponse", "KGEnrichmentResponse", "KGTunePromptResponse", - "KGEntityDeduplicationResponse", "WrappedEntityResponse", "WrappedEntitiesResponse", "WrappedRelationshipResponse", "WrappedRelationshipsResponse", "WrappedCommunityResponse", "WrappedCommunitiesResponse", - "WrappedKGCreationResponse", "WrappedKGEnrichmentResponse", "WrappedKGTunePromptResponse", - "WrappedKGEntityDeduplicationResponse", # TODO: Need to review anything above this "GraphResponse", "WrappedGraphResponse", diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index 3418e7750..f13c16a23 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -10,10 +10,7 @@ from core.base.abstractions import DataLevel, KGRunType from core.base.api.models import ( WrappedCommunitiesResponse, - WrappedKGCreationResponse, - WrappedKGEnrichmentResponse, WrappedEntitiesResponse, - WrappedKGEntityDeduplicationResponse, WrappedRelationshipsResponse, WrappedKGTunePromptResponse, ) @@ -105,7 +102,7 @@ async def create_graph( ), run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ): # -> WrappedKGCreationResponse: # type: ignore + ): """ Creating a graph on your documents. This endpoint takes input a list of document ids and KGCreationSettings. If document IDs are not provided, the graph will be created on all documents in the system. @@ -186,7 +183,7 @@ async def enrich_graph( ), run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ): # -> WrappedKGEnrichmentResponse: + ): """ This endpoint enriches the graph with additional information. It creates communities of nodes based on their similarity and adds embeddings to the graph. @@ -395,7 +392,7 @@ async def deduplicate_entities( None, description="Settings for the deduplication process." ), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedKGEntityDeduplicationResponse: + ): """ Deduplicate entities in the knowledge graph. """ diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index 8ebb35f6d..828e0ad77 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -8,7 +8,6 @@ from core.base import R2RException, RunType from core.base.api.models import ( GenericBooleanResponse, - GenericMessageResponse, WrappedBooleanResponse, WrappedCollectionResponse, WrappedCollectionsResponse, diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index 992265d40..edbaa5d2e 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -14,7 +14,6 @@ from core.base import R2RException, RunType, generate_document_id from core.base.abstractions import ( Entity, - GraphBuildSettings, KGCreationSettings, KGRunType, Relationship, @@ -28,7 +27,6 @@ WrappedDocumentResponse, WrappedDocumentsResponse, WrappedIngestionResponse, - WrappedKGCreationResponse, ) from core.providers import ( HatchetOrchestrationProvider, @@ -1257,7 +1255,7 @@ async def extract( description="Whether to run the entities and relationships extraction process with orchestration.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGCreationResponse: # type: ignore + ): """ Extracts entities and relationships from a document. The entities and relationships extraction process involves: diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index b1308b87d..9c34fe1be 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -16,19 +16,14 @@ ) from core.base.api.models import ( GenericBooleanResponse, - GenericMessageResponse, - PaginatedResultsWrapper, WrappedBooleanResponse, WrappedCommunitiesResponse, WrappedCommunityResponse, WrappedEntitiesResponse, WrappedEntityResponse, - WrappedGenericMessageResponse, WrappedGraphResponse, WrappedGraphsResponse, - WrappedKGCreationResponse, WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, WrappedKGTunePromptResponse, WrappedRelationshipResponse, WrappedRelationshipsResponse, @@ -87,7 +82,7 @@ async def _deduplicate_entities( run_type: Optional[KGRunType] = KGRunType.ESTIMATE, run_with_orchestration: bool = True, auth_user=None, - ) -> WrappedKGEntityDeduplicationResponse: + ): """Deduplicates entities in the knowledge graph using LLM-based analysis. The deduplication process: @@ -600,7 +595,7 @@ async def get_entities( description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Entity]]: + ) -> WrappedEntitiesResponse: """Lists all entities in the graph with pagination support.""" # return await self.services["kg"].get_entities( # id, offset, limit, auth_user @@ -610,8 +605,8 @@ async def get_entities( collection_id, offset, limit ) ) - print("entities = ", entities) - return entities, { + + return entities, { # type: ignore "total_entries": count, } @@ -624,7 +619,7 @@ async def create_entity( ), entity: Entity = Body(..., description="The entity to create"), auth_user=Depends(self.providers.auth.auth_wrapper), - ): # -> WrappedEntityResponse: + ) -> WrappedEntityResponse: """Creates a new entity in the graph.""" if ( not auth_user.is_superuser @@ -757,7 +752,7 @@ async def delete_entity( await self.providers.database.graph_handler.entities.delete( collection_id, [entity_id], "graph" ) - return {"success": True} + return GenericBooleanResponse(success=True) # type: ignore @self.router.get("/graphs/{collection_id}/relationships") @self.base_endpoint @@ -778,7 +773,7 @@ async def get_relationships( description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Relationship]]: + ) -> WrappedRelationshipsResponse: """ Lists all relationships in the graph with pagination support. """ @@ -801,7 +796,7 @@ async def get_relationships( ) ) - return relationships, { + return relationships, { # type: ignore "total_entries": count, } @@ -863,7 +858,7 @@ async def update_relationship( ..., description="The updated relationship object." ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): # -> WrappedRelationshipResponse: + ) -> WrappedRelationshipResponse: """Updates an existing relationship in the graph.""" relationship.id = relationship_id relationship.parent_id = relationship.parent_id or collection_id @@ -895,7 +890,7 @@ async def delete_relationship( await self.providers.database.graph_handler.relationships.delete( collection_id, [relationship_id], "graph" ) - return {"success": True} + return GenericBooleanResponse(success=True) # type: ignore @self.router.post( "/graphs/{collection_id}/communities/build", @@ -1388,7 +1383,7 @@ async def pull( f"No documents were added to graph {collection_id}, marking as failed." ) - return GenericBooleanResponse(success=success) + return GenericBooleanResponse(success=success) # type: ignore @self.router.delete( "/graphs/{collection_id}/documents/{document_id}", @@ -1460,4 +1455,4 @@ async def remove_document( ) ) - return GenericBooleanResponse(success=success) + return GenericBooleanResponse(success=success) # type: ignore diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index 04c793260..9bbddb0c1 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -220,7 +220,6 @@ async def search_app( description="Search query to find relevant documents", ), search_settings: SearchSettings = Body( - alias="searchSettings", default_factory=SearchSettings, description="Settings for vector-based search", ), @@ -379,22 +378,18 @@ async def search_app( async def rag_app( query: str = Body(...), search_settings: SearchSettings = Body( - alias="searchSettings", default_factory=SearchSettings, description="Settings for vector-based search", ), rag_generation_config: GenerationConfig = Body( - alias="ragGenerationConfig", default_factory=GenerationConfig, description="Configuration for RAG generation", ), task_prompt_override: Optional[str] = Body( - alias="taskPromptOverride", default=None, description="Optional custom prompt to override default", ), include_title_if_available: bool = Body( - alias="includeTitleIfAvailable", default=False, description="Include document titles in responses when available", ), @@ -564,32 +559,26 @@ async def agent_app( description="List of messages (deprecated, use message instead)", ), search_settings: SearchSettings = Body( - alias="searchSettings", default_factory=SearchSettings, description="Settings for vector-based search", ), rag_generation_config: GenerationConfig = Body( - alias="ragGenerationConfig", default_factory=GenerationConfig, description="Configuration for RAG generation", ), task_prompt_override: Optional[str] = Body( - alias="taskPromptOverride", default=None, description="Optional custom prompt to override default", ), include_title_if_available: bool = Body( - alias="includeTitleIfAvailable", default=True, description="Include document titles in responses when available", ), conversation_id: Optional[UUID] = Body( - alias="conversationId", default=None, description="ID of the conversation", ), branch_id: Optional[UUID] = Body( - alias="branchId", default=None, description="ID of the conversation branch", ), @@ -770,7 +759,6 @@ async def completion( ], ), generation_config: GenerationConfig = Body( - alias="generationConfig", default_factory=GenerationConfig, description="Configuration for text generation", example={ diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 958afdb80..51b1ffc39 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -2084,7 +2084,7 @@ async def update( UPDATE {self._get_table_name("graph")} SET {', '.join(update_fields)} WHERE id = ${param_index} - RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids, + RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids """ try: diff --git a/py/sdk/models.py b/py/sdk/models.py index 197a17397..987a8479d 100644 --- a/py/sdk/models.py +++ b/py/sdk/models.py @@ -23,9 +23,7 @@ ) from shared.api.models import ( CombinedSearchResponse, - KGCreationResponse, KGEnrichmentResponse, - KGEntityDeduplicationResponse, RAGResponse, UserResponse, ) @@ -52,8 +50,6 @@ "ChunkSearchResult", "SearchSettings", "KGEntityDeduplicationSettings", - "KGEntityDeduplicationResponse", - "KGCreationResponse", "KGEnrichmentResponse", "RAGResponse", "CombinedSearchResponse", diff --git a/py/sdk/v2/kg.py b/py/sdk/v2/kg.py index 993e96e89..8c6e618ac 100644 --- a/py/sdk/v2/kg.py +++ b/py/sdk/v2/kg.py @@ -4,7 +4,6 @@ from ..models import ( KGCreationSettings, KGEnrichmentSettings, - KGEntityDeduplicationResponse, KGEntityDeduplicationSettings, KGRunType, ) @@ -216,7 +215,7 @@ async def deduplicate_entities( deduplication_settings: Optional[ Union[dict, KGEntityDeduplicationSettings] ] = None, - ) -> KGEntityDeduplicationResponse: + ): """ Deduplicate entities in the knowledge graph. Args: diff --git a/py/sdk/v2/sync_kg.py b/py/sdk/v2/sync_kg.py index b224a483c..1caf4d9ec 100644 --- a/py/sdk/v2/sync_kg.py +++ b/py/sdk/v2/sync_kg.py @@ -4,7 +4,6 @@ from ..models import ( KGCreationSettings, KGEnrichmentSettings, - KGEntityDeduplicationResponse, KGEntityDeduplicationSettings, KGRunType, ) @@ -216,7 +215,7 @@ def deduplicate_entities( deduplication_settings: Optional[ Union[dict, KGEntityDeduplicationSettings] ] = None, - ) -> KGEntityDeduplicationResponse: + ): """ Deduplicate entities in the knowledge graph. Args: diff --git a/py/shared/api/models/__init__.py b/py/shared/api/models/__init__.py index 8cc600a12..0bf980c94 100644 --- a/py/shared/api/models/__init__.py +++ b/py/shared/api/models/__init__.py @@ -17,12 +17,8 @@ WrappedUpdateResponse, ) from shared.api.models.kg.responses import ( - KGCreationResponse, KGEnrichmentResponse, - KGEntityDeduplicationResponse, - WrappedKGCreationResponse, WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, # TODO: Need to review anything above this GraphResponse, WrappedGraphResponse, @@ -86,12 +82,8 @@ "WrappedUpdateResponse", "WrappedMetadataUpdateResponse", # Restructure Responses - "KGCreationResponse", "KGEnrichmentResponse", - "KGEntityDeduplicationResponse", - "WrappedKGCreationResponse", "WrappedKGEnrichmentResponse", - "WrappedKGEntityDeduplicationResponse", # TODO: Need to review anything above this "GraphResponse", "WrappedGraphResponse", diff --git a/py/shared/api/models/kg/responses.py b/py/shared/api/models/kg/responses.py index a72d606ed..e58f3d15f 100644 --- a/py/shared/api/models/kg/responses.py +++ b/py/shared/api/models/kg/responses.py @@ -119,25 +119,6 @@ class KGDeduplicationEstimate(R2RSerializable): ) -class KGCreationResponse(BaseModel): - message: str = Field( - ..., - description="A message describing the result of the KG creation request.", - ) - id: Optional[UUID] = Field( - None, - description="The ID of the created object.", - ) - task_id: Optional[UUID] = Field( - None, - description="The task ID of the KG creation request.", - ) - estimate: Optional[KGCreationEstimate] = Field( - None, - description="The estimation of the KG creation request.", - ) - - class Config: json_schema_extra = { "example": { @@ -195,40 +176,6 @@ class Config: } -class KGEntityDeduplicationResponse(BaseModel): - """Response for knowledge graph entity deduplication.""" - - message: str = Field( - ..., - description="The message to display to the user.", - ) - - task_id: Optional[UUID] = Field( - None, - description="The task ID of the KG entity deduplication request.", - ) - - estimate: Optional[KGDeduplicationEstimate] = Field( - None, - description="The estimation of the KG entity deduplication request.", - ) - - class Config: - json_schema_extra = { - "example": { - "message": "Entity deduplication queued successfully.", - "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96", - "estimate": { - "num_entities": 1000, - "estimated_llm_calls": "1000", - "estimated_total_in_out_tokens_in_millions": "1000", - "estimated_cost_in_usd": "1000", - "estimated_total_time_in_minutes": "1000", - }, - } - } - - class KGTunePromptResponse(R2RSerializable): """Response containing just the tuned prompt string.""" @@ -250,12 +197,8 @@ class Config: # CREATE -WrappedKGCreationResponse = ResultsWrapper[KGCreationResponse] WrappedKGEnrichmentResponse = ResultsWrapper[KGEnrichmentResponse] WrappedKGTunePromptResponse = ResultsWrapper[KGTunePromptResponse] -WrappedKGEntityDeduplicationResponse = ResultsWrapper[ - KGEntityDeduplicationResponse -] class GraphResponse(BaseModel): From 749935d6cf8eca2a879c41f3449eccc28d750582 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:35:19 -0600 Subject: [PATCH 06/28] Add Graph tests --- js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index 5d5d0a3e6..31cd60d74 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -24,7 +24,7 @@ describe("r2rClient V3 Collections Integration Tests", () => { graph1Id = response.results.id; expect(graph1Id).toEqual(response.results.id); expect(response.results.name).toEqual("Graph 1"); - expect(response.results.description).toBe(null); + expect(response.results.description).toBe(""); }); test("Create a graph with name and description", async () => { @@ -48,7 +48,7 @@ describe("r2rClient V3 Collections Integration Tests", () => { const response = await client.graphs.retrieve({ id: graph1Id }); expect(response.results).toBeDefined(); expect(response.results.name).toEqual("Graph 1"); - expect(response.results.description).toBe(null); + expect(response.results.description).toBe(""); }); test("Retrieve graph 2", async () => { @@ -67,7 +67,7 @@ describe("r2rClient V3 Collections Integration Tests", () => { expect(response.results.name).toEqual("Graph 1 Updated"); }); - test("Update the desription graph 2", async () => { + test("Update the description graph 2", async () => { const response = await client.graphs.update({ id: graph2Id, description: "Graph 2 Updated", From ef2e97ad1734301755bf52cb8dcf45855ce8e7a0 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:09:13 -0600 Subject: [PATCH 07/28] Fix Agent empty message bug --- js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts | 2 +- js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts | 4 ++-- py/core/main/services/retrieval_service.py | 3 +++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts b/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts index 59d346a53..89451941e 100644 --- a/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts @@ -85,7 +85,7 @@ describe("r2rClient V3 Documents Integration Tests", () => { expect(response.results).toBeDefined(); }, 30000); - test("Streaming RAG", async () => { + test("Streaming agent", async () => { const stream = await client.retrieval.agent({ message: message, ragGenerationConfig: { diff --git a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts index 2e5cfdbde..594663799 100644 --- a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts @@ -153,7 +153,7 @@ describe("r2rClient Integration Tests", () => { await expect( client.ingestChunks([{ text: "test chunks" }]), ).resolves.not.toThrow(); - }); + }, 10000); test("Ingest chunks", async () => { await expect( @@ -163,7 +163,7 @@ describe("r2rClient Integration Tests", () => { { source: "example" }, ), ).resolves.not.toThrow(); - }); + }, 10000); test("Search documents", async () => { await expect(client.search("test")).resolves.not.toThrow(); diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index 60173768d..867e06d86 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -314,6 +314,9 @@ async def agent( ) messages = messages or [] + if message and not messages: + messages = [message] + current_message = messages[-1] # type: ignore # Save the new message to the conversation From d245985b89af7984d32bb5839aed0e6272a7d934 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Tue, 26 Nov 2024 18:32:54 -0600 Subject: [PATCH 08/28] Check in JS routes --- js/sdk/src/types.ts | 26 +++ js/sdk/src/v3/clients/graphs.ts | 323 ++++++++++++++++++---------- py/core/main/api/v3/graph_router.py | 34 +-- py/core/providers/database/graph.py | 1 - 4 files changed, 236 insertions(+), 148 deletions(-) diff --git a/js/sdk/src/types.ts b/js/sdk/src/types.ts index efa06c0df..c29cb60e6 100644 --- a/js/sdk/src/types.ts +++ b/js/sdk/src/types.ts @@ -1,3 +1,29 @@ +export interface Entity { + name: string; + id?: string; + category?: string; + description?: string; + parent_id?: string; + description_embedding?: string; + chunk_ids: string[]; + metadata: Record; +} + +export interface Relationship { + id?: string; + subject: string; + predicate: string; + object: string; + description?: string; + subject_id?: string; + object_id?: string; + weight?: number; + chunk_ids: string[]; + parent_id?: string; + description_embedding?: string; + metadata: Record; +} + export interface UnprocessedChunk { id: string; document_id?: string; diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index d7906f539..669ec162c 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -4,34 +4,19 @@ import { WrappedGraphResponse, WrappedBooleanResponse, WrappedGraphsResponse, - WrappedGenericMessageResponse, WrappedEntityResponse, WrappedEntitiesResponse, WrappedRelationshipsResponse, WrappedRelationshipResponse, WrappedCommunitiesResponse, WrappedCommunityResponse, + Entity, + Relationship, } from "../../types"; export class GraphsClient { constructor(private client: r2rClient) {} - /** - * Create a new graph. - * @param name Name of the graph - * @param description Optional description of the graph - * @returns The created graph - */ - @feature("graphs.create") - async create(options: { - name: string; - description?: string; - }): Promise { - return this.client.makeRequest("POST", "graphs", { - data: options, - }); - } - /** * List graphs with pagination and filtering options. * @param ids Optional list of graph IDs to filter by @@ -41,7 +26,7 @@ export class GraphsClient { */ @feature("graphs.list") async list(options?: { - ids?: string[]; + collectionIds?: string[]; offset?: number; limit?: number; }): Promise { @@ -50,8 +35,8 @@ export class GraphsClient { limit: options?.limit ?? 100, }; - if (options?.ids && options.ids.length > 0) { - params.ids = options.ids; + if (options?.collectionIds && options.collectionIds.length > 0) { + params.ids = options.collectionIds; } return this.client.makeRequest("GET", "graphs", { @@ -65,8 +50,29 @@ export class GraphsClient { * @returns */ @feature("graphs.retrieve") - async retrieve(options: { id: string }): Promise { - return this.client.makeRequest("GET", `graphs/${options.id}`); + async retrieve(options: { + collectionId: string; + }): Promise { + return this.client.makeRequest("GET", `graphs/${options.collectionId}`); + } + + /** + * Deletes a graph and all its associated data. + * + * This endpoint permanently removes the specified graph along with all + * entities and relationships that belong to only this graph. + * + * Entities and relationships extracted from documents are not deleted + * and must be deleted separately using the /entities and /relationships + * endpoints. + * @param collectionId The collection ID of the graph to delete + * @returns + */ + @feature("graphs.reset") + async reset(options: { + collectionId: string; + }): Promise { + return this.client.makeRequest("DELETE", `graphs/${options.collectionId}`); } /** @@ -93,60 +99,35 @@ export class GraphsClient { } /** - * Delete a graph. - * @param options - * @returns - */ - @feature("graphs.delete") - async delete(options: { id: string }): Promise { - return this.client.makeRequest("DELETE", `graphs/${options.id}`); - } - - /** - * FIXME: Should this be `addEntity` or `createEntity`? - * Add an entity to a graph. - * @param id Graph ID - * @param entityId Entity ID to add + * Creates a new entity in the graph. + * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param entity Entity to add * @returns */ - @feature("graphs.addEntity") - async addEntity(options: { - id: string; - entityId: string; - }): Promise { + @feature("graphs.createEntity") + async createEntity(options: { + collectionId: string; + entity: Entity; + }): Promise { return this.client.makeRequest( "POST", - `graphs/${options.id}/entities/${options.entityId}`, - ); - } - - /** - * Remove an entity from a graph. - * @param id Graph ID - * @param entityId Entity ID to remove - * @returns - */ - @feature("graphs.removeEntity") - async removeEntity(options: { - id: string; - entityId: string; - }): Promise { - return this.client.makeRequest( - "DELETE", - `graphs/${options.id}/entities/${options.entityId}`, + `graphs/${options.collectionId}/entities`, + { + data: options.entity, + }, ); } /** * List all entities in a graph. - * @param id Graph ID + * @param collectionId Collection ID * @param offset Specifies the number of objects to skip. Defaults to 0. * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ @feature("graphs.listEntities") async listEntities(options: { - id: string; + collectionId: string; offset?: number; limit?: number; }): Promise { @@ -155,73 +136,100 @@ export class GraphsClient { limit: options?.limit ?? 100, }; - return this.client.makeRequest("GET", `graphs/${options.id}/entities`, { - params, - }); + return this.client.makeRequest( + "GET", + `graphs/${options.collectionId}/entities`, + { + params, + }, + ); } /** * Retrieve an entity from a graph. - * @param id Graph ID + * @param collectionId The collection ID corresponding to the graph to add the entity to. * @param entityId Entity ID to retrieve * @returns */ @feature("graphs.getEntity") async getEntity(options: { - id: string; + collectionId: string; entityId: string; }): Promise { return this.client.makeRequest( "GET", - `graphs/${options.id}/entities/${options.entityId}`, + `graphs/${options.collectionId}/entities/${options.entityId}`, ); } /** - * FIXME: Should this be `addRelationship` or `createRelationship`? - * Add a relationship to a graph. - * @param id Graph ID - * @param relationshipId Relationship ID to add + * Updates an existing entity in the graph. + * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param entityId Entity ID to update + * @param entity Entity to update * @returns */ - @feature("graphs.addRelationship") - async addRelationship(options: { - id: string; - relationshipId: string; - }): Promise { + @feature("graphs.updateEntity") + async updateEntity(options: { + collectionId: string; + entityId: string; + entity: Entity; + }): Promise { return this.client.makeRequest( "POST", - `graphs/${options.id}/relationships/${options.relationshipId}`, + `graphs/${options.collectionId}/entities/${options.entityId}`, + { + data: options.entity, + }, ); } /** - * Remove a relationship from a graph. - * @param id Graph ID - * @param relationshipId Relationship ID to remove + * Remove an entity from a graph. + * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param entityId Entity ID to remove * @returns */ - @feature("graphs.removeRelationship") - async removeRelationship(options: { - id: string; - relationshipId: string; + @feature("graphs.removeEntity") + async removeEntity(options: { + collectionId: string; + entityId: string; }): Promise { return this.client.makeRequest( "DELETE", - `graphs/${options.id}/relationships/${options.relationshipId}`, + `graphs/${options.collectionId}/entities/${options.entityId}`, + ); + } + /** + * Creates a new relationship in the graph. + * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param relationship Relationship to add + * @returns + */ + @feature("graphs.createRelationship") + async createRelationship(options: { + collectionId: string; + relationship: Relationship; + }): Promise { + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/relationships`, + { + data: options.relationship, + }, ); } /** * List all relationships in a graph. - * @param id Graph ID + * @param collectionId Collection ID * @param offset Specifies the number of objects to skip. Defaults to 0. * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ @feature("graphs.listRelationships") async listRelationships(options: { - id: string; + collectionId: string; offset?: number; limit?: number; }): Promise { @@ -232,7 +240,7 @@ export class GraphsClient { return this.client.makeRequest( "GET", - `graphs/${options.id}/relationships`, + `graphs/${options.collectionId}/relationships`, { params, }, @@ -241,66 +249,70 @@ export class GraphsClient { /** * Retrieve a relationship from a graph. - * @param id Graph ID + * @param collectionId The collection ID corresponding to the graph to add the entity to. * @param relationshipId Relationship ID to retrieve * @returns */ @feature("graphs.getRelationship") async getRelationship(options: { - id: string; + collectionId: string; relationshipId: string; }): Promise { return this.client.makeRequest( "GET", - `graphs/${options.id}/relationships/${options.relationshipId}`, + `graphs/${options.collectionId}/entities/${options.relationshipId}`, ); } /** - * FIXME: Should this be `addCommunity` or `createCommunity`? - * Add a community to a graph. - * @param id Graph ID - * @param communityId Community ID to add + * Updates an existing relationship in the graph. + * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param relationshipId Relationship ID to update + * @param relationship Relationship to update * @returns */ - @feature("graphs.addCommunity") - async addCommunity(options: { - id: string; - communityId: string; - }): Promise { + @feature("graphs.updateRelationship") + async updateRelationship(options: { + collectionId: string; + relationshipId: string; + relationship: Relationship; + }): Promise { return this.client.makeRequest( "POST", - `graphs/${options.id}/communities/${options.communityId}`, + `graphs/${options.collectionId}/relationships/${options.relationshipId}`, + { + data: options.relationship, + }, ); } /** - * Remove a community from a graph. - * @param id Graph ID - * @param communityId Community ID to remove + * Remove a relationship from a graph. + * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param relationshipId Entity ID to remove * @returns */ - @feature("graphs.removeCommunity") - async removeCommunity(options: { - id: string; - communityId: string; + @feature("graphs.removeRelationship") + async removeRelationship(options: { + collectionId: string; + relationshipId: string; }): Promise { return this.client.makeRequest( "DELETE", - `graphs/${options.id}/communities/${options.communityId}`, + `graphs/${options.collectionId}/entities/${options.relationshipId}`, ); } /** * List all communities in a graph. - * @param id Graph ID + * @param collectionId Collection ID * @param offset Specifies the number of objects to skip. Defaults to 0. * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ @feature("graphs.listCommunities") async listCommunities(options: { - id: string; + collectionId: string; offset?: number; limit?: number; }): Promise { @@ -309,25 +321,100 @@ export class GraphsClient { limit: options?.limit ?? 100, }; - return this.client.makeRequest("GET", `graphs/${options.id}/communities`, { - params, - }); + return this.client.makeRequest( + "GET", + `graphs/${options.collectionId}/communities`, + { + params, + }, + ); } /** * Retrieve a community from a graph. - * @param id Graph ID - * @param communityId Community ID to retrieve + * @param collectionId The ID of the collection to get communities for. + * @param communityId Entity ID to retrieve * @returns */ @feature("graphs.getCommunity") async getCommunity(options: { - id: string; + collectionId: string; communityId: string; }): Promise { return this.client.makeRequest( "GET", - `graphs/${options.id}/communities/${options.communityId}`, + `graphs/${options.collectionId}/communities/${options.communityId}`, + ); + } + + /** + * Updates an existing community in the graph. + * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param communityId Community ID to update + * @param entity Entity to update + * @returns WrappedEntityResponse + */ + @feature("graphs.updateCommunity") + async updateCommunity(options: { + collectionId: string; + communityId: string; + name?: string; + summary?: string; + findings?: string[]; + rating?: number; + ratingExplanation?: string; + level?: number; + attributes?: Record; + }): Promise { + const data = { + ...(options.name && { name: options.name }), + ...(options.summary && { summary: options.summary }), + ...(options.findings && { findings: options.findings }), + ...(options.rating && { rating: options.rating }), + ...(options.ratingExplanation && { + rating_explanation: options.ratingExplanation, + }), + ...(options.level && { level: options.level }), + ...(options.attributes && { attributes: options.attributes }), + }; + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/entities/${options.communityId}`, + { + data, + }, + ); + } + + /** + * Adds documents to a graph by copying their entities and relationships. + * + * This endpoint: + * 1. Copies document entities to the graph_entity table + * 2. Copies document relationships to the graph_relationship table + * 3. Associates the documents with the graph + * + * When a document is added: + * - Its entities and relationships are copied to graph-specific tables + * - Existing entities/relationships are updated by merging their properties + * - The document ID is recorded in the graph's document_ids array + * + * Documents added to a graph will contribute their knowledge to: + * - Graph analysis and querying + * - Community detection + * - Knowledge graph enrichment + * + * The user must have access to both the graph and the documents being added. + * @param options + * @returns + */ + @feature("graphs.pull") + async pull(options: { + collectionId: string; + }): Promise { + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/pull`, ); } } diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 9c34fe1be..8d674955b 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -3,7 +3,7 @@ from typing import Optional from uuid import UUID -from fastapi import Body, Depends, Path, Query, Request +from fastapi import Body, Depends, Path, Query from core.base import R2RException, RunType from core.base.abstractions import ( @@ -41,18 +41,6 @@ logger = logging.getLogger() -from enum import Enum - - -class GraphObjectType(str, Enum): - ENTITIES = "entities" - RELATIONSHIPS = "relationships" - COLLECTIONS = "collections" - DOCUMENTS = "documents" - - def __str__(self): - return self.value - class GraphRouter(BaseRouterV3): def __init__( @@ -66,15 +54,6 @@ def __init__( ): super().__init__(providers, services, orchestration_provider, run_type) - def _get_path_level(self, request: Request) -> DataLevel: - path = request.url.path - if "/chunks/" in path: - return DataLevel.CHUNK - elif "/documents/" in path: - return DataLevel.DOCUMENT - else: - return DataLevel.GRAPH - async def _deduplicate_entities( self, collection_id: UUID, @@ -1049,7 +1028,6 @@ async def create_communities( ) @self.base_endpoint async def get_communities( - request: Request, collection_id: UUID = Path( ..., description="The collection ID corresponding to the graph to get communities for.", @@ -1066,7 +1044,7 @@ async def get_communities( description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): + ) -> WrappedCommunitiesResponse: """ Lists all communities in the graph with pagination support. @@ -1086,7 +1064,7 @@ async def get_communities( auth_user=auth_user, ) - return communities, { + return communities, { # type: ignore "total_entries": count, } @@ -1113,7 +1091,6 @@ async def get_communities( ) @self.base_endpoint async def get_community( - request: Request, collection_id: UUID = Path( ..., description="The ID of the collection to get communities for.", @@ -1123,7 +1100,7 @@ async def get_community( description="The ID of the community to get.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): + ) -> WrappedCommunityResponse: """ Retrieves a specific community by its ID. @@ -1150,7 +1127,6 @@ async def get_community( ) @self.base_endpoint async def delete_community( - request: Request, collection_id: UUID = Path( ..., description="The collection ID corresponding to the graph to delete the community from.", @@ -1212,7 +1188,7 @@ async def update_community( level: Optional[int] = Body(None), attributes: Optional[dict] = Body(None), auth_user=Depends(self.providers.auth.auth_wrapper), - ): + ) -> WrappedCommunityResponse: """ Updates an existing community's metadata and properties. """ diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 51b1ffc39..b7ddd3096 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -1120,7 +1120,6 @@ async def delete( # return [row["id"] for row in results] - class PostgresCommunityHandler(CommunityHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: From 81e0e301af59a116fcea98508a0564ab754101e7 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 27 Nov 2024 08:42:44 -0600 Subject: [PATCH 09/28] More tests, examples --- .../GraphsIntegrationSuperUser.test.ts | 179 ++++++-- js/sdk/examples/data/raskolnikov_2.txt | 7 + js/sdk/src/v3/clients/graphs.ts | 101 ++++- py/core/main/api/v3/graph_router.py | 416 +++++++++++++++++- 4 files changed, 628 insertions(+), 75 deletions(-) create mode 100644 js/sdk/examples/data/raskolnikov_2.txt diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index 31cd60d74..3377d82f4 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -5,8 +5,8 @@ const baseUrl = "http://localhost:7272"; describe("r2rClient V3 Collections Integration Tests", () => { let client: r2rClient; - let graph1Id: string; - let graph2Id: string; + let documentId: string; + let collectionId: string; beforeAll(async () => { client = new r2rClient(baseUrl); @@ -16,75 +16,168 @@ describe("r2rClient V3 Collections Integration Tests", () => { }); }); - test("Create a graph with only a name", async () => { - const response = await client.graphs.create({ - name: "Graph 1", + test("Create document with file path", async () => { + const response = await client.documents.create({ + file: { + path: "examples/data/raskolnikov_2.txt", + name: "raskolnikov_2.txt", + }, + metadata: { title: "raskolnikov_2.txt" }, }); + + expect(response.results.document_id).toBeDefined(); + documentId = response.results.document_id; + }, 10000); + + test("Create new collection", async () => { + const response = await client.collections.create({ + name: "Raskolnikov Collection", + }); + expect(response).toBeTruthy(); + collectionId = response.results.id; + }); + + test("Retrieve collection", async () => { + const response = await client.collections.retrieve({ + id: collectionId, + }); + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(collectionId); + expect(response.results.name).toBe("Raskolnikov Collection"); + }); + + test("Update graph", async () => { + const response = await client.graphs.update({ + collectionId: collectionId, + name: "Raskolnikov Graph", + }); + expect(response.results).toBeDefined(); - graph1Id = response.results.id; - expect(graph1Id).toEqual(response.results.id); - expect(response.results.name).toEqual("Graph 1"); - expect(response.results.description).toBe(""); }); - test("Create a graph with name and description", async () => { - const response = await client.graphs.create({ - name: "2", - description: "Graph 2", + test("Retrieve graph and ensure that update was successful", async () => { + const response = await client.graphs.retrieve({ + collectionId: collectionId, }); - graph2Id = response.results.id; + expect(response.results).toBeDefined(); - expect(response.results.name).toEqual("2"); - expect(response.results.description).toEqual("Graph 2"); + expect(response.results.name).toBe("Raskolnikov Graph"); + expect(response.results.updated_at).not.toBe(response.results.created_at); }); - test("Ensure that there are two graphs", async () => { - const response = await client.graphs.list(); + test("List graphs", async () => { + const response = await client.graphs.list({}); + + expect(response.results).toBeDefined(); + }); + + test("Check that there are no entities in the graph", async () => { + const response = await client.graphs.listEntities({ + collectionId: collectionId, + }); + expect(response.results).toBeDefined(); - expect(response.results.length).toEqual(2); + expect(response.results.entries).toHaveLength(0); }); - test("Retrieve graph 1", async () => { - const response = await client.graphs.retrieve({ id: graph1Id }); + test("Check that there are no relationships in the graph", async () => { + const response = await client.graphs.listRelationships({ + collectionId: collectionId, + }); + expect(response.results).toBeDefined(); - expect(response.results.name).toEqual("Graph 1"); - expect(response.results.description).toBe(""); + expect(response.results.entries).toHaveLength; }); - test("Retrieve graph 2", async () => { - const response = await client.graphs.retrieve({ id: graph2Id }); + test("Extract entities from the document", async () => { + const response = await client.documents.extract({ + id: documentId, + }); + + await new Promise((resolve) => setTimeout(resolve, 10000)); + + expect(response.results).toBeDefined(); + }, 30000); + + test("Assign document to collection", async () => { + const response = await client.collections.addDocument({ + id: collectionId, + documentId: documentId, + }); expect(response.results).toBeDefined(); - expect(response.results.name).toEqual("2"); - expect(response.results.description).toEqual("Graph 2"); }); - test("Update the name of graph 1", async () => { - const response = await client.graphs.update({ - id: graph1Id, - name: "Graph 1 Updated", + test("Pull entities into the graph", async () => { + const response = await client.graphs.pull({ + collectionId: collectionId, }); + console.log("Pull entities into the graph", response.results); expect(response.results).toBeDefined(); - expect(response.results.name).toEqual("Graph 1 Updated"); }); - test("Update the description graph 2", async () => { - const response = await client.graphs.update({ - id: graph2Id, - description: "Graph 2 Updated", + test("Check that there are entities in the graph", async () => { + const response = await client.graphs.listEntities({ + collectionId: collectionId, }); expect(response.results).toBeDefined(); - expect(response.results.description).toEqual("Graph 2 Updated"); + expect(response.total_entries).toBeGreaterThanOrEqual(1); }); - test("Delete graph 1", async () => { - const response = await client.graphs.delete({ id: graph1Id }); + test("Check that there are relationships in the graph", async () => { + const response = await client.graphs.listRelationships({ + collectionId: collectionId, + }); + expect(response.results).toBeDefined(); + }); + + test("Reset the graph", async () => { + const response = await client.graphs.reset({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + }); + + test("Check that there are no entities in the graph", async () => { + const response = await client.graphs.listEntities({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + + test("Check that there are no relationships in the graph", async () => { + const response = await client.graphs.listRelationships({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + + test("Delete raskolnikov_2.txt", async () => { + const response = await client.documents.delete({ + id: documentId, + }); + expect(response.results).toBeDefined(); - expect(response.results.success).toBe(true); }); - test("Delete graph 2", async () => { - const response = await client.graphs.delete({ id: graph2Id }); + test("Check that the document is not in the collection", async () => { + const response = await client.collections.listDocuments({ + id: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + + test("Delete Raskolnikov Collection", async () => { + const response = await client.collections.delete({ + id: collectionId, + }); + expect(response.results).toBeDefined(); - expect(response.results.success).toBe(true); }); }); diff --git a/js/sdk/examples/data/raskolnikov_2.txt b/js/sdk/examples/data/raskolnikov_2.txt new file mode 100644 index 000000000..e82fe6b08 --- /dev/null +++ b/js/sdk/examples/data/raskolnikov_2.txt @@ -0,0 +1,7 @@ +When Raskolnikov got home, his hair was soaked with sweat and he was +breathing heavily. He went rapidly up the stairs, walked into his +unlocked room and at once fastened the latch. Then in senseless terror +he rushed to the corner, to that hole under the paper where he had put +the things; put his hand in, and for some minutes felt carefully in the +hole, in every crack and fold of the paper. Finding nothing, he got up +and drew a deep breath. \ No newline at end of file diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index 669ec162c..26e81ebde 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -19,7 +19,7 @@ export class GraphsClient { /** * List graphs with pagination and filtering options. - * @param ids Optional list of graph IDs to filter by + * @param collectionIds Optional list of collection IDs to filter by * @param offset Optional offset for pagination * @param limit Optional limit for pagination * @returns @@ -36,7 +36,7 @@ export class GraphsClient { }; if (options?.collectionIds && options.collectionIds.length > 0) { - params.ids = options.collectionIds; + params.collectionIds = options.collectionIds; } return this.client.makeRequest("GET", "graphs", { @@ -62,9 +62,7 @@ export class GraphsClient { * This endpoint permanently removes the specified graph along with all * entities and relationships that belong to only this graph. * - * Entities and relationships extracted from documents are not deleted - * and must be deleted separately using the /entities and /relationships - * endpoints. + * Entities and relationships extracted from documents are not deleted. * @param collectionId The collection ID of the graph to delete * @returns */ @@ -72,19 +70,19 @@ export class GraphsClient { async reset(options: { collectionId: string; }): Promise { - return this.client.makeRequest("DELETE", `graphs/${options.collectionId}`); + return this.client.makeRequest("POST", `graphs/${options.collectionId}/reset`); } /** * Update an existing graph. - * @param id Graph ID to update + * @param collectionId The collection ID corresponding to the graph to update. * @param name Optional new name for the graph * @param description Optional new description for the graph * @returns */ @feature("graphs.update") async update(options: { - id: string; + collectionId: string; name?: string; description?: string; }): Promise { @@ -93,7 +91,7 @@ export class GraphsClient { ...(options.description && { description: options.description }), }; - return this.client.makeRequest("POST", `graphs/${options.id}`, { + return this.client.makeRequest("POST", `graphs/${options.collectionId}`, { data, }); } @@ -303,6 +301,51 @@ export class GraphsClient { ); } + /** + * Creates communities in the graph by analyzing entity relationships and similarities. + * + * Communities are created through the following process: + * 1. Analyzes entity relationships and metadata to build a similarity graph + * 2. Applies advanced community detection algorithms (e.g. Leiden) to identify densely connected groups + * 3. Creates hierarchical community structure with multiple granularity levels + * 4. Generates natural language summaries and statistical insights for each community + * + * The resulting communities can be used to: + * - Understand high-level graph structure and organization + * - Identify key entity groupings and their relationships + * - Navigate and explore the graph at different levels of detail + * - Generate insights about entity clusters and their characteristics + * + * The community detection process is configurable through settings like: + * - Community detection algorithm parameters + * - Summary generation prompt + * @param collectionId The collection ID of the graph to create communities for + * @returns + */ + @feature("communities.build") + async build(options: { + collection_id: string; + settings?: Record; + runType?: string; + runWithOrchestration?: boolean; + }): Promise { + const data = { + ...(options.settings && { settings: options.settings }), + ...(options.runType && { run_type: options.runType }), + ...(options.runWithOrchestration && { + run_with_orchestration: options.runWithOrchestration, + }), + }; + + return this.client.makeRequest( + "POST", + `graphs/${options.collection_id}/communities/build`, + { + data, + }, + ); + } + /** * List all communities in a graph. * @param collectionId Collection ID @@ -386,6 +429,23 @@ export class GraphsClient { ); } + /** + * Delete a community in a graph. + * @param collectionId The collection ID corresponding to the graph. + * @param communityId Community ID to delete + * @returns + */ + @feature("graphs.deleteCommunity") + async deleteCommunity(options: { + collectionId: string; + communityId: string; + }): Promise { + return this.client.makeRequest( + "DELETE", + `graphs/${options.collectionId}/communities/${options.communityId}`, + ); + } + /** * Adds documents to a graph by copying their entities and relationships. * @@ -417,4 +477,27 @@ export class GraphsClient { `graphs/${options.collectionId}/pull`, ); } + + /** + * Removes a document from a graph and removes any associated entities + * + * This endpoint: + * 1. Removes the document ID from the graph's document_ids array + * 2. Optionally deletes the document's copied entities and relationships + * + * The user must have access to both the graph and the document being removed. + * @param collectionId The collection ID of the graph to remove the document from + * @param documentId The document ID to remove + * @returns + */ + @feature("graphs.removeDocument") + async removeDocument(options: { + collectionId: string; + documentId: string; + }): Promise { + return this.client.makeRequest( + "DELETE", + `graphs/${options.collectionId}/documents/${options.documentId}`, + ); + } } diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 8d674955b..cbedaba9d 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -253,12 +253,7 @@ def _setup_routes(self): client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.create( - graph={ - "name": "New Graph", - "description": "New Description" - } - ) + result = client.graphs.list() """ ), }, @@ -271,7 +266,7 @@ def _setup_routes(self): const client = new r2rClient("http://localhost:7272"); function main() { - const response = await client.graphs.list(); + const response = await client.graphs.list({}); } main(); @@ -355,7 +350,7 @@ async def list_graphs( function main() { const response = await client.graphs.retrieve({ - collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" }); } @@ -428,7 +423,7 @@ async def get_graph( function main() { const response = await client.graphs.reset({ - collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" }); } @@ -555,7 +550,44 @@ async def update_graph( description=description, ) - @self.router.get("/graphs/{collection_id}/entities") + @self.router.get( + "/graphs/{collection_id}/entities", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.get_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7") + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.get_entities({ + collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + }); + } + + main(); + """ + ), + }, + ], + }, + ) @self.base_endpoint async def get_entities( collection_id: UUID = Path( @@ -665,7 +697,48 @@ async def create_relationship( return relationship - @self.router.get("/graphs/{collection_id}/entities/{entity_id}") + @self.router.get( + "/graphs/{collection_id}/entities/{entity_id}", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.get_entity( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.get_entity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ] + }, + ) @self.base_endpoint async def get_entity( collection_id: UUID = Path( @@ -714,7 +787,49 @@ async def update_entity( print("results = ", results) return entity - @self.router.delete("/graphs/{collection_id}/entities/{entity_id}") + @self.router.delete( + "/graphs/{collection_id}/entities/{entity_id}", + summary="Remove an entity", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.remove_entity( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.removeEntity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ] + }, + ) @self.base_endpoint async def delete_entity( collection_id: UUID = Path( @@ -733,7 +848,45 @@ async def delete_entity( ) return GenericBooleanResponse(success=True) # type: ignore - @self.router.get("/graphs/{collection_id}/relationships") + @self.router.get( + "/graphs/{collection_id}/relationships", + description="Lists all relationships in the graph with pagination support.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.list_relationships(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7") + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.listRelationships({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + }); + } + + main(); + """ + ), + }, + ], + }, + ) @self.base_endpoint async def get_relationships( collection_id: UUID = Path( @@ -798,7 +951,47 @@ async def create_relationship( ) @self.router.get( - "/graphs/{collection_id}/relationships/{relationship_id}" + "/graphs/{collection_id}/relationships/{relationship_id}", + description="Retrieves a specific relationship by its ID.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.get_relationship( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.getRelationship({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ], + }, ) @self.base_endpoint async def get_relationship( @@ -846,7 +1039,47 @@ async def update_relationship( ) @self.router.delete( - "/graphs/{collection_id}/relationships/{relationship_id}" + "/graphs/{collection_id}/relationships/{relationship_id}", + description="Removes a relationship from the graph.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.delete_relationship( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.deleteRelationship({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ], + }, ) @self.base_endpoint async def delete_relationship( @@ -889,8 +1122,25 @@ async def delete_relationship( """ ), }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.communities.build({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + }); + } + + main(); + """ + ), + }, ], - "operationId": "graphs_build_communities_v3_graphs__id__communities_build_graphs", }, ) @self.base_endpoint @@ -903,7 +1153,8 @@ async def create_communities( run_with_orchestration: bool = Query(True), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedKGEnrichmentResponse: - """Creates communities in the graph by analyzing entity relationships and similarities. + """ + Creates communities in the graph by analyzing entity relationships and similarities. Communities are created through the following process: 1. Analyzes entity relationships and metadata to build a similarity graph @@ -918,8 +1169,8 @@ async def create_communities( - Generate insights about entity clusters and their characteristics The community detection process is configurable through settings like: - - Community detection algorithm parameters - - Summary generation prompt + - Community detection algorithm parameters + - Summary generation prompt """ return await self._create_communities( @@ -1019,7 +1270,25 @@ async def create_communities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.communities.get(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.graphs.list_communities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.listCommunities({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + }); + } + + main(); """ ), }, @@ -1082,7 +1351,25 @@ async def get_communities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.communities.get(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.graphs.get_community(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.getCommunity({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + }); + } + + main(); """ ), }, @@ -1124,6 +1411,45 @@ async def get_community( @self.router.delete( "/graphs/{collection_id}/communities/{community_id}", summary="Delete a community", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.delete_community( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + community_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.deleteCommunity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ] + }, ) @self.base_endpoint async def delete_community( @@ -1173,6 +1499,31 @@ async def delete_community( )""" ), }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + async function main() { + const response = await client.graphs.updateCommunity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + communityUpdate: { + metadata: { + topic: "Technology", + description: "Tech companies and products" + } + } + }); + } + + main(); + """ + ), + }, ] }, ) @@ -1224,7 +1575,7 @@ async def update_community( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.initialize( + result = client.graphs.pull( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" )""" ), @@ -1238,7 +1589,7 @@ async def update_community( const client = new r2rClient("http://localhost:7272"); async function main() { - const response = await client.graphs.addDocuments({ + const response = await client.graphs.pull({ collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" }); } @@ -1381,6 +1732,25 @@ async def pull( )""" ), }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + async function main() { + const response = await client.graphs.removeDocument({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + documentId: "f98db41a-5555-4444-3333-222222222222" + }); + } + + main(); + """ + ), + }, ] }, ) From 06eeddb291cb44df9fef69cd4341410d8066460f Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 27 Nov 2024 10:15:20 -0600 Subject: [PATCH 10/28] Sync python --- .../GraphsIntegrationSuperUser.test.ts | 4 +- js/sdk/examples/data/raskolnikov_2.txt | 2 +- js/sdk/src/v3/clients/graphs.ts | 53 +-- py/core/main/api/v3/graph_router.py | 107 ------ py/sdk/v3/graphs.py | 355 ++++++++++-------- 5 files changed, 239 insertions(+), 282 deletions(-) diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index 3377d82f4..3e746ef3c 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -68,8 +68,8 @@ describe("r2rClient V3 Collections Integration Tests", () => { test("List graphs", async () => { const response = await client.graphs.list({}); - expect(response.results).toBeDefined(); - }); + expect(response.results).toBeDefined(); + }); test("Check that there are no entities in the graph", async () => { const response = await client.graphs.listEntities({ diff --git a/js/sdk/examples/data/raskolnikov_2.txt b/js/sdk/examples/data/raskolnikov_2.txt index e82fe6b08..895e99965 100644 --- a/js/sdk/examples/data/raskolnikov_2.txt +++ b/js/sdk/examples/data/raskolnikov_2.txt @@ -4,4 +4,4 @@ unlocked room and at once fastened the latch. Then in senseless terror he rushed to the corner, to that hole under the paper where he had put the things; put his hand in, and for some minutes felt carefully in the hole, in every crack and fold of the paper. Finding nothing, he got up -and drew a deep breath. \ No newline at end of file +and drew a deep breath. diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index 26e81ebde..20b153460 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -46,7 +46,7 @@ export class GraphsClient { /** * Get detailed information about a specific graph. - * @param id Graph ID to retrieve + * @param collectionId The collection ID corresponding to the graph * @returns */ @feature("graphs.retrieve") @@ -63,19 +63,22 @@ export class GraphsClient { * entities and relationships that belong to only this graph. * * Entities and relationships extracted from documents are not deleted. - * @param collectionId The collection ID of the graph to delete + * @param collectionId The collection ID corresponding to the graph * @returns */ @feature("graphs.reset") async reset(options: { collectionId: string; }): Promise { - return this.client.makeRequest("POST", `graphs/${options.collectionId}/reset`); + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/reset`, + ); } /** * Update an existing graph. - * @param collectionId The collection ID corresponding to the graph to update. + * @param collectionId The collection ID corresponding to the graph * @param name Optional new name for the graph * @param description Optional new description for the graph * @returns @@ -98,7 +101,7 @@ export class GraphsClient { /** * Creates a new entity in the graph. - * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param collectionId The collection ID corresponding to the graph * @param entity Entity to add * @returns */ @@ -118,7 +121,7 @@ export class GraphsClient { /** * List all entities in a graph. - * @param collectionId Collection ID + * @param collectionId The collection ID corresponding to the graph * @param offset Specifies the number of objects to skip. Defaults to 0. * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns @@ -145,7 +148,7 @@ export class GraphsClient { /** * Retrieve an entity from a graph. - * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param collectionId The collection ID corresponding to the graph * @param entityId Entity ID to retrieve * @returns */ @@ -162,7 +165,7 @@ export class GraphsClient { /** * Updates an existing entity in the graph. - * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param collectionId The collection ID corresponding to the graph * @param entityId Entity ID to update * @param entity Entity to update * @returns @@ -184,7 +187,7 @@ export class GraphsClient { /** * Remove an entity from a graph. - * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param collectionId The collection ID corresponding to the graph * @param entityId Entity ID to remove * @returns */ @@ -200,7 +203,7 @@ export class GraphsClient { } /** * Creates a new relationship in the graph. - * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param collectionId The collection ID corresponding to the graph * @param relationship Relationship to add * @returns */ @@ -220,7 +223,7 @@ export class GraphsClient { /** * List all relationships in a graph. - * @param collectionId Collection ID + * @param collectionId The collection ID corresponding to the graph * @param offset Specifies the number of objects to skip. Defaults to 0. * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns @@ -247,7 +250,7 @@ export class GraphsClient { /** * Retrieve a relationship from a graph. - * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param collectionId The collection ID corresponding to the graph * @param relationshipId Relationship ID to retrieve * @returns */ @@ -264,7 +267,7 @@ export class GraphsClient { /** * Updates an existing relationship in the graph. - * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param collectionId The collection ID corresponding to the graph * @param relationshipId Relationship ID to update * @param relationship Relationship to update * @returns @@ -286,7 +289,7 @@ export class GraphsClient { /** * Remove a relationship from a graph. - * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param collectionId The collection ID corresponding to the graph * @param relationshipId Entity ID to remove * @returns */ @@ -319,10 +322,10 @@ export class GraphsClient { * The community detection process is configurable through settings like: * - Community detection algorithm parameters * - Summary generation prompt - * @param collectionId The collection ID of the graph to create communities for + * @param collectionId The collection ID corresponding to the graph * @returns */ - @feature("communities.build") + @feature("graphs.build") async build(options: { collection_id: string; settings?: Record; @@ -346,9 +349,11 @@ export class GraphsClient { ); } + // TODO: Create community + /** * List all communities in a graph. - * @param collectionId Collection ID + * @param collectionId The collection ID corresponding to the graph * @param offset Specifies the number of objects to skip. Defaults to 0. * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns @@ -375,7 +380,7 @@ export class GraphsClient { /** * Retrieve a community from a graph. - * @param collectionId The ID of the collection to get communities for. + * @param collectionId The collection ID corresponding to the graph * @param communityId Entity ID to retrieve * @returns */ @@ -392,10 +397,10 @@ export class GraphsClient { /** * Updates an existing community in the graph. - * @param collectionId The collection ID corresponding to the graph to add the entity to. + * @param collectionId The collection ID corresponding to the graph * @param communityId Community ID to update * @param entity Entity to update - * @returns WrappedEntityResponse + * @returns WrappedCommunityResponse */ @feature("graphs.updateCommunity") async updateCommunity(options: { @@ -408,7 +413,7 @@ export class GraphsClient { ratingExplanation?: string; level?: number; attributes?: Record; - }): Promise { + }): Promise { const data = { ...(options.name && { name: options.name }), ...(options.summary && { summary: options.summary }), @@ -431,7 +436,7 @@ export class GraphsClient { /** * Delete a community in a graph. - * @param collectionId The collection ID corresponding to the graph. + * @param collectionId The collection ID corresponding to the graph * @param communityId Community ID to delete * @returns */ @@ -465,7 +470,7 @@ export class GraphsClient { * - Knowledge graph enrichment * * The user must have access to both the graph and the documents being added. - * @param options + * @param collectionId The collection ID corresponding to the graph * @returns */ @feature("graphs.pull") @@ -486,7 +491,7 @@ export class GraphsClient { * 2. Optionally deletes the document's copied entities and relationships * * The user must have access to both the graph and the document being removed. - * @param collectionId The collection ID of the graph to remove the document from + * @param collectionId The collection ID corresponding to the graph * @param documentId The document ID to remove * @returns */ diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 90b22af4d..cd76fa97f 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -7,10 +7,7 @@ from core.base import R2RException, RunType from core.base.abstractions import ( - DataLevel, Entity, - GraphBuildSettings, - KGCreationSettings, KGRunType, Relationship, ) @@ -135,110 +132,6 @@ async def _get_collection_id( return collection_id def _setup_routes(self): - # @self.router.post( - # "/graphs", - # summary="Create a new graph", - # openapi_extra={ - # "x-codeSamples": [ - # { # TODO: Verify - # "lang": "Python", - # "source": textwrap.dedent( - # """ - # from r2r import R2RClient - - # client = R2RClient("http://localhost:7272") - # # when using auth, do client.login(...) - - # result = client.graphs.create( - # graph={ - # "name": "New Graph", - # "description": "New Description" - # } - # ) - # """ - # ), - # }, - # { - # "lang": "JavaScript", - # "source": textwrap.dedent( - # """ - # const { r2rClient } = require("r2r-js"); - - # const client = new r2rClient("http://localhost:7272"); - - # function main() { - # const response = await client.documents.create({ - # name: "New Graph", - # description: "New Description", - # }); - # } - - # main(); - # """ - # ), - # }, - # ] - # }, - # ) - # @self.base_endpoint - # async def create_graph( - # collection_id: Optional[UUID] = Body( - # None, - # description="Collection ID to associate with the graph. If not provided, uses user's default collection.", - # ), - # name: Optional[str] = Body( - # None, description="The name of the graph" - # ), - # description: Optional[str] = Body( - # None, description="An optional description of the graph" - # ), - # auth_user=Depends(self.providers.auth.auth_wrapper), - # ) -> WrappedGraphResponse: - # """ - # Creates a new empty graph. - - # This is the first step in building a knowledge graph. After creating the graph, you can: - - # 1. Add data to the graph: - # - Manually add entities and relationships via the /entities and /relationships endpoints - # - Automatically extract entities and relationships from documents via the /graphs/{id}/documents endpoint - - # 2. Build communities: - # - Build communities of related entities via the /graphs/{collection_id}/communities/build endpoint - - # 3. Update graph metadata: - # - Modify the graph name, description and settings via the /graphs/{collection_id} endpoint - - # The graph ID returned by this endpoint is required for all subsequent operations on the graph. - - # Raises: - # R2RException: If a graph already exists for the given collection. - # """ - - # collection_id = await self._get_collection_id( - # collection_id, auth_user - # ) - - # # Check if a graph already exists for this collection - # existing_graphs = await self.services["kg"].list_graphs( - # collection_id=collection_id, - # offset=0, - # limit=1, - # ) - - # if existing_graphs["total_entries"] > 0: - # raise R2RException( - # f"A graph already exists for collection {collection_id}. Only one graph per collection is allowed.", - # 409, # HTTP 409 Conflict status code - # ) - - # return await self.services["kg"].create_new_graph( - # user_id=auth_user.id, - # collection_id=collection_id, - # name=name, - # description=description, - # ) - @self.router.get( "/graphs", summary="List graphs", diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index 44847bfce..5e1479aac 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -3,7 +3,6 @@ from shared.api.models.base import ( WrappedBooleanResponse, - WrappedGenericMessageResponse, ) from shared.api.models.kg.responses import ( @@ -13,8 +12,8 @@ WrappedEntityResponse, WrappedRelationshipResponse, WrappedRelationshipsResponse, - # WrappedCommunitiesResponse, - # WrappedCommunityResponse, + WrappedCommunitiesResponse, + WrappedCommunityResponse, ) @@ -26,32 +25,9 @@ class GraphsSDK: def __init__(self, client): self.client = client - async def create( - self, - name: str, - description: Optional[str] = None, - ) -> WrappedGraphResponse: - """ - Create a new graph. - - Args: - name (str): Name of the graph - description (Optional[str]): Description of the graph - - Returns: - dict: Created graph information - """ - data = {"name": name, "description": description} - return await self.client._make_request( - "POST", - "graphs", - json=data, - version="v3", - ) - async def list( self, - ids: Optional[list[str | UUID]] = None, + collection_ids: Optional[list[str | UUID]] = None, offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedGraphsResponse: @@ -70,8 +46,8 @@ async def list( "offset": offset, "limit": limit, } - if ids: - params["ids"] = ids + if collection_ids: + params["collection_ids"] = collection_ids return await self.client._make_request( "GET", "graphs", params=params, version="v3" @@ -79,24 +55,46 @@ async def list( async def retrieve( self, - id: str | UUID, + collection_id: str | UUID, ) -> WrappedGraphResponse: """ Get detailed information about a specific graph. Args: - id (str | UUID): Graph ID to retrieve + collection_id (str | UUID): Graph ID to retrieve Returns: dict: Detailed graph information """ return await self.client._make_request( - "GET", f"graphs/{str(id)}", version="v3" + "GET", f"graphs/{str(collection_id)}", version="v3" + ) + + async def reset( + self, + collection_id: str | UUID, + ) -> WrappedBooleanResponse: + """ + Deletes a graph and all its associated data. + + This endpoint permanently removes the specified graph along with all + entities and relationships that belong to only this graph. + + Entities and relationships extracted from documents are not deleted. + + Args: + collection_id (str | UUID): Graph ID to reset + + Returns: + dict: Success message + """ + return await self.client._make_request( + "POST", f"graphs/{str(collection_id)}/reset", version="v3" ) async def update( self, - id: str | UUID, + collection_id: str | UUID, name: Optional[str] = None, description: Optional[str] = None, ) -> WrappedGraphResponse: @@ -104,7 +102,7 @@ async def update( Update graph information. Args: - id (str | UUID): Graph ID to update + collection_id (str | UUID): The collection ID corresponding to the graph name (Optional[str]): Optional new name for the graph description (Optional[str]): Optional new description for the graph @@ -119,60 +117,75 @@ async def update( return await self.client._make_request( "POST", - f"graphs/{str(id)}", + f"graphs/{str(collection_id)}", json=data, version="v3", ) - async def delete( + # TODO: create entity + + async def list_entities( self, - id: str | UUID, - ) -> WrappedBooleanResponse: + collection_id: str | UUID, + offset: Optional[int] = 0, + limit: Optional[int] = 100, + ) -> WrappedEntitiesResponse: """ - Delete a graph. + List entities in a graph. Args: - id (str | UUID): Graph ID to delete + collection_id (str | UUID): Graph ID to list entities from + offset (int, optional): Specifies the number of objects to skip. Defaults to 0. + limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: - bool: True if deletion was successful + dict: List of entities and pagination information """ - result = await self.client._make_request( - "DELETE", f"graphs/{str(id)}", version="v3" + params: dict = { + "offset": offset, + "limit": limit, + } + + return await self.client._make_request( + "GET", + f"graphs/{str(collection_id)}/entities", + params=params, + version="v3", ) - return result.get("results", True) - async def add_entity( + async def get_entity( self, - id: str | UUID, + collection_id: str | UUID, entity_id: str | UUID, - ) -> WrappedGenericMessageResponse: + ) -> WrappedEntityResponse: """ - Add an entity to a graph. + Get entity information in a graph. Args: - id (str | UUID): Graph ID to add entity to - entity_id (str | UUID): Entity ID to add to the graph + collection_id (str | UUID): The collection ID corresponding to the graph + entity_id (str | UUID): Entity ID to get from the graph Returns: - dict: Success message + dict: Entity information """ return await self.client._make_request( - "POST", - f"graphs/{str(id)}/entities/{str(entity_id)}", + "GET", + f"graphs/{str(collection_id)}/entities/{str(entity_id)}", version="v3", ) + # TODO: update entity + async def remove_entity( self, - id: str | UUID, + collection_id: str | UUID, entity_id: str | UUID, ) -> WrappedBooleanResponse: """ Remove an entity from a graph. Args: - id (str | UUID): Graph ID to remove entity from + collection_id (str | UUID): The collection ID corresponding to the graph entity_id (str | UUID): Entity ID to remove from the graph Returns: @@ -180,26 +193,28 @@ async def remove_entity( """ return await self.client._make_request( "DELETE", - f"graphs/{str(id)}/entities/{str(entity_id)}", + f"graphs/{str(collection_id)}/entities/{str(entity_id)}", version="v3", ) - async def list_entities( + # TODO: create relationship + + async def list_relationships( self, - id: str | UUID, + collection_id: str | UUID, offset: Optional[int] = 0, limit: Optional[int] = 100, - ) -> WrappedEntitiesResponse: + ) -> WrappedRelationshipsResponse: """ - List entities in a graph. + List relationships in a graph. Args: - id (str | UUID): Graph ID to list entities from + collection_id (str | UUID): The collection ID corresponding to the graph offset (int, optional): Specifies the number of objects to skip. Defaults to 0. limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: - dict: List of entities and pagination information + dict: List of relationships and pagination information """ params: dict = { "offset": offset, @@ -208,90 +223,105 @@ async def list_entities( return await self.client._make_request( "GET", - f"graphs/{str(id)}/entities", + f"graphs/{str(collection_id)}/relationships", params=params, version="v3", ) - async def get_entity( + async def get_relationship( self, - id: str | UUID, - entity_id: str | UUID, - ) -> WrappedEntityResponse: + collection_id: str | UUID, + relationship_id: str | UUID, + ) -> WrappedRelationshipResponse: """ - Get entity information in a graph. + Get relationship information in a graph. Args: - id (str | UUID): Graph ID to get entity from - entity_id (str | UUID): Entity ID to get from the graph + collection_id (str | UUID): The collection ID corresponding to the graph + relationship_id (str | UUID): Relationship ID to get from the graph Returns: - dict: Entity information + dict: Relationship information """ return await self.client._make_request( "GET", - f"graphs/{str(id)}/entities/{str(entity_id)}", + f"graphs/{str(collection_id)}/relationships/{str(relationship_id)}", version="v3", ) - async def add_relationship( + # TODO: update relationship + + async def remove_relationship( self, - id: str | UUID, + collection_id: str | UUID, relationship_id: str | UUID, - ) -> WrappedGenericMessageResponse: + ) -> WrappedBooleanResponse: """ - Add a relationship to a graph. + Remove a relationship from a graph. Args: - id (str | UUID): Graph ID to add relationship to - relationship_id (str | UUID): Relationship ID to add to the graph + collection_id (str | UUID): The collection ID corresponding to the graph + relationship_id (str | UUID): Relationship ID to remove from the graph Returns: dict: Success message """ return await self.client._make_request( - "POST", - f"graphs/{str(id)}/relationships/{str(relationship_id)}", + "DELETE", + f"graphs/{str(collection_id)}/relationships/{str(relationship_id)}", version="v3", ) - async def remove_relationship( + async def build( self, - id: str | UUID, - relationship_id: str | UUID, + collection_id: str | UUID, + settings: dict, + run_type: str = "estimate", + run_with_orchestration: bool = True, ) -> WrappedBooleanResponse: """ - Remove a relationship from a graph. + Build a graph. Args: - id (str | UUID): Graph ID to remove relationship from - relationship_id (str | UUID): Relationship ID to remove from the graph + collection_id (str | UUID): The collection ID corresponding to the graph + settings (dict): Settings for the build + run_type (str, optional): Type of build to run. Defaults to "estimate". + run_with_orchestration (bool, optional): Whether to run with orchestration. Defaults to True. Returns: dict: Success message """ + data = { + "settings": settings, + "run_type": run_type, + "run_with_orchestration": run_with_orchestration, + } + return await self.client._make_request( - "DELETE", - f"graphs/{str(id)}/relationships/{str(relationship_id)}", + "POST", + f"graphs/{str(collection_id)}/build", + json=data, version="v3", ) - async def list_relationships( + # TODO: create community + + async def list_communities( self, - id: str | UUID, + collection_id: str | UUID, offset: Optional[int] = 0, limit: Optional[int] = 100, - ) -> WrappedRelationshipsResponse: + ) -> WrappedCommunitiesResponse: """ - List relationships in a graph. + List communities in a graph. Args: - id (str | UUID): Graph ID to list relationships from + collection_id (str | UUID): The collection ID corresponding to the graph offset (int, optional): Specifies the number of objects to skip. Defaults to 0. limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: - dict: List of relationships and pagination information + dict: List of communities and pagination information """ params: dict = { "offset": offset, @@ -300,63 +330,94 @@ async def list_relationships( return await self.client._make_request( "GET", - f"graphs/{str(id)}/relationships", + f"graphs/{str(collection_id)}/communities", params=params, version="v3", ) - async def get_relationship( + async def get_community( self, - id: str | UUID, - relationship_id: str | UUID, - ) -> WrappedRelationshipResponse: + collection_id: str | UUID, + community_id: str | UUID, + ) -> WrappedCommunityResponse: """ - Get relationship information in a graph. + Get community information in a graph. Args: - id (str | UUID): Graph ID to get relationship from - relationship_id (str | UUID): Relationship ID to get from the graph + collection_id (str | UUID): The collection ID corresponding to the graph + community_id (str | UUID): Community ID to get from the graph Returns: - dict: Relationship information + dict: Community information """ return await self.client._make_request( "GET", - f"graphs/{str(id)}/relationships/{str(relationship_id)}", + f"graphs/{str(collection_id)}/communities/{str(community_id)}", version="v3", ) - async def add_community( + async def update_community( self, - id: str | UUID, + collection_id: str | UUID, community_id: str | UUID, - ) -> WrappedGenericMessageResponse: + name: Optional[str] = None, + summary: Optional[str] = None, + findings: Optional[list[str]] = None, + rating: Optional[int] = None, + rating_explanation: Optional[str] = None, + level: Optional[int] = None, + attributes: Optional[dict] = None, + ) -> WrappedCommunityResponse: """ - Add a community to a graph. + Update community information. Args: - id (str | UUID): Graph ID to add community to - community_id (str | UUID): Community ID to add to the graph + collection_id (str | UUID): The collection ID corresponding to the graph + community_id (str | UUID): Community ID to update + name (Optional[str]): Optional new name for the community + summary (Optional[str]): Optional new summary for the community + findings (Optional[list[str]]): Optional new findings for the community + rating (Optional[int]): Optional new rating for the community + rating_explanation (Optional[str]): Optional new rating explanation for the community + level (Optional[int]): Optional new level for the community + attributes (Optional[dict]): Optional new attributes for the community Returns: - dict: Success message + dict: Updated community information """ + data = {} + if name is not None: + data["name"] = name + if summary is not None: + data["summary"] = summary + if findings is not None: + data["findings"] = findings + if rating is not None: + data["rating"] = rating + if rating_explanation is not None: + data["rating_explanation"] = rating_explanation + if level is not None: + data["level"] = level + if attributes is not None: + data["attributes"] = attributes + return await self.client._make_request( "POST", - f"graphs/{str(id)}/communities/{str(community_id)}", + f"graphs/{str(collection_id)}/communities/{str(community_id)}", + json=data, version="v3", ) - async def remove_community( + async def delete_community( self, - id: str | UUID, + collection_id: str | UUID, community_id: str | UUID, ) -> WrappedBooleanResponse: """ Remove a community from a graph. Args: - id (str | UUID): Graph ID to remove community from + collection_id (str | UUID): The collection ID corresponding to the graph community_id (str | UUID): Community ID to remove from the graph Returns: @@ -364,56 +425,54 @@ async def remove_community( """ return await self.client._make_request( "DELETE", - f"graphs/{str(id)}/communities/{str(community_id)}", + f"graphs/{str(collection_id)}/communities/{str(community_id)}", version="v3", ) - async def list_communities( + async def pull( self, - id: str | UUID, - offset: Optional[int] = 0, - limit: Optional[int] = 100, - ): # -> WrappedCommunitiesResponse + collection_id: str | UUID, + ) -> WrappedBooleanResponse: """ - List communities in a graph. + Adds documents to a graph by copying their entities and relationships. - Args: - id (str | UUID): Graph ID to list communities from - offset (int, optional): Specifies the number of objects to skip. Defaults to 0. - limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + This endpoint: + 1. Copies document entities to the graph_entity table + 2. Copies document relationships to the graph_relationship table + 3. Associates the documents with the graph - Returns: - dict: List of communities and pagination information - """ - params: dict = { - "offset": offset, - "limit": limit, - } + When a document is added: + - Its entities and relationships are copied to graph-specific tables + - Existing entities/relationships are updated by merging their properties + - The document ID is recorded in the graph's document_ids array + Documents added to a graph will contribute their knowledge to: + - Graph analysis and querying + - Community detection + - Knowledge graph enrichment + """ return await self.client._make_request( - "GET", - f"graphs/{str(id)}/communities", - params=params, + "POST", + f"graphs/{str(collection_id)}/pull", version="v3", ) - async def get_community( + async def remove_document( self, - id: str | UUID, - community_id: str | UUID, - ): # -> WrappedCommunityResponse + collection_id: str | UUID, + document_id: str | UUID, + ) -> WrappedBooleanResponse: """ - Get community information in a graph. + Removes a document from a graph and removes any associated entities - Args: - id (str | UUID): Graph ID to get community from - community_id (str | UUID): Community ID to get from the graph + This endpoint: + 1. Removes the document ID from the graph's document_ids array + 2. Optionally deletes the document's copied entities and relationships - Returns: - dict: Community information + The user must have access to both the graph and the document being removed. """ return await self.client._make_request( - "GET", - f"graphs/{str(id)}/communities/{str(community_id)}", + "DELETE", + f"graphs/{str(collection_id)}/documents/{str(document_id)}", version="v3", ) From 569b81685f4ee97199fadd60887425e115f7f9c6 Mon Sep 17 00:00:00 2001 From: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 27 Nov 2024 16:13:18 -0600 Subject: [PATCH 11/28] Expose Entity/Relationship Params in Routes (#1640) * Expose Entity/Relationship Params * Descriptions * Modify create entities * Create relationships * set parent_id * Update entitiy * Update Relationships --- .../GraphsIntegrationSuperUser.test.ts | 173 ++++++- js/sdk/src/v3/clients/graphs.ts | 76 +++- py/core/base/providers/database.py | 8 +- py/core/main/api/v3/graph_router.py | 189 +++++--- .../main/orchestration/simple/kg_workflow.py | 1 - py/core/main/services/kg_service.py | 149 ++++-- py/core/pipes/kg/extraction.py | 4 +- py/core/pipes/kg/storage.py | 2 +- py/core/providers/database/document.py | 4 +- py/core/providers/database/graph.py | 423 +++++++++++------- py/sdk/v3/graphs.py | 6 +- py/shared/abstractions/graph.py | 19 +- 12 files changed, 736 insertions(+), 318 deletions(-) diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index 3e746ef3c..d7cf63d35 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -7,6 +7,9 @@ describe("r2rClient V3 Collections Integration Tests", () => { let client: r2rClient; let documentId: string; let collectionId: string; + let entity1Id: string; + let entity2Id: string; + let relationshipId: string; beforeAll(async () => { client = new r2rClient(baseUrl); @@ -94,10 +97,10 @@ describe("r2rClient V3 Collections Integration Tests", () => { id: documentId, }); - await new Promise((resolve) => setTimeout(resolve, 10000)); + await new Promise((resolve) => setTimeout(resolve, 30000)); expect(response.results).toBeDefined(); - }, 30000); + }, 60000); test("Assign document to collection", async () => { const response = await client.collections.addDocument({ @@ -111,7 +114,6 @@ describe("r2rClient V3 Collections Integration Tests", () => { const response = await client.graphs.pull({ collectionId: collectionId, }); - console.log("Pull entities into the graph", response.results); expect(response.results).toBeDefined(); }); @@ -130,6 +132,171 @@ describe("r2rClient V3 Collections Integration Tests", () => { expect(response.results).toBeDefined(); }); + test("Create a new entity", async () => { + const response = await client.graphs.createEntity({ + collectionId: collectionId, + name: "Razumikhin", + description: "A good friend of Raskolnikov", + category: "Person", + }); + + expect(response.results).toBeDefined(); + entity1Id = response.results.id; + }); + + test("Create another new entity", async () => { + const response = await client.graphs.createEntity({ + collectionId: collectionId, + name: "Dunia", + description: "The sister of Raskolnikov", + category: "Person", + }); + + expect(response.results).toBeDefined(); + entity2Id = response.results.id; + }); + + test("Retrieve the entity", async () => { + const response = await client.graphs.getEntity({ + collectionId: collectionId, + entityId: entity1Id, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(entity1Id); + expect(response.results.name).toBe("Razumikhin"); + expect(response.results.description).toBe("A good friend of Raskolnikov"); + }); + + test("Retrieve the other entity", async () => { + const response = await client.graphs.getEntity({ + collectionId: collectionId, + entityId: entity2Id, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(entity2Id); + expect(response.results.name).toBe("Dunia"); + expect(response.results.description).toBe("The sister of Raskolnikov"); + }); + + test("Check that the entities are in the graph", async () => { + const response = await client.graphs.listEntities({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.map((entity) => entity.id)).toContain(entity1Id); + expect(response.results.map((entity) => entity.id)).toContain(entity2Id); + }); + + test("Create a relationship between the entities", async () => { + const response = await client.graphs.createRelationship({ + collectionId: collectionId, + subject: "Razumikhin", + subjectId: entity1Id, + predicate: "falls in love with", + object: "Dunia", + objectId: entity2Id, + }); + + relationshipId = response.results.id; + + expect(response.results).toBeDefined(); + expect(response.results.subject).toBe("Razumikhin"); + expect(response.results.object).toBe("Dunia"); + expect(response.results.predicate).toBe("falls in love with"); + }); + + test("Retrieve the relationship", async () => { + const response = await client.graphs.getRelationship({ + collectionId: collectionId, + relationshipId: relationshipId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(relationshipId); + expect(response.results.subject).toBe("Razumikhin"); + expect(response.results.object).toBe("Dunia"); + expect(response.results.predicate).toBe("falls in love with"); + }); + + test("Update the entity", async () => { + const response = await client.graphs.updateEntity({ + collectionId: collectionId, + entityId: entity1Id, + name: "Dmitri Prokofich Razumikhin", + description: "A good friend of Raskolnikov and Dunia", + category: "Person", + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(entity1Id); + expect(response.results.name).toBe("Dmitri Prokofich Razumikhin"); + expect(response.results.description).toBe( + "A good friend of Raskolnikov and Dunia", + ); + }); + + test("Retrieve the updated entity", async () => { + const response = await client.graphs.getEntity({ + collectionId: collectionId, + entityId: entity1Id, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(entity1Id); + expect(response.results.name).toBe("Dmitri Prokofich Razumikhin"); + expect(response.results.description).toBe( + "A good friend of Raskolnikov and Dunia", + ); + }); + + // This test is failing because we attach a separate name to the relationship, rather + // than use the names of the entities. This needs to be fixed in the backend. + // test("Ensure that the entity was updated in the relationship", async () => { + // const response = await client.graphs.getRelationship({ + // collectionId: collectionId, + // relationshipId: relationshipId, + // }); + + // expect(response.results).toBeDefined(); + // expect(response.results.subject).toBe("Dmitri Prokofich Razumikhin"); + // expect(response.results.object).toBe("Dunia"); + // expect(response.results.predicate).toBe("falls in love with"); + // }); + + test("Update the relationship", async () => { + const response = await client.graphs.updateRelationship({ + collectionId: collectionId, + relationshipId: relationshipId, + subject: "Razumikhin", + subjectId: entity1Id, + predicate: "marries", + object: "Dunia", + objectId: entity2Id, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(relationshipId); + expect(response.results.subject).toBe("Razumikhin"); + expect(response.results.object).toBe("Dunia"); + expect(response.results.predicate).toBe("marries"); + }); + + test("Retrieve the updated relationship", async () => { + const response = await client.graphs.getRelationship({ + collectionId: collectionId, + relationshipId: relationshipId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(relationshipId); + expect(response.results.subject).toBe("Razumikhin"); + expect(response.results.object).toBe("Dunia"); + expect(response.results.predicate).toBe("marries"); + }); + test("Reset the graph", async () => { const response = await client.graphs.reset({ collectionId: collectionId, diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index 20b153460..2189cb10c 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -108,13 +108,23 @@ export class GraphsClient { @feature("graphs.createEntity") async createEntity(options: { collectionId: string; - entity: Entity; + name: string; + description?: string; + category?: string; + metadata?: Record; }): Promise { + const data = { + name: options.name, + ...(options.description && { description: options.description }), + ...(options.category && { category: options.category }), + ...(options.metadata && { metadata: options.metadata }), + }; + return this.client.makeRequest( "POST", `graphs/${options.collectionId}/entities`, { - data: options.entity, + data, }, ); } @@ -174,13 +184,23 @@ export class GraphsClient { async updateEntity(options: { collectionId: string; entityId: string; - entity: Entity; + name?: string; + description?: string; + category?: string; + metadata?: Record; }): Promise { + const data = { + ...(options.name && { name: options.name }), + ...(options.description && { description: options.description }), + ...(options.category && { category: options.category }), + ...(options.metadata && { metadata: options.metadata }), + }; + return this.client.makeRequest( "POST", `graphs/${options.collectionId}/entities/${options.entityId}`, { - data: options.entity, + data, }, ); } @@ -210,13 +230,31 @@ export class GraphsClient { @feature("graphs.createRelationship") async createRelationship(options: { collectionId: string; - relationship: Relationship; + subject: string; + subjectId: string; + predicate: string; + object: string; + objectId: string; + description?: string; + weight?: number; + metadata?: Record; }): Promise { + const data = { + subject: options.subject, + subject_id: options.subjectId, + predicate: options.predicate, + object: options.object, + object_id: options.objectId, + ...(options.description && { description: options.description }), + ...(options.weight && { weight: options.weight }), + ...(options.metadata && { metadata: options.metadata }), + }; + return this.client.makeRequest( "POST", `graphs/${options.collectionId}/relationships`, { - data: options.relationship, + data, }, ); } @@ -261,7 +299,7 @@ export class GraphsClient { }): Promise { return this.client.makeRequest( "GET", - `graphs/${options.collectionId}/entities/${options.relationshipId}`, + `graphs/${options.collectionId}/relationships/${options.relationshipId}`, ); } @@ -270,19 +308,37 @@ export class GraphsClient { * @param collectionId The collection ID corresponding to the graph * @param relationshipId Relationship ID to update * @param relationship Relationship to update - * @returns + * @returns WrappedRelationshipResponse */ @feature("graphs.updateRelationship") async updateRelationship(options: { collectionId: string; relationshipId: string; - relationship: Relationship; + subject?: string; + subjectId?: string; + predicate?: string; + object?: string; + objectId?: string; + description?: string; + weight?: number; + metadata?: Record; }): Promise { + const data = { + ...(options.subject && { subject: options.subject }), + ...(options.subjectId && { subject_id: options.subjectId }), + ...(options.predicate && { predicate: options.predicate }), + ...(options.object && { object: options.object }), + ...(options.objectId && { object_id: options.objectId }), + ...(options.description && { description: options.description }), + ...(options.weight && { weight: options.weight }), + ...(options.metadata && { metadata: options.metadata }), + }; + return this.client.makeRequest( "POST", `graphs/${options.collectionId}/relationships/${options.relationshipId}`, { - data: options.relationship, + data, }, ); } diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index d11d496fd..0f5007cf9 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -604,7 +604,7 @@ async def list_chunks( class EntityHandler(Handler): @abstractmethod - async def create(self, *args: Any, **kwargs: Any) -> None: + async def create(self, *args: Any, **kwargs: Any) -> Entity: """Create entities in storage.""" pass @@ -614,7 +614,7 @@ async def get(self, *args: Any, **kwargs: Any) -> list[Entity]: pass @abstractmethod - async def update(self, *args: Any, **kwargs: Any) -> None: + async def update(self, *args: Any, **kwargs: Any) -> Entity: """Update entities in storage.""" pass @@ -626,7 +626,7 @@ async def delete(self, *args: Any, **kwargs: Any) -> None: class RelationshipHandler(Handler): @abstractmethod - async def create(self, *args: Any, **kwargs: Any) -> None: + async def create(self, *args: Any, **kwargs: Any) -> Relationship: """Add relationships to storage.""" pass @@ -636,7 +636,7 @@ async def get(self, *args: Any, **kwargs: Any) -> list[Relationship]: pass @abstractmethod - async def update(self, *args: Any, **kwargs: Any) -> None: + async def update(self, *args: Any, **kwargs: Any) -> Relationship: """Update relationships in storage.""" pass diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index cd76fa97f..a9281066b 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -7,9 +7,7 @@ from core.base import R2RException, RunType from core.base.abstractions import ( - Entity, KGRunType, - Relationship, ) from core.base.api.models import ( GenericBooleanResponse, @@ -521,7 +519,18 @@ async def create_entity( ..., description="The collection ID corresponding to the graph to add the entity to.", ), - entity: Entity = Body(..., description="The entity to create"), + name: str = Body( + ..., description="The name of the entity to create." + ), + description: Optional[str] = Body( + None, description="The description of the entity to create." + ), + category: Optional[str] = Body( + None, description="The category of the entity to create." + ), + metadata: Optional[dict] = Body( + None, description="The metadata of the entity to create." + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedEntityResponse: """Creates a new entity in the graph.""" @@ -534,26 +543,13 @@ async def create_entity( 403, ) - # Set parent ID to graph ID - entity.parent_id = collection_id - - # Create entity - created_ids = ( - await self.providers.database.graph_handler.entities.create( - entities=[entity], store_type="graph" - ) - ) - if not created_ids: - raise R2RException("Failed to create entity", 500) - - result = await self.providers.database.graph_handler.entities.get( + return await self.services["kg"].create_entity( + name=name, + description=description, parent_id=collection_id, - store_type="graph", - entity_ids=[created_ids[0]], + category=category, + metadata=metadata, ) - if len(result) == 0: - raise R2RException("Failed to create entity", 500) - return result[0] @self.router.post("/graphs/{collection_id}/relationships") @self.base_endpoint @@ -562,8 +558,32 @@ async def create_relationship( ..., description="The collection ID corresponding to the graph to add the relationship to.", ), - relationship: Relationship = Body( - ..., description="The relationship to create" + subject: str = Body( + ..., description="The subject of the relationship to create." + ), + subject_id: UUID = Body( + ..., + description="The ID of the subject of the relationship to create.", + ), + predicate: str = Body( + ..., description="The predicate of the relationship to create." + ), + object: str = Body( + ..., description="The object of the relationship to create." + ), + object_id: UUID = Body( + ..., + description="The ID of the object of the relationship to create.", + ), + description: Optional[str] = Body( + None, + description="The description of the relationship to create.", + ), + weight: Optional[float] = Body( + None, description="The weight of the relationship to create." + ), + metadata: Optional[dict] = Body( + None, description="The metadata of the relationship to create." ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedRelationshipResponse: @@ -577,16 +597,18 @@ async def create_relationship( 403, ) - # Set parent ID to graph ID - relationship.parent_id = collection_id - - # Create relationship - await self.providers.database.graph_handler.relationships.create( - relationships=[relationship], store_type="graph" + return await self.services["kg"].create_relationship( + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + description=description, + weight=weight, + metadata=metadata, + parent_id=collection_id, ) - return relationship - @self.router.get( "/graphs/{collection_id}/entities/{entity_id}", openapi_extra={ @@ -659,23 +681,37 @@ async def update_entity( entity_id: UUID = Path( ..., description="The ID of the entity to update." ), - entity: Entity = Body( - ..., description="The updated entity object." + name: Optional[str] = Body( + ..., description="The updated name of the entity." + ), + description: Optional[str] = Body( + None, description="The updated description of the entity." + ), + category: Optional[str] = Body( + None, description="The updated category of the entity." + ), + metadata: Optional[dict] = Body( + None, description="The updated metadata of the entity." ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedEntityResponse: """Updates an existing entity in the graph.""" - entity.id = entity_id - entity.parent_id = ( - entity.parent_id or collection_id - ) # Set parent ID to graph ID - results = await self.providers.database.graph_handler.entities.update( - [entity], - store_type="graph", - # id, entity_id, entity, auth_user + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + + return await self.services["kg"].update_entity( + entity_id=entity_id, + name=name, + category=category, + description=description, + metadata=metadata, ) - print("results = ", results) - return entity @self.router.delete( "/graphs/{collection_id}/entities/{entity_id}", @@ -822,24 +858,6 @@ async def get_relationships( "total_entries": count, } - @self.router.post("/graphs/{collection_id}/relationships") - @self.base_endpoint - async def create_relationship( - collection_id: UUID = Path( - ..., - description="The collection ID corresponding to the graph to add the relationship to.", - ), - relationship_ids: list[UUID] = Body( - ..., - description="The IDs of the relationships to add to the graph.", - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedRelationshipResponse: - """Creates a new relationship in the graph.""" - return await self.providers.database.graph_handler.relationships.add_to_graph( - collection_id, relationship_ids, "graph" - ) - @self.router.get( "/graphs/{collection_id}/relationships/{relationship_id}", description="Retrieves a specific relationship by its ID.", @@ -916,16 +934,53 @@ async def update_relationship( relationship_id: UUID = Path( ..., description="The ID of the relationship to update." ), - relationship: Relationship = Body( - ..., description="The updated relationship object." + subject: Optional[str] = Body( + ..., description="The updated subject of the relationship." + ), + subject_id: Optional[UUID] = Body( + ..., description="The updated subject ID of the relationship." + ), + predicate: Optional[str] = Body( + ..., description="The updated predicate of the relationship." + ), + object: Optional[str] = Body( + ..., description="The updated object of the relationship." + ), + object_id: Optional[UUID] = Body( + ..., description="The updated object ID of the relationship." + ), + description: Optional[str] = Body( + None, + description="The updated description of the relationship.", + ), + weight: Optional[float] = Body( + None, description="The updated weight of the relationship." + ), + metadata: Optional[dict] = Body( + None, description="The updated metadata of the relationship." ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedRelationshipResponse: """Updates an existing relationship in the graph.""" - relationship.id = relationship_id - relationship.parent_id = relationship.parent_id or collection_id - return await self.providers.database.graph_handler.relationships.update( - [relationship], "graph" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + + return await self.services["kg"].update_relationship( + relationship_id=relationship_id, + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + description=description, + weight=weight, + metadata=metadata, ) @self.router.delete( diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index 6bd3cc666..cd1534a20 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -64,7 +64,6 @@ async def create_graph(input_data): document_id=document_id, **input_data["kg_creation_settings"], ): - print("extraction = ", extraction) extractions.append(extraction) await service.store_kg_extractions(extractions) diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 15354e90f..37167d6c9 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -133,15 +133,15 @@ async def kg_relationships_extraction( return await _collect_results(result_gen) - @telemetry_event("create_entities") - async def create_entities( + @telemetry_event("create_entity") + async def create_entity( self, name: str, description: str, - metadata: Optional[dict] = None, + parent_id: UUID, category: Optional[str] = None, - auth_user: Optional[Any] = None, - ): + metadata: Optional[dict] = None, + ) -> Entity: description_embedding = str( await self.providers.embedding.async_get_embedding(description) @@ -149,11 +149,12 @@ async def create_entities( return await self.providers.database.graph_handler.entities.create( name=name, + parent_id=parent_id, + store_type="graph", # type: ignore category=category, description=description, description_embedding=description_embedding, metadata=metadata, - auth_user=auth_user, ) @telemetry_event("list_entities") @@ -180,31 +181,29 @@ async def list_entities( ) @telemetry_event("update_entity") - async def update_entity_v3( + async def update_entity( self, - id: UUID, - name: Optional[str], - category: Optional[str], - description: Optional[str], - attributes: Optional[dict], - auth_user: Optional[Any] = None, - ): + entity_id: UUID, + name: Optional[str] = None, + description: Optional[str] = None, + category: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> Entity: + description_embedding = None if description is not None: description_embedding = str( await self.providers.embedding.async_get_embedding(description) ) - else: - description_embedding = None return await self.providers.database.graph_handler.entities.update( - id=id, + entity_id=entity_id, name=name, - category=category, description=description, + category=category, description_embedding=description_embedding, - attributes=attributes, - auth_user=auth_user, + metadata=metadata, + store_type="graph", # type: ignore ) @telemetry_event("delete_entity") @@ -278,25 +277,38 @@ async def list_relationships_v3( relationship_id=relationship_id, ) - @telemetry_event("create_relationships_v3") - async def create_relationships_v3( + @telemetry_event("create_relationship") + async def create_relationship( self, - relationships: list[Relationship], - **kwargs, - ): - for relationship in relationships: - if relationship.description: - relationships.description_embedding = str( - await self.providers.embedding.async_get_embedding( - relationship.description - ) - ) - - print("relationships = ", relationships) + subject: str, + subject_id: UUID, + predicate: str, + object: str, + object_id: UUID, + parent_id: UUID, + description: str | None = None, + weight: float | None = 1.0, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Relationship: + description_embedding = None + if description: + description_embedding = str( + await self.providers.embedding.async_get_embedding(description) + ) return ( await self.providers.database.graph_handler.relationships.create( - relationships + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + parent_id=parent_id, + description=description, + description_embedding=description_embedding, + weight=weight, + metadata=metadata, + store_type="graph", # type: ignore ) ) @@ -310,21 +322,44 @@ async def delete_relationship_v3( ): return ( await self.providers.database.graph_handler.relationships.delete( - level=level, id=id, relationship_id=relationship_id, ) ) - @telemetry_event("update_relationship_v3") - async def update_relationship_v3( + @telemetry_event("update_relationship") + async def update_relationship( self, - relationship: Relationship, - **kwargs, - ): + relationship_id: UUID, + subject: Optional[str] = None, + subject_id: Optional[UUID] = None, + predicate: Optional[str] = None, + object: Optional[str] = None, + object_id: Optional[UUID] = None, + description: Optional[str] = None, + weight: Optional[float] = None, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Relationship: + + description_embedding = None + if description is not None: + description_embedding = str( + await self.providers.embedding.async_get_embedding(description) + ) + return ( await self.providers.database.graph_handler.relationships.update( - relationship + relationship_id=relationship_id, + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + description=description, + description_embedding=description_embedding, + weight=weight, + metadata=metadata, + store_type="graph", # type: ignore ) ) @@ -1206,21 +1241,37 @@ async def store_kg_extractions( total_entities, total_relationships = 0, 0 for extraction in kg_extractions: - print("extraction = ", extraction) - total_entities, total_relationships = ( total_entities + len(extraction.entities), total_relationships + len(extraction.relationships), ) - if extraction.entities: + for entity in extraction.entities: await self.providers.database.graph_handler.entities.create( - extraction.entities, store_type="document" + name=entity.name, + parent_id=entity.parent_id, + store_type="document", # type: ignore + category=entity.category, + description=entity.description, + description_embedding=entity.description_embedding, + chunk_ids=entity.chunk_ids, + metadata=entity.metadata, ) if extraction.relationships: - await self.providers.database.graph_handler.relationships.create( - extraction.relationships, store_type="document" - ) + for relationship in extraction.relationships: + await self.providers.database.graph_handler.relationships.create( + subject=relationship.subject, + subject_id=relationship.subject_id, + predicate=relationship.predicate, + object=relationship.object, + object_id=relationship.object_id, + parent_id=relationship.parent_id, + description=relationship.description, + description_embedding=relationship.description_embedding, + weight=relationship.weight, + metadata=relationship.metadata, + store_type="document", # type: ignore + ) return (total_entities, total_relationships) diff --git a/py/core/pipes/kg/extraction.py b/py/core/pipes/kg/extraction.py index 826356f47..3a5241833 100644 --- a/py/core/pipes/kg/extraction.py +++ b/py/core/pipes/kg/extraction.py @@ -143,7 +143,7 @@ def parse_fn(response_str: str) -> Any: category=entity_category, description=entity_description, name=entity_value, - document_id=extractions[0].document_id, + parent_id=extractions[0].document_id, chunk_ids=[ extraction.id for extraction in extractions ], @@ -167,7 +167,7 @@ def parse_fn(response_str: str) -> Any: object=object, description=description, weight=weight, - document_id=extractions[0].document_id, + parent_id=extractions[0].document_id, chunk_ids=[ extraction.id for extraction in extractions ], diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index 37eeef3f7..53a0f132e 100644 --- a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -64,7 +64,7 @@ async def store( if not extraction.entities[0].chunk_ids: for i in range(len(extraction.entities)): extraction.entities[i].chunk_ids = extraction.chunk_ids - extraction.entities[i].document_id = ( + extraction.entities[i].parent_id = ( extraction.document_id ) diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py index e55ef9d39..1639fc484 100644 --- a/py/core/providers/database/document.py +++ b/py/core/providers/database/document.py @@ -138,7 +138,7 @@ async def upsert_documents_overview( summary = $12, summary_embedding = $13 WHERE document_id = $14 """ - print("db_entry = ", db_entry) + await conn.execute( update_query, db_entry["collection_ids"], @@ -465,7 +465,7 @@ async def get_documents_overview( logger.warning( f"Failed to parse embedding for document {row['document_id']}: {e}" ) - print("row = ", row) + documents.append( DocumentResponse( id=row["document_id"], diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 62d0d1652..84db93669 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -119,48 +119,59 @@ async def create_tables(self) -> None: await self.connection_manager.execute_query(QUERY) async def create( - self, entities: list[Entity], store_type: StoreType - ) -> list[UUID]: - """Create multiple entities in the specified store.""" + self, + name: str, + parent_id: UUID, + store_type: StoreType, + category: Optional[str] = None, + description: Optional[str] = None, + description_embedding: Optional[list[float] | str] = None, + chunk_ids: Optional[list[UUID]] = None, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Entity: + """Create a new entity in the specified store.""" table_name = self._get_entity_table_for_store(store_type) - values = [] - results = [] - for entity in entities: - metadata = entity.metadata - if isinstance(metadata, str): - try: - metadata = json.loads(metadata) - except json.JSONDecodeError: - pass + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + pass - description_embedding = entity.description_embedding - if isinstance(description_embedding, list): - description_embedding = str(description_embedding) + if isinstance(description_embedding, list): + description_embedding = str(description_embedding) - value = ( - entity.name, - entity.category, - entity.description, - entity.parent_id, - description_embedding, - entity.chunk_ids, - json.dumps(metadata) if metadata else None, - ) - values.append(value) - - QUERY = f""" + query = f""" INSERT INTO {self._get_table_name(table_name)} (name, category, description, parent_id, description_embedding, chunk_ids, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7) - RETURNING id + RETURNING id, name, category, description, parent_id, chunk_ids, metadata """ - for value in values: - result = await self.connection_manager.fetchrow_query(QUERY, value) - results.append(result["id"]) + params = [ + name, + category, + description, + parent_id, + description_embedding, + chunk_ids, + json.dumps(metadata) if metadata else None, + ] - return results + result = await self.connection_manager.fetchrow_query( + query=query, + params=params, + ) + + return Entity( + id=result["id"], + name=result["name"], + category=result["category"], + description=result["description"], + parent_id=result["parent_id"], + chunk_ids=result["chunk_ids"], + metadata=result["metadata"], + ) async def get( self, @@ -244,63 +255,84 @@ async def get( return entities, count async def update( - self, entities: list[Entity], store_type: StoreType - ) -> list[UUID]: - """Update multiple entities in the specified store.""" + self, + entity_id: UUID, + store_type: StoreType, + name: Optional[str] = None, + description: Optional[str] = None, + description_embedding: Optional[list[float] | str] = None, + category: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> Entity: + """Update an entity in the specified store.""" table_name = self._get_entity_table_for_store(store_type) - results = [] + update_fields = [] + params: list = [] + param_index = 1 - print("entities = ", entities) - QUERY = f""" - UPDATE {self._get_table_name(table_name)} - SET - name = $1, - category = $2, - description = $3, - description_embedding = $4, - chunk_ids = $5, - metadata = $6, - updated_at = CURRENT_TIMESTAMP - WHERE id = $7 AND parent_id = $8 - RETURNING id - """ + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + pass - for entity in entities: - metadata = entity.metadata - if isinstance(metadata, str): - try: - metadata = json.loads(metadata) - except json.JSONDecodeError: - pass + if name is not None: + update_fields.append(f"name = ${param_index}") + params.append(name) + param_index += 1 - description_embedding = entity.description_embedding - if isinstance(description_embedding, list): - description_embedding = str(description_embedding) + if description is not None: + update_fields.append(f"description = ${param_index}") + params.append(description) + param_index += 1 - params = [ - entity.name, - entity.category, - entity.description, - description_embedding, - entity.chunk_ids, - json.dumps(metadata) if metadata else None, - entity.id, - entity.parent_id, - ] - print("QUERY = ", QUERY) + if description_embedding is not None: + update_fields.append(f"description_embedding = ${param_index}") + params.append(description_embedding) + param_index += 1 + + if category is not None: + update_fields.append(f"category = ${param_index}") + params.append(category) + param_index += 1 + + if metadata is not None: + update_fields.append(f"metadata = ${param_index}") + params.append(json.dumps(metadata)) + param_index += 1 + + if not update_fields: + raise R2RException(status_code=400, message="No fields to update") + + update_fields.append("updated_at = NOW()") + params.append(entity_id) + query = f""" + UPDATE {self._get_table_name(table_name)} + SET {', '.join(update_fields)} + WHERE id = ${param_index} + RETURNING id, name, category, description, parent_id, chunk_ids, metadata + """ + try: result = await self.connection_manager.fetchrow_query( - QUERY, params + query=query, + params=params, ) - if not result: - raise R2RException( - f"Entity {entity.id} not found in {store_type} store or no permission to update", - 404, - ) - - results.append(result["id"]) - return results + return Entity( + id=result["id"], + name=result["name"], + category=result["category"], + description=result["description"], + parent_id=result["parent_id"], + chunk_ids=result["chunk_ids"], + metadata=result["metadata"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the entity: {e}", + ) async def delete( self, @@ -434,53 +466,72 @@ async def create_tables(self) -> None: await self.connection_manager.execute_query(QUERY) async def create( - self, relationships: list[Relationship], store_type: StoreType - ) -> list[UUID]: - """Create multiple relationships in the specified store.""" + self, + subject: str, + subject_id: UUID, + predicate: str, + object: str, + object_id: UUID, + parent_id: UUID, + store_type: StoreType, + description: str | None = None, + weight: float | None = 1.0, + chunk_ids: Optional[list[UUID]] = None, + description_embedding: Optional[list[float] | str] = None, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Relationship: + """Create a new relationship in the specified store.""" table_name = self._get_relationship_table_for_store(store_type) - values = [] - results = [] - - for relationship in relationships: - metadata = relationship.metadata - if isinstance(metadata, str): - try: - metadata = json.loads(metadata) - except json.JSONDecodeError: - pass - description_embedding = relationship.description_embedding - if isinstance(description_embedding, list): - description_embedding = str(description_embedding) + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + pass - value = ( - relationship.subject, - relationship.predicate, - relationship.object, - relationship.description, - relationship.subject_id, - relationship.object_id, - relationship.weight, - relationship.chunk_ids, - relationship.parent_id, - description_embedding, - json.dumps(metadata) if metadata else None, - ) - values.append(value) + if isinstance(description_embedding, list): + description_embedding = str(description_embedding) - QUERY = f""" + query = f""" INSERT INTO {self._get_table_name(table_name)} (subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, description_embedding, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - RETURNING id + RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata """ - for value in values: - result = await self.connection_manager.fetchrow_query(QUERY, value) - results.append(result["id"]) + params = [ + subject, + predicate, + object, + description, + subject_id, + object_id, + weight, + chunk_ids, + parent_id, + description_embedding, + json.dumps(metadata) if metadata else None, + ] + + result = await self.connection_manager.fetchrow_query( + query=query, + params=params, + ) - return results + return Relationship( + id=result["id"], + subject=result["subject"], + predicate=result["predicate"], + object=result["object"], + description=result["description"], + subject_id=result["subject_id"], + object_id=result["object_id"], + weight=result["weight"], + chunk_ids=result["chunk_ids"], + parent_id=result["parent_id"], + metadata=result["metadata"], + ) async def get( self, @@ -587,62 +638,108 @@ async def get( return relationships, count async def update( - self, relationships: list[Relationship], store_type: StoreType - ) -> list[UUID]: + self, + relationship_id: UUID, + store_type: StoreType, + subject: Optional[str], + subject_id: Optional[UUID], + predicate: Optional[str], + object: Optional[str], + object_id: Optional[UUID], + description: Optional[str], + description_embedding: Optional[list[float] | str], + weight: Optional[float], + metadata: Optional[dict[str, Any] | str], + ) -> Relationship: """Update multiple relationships in the specified store.""" table_name = self._get_relationship_table_for_store(store_type) - results = [] + update_fields = [] + params: list = [] + param_index = 1 - QUERY = f""" - UPDATE {self._get_table_name(table_name)} - SET - subject = $1, - predicate = $2, - object = $3, - description = $4, - subject_id = $5, - object_id = $6, - weight = $7, - chunk_ids = $8, - metadata = $9, - updated_at = CURRENT_TIMESTAMP - WHERE id = $10 AND parent_id = $11 - RETURNING id - """ + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + pass - for relationship in relationships: - metadata = relationship.metadata - if isinstance(metadata, str): - try: - metadata = json.loads(metadata) - except json.JSONDecodeError: - pass + if subject is not None: + update_fields.append(f"subject = ${param_index}") + params.append(subject) + param_index += 1 - params = [ - relationship.subject, - relationship.predicate, - relationship.object, - relationship.description, - relationship.subject_id, - relationship.object_id, - relationship.weight, - relationship.chunk_ids, - json.dumps(metadata) if metadata else None, - relationship.id, - relationship.parent_id, - ] + if subject_id is not None: + update_fields.append(f"subject_id = ${param_index}") + params.append(subject_id) + param_index += 1 + if predicate is not None: + update_fields.append(f"predicate = ${param_index}") + params.append(predicate) + param_index += 1 + + if object is not None: + update_fields.append(f"object = ${param_index}") + params.append(object) + param_index += 1 + + if object_id is not None: + update_fields.append(f"object_id = ${param_index}") + params.append(object_id) + param_index += 1 + + if description is not None: + update_fields.append(f"description = ${param_index}") + params.append(description) + param_index += 1 + + if description_embedding is not None: + update_fields.append(f"description_embedding = ${param_index}") + params.append(description_embedding) + param_index += 1 + + if weight is not None: + update_fields.append(f"weight = ${param_index}") + params.append(weight) + param_index += 1 + + if not update_fields: + raise R2RException(status_code=400, message="No fields to update") + + update_fields.append("updated_at = NOW()") + params.append(relationship_id) + + query = f""" + UPDATE {self._get_table_name(table_name)} + SET {', '.join(update_fields)} + WHERE id = ${param_index} + RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata + """ + + try: result = await self.connection_manager.fetchrow_query( - QUERY, params + query=query, + params=params, ) - if not result: - raise R2RException( - f"Relationship {relationship.id} not found in {store_type} store or no permission to update", - 404, - ) - results.append(result["id"]) - return results + return Relationship( + id=result["id"], + subject=result["subject"], + predicate=result["predicate"], + object=result["object"], + description=result["description"], + subject_id=result["subject_id"], + object_id=result["object_id"], + weight=result["weight"], + chunk_ids=result["chunk_ids"], + parent_id=result["parent_id"], + metadata=result["metadata"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the relationship: {e}", + ) async def delete( self, diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index 5e1479aac..b7954180f 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -16,6 +16,8 @@ WrappedCommunityResponse, ) +_list = list # Required for type hinting since we have a list method + class GraphsSDK: """ @@ -362,7 +364,7 @@ async def update_community( community_id: str | UUID, name: Optional[str] = None, summary: Optional[str] = None, - findings: Optional[list[str]] = None, + findings: Optional[_list[str]] = None, rating: Optional[int] = None, rating_explanation: Optional[str] = None, level: Optional[int] = None, @@ -393,7 +395,7 @@ async def update_community( if findings is not None: data["findings"] = findings if rating is not None: - data["rating"] = rating + data["rating"] = str(rating) if rating_explanation is not None: data["rating_explanation"] = rating_explanation if level is not None: diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index de2a794dc..668eb1333 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -36,24 +36,15 @@ class Entity(R2RSerializable): """An entity extracted from a document.""" name: str - # id is Union of UUID and int for backwards compatibility - # we will migrate to UUID only in the future - # sid is also deprecated and needs to be removed in the future - id: Optional[UUID | int] = None - category: Optional[str] = None description: Optional[str] = None + category: Optional[str] = None + metadata: Optional[dict[str, Any] | str] = None + + id: Optional[UUID] = None parent_id: Optional[UUID] = None # graph_id | document_id - # document_ids: list[UUID] = [] description_embedding: Optional[list[float] | str] = None - chunk_ids: Optional[list[UUID]] = [] - # we don't use these yet - # name_embedding: Optional[list[float]] = None - # graph_embedding: Optional[list[float]] = None - # rank: Optional[int] = None - metadata: Optional[dict[str, Any] | str] = None - def __str__(self): return f"{self.name}:{self.category}" @@ -77,7 +68,7 @@ class Relationship(R2RSerializable): subject_id: Optional[UUID] = None object_id: Optional[UUID] = None weight: float | None = 1.0 - chunk_ids: list[UUID] = [] + chunk_ids: Optional[list[UUID]] = [] parent_id: Optional[UUID] = None description_embedding: Optional[list[float] | str] = None From 32c4054bbd2bc65ea4c9d1e90e4a91ab11d15938 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 27 Nov 2024 16:15:50 -0600 Subject: [PATCH 12/28] Check in --- .../GraphsIntegrationSuperUser.test.ts.txt | 90 +++++++++++++ .../GraphsIntegrationUser.test.ts.txt | 122 ++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts.txt create mode 100644 js/sdk/__tests__/GraphsIntegrationUser.test.ts.txt diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts.txt b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts.txt new file mode 100644 index 000000000..31cd60d74 --- /dev/null +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts.txt @@ -0,0 +1,90 @@ +import { r2rClient } from "../src/index"; +import { describe, test, beforeAll, expect } from "@jest/globals"; + +const baseUrl = "http://localhost:7272"; + +describe("r2rClient V3 Collections Integration Tests", () => { + let client: r2rClient; + let graph1Id: string; + let graph2Id: string; + + beforeAll(async () => { + client = new r2rClient(baseUrl); + await client.users.login({ + email: "admin@example.com", + password: "change_me_immediately", + }); + }); + + test("Create a graph with only a name", async () => { + const response = await client.graphs.create({ + name: "Graph 1", + }); + expect(response.results).toBeDefined(); + graph1Id = response.results.id; + expect(graph1Id).toEqual(response.results.id); + expect(response.results.name).toEqual("Graph 1"); + expect(response.results.description).toBe(""); + }); + + test("Create a graph with name and description", async () => { + const response = await client.graphs.create({ + name: "2", + description: "Graph 2", + }); + graph2Id = response.results.id; + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("2"); + expect(response.results.description).toEqual("Graph 2"); + }); + + test("Ensure that there are two graphs", async () => { + const response = await client.graphs.list(); + expect(response.results).toBeDefined(); + expect(response.results.length).toEqual(2); + }); + + test("Retrieve graph 1", async () => { + const response = await client.graphs.retrieve({ id: graph1Id }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("Graph 1"); + expect(response.results.description).toBe(""); + }); + + test("Retrieve graph 2", async () => { + const response = await client.graphs.retrieve({ id: graph2Id }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("2"); + expect(response.results.description).toEqual("Graph 2"); + }); + + test("Update the name of graph 1", async () => { + const response = await client.graphs.update({ + id: graph1Id, + name: "Graph 1 Updated", + }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("Graph 1 Updated"); + }); + + test("Update the description graph 2", async () => { + const response = await client.graphs.update({ + id: graph2Id, + description: "Graph 2 Updated", + }); + expect(response.results).toBeDefined(); + expect(response.results.description).toEqual("Graph 2 Updated"); + }); + + test("Delete graph 1", async () => { + const response = await client.graphs.delete({ id: graph1Id }); + expect(response.results).toBeDefined(); + expect(response.results.success).toBe(true); + }); + + test("Delete graph 2", async () => { + const response = await client.graphs.delete({ id: graph2Id }); + expect(response.results).toBeDefined(); + expect(response.results.success).toBe(true); + }); +}); diff --git a/js/sdk/__tests__/GraphsIntegrationUser.test.ts.txt b/js/sdk/__tests__/GraphsIntegrationUser.test.ts.txt new file mode 100644 index 000000000..4946015c2 --- /dev/null +++ b/js/sdk/__tests__/GraphsIntegrationUser.test.ts.txt @@ -0,0 +1,122 @@ +import { r2rClient } from "../src/index"; +import { describe, test, beforeAll, expect } from "@jest/globals"; + +const baseUrl = "http://localhost:7272"; + +describe("r2rClient V3 Collections Integration Tests", () => { + let client: r2rClient; + + let graph1Id: string; + let graph2Id: string; + + let entity1Id: string; + + beforeAll(async () => { + client = new r2rClient(baseUrl); + await client.users.login({ + email: "admin@example.com", + password: "change_me_immediately", + }); + }); + + test("Create a graph with only a name", async () => { + const response = await client.graphs.create({ + name: "Graph 1", + }); + expect(response.results).toBeDefined(); + graph1Id = response.results.id; + expect(graph1Id).toEqual(response.results.id); + expect(response.results.name).toEqual("Graph 1"); + expect(response.results.description).toBe(null); + }); + + test("Create a graph with name and description", async () => { + const response = await client.graphs.create({ + name: "2", + description: "Graph 2", + }); + graph2Id = response.results.id; + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("2"); + expect(response.results.description).toEqual("Graph 2"); + }); + + test("Ensure that there are two graphs", async () => { + const response = await client.graphs.list(); + expect(response.results).toBeDefined(); + expect(response.results.length).toEqual(2); + }); + + test("Retrieve graph 1", async () => { + const response = await client.graphs.retrieve({ id: graph1Id }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("Graph 1"); + expect(response.results.description).toBe(null); + }); + + test("Retrieve graph 2", async () => { + const response = await client.graphs.retrieve({ id: graph2Id }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("2"); + expect(response.results.description).toEqual("Graph 2"); + }); + + test("Update the name of graph 1", async () => { + const response = await client.graphs.update({ + id: graph1Id, + name: "Graph 1 Updated", + }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("Graph 1 Updated"); + }); + + test("Update the desription graph 2", async () => { + const response = await client.graphs.update({ + id: graph2Id, + description: "Graph 2 Updated", + }); + expect(response.results).toBeDefined(); + expect(response.results.description).toEqual("Graph 2 Updated"); + }); + + test("Create an entity and add it to graph 1", async () => { + const createResponse = await client.entities.create({ + name: "Entity 1", + description: "Entity 1 Description", + }); + entity1Id = createResponse.results.id; + expect(createResponse.results).toBeDefined(); + expect(createResponse.results.name).toEqual("Entity 1"); + + const addResponse = await client.graphs.addEntity({ + id: graph1Id, + entityId: createResponse.results.id, + }); + expect(addResponse.results).toBeDefined(); + }); + + test("Remove entity from graph 1", async () => { + const response = await client.graphs.removeEntity({ + id: graph1Id, + entityId: entity1Id, + }); + expect(response.results).toBeDefined(); + }); + + test("Delete entity from graph 1", async () => { + const response = await client.entities.delete({ id: entity1Id }); + expect(response.results).toBeDefined(); + }); + + test("Delete graph 1", async () => { + const response = await client.graphs.delete({ id: graph1Id }); + expect(response.results).toBeDefined(); + expect(response.results.success).toBe(true); + }); + + test("Delete graph 2", async () => { + const response = await client.graphs.delete({ id: graph2Id }); + expect(response.results).toBeDefined(); + expect(response.results.success).toBe(true); + }); +}); From 759150beb6bb0108b9a2ff95660964c80de7afca Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 27 Nov 2024 16:21:37 -0600 Subject: [PATCH 13/28] Ellipsis fixes --- js/sdk/src/v3/clients/documents.ts | 5 +++-- js/sdk/src/v3/clients/graphs.ts | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/js/sdk/src/v3/clients/documents.ts b/js/sdk/src/v3/clients/documents.ts index f43f1f1f9..bc720b4c3 100644 --- a/js/sdk/src/v3/clients/documents.ts +++ b/js/sdk/src/v3/clients/documents.ts @@ -8,6 +8,7 @@ import { WrappedDocumentsResponse, WrappedEntitiesResponse, WrappedIngestionResponse, + WrappedRelationshipsResponse, } from "../../types"; import { feature } from "../../feature"; @@ -435,7 +436,7 @@ export class DocumentsClient { * @param includeEmbeddings Whether to include vector embeddings in the response. * @param entityNames Filter relationships by specific entity names. * @param relationshipTypes Filter relationships by specific relationship types. - * @returns + * @returns WrappedRelationshipsResponse */ @feature("documents.listRelationships") async listRelationships(options: { @@ -445,7 +446,7 @@ export class DocumentsClient { includeVectors?: boolean; entityNames?: string[]; relationshipTypes?: string[]; - }): Promise { + }): Promise { const params: Record = { offset: options.offset ?? 0, limit: options.limit ?? 100, diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index 2189cb10c..d4c197944 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -356,7 +356,7 @@ export class GraphsClient { }): Promise { return this.client.makeRequest( "DELETE", - `graphs/${options.collectionId}/entities/${options.relationshipId}`, + `graphs/${options.collectionId}/relationships/${options.relationshipId}`, ); } @@ -383,7 +383,7 @@ export class GraphsClient { */ @feature("graphs.build") async build(options: { - collection_id: string; + collectionId: string; settings?: Record; runType?: string; runWithOrchestration?: boolean; @@ -398,7 +398,7 @@ export class GraphsClient { return this.client.makeRequest( "POST", - `graphs/${options.collection_id}/communities/build`, + `graphs/${options.collectionId}/communities/build`, { data, }, @@ -483,7 +483,7 @@ export class GraphsClient { }; return this.client.makeRequest( "POST", - `graphs/${options.collectionId}/entities/${options.communityId}`, + `graphs/${options.collectionId}/communities/${options.communityId}`, { data, }, From 7e1a947a87b3736617b6ebbfdbcd4e350edf7fb5 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 27 Nov 2024 16:47:06 -0600 Subject: [PATCH 14/28] More cleanup --- .../ChunksIntegrationSuperUser.test.ts | 2 +- .../GraphsIntegrationSuperUser.test.ts | 33 ++++++++++++++++++- js/sdk/src/types.ts | 26 --------------- js/sdk/src/v3/clients/graphs.ts | 2 -- py/core/main/api/v2/kg_router.py | 7 ++-- py/core/providers/database/graph.py | 4 +-- 6 files changed, 37 insertions(+), 37 deletions(-) diff --git a/js/sdk/__tests__/ChunksIntegrationSuperUser.test.ts b/js/sdk/__tests__/ChunksIntegrationSuperUser.test.ts index a15336239..6370b9b32 100644 --- a/js/sdk/__tests__/ChunksIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/ChunksIntegrationSuperUser.test.ts @@ -31,7 +31,7 @@ describe("r2rClient V3 Collections Integration Tests", () => { expect(response.results).toEqual([ { document_id: expect.any(String), - message: "Ingestion task completed successfully.", + message: "Document created and ingested successfully.", }, ]); }, 10000); diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index d7cf63d35..df9b93a57 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -123,13 +123,44 @@ describe("r2rClient V3 Collections Integration Tests", () => { }); expect(response.results).toBeDefined(); expect(response.total_entries).toBeGreaterThanOrEqual(1); - }); + }, 60000); test("Check that there are relationships in the graph", async () => { const response = await client.graphs.listRelationships({ collectionId: collectionId, }); expect(response.results).toBeDefined(); + expect(response.total_entries).toBeGreaterThanOrEqual(1); + }); + + test("Check that there are no communities in the graph prior to building", async () => { + const response = await client.graphs.listCommunities({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + + test("Build communities", async () => { + const response = await client.graphs.build({ + collectionId: collectionId, + }); + + await new Promise((resolve) => setTimeout(resolve, 30000)); + + expect(response.results).toBeDefined(); + }); + + test("Check that there are communities in the graph", async () => { + const response = await client.graphs.listCommunities({ + collectionId: collectionId, + }); + + console.log("Communities: ", response); + + expect(response.results).toBeDefined(); + expect(response.total_entries).toBeGreaterThanOrEqual(1); }); test("Create a new entity", async () => { diff --git a/js/sdk/src/types.ts b/js/sdk/src/types.ts index c29cb60e6..efa06c0df 100644 --- a/js/sdk/src/types.ts +++ b/js/sdk/src/types.ts @@ -1,29 +1,3 @@ -export interface Entity { - name: string; - id?: string; - category?: string; - description?: string; - parent_id?: string; - description_embedding?: string; - chunk_ids: string[]; - metadata: Record; -} - -export interface Relationship { - id?: string; - subject: string; - predicate: string; - object: string; - description?: string; - subject_id?: string; - object_id?: string; - weight?: number; - chunk_ids: string[]; - parent_id?: string; - description_embedding?: string; - metadata: Record; -} - export interface UnprocessedChunk { id: string; document_id?: string; diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index d4c197944..4a1c23140 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -10,8 +10,6 @@ import { WrappedRelationshipResponse, WrappedCommunitiesResponse, WrappedCommunityResponse, - Entity, - Relationship, } from "../../types"; export class GraphsClient { diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index b2c496a56..108021617 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Optional, Union +from typing import Optional from uuid import UUID import yaml @@ -12,9 +12,6 @@ WrappedCommunitiesResponse, WrappedEntitiesResponse, WrappedRelationshipsResponse, - WrappedKGCreationResponse, - WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, WrappedKGTunePromptResponse, WrappedRelationshipsResponse, ) @@ -39,7 +36,7 @@ def __init__( self, service: KgService, orchestration_provider: Optional[ - Union[HatchetOrchestrationProvider, SimpleOrchestrationProvider] + HatchetOrchestrationProvider | SimpleOrchestrationProvider ] = None, run_type: RunType = RunType.KG, ): diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 102e9c5b4..bc9740581 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -4,7 +4,7 @@ import logging import time from enum import Enum -from typing import Any, AsyncGenerator, List, Optional, Set, Tuple, Union +from typing import Any, AsyncGenerator, Optional, Set, Tuple, Union from uuid import UUID, uuid4 import asyncpg @@ -2429,7 +2429,7 @@ async def has_document(self, graph_id: UUID, document_id: UUID) -> bool: # return [item["community_id"] for item in community_ids] async def check_communities_exist( - self, collection_id: UUID, community_ids: List[UUID] + self, collection_id: UUID, community_ids: list[UUID] ) -> Set[UUID]: """ Check which communities already exist in the database. From d3d8936280d62aaa17bcdffc8b4d75372e93c995 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 27 Nov 2024 23:36:09 -0600 Subject: [PATCH 15/28] Start CRUD on communities --- .../GraphsIntegrationSuperUser.test.ts | 8 +- js/sdk/src/v3/clients/collections.ts | 45 ++++ js/sdk/src/v3/clients/graphs.ts | 55 +---- py/core/base/api/models/__init__.py | 3 + py/core/main/api/v3/collections_router.py | 3 +- py/core/main/api/v3/graph_router.py | 17 +- py/core/providers/database/graph.py | 216 +++++++++++------- py/shared/api/models/__init__.py | 4 - 8 files changed, 202 insertions(+), 149 deletions(-) diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index df9b93a57..e2c8ec52a 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -3,7 +3,7 @@ import { describe, test, beforeAll, expect } from "@jest/globals"; const baseUrl = "http://localhost:7272"; -describe("r2rClient V3 Collections Integration Tests", () => { +describe("r2rClient V3 Graphs Integration Tests", () => { let client: r2rClient; let documentId: string; let collectionId: string; @@ -143,14 +143,14 @@ describe("r2rClient V3 Collections Integration Tests", () => { }); test("Build communities", async () => { - const response = await client.graphs.build({ + const response = await client.graphs.buildCommunities({ collectionId: collectionId, }); - await new Promise((resolve) => setTimeout(resolve, 30000)); + await new Promise((resolve) => setTimeout(resolve, 15000)); expect(response.results).toBeDefined(); - }); + }, 45000); test("Check that there are communities in the graph", async () => { const response = await client.graphs.listCommunities({ diff --git a/js/sdk/src/v3/clients/collections.ts b/js/sdk/src/v3/clients/collections.ts index a975f90c8..c77b1eda5 100644 --- a/js/sdk/src/v3/clients/collections.ts +++ b/js/sdk/src/v3/clients/collections.ts @@ -215,4 +215,49 @@ export class CollectionsClient { `collections/${options.id}/users/${options.userId}`, ); } + + /** + * Creates communities in the graph by analyzing entity relationships and similarities. + * + * Communities are created through the following process: + * 1. Analyzes entity relationships and metadata to build a similarity graph + * 2. Applies advanced community detection algorithms (e.g. Leiden) to identify densely connected groups + * 3. Creates hierarchical community structure with multiple granularity levels + * 4. Generates natural language summaries and statistical insights for each community + * + * The resulting communities can be used to: + * - Understand high-level graph structure and organization + * - Identify key entity groupings and their relationships + * - Navigate and explore the graph at different levels of detail + * - Generate insights about entity clusters and their characteristics + * + * The community detection process is configurable through settings like: + * - Community detection algorithm parameters + * - Summary generation prompt + * @param collectionId The collection ID corresponding to the graph + * @returns + */ + @feature("collections.extract") + async extract(options: { + collectionId: string; + runType?: string; + settings?: Record; + runWithOrchestration?: boolean; + }): Promise { + const data = { + ...(options.settings && { settings: options.settings }), + ...(options.runType && { run_type: options.runType }), + ...(options.runWithOrchestration && { + run_with_orchestration: options.runWithOrchestration, + }), + }; + + return this.client.makeRequest( + "POST", + `collections/${options.collectionId}/extract`, + { + data, + }, + ); + } } diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index 4a1c23140..10a815fd8 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -358,51 +358,6 @@ export class GraphsClient { ); } - /** - * Creates communities in the graph by analyzing entity relationships and similarities. - * - * Communities are created through the following process: - * 1. Analyzes entity relationships and metadata to build a similarity graph - * 2. Applies advanced community detection algorithms (e.g. Leiden) to identify densely connected groups - * 3. Creates hierarchical community structure with multiple granularity levels - * 4. Generates natural language summaries and statistical insights for each community - * - * The resulting communities can be used to: - * - Understand high-level graph structure and organization - * - Identify key entity groupings and their relationships - * - Navigate and explore the graph at different levels of detail - * - Generate insights about entity clusters and their characteristics - * - * The community detection process is configurable through settings like: - * - Community detection algorithm parameters - * - Summary generation prompt - * @param collectionId The collection ID corresponding to the graph - * @returns - */ - @feature("graphs.build") - async build(options: { - collectionId: string; - settings?: Record; - runType?: string; - runWithOrchestration?: boolean; - }): Promise { - const data = { - ...(options.settings && { settings: options.settings }), - ...(options.runType && { run_type: options.runType }), - ...(options.runWithOrchestration && { - run_with_orchestration: options.runWithOrchestration, - }), - }; - - return this.client.makeRequest( - "POST", - `graphs/${options.collectionId}/communities/build`, - { - data, - }, - ); - } - // TODO: Create community /** @@ -559,4 +514,14 @@ export class GraphsClient { `graphs/${options.collectionId}/documents/${options.documentId}`, ); } + + @feature("graphs.buildCommunities") + async buildCommunities(options: { + collectionId: string; + }): Promise { + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/communities/build`, + ); + } } diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index 3051949a2..407eacb2c 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -24,6 +24,7 @@ KGEnrichmentResponse, KGTunePromptResponse, Relationship, + GraphResponse, WrappedCommunitiesResponse, WrappedCommunityResponse, WrappedEntitiesResponse, @@ -32,6 +33,8 @@ WrappedKGTunePromptResponse, WrappedRelationshipResponse, WrappedRelationshipsResponse, + WrappedGraphResponse, + WrappedGraphsResponse, ) from shared.api.models.management.responses import ( # Document Responses; Prompt Responses; Chunk Responses; Conversation Responses; User Responses; TODO: anything below this hasn't been reviewed AnalyticsResponse, diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index e0c81dce3..b5ceacc16 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -13,7 +13,6 @@ WrappedCollectionsResponse, WrappedDocumentsResponse, WrappedGenericMessageResponse, - WrappedKGCreationResponse, WrappedUsersResponse, ) from core.providers import ( @@ -1065,7 +1064,7 @@ async def extract( description="Whether to run the entities and relationships extraction process with orchestration.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGCreationResponse: # type: ignore + ): """ Extracts entities and relationships from a document. The entities and relationships extraction process involves: diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index bf8e681df..265f18bb9 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -1274,18 +1274,13 @@ async def get_communities( By default, all attributes are returned, but this can be limited using the `attributes` parameter. """ - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can access this endpoint.", 403 - ) - communities, count = await self.services[ - "kg" - ].providers.database.graph_handler.communities.get( - graph_id=collection_id, - offset=offset, - limit=limit, - auth_user=auth_user, + communities, count = ( + await self.providers.database.graph_handler.get_communities( + graph_id=collection_id, + offset=offset, + limit=limit, + ) ) return communities, { # type: ignore diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index cb24846ca..7bf478c7e 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -188,7 +188,7 @@ async def get( table_name = self._get_entity_table_for_store(store_type) conditions = ["parent_id = $1"] - params = [parent_id] + params: list[Any] = [parent_id] param_index = 2 if entity_ids: @@ -311,7 +311,7 @@ async def update( query = f""" UPDATE {self._get_table_name(table_name)} SET {', '.join(update_fields)} - WHERE id = ${param_index} + WHERE id = ${param_index}\ RETURNING id, name, category, description, parent_id, chunk_ids, metadata """ try: @@ -848,8 +848,6 @@ async def create_tables(self) -> None: metadata JSONB, UNIQUE (community_id, level, graph_id, collection_id) );""" - # created_by UUID REFERENCES {self._get_table_name("users")}(user_id), - # updated_by UUID REFERENCES {self._get_table_name("users")}(user_id), await self.connection_manager.execute_query(query) @@ -982,61 +980,83 @@ async def delete( async def get( self, - graph_id: UUID, + parent_id: UUID, + store_type: StoreType, offset: int, limit: int, - community_id: Optional[UUID] = None, - auth_user: Optional[Any] = None, + community_ids: Optional[list[UUID]] = None, + community_names: Optional[list[str]] = None, + include_embeddings: bool = False, ): + """Retrieve communities from the specified store.""" + # Do we ever want to get communities from document store? + table_name = "graph_community" - if not auth_user.is_superuser: - if not await self._check_permissions(graph_id, auth_user.id): - raise R2RException( - "You do not have permission to access this graph.", - 403, - ) + conditions = ["graph_id = $1"] + params: list[Any] = [parent_id] + param_index = 2 - if community_id is None: + if community_ids: + conditions.append(f"id = ANY(${param_index})") + params.append(community_ids) + param_index += 1 - QUERY = f""" - SELECT - id, graph_id, name, summary, findings, rating, rating_explanation, level, metadata, created_by, updated_by, created_at, updated_at - FROM {self._get_table_name("graph_community")} WHERE graph_id = $1 - OFFSET $2 LIMIT $3 - """ - params = [graph_id, offset, limit] - communities = [ - Community(**row) - for row in await self.connection_manager.fetch_query( - QUERY, params - ) - ] + if community_names: + conditions.append(f"name = ANY(${param_index})") + params.append(community_names) + param_index += 1 - QUERY_COUNT = f""" - SELECT COUNT(*) FROM {self._get_table_name("graph_community")} WHERE graph_id = $1 - """ - count = ( - await self.connection_manager.fetch_query( - QUERY_COUNT, [graph_id] - ) - )[0]["count"] + select_fields = """ + id, graph_id, name, summary, findings, rating, + rating_explanation, level, metadata, created_at, updated_at + """ + if include_embeddings: + select_fields += ", description_embedding" - return communities, count + COUNT_QUERY = f""" + SELECT COUNT(*) + FROM {self._get_table_name(table_name)} + WHERE {' AND '.join(conditions)} + """ - else: - QUERY = f""" - SELECT - id, graph_id, name, summary, findings, rating, rating_explanation, level, metadata, created_by, updated_by, created_at, updated_at - FROM {self._get_table_name("graph_community")} WHERE graph_id = $1 AND id = $2 - """ - params = [graph_id, community_id] - return [ - Community( - **await self.connection_manager.fetchrow_query( - QUERY, params + count = ( + await self.connection_manager.fetch_query( + COUNT_QUERY, params[: param_index - 1] + ) + )[0]["count"] + + QUERY = f""" + SELECT {select_fields} + FROM {self._get_table_name(table_name)} + WHERE {' AND '.join(conditions)} + ORDER BY created_at + OFFSET ${param_index} + """ + params.append(offset) + param_index += 1 + + if limit != -1: + QUERY += f" LIMIT ${param_index}" + params.append(limit) + + rows = await self.connection_manager.fetch_query(QUERY, params) + + communities = [] + for row in rows: + community_dict = dict(row) + + # Process metadata if it exists and is a string + if isinstance(community_dict["metadata"], str): + try: + community_dict["metadata"] = json.loads( + community_dict["metadata"] ) - ) - ] + except json.JSONDecodeError: + pass + + communities.append(Community(**community_dict)) + + return communities, count class PostgresGraphHandler(GraphHandler): @@ -2484,56 +2504,86 @@ async def add_community_info( async def get_communities( self, + graph_id: UUID, offset: int, limit: int, - collection_id: Optional[UUID] = None, - levels: Optional[list[int]] = None, community_ids: Optional[list[UUID]] = None, - ) -> dict: - conditions = [] - params: list = [collection_id] + levels: Optional[list[int]] = None, + include_embeddings: bool = False, + ) -> tuple[list[Community], int]: + """ + Get communities for a graph. + + Args: + graph_id: UUID of the graph + offset: Number of records to skip + limit: Maximum number of records to return (-1 for no limit) + community_ids: Optional list of community IDs to filter by + levels: Optional list of levels to filter by + include_embeddings: Whether to include embeddings in the response + + Returns: + Tuple of (list of communities, total count) + """ + conditions = ["graph_id = $1"] + params: list[Any] = [graph_id] param_index = 2 - if levels is not None: + if community_ids: + conditions.append(f"id = ANY(${param_index})") + params.append(community_ids) + param_index += 1 + + if levels: conditions.append(f"level = ANY(${param_index})") params.append(levels) param_index += 1 - if community_ids is not None: - conditions.append(f"community_id = ANY(${param_index})") - params.append(community_ids) - param_index += 1 + select_fields = """ + id, graph_id, name, summary, findings, rating, + rating_explanation, level, metadata, created_at, updated_at + """ + if include_embeddings: + select_fields += ", description_embedding" - pagination_params = [] - if offset: - pagination_params.append(f"OFFSET ${param_index}") - params.append(offset) - param_index += 1 + COUNT_QUERY = f""" + SELECT COUNT(*) + FROM {self._get_table_name("graph_community")} + WHERE {' AND '.join(conditions)} + """ + count = ( + await self.connection_manager.fetch_query(COUNT_QUERY, params) + )[0]["count"] + + QUERY = f""" + SELECT {select_fields} + FROM {self._get_table_name("graph_community")} + WHERE {' AND '.join(conditions)} + ORDER BY created_at + OFFSET ${param_index} + """ + params.append(offset) + param_index += 1 if limit != -1: - pagination_params.append(f"LIMIT ${param_index}") + QUERY += f" LIMIT ${param_index}" params.append(limit) - param_index += 1 - - pagination_clause = " ".join(pagination_params) - query = f""" - SELECT id, community_id, collection_id, level, name, summary, findings, rating, rating_explanation, COUNT(*) OVER() AS total_entries - FROM {self._get_table_name('graph_community')} - WHERE collection_id = $1 - {" AND " + " AND ".join(conditions) if conditions else ""} - ORDER BY community_id - {pagination_clause} - """ + rows = await self.connection_manager.fetch_query(QUERY, params) - results = await self.connection_manager.fetch_query(query, params) - total_entries = results[0]["total_entries"] if results else 0 - communities = [Community(**community) for community in results] + communities = [] + for row in rows: + community_dict = dict(row) + if isinstance(community_dict["metadata"], str): + try: + community_dict["metadata"] = json.loads( + community_dict["metadata"] + ) + except json.JSONDecodeError: + pass + communities.append(Community(**community_dict)) - return { - "communities": communities, - "total_entries": total_entries, - } + return communities, count async def get_community_details( self, diff --git a/py/shared/api/models/__init__.py b/py/shared/api/models/__init__.py index d9c123ec8..4cfc484fa 100644 --- a/py/shared/api/models/__init__.py +++ b/py/shared/api/models/__init__.py @@ -18,14 +18,10 @@ ) from shared.api.models.kg.responses import ( # TODO: Need to review anything above this GraphResponse, - KGCreationResponse, KGEnrichmentResponse, - KGEntityDeduplicationResponse, WrappedGraphResponse, WrappedGraphsResponse, - WrappedKGCreationResponse, WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, ) from shared.api.models.management.responses import ( # Chunk Responses; Conversation Responses; Document Responses; Collection Responses; Prompt Responses; System Responses; User Responses AnalyticsResponse, From 588b7b30e097a8ca75c105bd6bd9aae73a8cd6a9 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 29 Nov 2024 11:38:32 -0600 Subject: [PATCH 16/28] Communities DB --- .../r2r-js-sdk-integration-tests.yml | 96 ++++++++++++++++--- .../GraphsIntegrationSuperUser.test.ts | 2 - py/core/main/api/v3/graph_router.py | 2 +- py/core/providers/database/graph.py | 18 +--- py/shared/abstractions/graph.py | 4 +- 5 files changed, 89 insertions(+), 33 deletions(-) diff --git a/.github/workflows/r2r-js-sdk-integration-tests.yml b/.github/workflows/r2r-js-sdk-integration-tests.yml index 2712ab5c1..2e15fd1d9 100644 --- a/.github/workflows/r2r-js-sdk-integration-tests.yml +++ b/.github/workflows/r2r-js-sdk-integration-tests.yml @@ -1,23 +1,13 @@ name: R2R JS SDK Integration Tests + on: push: branches: - - '**' # Trigger on all branches + - '**' + jobs: - test: + setup: runs-on: ubuntu-latest - env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} - AZURE_API_BASE: ${{ secrets.AZURE_API_BASE }} - AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }} - TELEMETRY_ENABLED: 'false' - R2R_POSTGRES_HOST: localhost - R2R_POSTGRES_DBNAME: postgres - R2R_POSTGRES_PORT: '5432' - R2R_POSTGRES_PASSWORD: postgres - R2R_POSTGRES_USER: postgres - R2R_PROJECT_NAME: r2r_default steps: - uses: actions/checkout@v2 - name: Set up Python and install dependencies @@ -45,6 +35,82 @@ jobs: - name: Check if R2R server is running run: | curl http://localhost:7272/v2/health || echo "Server not responding" + + v2-unit-test: + needs: setup + runs-on: ubuntu-latest + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} + AZURE_API_BASE: ${{ secrets.AZURE_API_BASE }} + AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }} + TELEMETRY_ENABLED: 'false' + R2R_POSTGRES_HOST: localhost + R2R_POSTGRES_DBNAME: postgres + R2R_POSTGRES_PORT: '5432' + R2R_POSTGRES_PASSWORD: postgres + R2R_POSTGRES_USER: postgres + R2R_PROJECT_NAME: r2r_default + steps: + - name: Run r2rV2Client tests + working-directory: ./js/sdk + run: pnpm jest r2rV2Client.test.ts + + v2-integration-tests: + needs: v2-unit-test + runs-on: ubuntu-latest + strategy: + matrix: + test-group: + - r2rV2ClientIntegrationSuperUser.test.ts + - r2rV2ClientIntegrationUser.test.ts + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} + AZURE_API_BASE: ${{ secrets.AZURE_API_BASE }} + AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }} + TELEMETRY_ENABLED: 'false' + R2R_POSTGRES_HOST: localhost + R2R_POSTGRES_DBNAME: postgres + R2R_POSTGRES_PORT: '5432' + R2R_POSTGRES_PASSWORD: postgres + R2R_POSTGRES_USER: postgres + R2R_PROJECT_NAME: r2r_default + steps: - name: Run integration tests working-directory: ./js/sdk - run: pnpm test + run: pnpm jest ${{ matrix.test-group }} + + v3-integration-tests: + needs: setup + runs-on: ubuntu-latest + strategy: + matrix: + test-group: + - ChunksIntegrationSuperUser.test.ts + - CollectionsIntegrationSuperUser.test.ts + - ConversationsIntegrationSuperUser.test.ts + - DocumentsAndCollectionsIntegrationUser.test.ts + - DocumentsIntegrationSuperUser.test.ts + - GraphsIntegrationSuperUser.test.ts + - PromptsIntegrationSuperUser.test.ts + - RetrievalIntegrationSuperUser.test.ts + - SystemIntegrationSuperUser.test.ts + - SystemIntegrationUser.test.ts + - UsersIntegrationSuperUser.test.ts + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} + AZURE_API_BASE: ${{ secrets.AZURE_API_BASE }} + AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }} + TELEMETRY_ENABLED: 'false' + R2R_POSTGRES_HOST: localhost + R2R_POSTGRES_DBNAME: postgres + R2R_POSTGRES_PORT: '5432' + R2R_POSTGRES_PASSWORD: postgres + R2R_POSTGRES_USER: postgres + R2R_PROJECT_NAME: r2r_default + steps: + - name: Run remaining tests + working-directory: ./js/sdk + run: pnpm jest ${{ matrix.test-group }} diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index e2c8ec52a..d58182966 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -157,8 +157,6 @@ describe("r2rClient V3 Graphs Integration Tests", () => { collectionId: collectionId, }); - console.log("Communities: ", response); - expect(response.results).toBeDefined(); expect(response.total_entries).toBeGreaterThanOrEqual(1); }); diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 265f18bb9..e9a80762d 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -1277,7 +1277,7 @@ async def get_communities( communities, count = ( await self.providers.database.graph_handler.get_communities( - graph_id=collection_id, + collection_id=collection_id, offset=offset, limit=limit, ) diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 7bf478c7e..1dc8c55b1 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -2504,7 +2504,7 @@ async def add_community_info( async def get_communities( self, - graph_id: UUID, + collection_id: UUID, offset: int, limit: int, community_ids: Optional[list[UUID]] = None, @@ -2515,7 +2515,7 @@ async def get_communities( Get communities for a graph. Args: - graph_id: UUID of the graph + collection_id: UUID of the collection offset: Number of records to skip limit: Maximum number of records to return (-1 for no limit) community_ids: Optional list of community IDs to filter by @@ -2525,8 +2525,8 @@ async def get_communities( Returns: Tuple of (list of communities, total count) """ - conditions = ["graph_id = $1"] - params: list[Any] = [graph_id] + conditions = ["collection_id = $1"] + params: list[Any] = [collection_id] param_index = 2 if community_ids: @@ -2540,8 +2540,7 @@ async def get_communities( param_index += 1 select_fields = """ - id, graph_id, name, summary, findings, rating, - rating_explanation, level, metadata, created_at, updated_at + id, collection_id, name, summary, findings, rating, rating_explanation """ if include_embeddings: select_fields += ", description_embedding" @@ -2574,13 +2573,6 @@ async def get_communities( communities = [] for row in rows: community_dict = dict(row) - if isinstance(community_dict["metadata"], str): - try: - community_dict["metadata"] = json.loads( - community_dict["metadata"] - ) - except json.JSONDecodeError: - pass communities.append(Community(**community_dict)) return communities, count diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index db811638d..5bc81dc97 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -89,7 +89,7 @@ class CommunityInfo(R2RSerializable): node: str cluster: UUID - level: int + level: Optional[int] id: Optional[UUID | int] = None parent_cluster: int | None is_final_cluster: bool @@ -104,10 +104,10 @@ def __init__(self, **kwargs): @dataclass class Community(R2RSerializable): - level: int name: str = "" summary: str = "" + level: Optional[int] = None findings: list[str] = [] id: Optional[int | UUID] = None community_id: Optional[UUID] = None From 326509a8468f5708aba212595c83ceec92d159af Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 29 Nov 2024 11:48:36 -0600 Subject: [PATCH 17/28] Explicit working path --- .github/workflows/r2r-js-sdk-integration-tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/r2r-js-sdk-integration-tests.yml b/.github/workflows/r2r-js-sdk-integration-tests.yml index 2e15fd1d9..20cc1e52e 100644 --- a/.github/workflows/r2r-js-sdk-integration-tests.yml +++ b/.github/workflows/r2r-js-sdk-integration-tests.yml @@ -30,7 +30,7 @@ jobs: version: 8.x run_install: false - name: Install JS SDK dependencies - working-directory: ./js/sdk + working-directory: /home/runner/work/R2R/js/sdk run: pnpm install - name: Check if R2R server is running run: | @@ -53,7 +53,7 @@ jobs: R2R_PROJECT_NAME: r2r_default steps: - name: Run r2rV2Client tests - working-directory: ./js/sdk + working-directory: /home/runner/work/R2R/js/sdk run: pnpm jest r2rV2Client.test.ts v2-integration-tests: @@ -78,7 +78,7 @@ jobs: R2R_PROJECT_NAME: r2r_default steps: - name: Run integration tests - working-directory: ./js/sdk + working-directory: /home/runner/work/R2R/js/sdk run: pnpm jest ${{ matrix.test-group }} v3-integration-tests: @@ -112,5 +112,5 @@ jobs: R2R_PROJECT_NAME: r2r_default steps: - name: Run remaining tests - working-directory: ./js/sdk + working-directory: /home/runner/work/R2R/js/sdk run: pnpm jest ${{ matrix.test-group }} From 2d45e09c8afe969744979e6142638d49f2408534 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 29 Nov 2024 11:53:20 -0600 Subject: [PATCH 18/28] Once again --- .../r2r-js-sdk-integration-tests.yml | 77 ++++++++++++++++++- 1 file changed, 73 insertions(+), 4 deletions(-) diff --git a/.github/workflows/r2r-js-sdk-integration-tests.yml b/.github/workflows/r2r-js-sdk-integration-tests.yml index 20cc1e52e..bac99dc10 100644 --- a/.github/workflows/r2r-js-sdk-integration-tests.yml +++ b/.github/workflows/r2r-js-sdk-integration-tests.yml @@ -30,7 +30,7 @@ jobs: version: 8.x run_install: false - name: Install JS SDK dependencies - working-directory: /home/runner/work/R2R/js/sdk + working-directory: ./js/sdk run: pnpm install - name: Check if R2R server is running run: | @@ -52,8 +52,31 @@ jobs: R2R_POSTGRES_USER: postgres R2R_PROJECT_NAME: r2r_default steps: + - uses: actions/checkout@v2 + - name: Set up Python and install dependencies + uses: ./.github/actions/setup-python-light + with: + os: ubuntu-latest + - name: Setup and start PostgreSQL + uses: ./.github/actions/setup-postgres-ext + with: + os: ubuntu-latest + - name: Start R2R Light server + uses: ./.github/actions/start-r2r-light + - name: Use Node.js + uses: actions/setup-node@v2 + with: + node-version: "20.x" + - name: Install pnpm + uses: pnpm/action-setup@v2 + with: + version: 8.x + run_install: false + - name: Install JS SDK dependencies + working-directory: ./js/sdk + run: pnpm install - name: Run r2rV2Client tests - working-directory: /home/runner/work/R2R/js/sdk + working-directory: ./js/sdk run: pnpm jest r2rV2Client.test.ts v2-integration-tests: @@ -77,8 +100,31 @@ jobs: R2R_POSTGRES_USER: postgres R2R_PROJECT_NAME: r2r_default steps: + - uses: actions/checkout@v2 + - name: Set up Python and install dependencies + uses: ./.github/actions/setup-python-light + with: + os: ubuntu-latest + - name: Setup and start PostgreSQL + uses: ./.github/actions/setup-postgres-ext + with: + os: ubuntu-latest + - name: Start R2R Light server + uses: ./.github/actions/start-r2r-light + - name: Use Node.js + uses: actions/setup-node@v2 + with: + node-version: "20.x" + - name: Install pnpm + uses: pnpm/action-setup@v2 + with: + version: 8.x + run_install: false + - name: Install JS SDK dependencies + working-directory: ./js/sdk + run: pnpm install - name: Run integration tests - working-directory: /home/runner/work/R2R/js/sdk + working-directory: ./js/sdk run: pnpm jest ${{ matrix.test-group }} v3-integration-tests: @@ -111,6 +157,29 @@ jobs: R2R_POSTGRES_USER: postgres R2R_PROJECT_NAME: r2r_default steps: + - uses: actions/checkout@v2 + - name: Set up Python and install dependencies + uses: ./.github/actions/setup-python-light + with: + os: ubuntu-latest + - name: Setup and start PostgreSQL + uses: ./.github/actions/setup-postgres-ext + with: + os: ubuntu-latest + - name: Start R2R Light server + uses: ./.github/actions/start-r2r-light + - name: Use Node.js + uses: actions/setup-node@v2 + with: + node-version: "20.x" + - name: Install pnpm + uses: pnpm/action-setup@v2 + with: + version: 8.x + run_install: false + - name: Install JS SDK dependencies + working-directory: ./js/sdk + run: pnpm install - name: Run remaining tests - working-directory: /home/runner/work/R2R/js/sdk + working-directory: ./js/sdk run: pnpm jest ${{ matrix.test-group }} From 8b5dbb2ac4b246bccf4b1cdc477c4eef37aca5c0 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 29 Nov 2024 12:02:54 -0600 Subject: [PATCH 19/28] Fail fast false --- .github/workflows/r2r-js-sdk-integration-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/r2r-js-sdk-integration-tests.yml b/.github/workflows/r2r-js-sdk-integration-tests.yml index bac99dc10..9f0ae0365 100644 --- a/.github/workflows/r2r-js-sdk-integration-tests.yml +++ b/.github/workflows/r2r-js-sdk-integration-tests.yml @@ -83,6 +83,7 @@ jobs: needs: v2-unit-test runs-on: ubuntu-latest strategy: + fail-fast: false matrix: test-group: - r2rV2ClientIntegrationSuperUser.test.ts @@ -131,6 +132,7 @@ jobs: needs: setup runs-on: ubuntu-latest strategy: + fail-fast: false matrix: test-group: - ChunksIntegrationSuperUser.test.ts From 2a4c8e6e5c83f43e3d44504a3ee7be4a5e3156cf Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:08:03 -0600 Subject: [PATCH 20/28] Testing around community creation --- .../GraphsIntegrationSuperUser.test.ts | 60 ++++++++++ js/sdk/src/types.ts | 14 ++- js/sdk/src/v3/clients/graphs.ts | 56 ++++++++- py/core/base/providers/database.py | 4 +- py/core/main/api/v3/graph_router.py | 31 +++-- .../main/orchestration/simple/kg_workflow.py | 22 ++-- py/core/main/services/kg_service.py | 51 ++++----- py/core/providers/database/graph.py | 107 +++++++----------- py/shared/abstractions/graph.py | 6 + 9 files changed, 224 insertions(+), 127 deletions(-) diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index d58182966..490db1cc0 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -10,6 +10,7 @@ describe("r2rClient V3 Graphs Integration Tests", () => { let entity1Id: string; let entity2Id: string; let relationshipId: string; + let communityId: string; beforeAll(async () => { client = new r2rClient(baseUrl); @@ -250,6 +251,65 @@ describe("r2rClient V3 Graphs Integration Tests", () => { expect(response.results.predicate).toBe("falls in love with"); }); + test("Create a new community", async () => { + const response = await client.graphs.createCommunity({ + collectionId: collectionId, + name: "Raskolnikov and Dunia Community", + summary: + "Raskolnikov and Dunia are siblings, the children of Pulcheria Alexandrovna", + findings: [ + "Raskolnikov and Dunia are siblings", + "They are the children of Pulcheria Alexandrovna", + "Their family comes from a modest background", + "Dunia works as a governess to support the family", + "Raskolnikov is a former university student", + "Both siblings are intelligent and well-educated", + "They maintain a close relationship despite living apart", + "Their mother Pulcheria writes letters to keep them connected", + ], + rating: 10, + ratingExplanation: + "Raskolnikov and Dunia are central to the story and have a complex relationship", + }); + + communityId = response.results.id; + + expect(response.results).toBeDefined(); + expect(response.results.name).toBe("Raskolnikov and Dunia Community"); + expect(response.results.summary).toBe( + "Raskolnikov and Dunia are siblings, the children of Pulcheria Alexandrovna", + ); + expect(response.results.findings).toContain( + "Raskolnikov and Dunia are siblings", + ); + expect(response.results.findings).toContain( + "They are the children of Pulcheria Alexandrovna", + ); + expect(response.results.findings).toContain( + "Their family comes from a modest background", + ); + expect(response.results.findings).toContain( + "Dunia works as a governess to support the family", + ); + expect(response.results.findings).toContain( + "Raskolnikov is a former university student", + ); + expect(response.results.findings).toContain( + "Both siblings are intelligent and well-educated", + ); + expect(response.results.findings).toContain( + "They maintain a close relationship despite living apart", + ); + expect(response.results.findings).toContain( + "Their mother Pulcheria writes letters to keep them connected", + ); + expect(response.results.rating).toBe(10); + //TODO: Why is this failing? + // expect(response.results.ratingExplanation).toBe( + // "Raskolnikov and Dunia are central to the story and have a complex relationship", + // ); + }); + test("Update the entity", async () => { const response = await client.graphs.updateEntity({ collectionId: collectionId, diff --git a/js/sdk/src/types.ts b/js/sdk/src/types.ts index efa06c0df..0c72d2ff3 100644 --- a/js/sdk/src/types.ts +++ b/js/sdk/src/types.ts @@ -47,9 +47,19 @@ export interface CollectionResponse { document_count: number; } -//TODO: Sync this with the finished API response model // Community types -export interface CommunityResponse {} +export interface CommunityResponse { + id: string; + name: string; + summary: string; + findings: string[]; + communityId?: string; + graphId?: string; + collectionId?: string; + rating?: number; + ratingExplanation?: string; + descriptionEmbedding?: string; +} // Conversation types export interface ConversationResponse { diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index 10a815fd8..85734c554 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -358,7 +358,59 @@ export class GraphsClient { ); } - // TODO: Create community + /** + * Creates a new community in the graph. + * + * While communities are typically built automatically via the /graphs/{id}/communities/build endpoint, + * this endpoint allows you to manually create your own communities. + * + * This can be useful when you want to: + * - Define custom groupings of entities based on domain knowledge + * - Add communities that weren't detected by the automatic process + * - Create hierarchical organization structures + * - Tag groups of entities with specific metadata + * + * The created communities will be integrated with any existing automatically detected communities + * in the graph's community structure. + * + * @param collectionId The collection ID corresponding to the graph + * @param name Name of the community + * @param summary Summary of the community + * @param findings Findings or insights about the community + * @param rating Rating of the community + * @param ratingExplanation Explanation of the community rating + * @param attributes Additional attributes to associate with the community + * @returns WrappedCommunityResponse + */ + @feature("graphs.createCommunity") + async createCommunity(options: { + collectionId: string; + name: string; + summary: string; + findings?: string[]; + rating?: number; + ratingExplanation?: string; + attributes?: Record; + }): Promise { + const data = { + name: options.name, + ...(options.summary && { summary: options.summary }), + ...(options.findings && { findings: options.findings }), + ...(options.rating && { rating: options.rating }), + ...(options.ratingExplanation && { + rating_explanation: options.ratingExplanation, + }), + ...(options.attributes && { attributes: options.attributes }), + }; + + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/communities`, + { + data, + }, + ); + } /** * List all communities in a graph. @@ -420,7 +472,6 @@ export class GraphsClient { findings?: string[]; rating?: number; ratingExplanation?: string; - level?: number; attributes?: Record; }): Promise { const data = { @@ -431,7 +482,6 @@ export class GraphsClient { ...(options.ratingExplanation && { rating_explanation: options.ratingExplanation, }), - ...(options.level && { level: options.level }), ...(options.attributes && { attributes: options.attributes }), }; return this.client.makeRequest( diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 0f5007cf9..262c978ad 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -648,7 +648,7 @@ async def delete(self, *args: Any, **kwargs: Any) -> None: class CommunityHandler(Handler): @abstractmethod - async def create(self, *args: Any, **kwargs: Any) -> None: + async def create(self, *args: Any, **kwargs: Any) -> Community: """Create communities in storage.""" pass @@ -658,7 +658,7 @@ async def get(self, *args: Any, **kwargs: Any) -> list[Community]: pass @abstractmethod - async def update(self, *args: Any, **kwargs: Any) -> None: + async def update(self, *args: Any, **kwargs: Any) -> Community: """Update communities in storage.""" pass diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 578e03ab9..ded054e81 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -1158,7 +1158,7 @@ async def delete_relationship( }, ) @self.base_endpoint - async def create_communities( + async def create_community( collection_id: UUID = Path( ..., description="The collection ID corresponding to the graph to create the community in.", @@ -1168,12 +1168,6 @@ async def create_communities( findings: Optional[list[str]] = Body( default=[], description="Findings about the community" ), - level: Optional[int] = Body( - default=0, - ge=0, - le=100, - description="The level of the community", - ), rating: Optional[float] = Body( default=5, ge=1, le=10, description="Rating between 1 and 10" ), @@ -1184,13 +1178,14 @@ async def create_communities( default=None, description="Attributes for the community" ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): + ) -> WrappedCommunityResponse: """ Creates a new community in the graph. While communities are typically built automatically via the /graphs/{id}/communities/build endpoint, - this endpoint allows you to manually create your own communities. This can be useful when you want to: + this endpoint allows you to manually create your own communities. + This can be useful when you want to: - Define custom groupings of entities based on domain knowledge - Add communities that weren't detected by the automatic process - Create hierarchical organization structures @@ -1199,16 +1194,22 @@ async def create_communities( The created communities will be integrated with any existing automatically detected communities in the graph's community structure. """ - return await self.services["kg"].create_community_v3( - graph_id=collection_id, + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + + return await self.services["kg"].create_community( + parent_id=collection_id, name=name, summary=summary, findings=findings, rating=rating, rating_explanation=rating_explanation, - level=level, - attributes=attributes, - auth_user=auth_user, ) @self.router.get( @@ -1486,7 +1487,6 @@ async def update_community( findings: Optional[list[str]] = Body(None), rating: Optional[float] = Body(None), rating_explanation: Optional[str] = Body(None), - level: Optional[int] = Body(None), attributes: Optional[dict] = Body(None), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedCommunityResponse: @@ -1506,7 +1506,6 @@ async def update_community( findings=findings, rating=rating, rating_explanation=rating_explanation, - level=level, attributes=attributes, auth_user=auth_user, ) diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index a26b7d627..8c1a889eb 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -53,28 +53,30 @@ async def extract_triples(input_data): offset = 0 while True: # Fetch current batch - batch = (await service.providers.database.collections_handler.documents_in_collection( - collection_id=collection_id, - offset=offset, - limit=batch_size - ))["results"] - + batch = ( + await service.providers.database.collections_handler.documents_in_collection( + collection_id=collection_id, + offset=offset, + limit=batch_size, + ) + )["results"] + # If no documents returned, we've reached the end if not batch: break - + # Add current batch to results documents.extend(batch) - + # Update offset for next batch offset += batch_size - + # Optional: If batch is smaller than batch_size, we've reached the end if len(batch) < batch_size: break # documents = service.providers.database.collections_handler.documents_in_collection(input_data.get("collection_id"), offset=0, limit=1000) - print('extracting for documents = ', documents) + print("extracting for documents = ", documents) document_ids = [document.id for document in documents] logger.info( diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index e4700a872..57c5688fe 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -384,38 +384,32 @@ async def get_relationships( ################### COMMUNITIES ################### - @telemetry_event("create_community_v3") - async def create_community_v3( + @telemetry_event("create_community") + async def create_community( self, - graph_id: UUID, + parent_id: UUID, name: str, summary: str, - findings: list[str], + findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], - level: Optional[int], - attributes: Optional[dict], - auth_user: Any, - **kwargs, - ): - embedding = str( + ) -> Community: + description_embedding = str( await self.providers.embedding.async_get_embedding(summary) ) return await self.providers.database.graph_handler.communities.create( - graph_id=graph_id, + parent_id=parent_id, + store_type="graph", # type: ignore name=name, summary=summary, - embedding=embedding, + description_embedding=description_embedding, findings=findings, rating=rating, rating_explanation=rating_explanation, - level=level, - attributes=attributes, - auth_user=auth_user, ) - @telemetry_event("update_community_v3") - async def update_community_v3( + @telemetry_event("update_community") + async def update_community( self, id: UUID, community_id: UUID, @@ -424,10 +418,6 @@ async def update_community_v3( findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], - level: Optional[int], - attributes: Optional[dict], - auth_user: Any, - **kwargs, ): if summary is not None: embedding = str( @@ -445,9 +435,6 @@ async def update_community_v3( findings=findings, rating=rating, rating_explanation=rating_explanation, - level=level, - attributes=attributes, - auth_user=auth_user, ) @telemetry_event("delete_community_v3") @@ -1109,11 +1096,11 @@ async def _extract_kg( "relation_types": "\n".join(relation_types), }, ) - print('starting a job....') + print("starting a job....") for attempt in range(retries): try: - print('getting a response....') + print("getting a response....") response = await self.providers.llm.aget_completion( messages, @@ -1121,7 +1108,7 @@ async def _extract_kg( ) kg_extraction = response.choices[0].message.content - print('kg_extraction = ', kg_extraction) + print("kg_extraction = ", kg_extraction) if not kg_extraction: raise R2RException( @@ -1150,7 +1137,7 @@ async def parse_fn(response_str: str) -> Any: relationships = re.findall( relationship_pattern, response_str ) - print('found len(relationships) = ', len(relationships)) + print("found len(relationships) = ", len(relationships)) entities_arr = [] for entity in entities: @@ -1173,7 +1160,7 @@ async def parse_fn(response_str: str) -> Any: attributes={}, ) ) - print('found len(entities) = ', len(entities)) + print("found len(entities) = ", len(entities)) relations_arr = [] for relationship in relationships: @@ -1246,7 +1233,7 @@ async def store_kg_extractions( total_entities, total_relationships = 0, 0 - print('received len(kg_extractions) = ', len(kg_extractions)) + print("received len(kg_extractions) = ", len(kg_extractions)) for extraction in kg_extractions: # print("extraction = ", extraction) @@ -1254,7 +1241,9 @@ async def store_kg_extractions( # total_entities + len(extraction.entities), # total_relationships + len(extraction.relationships), # ) - print('storing len(extraction.entities) = ', len(extraction.entities)) + print( + "storing len(extraction.entities) = ", len(extraction.entities) + ) for entity in extraction.entities: await self.providers.database.graph_handler.entities.create( diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 1dc8c55b1..b613f2217 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -121,9 +121,9 @@ async def create_tables(self) -> None: async def create( self, - name: str, parent_id: UUID, store_type: StoreType, + name: str, category: Optional[str] = None, description: Optional[str] = None, description_embedding: Optional[list[float] | str] = None, @@ -817,7 +817,7 @@ async def create_tables(self) -> None: node TEXT NOT NULL, cluster UUID NOT NULL, parent_cluster INT, - level INT NOT NULL, + level INT, is_final_cluster BOOLEAN NOT NULL, relationship_ids UUID[] NOT NULL, graph_id UUID, @@ -836,12 +836,12 @@ async def create_tables(self) -> None: graph_id UUID, collection_id UUID, community_id UUID, - level INT NOT NULL, + level INT, name TEXT NOT NULL, summary TEXT NOT NULL, - findings TEXT[] NOT NULL, - rating FLOAT NOT NULL, - rating_explanation TEXT NOT NULL, + findings TEXT[], + rating FLOAT, + rating_explanation TEXT, description_embedding {vector_column_str} NOT NULL, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, @@ -853,47 +853,54 @@ async def create_tables(self) -> None: async def create( self, - graph_id: UUID, + parent_id: UUID, + store_type: StoreType, name: str, summary: str, - embedding: str, - findings: list[str], + findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], - level: Optional[int], - metadata: Optional[dict], - auth_user: Any, - ) -> None: + description_embedding: Optional[list[float] | str] = None, + ) -> Community: + # Do we ever want to get communities from document store? + table_name = "graph_community" - if not auth_user.is_superuser: - if not await self._check_permissions(graph_id, auth_user.id): - raise R2RException( - "You do not have permission to create this community.", - 403, - ) + if isinstance(description_embedding, list): + description_embedding = str(description_embedding) - QUERY = f""" - INSERT INTO {self._get_table_name("graph_community")} - (graph_id, name, summary, findings, rating, rating_explanation, description_embedding, level, metadata, created_by, updated_by) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - RETURNING id, graph_id, name, summary, findings, rating, rating_explanation, level, metadata, created_by, updated_by + query = f""" + INSERT INTO {self._get_table_name(table_name)} + (community_id, name, summary, findings, rating, rating_explanation, description_embedding) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at """ params = [ - graph_id, + parent_id, name, summary, findings, rating, rating_explanation, - embedding, - level, - metadata, - auth_user.id, - auth_user.id, + description_embedding, ] - return await self.connection_manager.fetchrow_query(QUERY, params) + result = await self.connection_manager.fetchrow_query( + query=query, + params=params, + ) + + return Community( + id=result["id"], + community_id=result["community_id"], + name=result["name"], + summary=result["summary"], + findings=result["findings"], + rating=result["rating"], + rating_explanation=result["rating_explanation"], + created_at=result["created_at"], + updated_at=result["updated_at"], + ) async def update( self, @@ -905,20 +912,10 @@ async def update( findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], - level: Optional[int], - metadata: Optional[dict], - auth_user: Any, - ) -> None: - - if not auth_user.is_superuser: - if not await self._check_permissions(id, auth_user.id): - raise R2RException( - "You do not have permission to update this community.", - 403, - ) + ) -> Community: update_fields = [] - params = [community_id] # type: ignore + params: list[Any] = [community_id] # type: ignore if name is not None: update_fields.append(f"name = ${len(params)+1}") params.append(name) @@ -943,17 +940,6 @@ async def update( update_fields.append(f"rating_explanation = ${len(params)+1}") params.append(rating_explanation) - if level is not None: - update_fields.append(f"level = ${len(params)+1}") - params.append(level) - - if metadata is not None: - update_fields.append(f"metadata = ${len(params)+1}") - params.append(metadata) - - update_fields.append(f"updated_by = ${len(params)+1}") - params.append(auth_user.id) - update_fields.append(f"updated_at = CURRENT_TIMESTAMP") QUERY = f""" @@ -963,16 +949,11 @@ async def update( return await self.connection_manager.fetchrow_query(QUERY, params) async def delete( - self, graph_id: UUID, community_id: UUID, auth_user: Any + self, + graph_id: UUID, + community_id: UUID, ) -> None: - if not auth_user.is_superuser: - if not await self._check_permissions(graph_id, auth_user.id): - raise R2RException( - "You do not have permission to delete this community.", - 403, - ) - QUERY = f""" DELETE FROM {self._get_table_name("graph_community")} WHERE id = $1 """ @@ -2059,7 +2040,7 @@ async def get_entities( Tuple of (list of entities, total count) """ conditions = ["parent_id = $1"] - params = [graph_id] + params: [Any] = [graph_id] param_index = 2 if entity_ids: diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 5bc81dc97..d7fe6548b 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -117,6 +117,12 @@ class Community(R2RSerializable): rating_explanation: str | None = None description_embedding: list[float] | None = None attributes: dict[str, Any] | None = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) def __init__(self, **kwargs): if isinstance(kwargs.get("attributes", None), str): From 555f7328397b4b245f5680fef334aaf5f024528d Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:25:56 -0600 Subject: [PATCH 21/28] Delete community test --- .../GraphsIntegrationSuperUser.test.ts | 18 ++++++ py/core/main/api/v3/graph_router.py | 46 ++++++++++++++-- py/core/main/services/kg_service.py | 17 ++---- py/core/providers/database/graph.py | 55 ++++++++++++------- 4 files changed, 100 insertions(+), 36 deletions(-) diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index 490db1cc0..3341ca818 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -386,6 +386,24 @@ describe("r2rClient V3 Graphs Integration Tests", () => { expect(response.results.predicate).toBe("marries"); }); + test("Delete the community", async () => { + const response = await client.graphs.deleteCommunity({ + collectionId: collectionId, + communityId: communityId, + }); + + expect(response.results).toBeDefined(); + }); + + test("Check that the community was deleted", async () => { + const response = await client.graphs.listCommunities({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + test("Reset the graph", async () => { const response = await client.graphs.reset({ collectionId: collectionId, diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index ded054e81..cdd23c724 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -1150,7 +1150,37 @@ async def delete_relationship( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.communities.create(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", communities=[community1, community2]) + result = client.graphs.create_community( + collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + name="My Community", + summary="A summary of the community", + findings=["Finding 1", "Finding 2"], + rating=5, + rating_explanation="This is a rating explanation", + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.createCommunity({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + name: "My Community", + summary: "A summary of the community", + findings: ["Finding 1", "Finding 2"], + rating: 5, + ratingExplanation: "This is a rating explanation", + }); + } + + main(); """ ), }, @@ -1414,14 +1444,18 @@ async def delete_community( ), auth_user=Depends(self.providers.auth.auth_wrapper), ): - if not auth_user.is_superuser: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): raise R2RException( - "Only superusers can access this endpoint.", 403 + "The currently authenticated user does not have access to the specified graph.", + 403, ) - await self.services["kg"].delete_community_v3( - graph_id=collection_id, + + await self.services["kg"].delete_community( + parent_id=collection_id, community_id=community_id, - auth_user=auth_user, ) return GenericBooleanResponse(success=True) # type: ignore diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 57c5688fe..b4ea2aa91 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -315,10 +315,8 @@ async def create_relationship( @telemetry_event("delete_relationship_v3") async def delete_relationship_v3( self, - level: DataLevel, id: UUID, relationship_id: UUID, - **kwargs, ): return ( await self.providers.database.graph_handler.relationships.delete( @@ -437,18 +435,15 @@ async def update_community( rating_explanation=rating_explanation, ) - @telemetry_event("delete_community_v3") - async def delete_community_v3( + @telemetry_event("delete_community") + async def delete_community( self, - graph_id: UUID, + parent_id: UUID, community_id: UUID, - auth_user: Any, - **kwargs, - ): - return await self.providers.database.graph_handler.communities.delete( - graph_id=graph_id, + ) -> None: + await self.providers.database.graph_handler.communities.delete( + parent_id=parent_id, community_id=community_id, - auth_user=auth_user, ) @telemetry_event("list_communities_v3") diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index b613f2217..4896e651a 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -885,22 +885,28 @@ async def create( description_embedding, ] - result = await self.connection_manager.fetchrow_query( - query=query, - params=params, - ) + try: + result = await self.connection_manager.fetchrow_query( + query=query, + params=params, + ) - return Community( - id=result["id"], - community_id=result["community_id"], - name=result["name"], - summary=result["summary"], - findings=result["findings"], - rating=result["rating"], - rating_explanation=result["rating_explanation"], - created_at=result["created_at"], - updated_at=result["updated_at"], - ) + return Community( + id=result["id"], + community_id=result["community_id"], + name=result["name"], + summary=result["summary"], + findings=result["findings"], + rating=result["rating"], + rating_explanation=result["rating_explanation"], + created_at=result["created_at"], + updated_at=result["updated_at"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while creating the community: {e}", + ) async def update( self, @@ -950,14 +956,25 @@ async def update( async def delete( self, - graph_id: UUID, + parent_id: UUID, community_id: UUID, ) -> None: + table_name = "graph_community" - QUERY = f""" - DELETE FROM {self._get_table_name("graph_community")} WHERE id = $1 + query = f""" + DELETE FROM {self._get_table_name(table_name)} + WHERE id = $1 AND graph_id = $2 """ - await self.connection_manager.execute_query(QUERY, [community_id]) + + params = [community_id, parent_id] + + try: + await self.connection_manager.execute_query(query, params) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while deleting the community: {e}", + ) async def get( self, From a00fab88e7a036e51f64d1c2b608dec4e49f9d91 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 29 Nov 2024 21:43:00 -0600 Subject: [PATCH 22/28] Update community tests --- .../GraphsIntegrationSuperUser.test.ts | 34 ++++++ .../r2rV2ClientIntegrationSuperUser.test.ts | 2 +- py/core/main/api/v3/graph_router.py | 53 +++++---- py/core/main/services/kg_service.py | 23 ++-- py/core/providers/database/graph.py | 106 ++++++++++-------- 5 files changed, 138 insertions(+), 80 deletions(-) diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index 3341ca818..68a510337 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -386,6 +386,40 @@ describe("r2rClient V3 Graphs Integration Tests", () => { expect(response.results.predicate).toBe("marries"); }); + test("Update the community", async () => { + const response = await client.graphs.updateCommunity({ + collectionId: collectionId, + communityId: communityId, + name: "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community", + summary: + "Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova", + }); + + expect(response.results).toBeDefined(); + expect(response.results.name).toBe( + "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community", + ); + expect(response.results.summary).toBe( + "Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova", + ); + }); + + test("Retrieve the updated community", async () => { + const response = await client.graphs.getCommunity({ + collectionId: collectionId, + communityId: communityId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(communityId); + expect(response.results.name).toBe( + "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community", + ); + expect(response.results.summary).toBe( + "Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova", + ); + }); + test("Delete the community", async () => { const response = await client.graphs.deleteCommunity({ collectionId: collectionId, diff --git a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts index 594663799..ddeb9a365 100644 --- a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts @@ -129,7 +129,7 @@ describe("r2rClient Integration Tests", () => { metadatas: [{ title: "raskolnikov.txt" }, { title: "karamozov.txt" }], }), ).resolves.not.toThrow(); - }); + }, 10000); test("Ingest files in folder", async () => { const files = ["examples/data/folder"]; diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index cdd23c724..eeb695fae 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -1204,9 +1204,6 @@ async def create_community( rating_explanation: Optional[str] = Body( default="", description="Explanation for the rating" ), - attributes: Optional[dict] = Body( - default=None, description="Attributes for the community" - ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedCommunityResponse: """ @@ -1302,9 +1299,15 @@ async def get_communities( ) -> WrappedCommunitiesResponse: """ Lists all communities in the graph with pagination support. - - By default, all attributes are returned, but this can be limited using the `attributes` parameter. """ + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) communities, count = ( await self.providers.database.graph_handler.get_communities( @@ -1371,23 +1374,29 @@ async def get_community( ) -> WrappedCommunityResponse: """ Retrieves a specific community by its ID. - - By default, all attributes are returned, but this can be limited using the `attributes` parameter. """ - if not auth_user.is_superuser: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): raise R2RException( - "Only superusers can access this endpoint.", 403 + "The currently authenticated user does not have access to the specified graph.", + 403, ) - return await self.services[ + results = await self.services[ "kg" ].providers.database.graph_handler.communities.get( - graph_id=collection_id, - community_id=community_id, - auth_user=auth_user, + parent_id=collection_id, + community_ids=[community_id], + store_type="graph", offset=0, limit=1, ) + print(f"results: {results}") + if len(results) == 0 or len(results[0]) == 0: + raise R2RException("Relationship not found", 404) + return results[0][0] @self.router.delete( "/graphs/{collection_id}/communities/{community_id}", @@ -1519,29 +1528,29 @@ async def update_community( name: Optional[str] = Body(None), summary: Optional[str] = Body(None), findings: Optional[list[str]] = Body(None), - rating: Optional[float] = Body(None), + rating: Optional[float] = Body(default=None, ge=1, le=10), rating_explanation: Optional[str] = Body(None), - attributes: Optional[dict] = Body(None), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedCommunityResponse: """ - Updates an existing community's metadata and properties. + Updates an existing community in the graph. """ - if not auth_user.is_superuser: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): raise R2RException( - "Only superusers can update communities", 403 + "The currently authenticated user does not have access to the specified graph.", + 403, ) - return await self.services["kg"].update_community_v3( - id=collection_id, + return await self.services["kg"].update_community( community_id=community_id, name=name, summary=summary, findings=findings, rating=rating, rating_explanation=rating_explanation, - attributes=attributes, - auth_user=auth_user, ) @self.router.post( diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index b4ea2aa91..f29553f32 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -409,27 +409,25 @@ async def create_community( @telemetry_event("update_community") async def update_community( self, - id: UUID, community_id: UUID, name: Optional[str], summary: Optional[str], findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], - ): + ) -> Community: + summary_embedding = None if summary is not None: - embedding = str( + summary_embedding = str( await self.providers.embedding.async_get_embedding(summary) ) - else: - embedding = None return await self.providers.database.graph_handler.communities.update( - id=id, community_id=community_id, + store_type="graph", # type: ignore name=name, summary=summary, - embedding=embedding, + summary_embedding=summary_embedding, findings=findings, rating=rating, rating_explanation=rating_explanation, @@ -446,16 +444,16 @@ async def delete_community( community_id=community_id, ) - @telemetry_event("list_communities_v3") - async def list_communities_v3( + @telemetry_event("list_communities") + async def list_communities( self, - id: UUID, + collection_id: UUID, offset: int, limit: int, - **kwargs, ): return await self.providers.database.graph_handler.communities.get( - id=id, + parent_id=collection_id, + store_type="graph", # type: ignore offset=offset, limit=limit, ) @@ -473,7 +471,6 @@ async def get_communities( ): return await self.providers.database.graph_handler.get_communities( collection_id=collection_id, - levels=levels, community_ids=community_ids, offset=offset or 0, limit=limit or -1, diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 4896e651a..6ceefe7b1 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -870,9 +870,9 @@ async def create( query = f""" INSERT INTO {self._get_table_name(table_name)} - (community_id, name, summary, findings, rating, rating_explanation, description_embedding) + (collection_id, name, summary, findings, rating, rating_explanation, description_embedding) VALUES ($1, $2, $3, $4, $5, $6, $7) - RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at + RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at """ params = [ @@ -893,7 +893,7 @@ async def create( return Community( id=result["id"], - community_id=result["community_id"], + collection_id=result["collection_id"], name=result["name"], summary=result["summary"], findings=result["findings"], @@ -910,49 +910,83 @@ async def create( async def update( self, - id: UUID, community_id: UUID, - name: Optional[str], - summary: Optional[str], - embedding: Optional[str], - findings: Optional[list[str]], - rating: Optional[float], - rating_explanation: Optional[str], + store_type: StoreType, + name: Optional[str] = None, + summary: Optional[str] = None, + summary_embedding: Optional[list[float] | str] = None, + findings: Optional[list[str]] = None, + rating: Optional[float] = None, + rating_explanation: Optional[str] = None, ) -> Community: - + table_name = "graph_community" update_fields = [] - params: list[Any] = [community_id] # type: ignore + params: list[Any] = [] + param_index = 1 + if name is not None: - update_fields.append(f"name = ${len(params)+1}") + update_fields.append(f"name = ${param_index}") params.append(name) + param_index += 1 if summary is not None: - update_fields.append(f"summary = ${len(params)+1}") + update_fields.append(f"summary = ${param_index}") params.append(summary) + param_index += 1 - if embedding is not None: - update_fields.append(f"description_embedding = ${len(params)+1}") - params.append(embedding) + if summary_embedding is not None: + update_fields.append(f"description_embedding = ${param_index}") + params.append(summary_embedding) + param_index += 1 if findings is not None: - update_fields.append(f"findings = ${len(params)+1}") + update_fields.append(f"findings = ${param_index}") params.append(findings) + param_index += 1 if rating is not None: - update_fields.append(f"rating = ${len(params)+1}") + update_fields.append(f"rating = ${param_index}") params.append(rating) + param_index += 1 if rating_explanation is not None: - update_fields.append(f"rating_explanation = ${len(params)+1}") + update_fields.append(f"rating_explanation = ${param_index}") params.append(rating_explanation) + param_index += 1 - update_fields.append(f"updated_at = CURRENT_TIMESTAMP") + if not update_fields: + raise R2RException(status_code=400, message="No fields to update") - QUERY = f""" - UPDATE {self._get_table_name("graph_community")} SET {", ".join(update_fields)} WHERE id = $1 - RETURNING id, graph_id, name, summary, findings, rating, rating_explanation, metadata, level, created_by, updated_by, updated_at + update_fields.append("updated_at = NOW()") + params.append(community_id) + + query = f""" + UPDATE {self._get_table_name(table_name)} + SET {", ".join(update_fields)} + WHERE id = ${param_index}\ + RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at """ - return await self.connection_manager.fetchrow_query(QUERY, params) + try: + result = await self.connection_manager.fetchrow_query( + query, params + ) + + return Community( + id=result["id"], + community_id=result["community_id"], + name=result["name"], + summary=result["summary"], + findings=result["findings"], + rating=result["rating"], + rating_explanation=result["rating_explanation"], + created_at=result["created_at"], + updated_at=result["updated_at"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the community: {e}", + ) async def delete( self, @@ -990,7 +1024,7 @@ async def get( # Do we ever want to get communities from document store? table_name = "graph_community" - conditions = ["graph_id = $1"] + conditions = ["collection_id = $1"] params: list[Any] = [parent_id] param_index = 2 @@ -1005,8 +1039,8 @@ async def get( param_index += 1 select_fields = """ - id, graph_id, name, summary, findings, rating, - rating_explanation, level, metadata, created_at, updated_at + id, community_id, name, summary, findings, rating, + rating_explanation, level, created_at, updated_at """ if include_embeddings: select_fields += ", description_embedding" @@ -1043,15 +1077,6 @@ async def get( for row in rows: community_dict = dict(row) - # Process metadata if it exists and is a string - if isinstance(community_dict["metadata"], str): - try: - community_dict["metadata"] = json.loads( - community_dict["metadata"] - ) - except json.JSONDecodeError: - pass - communities.append(Community(**community_dict)) return communities, count @@ -2506,7 +2531,6 @@ async def get_communities( offset: int, limit: int, community_ids: Optional[list[UUID]] = None, - levels: Optional[list[int]] = None, include_embeddings: bool = False, ) -> tuple[list[Community], int]: """ @@ -2517,7 +2541,6 @@ async def get_communities( offset: Number of records to skip limit: Maximum number of records to return (-1 for no limit) community_ids: Optional list of community IDs to filter by - levels: Optional list of levels to filter by include_embeddings: Whether to include embeddings in the response Returns: @@ -2532,11 +2555,6 @@ async def get_communities( params.append(community_ids) param_index += 1 - if levels: - conditions.append(f"level = ANY(${param_index})") - params.append(levels) - param_index += 1 - select_fields = """ id, collection_id, name, summary, findings, rating, rating_explanation """ From ed7f7344e31d84da92064fda32f750a7db68ec04 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 29 Nov 2024 22:28:00 -0600 Subject: [PATCH 23/28] Clean up type errors, cleaner code --- js/sdk/src/v3/clients/graphs.ts | 25 +++++++++++ py/core/main/api/v3/graph_router.py | 68 ++++++++++++++++++++++++----- py/core/main/services/kg_service.py | 46 +++++++------------ py/core/providers/database/graph.py | 4 +- py/shared/abstractions/graph.py | 2 +- 5 files changed, 101 insertions(+), 44 deletions(-) diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index 85734c554..351b0b2e3 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -565,9 +565,34 @@ export class GraphsClient { ); } + /** + * Creates communities in the graph by analyzing entity relationships and similarities. + * + * Communities are created through the following process: + * 1. Analyzes entity relationships and metadata to build a similarity graph + * 2. Applies advanced community detection algorithms (e.g. Leiden) to identify densely connected groups + * 3. Creates hierarchical community structure with multiple granularity levels + * 4. Generates natural language summaries and statistical insights for each community + * + * The resulting communities can be used to: + * - Understand high-level graph structure and organization + * - Identify key entity groupings and their relationships + * - Navigate and explore the graph at different levels of detail + * - Generate insights about entity clusters and their characteristics + * + * The community detection process is configurable through settings like: + * - Community detection algorithm parameters + * - Summary generation prompt + * + * @param options + * @returns + */ @feature("graphs.buildCommunities") async buildCommunities(options: { collectionId: string; + runType?: string; + kgEntichmentSettings?: Record; + runWithOrchestration?: boolean; }): Promise { return this.client.makeRequest( "POST", diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index eeb695fae..a8f189d23 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -304,7 +304,8 @@ async def build_communities( run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.providers.auth.auth_wrapper), ): # -> WrappedKGEnrichmentResponse: - """Creates communities in the graph by analyzing entity relationships and similarities. + """ + Creates communities in the graph by analyzing entity relationships and similarities. Communities are created through the following process: 1. Analyzes entity relationships and metadata to build a similarity graph @@ -323,8 +324,14 @@ async def build_communities( - Summary generation prompt """ print("collection_id = ", collection_id) - if not auth_user.is_superuser: - logger.warning("Implement permission checks here.") + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) # If no collection ID is provided, use the default user collection # id = generate_default_user_collection_id(auth_user.id) @@ -586,6 +593,14 @@ async def get_entities( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedEntitiesResponse: """Lists all entities in the graph with pagination support.""" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) # return await self.services["kg"].get_entities( # id, offset, limit, auth_user # ) @@ -626,7 +641,7 @@ async def create_entity( and collection_id not in auth_user.graph_ids ): raise R2RException( - "The currently authenticated user does not have access to this graph.", + "The currently authenticated user does not have access to the specified graph.", 403, ) @@ -680,7 +695,7 @@ async def create_relationship( and collection_id not in auth_user.graph_ids ): raise R2RException( - "The currently authenticated user does not have access to this graph.", + "The currently authenticated user does not have access to the specified graph.", 403, ) @@ -750,7 +765,15 @@ async def get_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedEntityResponse: """Retrieves a specific entity by its ID.""" - # Note: The original was missing implementation, so assuming similar pattern to relationships + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + result = await self.providers.database.graph_handler.entities.get( collection_id, "graph", entity_ids=[entity_id] ) @@ -856,6 +879,15 @@ async def delete_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedBooleanResponse: """Removes an entity from the graph.""" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + await self.providers.database.graph_handler.entities.delete( collection_id, [entity_id], "graph" ) @@ -922,13 +954,12 @@ async def get_relationships( """ Lists all relationships in the graph with pagination support. """ - # Permission check if ( not auth_user.is_superuser and collection_id not in auth_user.graph_ids ): raise R2RException( - "The currently authenticated user does not have access to this graph.", + "The currently authenticated user does not have access to the specified graph.", 403, ) @@ -1000,6 +1031,15 @@ async def get_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedRelationshipResponse: """Retrieves a specific relationship by its ID.""" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + results = ( await self.providers.database.graph_handler.relationships.get( collection_id, "graph", relationship_ids=[relationship_id] @@ -1126,6 +1166,15 @@ async def delete_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedBooleanResponse: """Removes a relationship from the graph.""" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + # return await self.services[ # "kg" # ].documents.graph_handler.relationships.remove_from_graph( @@ -1632,6 +1681,7 @@ async def pull( "The currently authenticated user does not have access to the specified graph.", 403, ) + list_graphs_response = await self.services["kg"].list_graphs( # user_ids=None, graph_ids=[collection_id], @@ -1766,7 +1816,6 @@ async def remove_document( The user must have access to both the graph and the document being removed. """ - # Check user permissions for graph if ( not auth_user.is_superuser and collection_id not in auth_user.graph_ids @@ -1776,7 +1825,6 @@ async def remove_document( 403, ) - # Check user permissions for document if ( not auth_user.is_superuser and document_id not in auth_user.document_ids diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index f29553f32..b5c446f66 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -207,12 +207,11 @@ async def update_entity( ) @telemetry_event("delete_entity") - async def delete_entity_v3( + async def delete_entity( self, id: UUID, entity_id: UUID, level: DataLevel, - **kwargs, ): return await self.providers.database.graph_handler.entities.delete( id=id, @@ -220,19 +219,6 @@ async def delete_entity_v3( level=level, ) - @telemetry_event("add_entity_to_graph") - async def add_entity_to_graph( - self, - graph_id: UUID, - entity_id: UUID, - auth_user: Optional[Any] = None, - ): - return ( - await self.providers.database.graph_handler.entities.add_to_graph( - graph_id, entity_id, auth_user - ) - ) - # TODO: deprecate this @telemetry_event("get_entities") async def get_entities( @@ -312,8 +298,8 @@ async def create_relationship( ) ) - @telemetry_event("delete_relationship_v3") - async def delete_relationship_v3( + @telemetry_event("delete_relationship") + async def delete_relationship( self, id: UUID, relationship_id: UUID, @@ -365,19 +351,19 @@ async def update_relationship( @telemetry_event("get_triples") async def get_relationships( self, - collection_id: Optional[UUID] = None, + offset: int, + limit: int, + collection_id: UUID, entity_names: Optional[list[str]] = None, - relationship_ids: Optional[list[str]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, + relationship_ids: Optional[list[UUID]] = None, ): - return await self.providers.database.graph_handler.get_relationships( - collection_id=collection_id, + return await self.providers.database.graph_handler.relationships.get( + parent_id=collection_id, + store_type="graph", # type: ignore entity_names=entity_names, relationship_ids=relationship_ids, - offset=offset or 0, - limit=limit or -1, + offset=offset, + limit=limit, ) ################### COMMUNITIES ################### @@ -458,12 +444,10 @@ async def list_communities( limit=limit, ) - # TODO: deprecate this @telemetry_event("get_communities") async def get_communities( self, - collection_id: Optional[UUID] = None, - levels: Optional[list[int]] = None, + collection_id: UUID, community_ids: Optional[list[int]] = None, offset: Optional[int] = None, limit: Optional[int] = None, @@ -472,8 +456,8 @@ async def get_communities( return await self.providers.database.graph_handler.get_communities( collection_id=collection_id, community_ids=community_ids, - offset=offset or 0, - limit=limit or -1, + offset=offset, + limit=limit, ) # @telemetry_event("create_new_graph") diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 6ceefe7b1..be559217a 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -544,7 +544,7 @@ async def get( entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, include_metadata: bool = False, - ) -> tuple[list[Relationship], int]: + ): """ Get relationships from the specified store. @@ -564,7 +564,7 @@ async def get( table_name = self._get_relationship_table_for_store(store_type) conditions = ["parent_id = $1"] - params = [parent_id] + params: list[Any] = [parent_id] param_index = 2 if relationship_ids: diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index d7fe6548b..71221546f 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -90,9 +90,9 @@ class CommunityInfo(R2RSerializable): node: str cluster: UUID level: Optional[int] - id: Optional[UUID | int] = None parent_cluster: int | None is_final_cluster: bool + id: Optional[UUID | int] = None graph_id: Optional[UUID] = None collection_id: Optional[UUID] = None # for backwards compatibility relationship_ids: Optional[list[UUID]] = None From a21003ef93a0e975027d3f38425322e215669907 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Sat, 30 Nov 2024 11:45:17 -0600 Subject: [PATCH 24/28] More cleanup --- py/core/base/abstractions/__init__.py | 2 - py/core/main/api/v3/graph_router.py | 67 +++++++------ py/core/main/services/kg_service.py | 115 +++++++--------------- py/core/pipes/kg/deduplication.py | 18 ++-- py/core/pipes/retrieval/kg_search_pipe.py | 4 +- py/core/providers/database/graph.py | 33 +++---- py/sdk/models.py | 2 - py/sdk/v3/conversations.py | 4 +- py/sdk/v3/graphs.py | 4 +- py/shared/abstractions/__init__.py | 2 - py/shared/abstractions/search.py | 4 - 11 files changed, 94 insertions(+), 161 deletions(-) diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 8b4b1a5e5..3cc40c678 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -65,7 +65,6 @@ KGEntityResult, KGGlobalResult, KGRelationshipResult, - KGSearchMethod, KGSearchResultType, SearchSettings, ) @@ -131,7 +130,6 @@ # Search abstractions "AggregateSearchResult", "GraphSearchResult", - "KGSearchMethod", "KGSearchResultType", "KGEntityResult", "KGRelationshipResult", diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index a8f189d23..3ba13b1a5 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -601,13 +601,11 @@ async def get_entities( "The currently authenticated user does not have access to the specified graph.", 403, ) - # return await self.services["kg"].get_entities( - # id, offset, limit, auth_user - # ) - entities, count = ( - await self.providers.database.graph_handler.get_entities( - collection_id, offset, limit - ) + + entities, count = await self.services["kg"].get_entities( + parent_id=collection_id, + offset=offset, + limit=limit, ) return entities, { # type: ignore @@ -775,7 +773,11 @@ async def get_entity( ) result = await self.providers.database.graph_handler.entities.get( - collection_id, "graph", entity_ids=[entity_id] + parent_id=collection_id, + store_type="graph", + offset=0, + limit=1, + entity_ids=[entity_id], ) if len(result) == 0 or len(result[0]) == 0: raise R2RException("Entity not found", 404) @@ -888,9 +890,11 @@ async def delete_entity( 403, ) - await self.providers.database.graph_handler.entities.delete( - collection_id, [entity_id], "graph" + self.services["kg"].delete_entity( + parent_id=collection_id, + entity_id=entity_id, ) + return GenericBooleanResponse(success=True) # type: ignore @self.router.get( @@ -963,13 +967,10 @@ async def get_relationships( 403, ) - relationships, count = ( - await self.providers.database.graph_handler.relationships.get( - parent_id=collection_id, - store_type="graph", - offset=offset, - limit=limit, - ) + relationships, count = await self.services["kg"].get_relationships( + parent_id=collection_id, + offset=offset, + limit=limit, ) return relationships, { # type: ignore @@ -1042,7 +1043,11 @@ async def get_relationship( results = ( await self.providers.database.graph_handler.relationships.get( - collection_id, "graph", relationship_ids=[relationship_id] + parent_id=collection_id, + store_type="graph", + offset=0, + limit=1, + relationship_ids=[relationship_id], ) ) if len(results) == 0 or len(results[0]) == 0: @@ -1175,14 +1180,11 @@ async def delete_relationship( 403, ) - # return await self.services[ - # "kg" - # ].documents.graph_handler.relationships.remove_from_graph( - # id, relationship_id, auth_user - # ) - await self.providers.database.graph_handler.relationships.delete( - collection_id, [relationship_id], "graph" + await self.services["kg"].delete_relationship( + parent_id=collection_id, + relationship_id=relationship_id, ) + return GenericBooleanResponse(success=True) # type: ignore @self.router.post( @@ -1358,12 +1360,10 @@ async def get_communities( 403, ) - communities, count = ( - await self.providers.database.graph_handler.get_communities( - collection_id=collection_id, - offset=offset, - limit=limit, - ) + communities, count = await self.services["kg"].get_communities( + parent_id=collection_id, + offset=offset, + limit=limit, ) return communities, { # type: ignore @@ -1720,7 +1720,10 @@ async def pull( ) entities = ( await self.providers.database.graph_handler.entities.get( - document.id, store_type="document" + parent_id=document.id, + store_type="document", + offset=0, + limit=100, ) ) has_document = ( diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index b5c446f66..22c90a96f 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -157,29 +157,6 @@ async def create_entity( metadata=metadata, ) - @telemetry_event("list_entities") - async def list_entities( - self, - offset: int, - limit: int, - id: Optional[UUID] = None, - graph_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - include_embeddings: Optional[bool] = False, - user_id: Optional[UUID] = None, - ): - return await self.providers.database.graph_handler.entities.get( - id=id, - graph_id=graph_id, - document_id=document_id, - entity_names=entity_names, - include_embeddings=include_embeddings, - offset=offset, - limit=limit, - user_id=user_id, - ) - @telemetry_event("update_entity") async def update_entity( self, @@ -198,69 +175,43 @@ async def update_entity( return await self.providers.database.graph_handler.entities.update( entity_id=entity_id, + store_type="graph", # type: ignore name=name, description=description, - category=category, description_embedding=description_embedding, + category=category, metadata=metadata, - store_type="graph", # type: ignore ) @telemetry_event("delete_entity") async def delete_entity( self, - id: UUID, + parent_id: UUID, entity_id: UUID, - level: DataLevel, ): return await self.providers.database.graph_handler.entities.delete( - id=id, - entity_id=entity_id, - level=level, + parent_id=parent_id, + entity_ids=[entity_id], + store_type="graph", # type: ignore ) - # TODO: deprecate this @telemetry_event("get_entities") async def get_entities( self, - collection_id: Optional[UUID] = None, - entity_ids: Optional[list[str]] = None, - entity_table_name: str = "entity", - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, - ): - return await self.providers.database.graph_handler.get_entities( - collection_id=collection_id, - entity_ids=entity_ids, - entity_table_name=entity_table_name, - offset=offset or 0, - limit=limit or -1, - ) - - ################### RELATIONSHIPS ################### - - @telemetry_event("list_relationships_v3") - async def list_relationships_v3( - self, - id: UUID, - level: DataLevel, + parent_id: UUID, offset: int, limit: int, + entity_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, - relationship_types: Optional[list[str]] = None, - attributes: Optional[list[str]] = None, - relationship_id: Optional[UUID] = None, + include_embeddings: bool = False, ): - return await self.providers.database.graph_handler.relationships.get( - id=id, - level=level, - entity_names=entity_names, - relationship_types=relationship_types, - attributes=attributes, + return await self.providers.database.graph_handler.get_entities( + parent_id=parent_id, offset=offset, limit=limit, - relationship_id=relationship_id, + entity_ids=entity_ids, + entity_names=entity_names, + include_embeddings=include_embeddings, ) @telemetry_event("create_relationship") @@ -301,13 +252,14 @@ async def create_relationship( @telemetry_event("delete_relationship") async def delete_relationship( self, - id: UUID, + parent_id: UUID, relationship_id: UUID, ): return ( await self.providers.database.graph_handler.relationships.delete( - id=id, - relationship_id=relationship_id, + parent_id=parent_id, + relationship_ids=[relationship_id], + store_type="graph", # type: ignore ) ) @@ -347,27 +299,24 @@ async def update_relationship( ) ) - # TODO: deprecate this - @telemetry_event("get_triples") + @telemetry_event("get_relationships") async def get_relationships( self, + parent_id: UUID, offset: int, limit: int, - collection_id: UUID, - entity_names: Optional[list[str]] = None, relationship_ids: Optional[list[UUID]] = None, + entity_names: Optional[list[str]] = None, ): return await self.providers.database.graph_handler.relationships.get( - parent_id=collection_id, + parent_id=parent_id, store_type="graph", # type: ignore - entity_names=entity_names, - relationship_ids=relationship_ids, offset=offset, limit=limit, + relationship_ids=relationship_ids, + entity_names=entity_names, ) - ################### COMMUNITIES ################### - @telemetry_event("create_community") async def create_community( self, @@ -447,17 +396,19 @@ async def list_communities( @telemetry_event("get_communities") async def get_communities( self, - collection_id: UUID, - community_ids: Optional[list[int]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, + parent_id: UUID, + offset: int, + limit: int, + community_ids: Optional[list[UUID]] = None, + community_names: Optional[list[str]] = None, + include_embeddings: bool = False, ): return await self.providers.database.graph_handler.get_communities( - collection_id=collection_id, - community_ids=community_ids, + parent_id=parent_id, offset=offset, limit=limit, + community_ids=community_ids, + include_embeddings=include_embeddings, ) # @telemetry_event("create_new_graph") diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index 5838d1b7c..9ce3f62ba 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -1,10 +1,8 @@ import json import logging -from typing import Any, Union +from typing import Any from uuid import UUID -from fastapi import HTTPException - from core.base import AsyncState from core.base.abstractions import DataLevel, Entity, KGEntityDeduplicationType from core.base.pipes import AsyncPipe @@ -26,14 +24,12 @@ def __init__( self, config: AsyncPipe.PipeConfig, database_provider: PostgresDBProvider, - llm_provider: Union[ - OpenAICompletionProvider, LiteLLMCompletionProvider - ], - embedding_provider: Union[ - LiteLLMEmbeddingProvider, - OpenAIEmbeddingProvider, - OllamaEmbeddingProvider, - ], + llm_provider: OpenAICompletionProvider | LiteLLMCompletionProvider, + embedding_provider: ( + LiteLLMEmbeddingProvider + | OpenAIEmbeddingProvider + | OllamaEmbeddingProvider + ), logging_provider: SqlitePersistentLoggingProvider, **kwargs, ): diff --git a/py/core/pipes/retrieval/kg_search_pipe.py b/py/core/pipes/retrieval/kg_search_pipe.py index a8ba17177..ff6625351 100644 --- a/py/core/pipes/retrieval/kg_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_pipe.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, AsyncGenerator, Optional +from typing import Any, AsyncGenerator from uuid import UUID from core.base import ( @@ -15,7 +15,6 @@ KGCommunityResult, KGEntityResult, KGRelationshipResult, - KGSearchMethod, KGSearchResultType, SearchSettings, ) @@ -234,7 +233,6 @@ async def search( rating_explanation=search_result["rating_explanation"], findings=search_result["findings"], ), - # method=KGSearchMethod.LOCAL, result_type=KGSearchResultType.COMMUNITY, metadata=( { diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index be559217a..c1fa9bac4 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -14,7 +14,6 @@ from core.base.abstractions import ( Community, CommunityInfo, - DataLevel, Entity, Graph, KGCreationSettings, @@ -178,8 +177,8 @@ async def get( self, parent_id: UUID, store_type: StoreType, - offset: int = 0, - limit: int = 100, + offset: int, + limit: int, entity_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, include_embeddings: bool = False, @@ -268,7 +267,7 @@ async def update( """Update an entity in the specified store.""" table_name = self._get_entity_table_for_store(store_type) update_fields = [] - params: list = [] + params: list[Any] = [] param_index = 1 if isinstance(metadata, str): @@ -340,7 +339,7 @@ async def delete( parent_id: UUID, entity_ids: Optional[list[UUID]] = None, store_type: StoreType = StoreType.GRAPH, - ) -> list[UUID]: + ) -> None: """ Delete entities from the specified store. If entity_ids is not provided, deletes all entities for the given parent_id. @@ -387,8 +386,6 @@ async def delete( 404, ) - return [row["id"] for row in results] - class PostgresRelationshipHandler(RelationshipHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -538,8 +535,8 @@ async def get( self, parent_id: UUID, store_type: StoreType, - offset: int = 0, - limit: int = 100, + offset: int, + limit: int, relationship_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, @@ -747,7 +744,7 @@ async def delete( parent_id: UUID, relationship_ids: Optional[list[UUID]] = None, store_type: StoreType = StoreType.GRAPH, - ) -> list[UUID]: + ) -> None: """ Delete relationships from the specified store. If relationship_ids is not provided, deletes all relationships for the given parent_id. @@ -791,8 +788,6 @@ async def delete( 404, ) - return [row["id"] for row in results] - class PostgresCommunityHandler(CommunityHandler): @@ -2060,7 +2055,7 @@ async def get_deduplication_estimate( async def get_entities( self, - graph_id: UUID, + parent_id: UUID, offset: int, limit: int, entity_ids: Optional[list[UUID]] = None, @@ -2073,7 +2068,7 @@ async def get_entities( Args: offset: Number of records to skip limit: Maximum number of records to return (-1 for no limit) - graph_id: UUID of the graph + parent_id: UUID of the collection entity_ids: Optional list of entity IDs to filter by entity_names: Optional list of entity names to filter by include_embeddings: Whether to include embeddings in the response @@ -2082,7 +2077,7 @@ async def get_entities( Tuple of (list of entities, total count) """ conditions = ["parent_id = $1"] - params: [Any] = [graph_id] + params: list[Any] = [parent_id] param_index = 2 if entity_ids: @@ -2320,8 +2315,8 @@ async def get( self, parent_id: UUID, store_type: StoreType, - offset: int = 0, - limit: int = 100, + offset: int, + limit: int, entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, relationship_id: Optional[UUID] = None, @@ -2527,7 +2522,7 @@ async def add_community_info( async def get_communities( self, - collection_id: UUID, + parent_id: UUID, offset: int, limit: int, community_ids: Optional[list[UUID]] = None, @@ -2547,7 +2542,7 @@ async def get_communities( Tuple of (list of communities, total count) """ conditions = ["collection_id = $1"] - params: list[Any] = [collection_id] + params: list[Any] = [parent_id] param_index = 2 if community_ids: diff --git a/py/sdk/models.py b/py/sdk/models.py index 987a8479d..2891c4216 100644 --- a/py/sdk/models.py +++ b/py/sdk/models.py @@ -12,7 +12,6 @@ KGGlobalResult, KGRelationshipResult, KGRunType, - KGSearchMethod, KGSearchResultType, Message, MessageType, @@ -38,7 +37,6 @@ "KGGlobalResult", "KGRelationshipResult", "KGRunType", - "KGSearchMethod", "GraphSearchResult", "KGSearchResultType", "GraphSearchSettings", diff --git a/py/sdk/v3/conversations.py b/py/sdk/v3/conversations.py index bc6b0aad6..4fa3887a8 100644 --- a/py/sdk/v3/conversations.py +++ b/py/sdk/v3/conversations.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from uuid import UUID from shared.api.models.base import WrappedBooleanResponse @@ -125,7 +125,7 @@ async def add_message( Returns: dict: Result of the operation, including the new message ID """ - data = { + data: dict[str, Any] = { "content": content, "role": role, } diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index b7954180f..ef819502e 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from uuid import UUID from shared.api.models.base import ( @@ -387,7 +387,7 @@ async def update_community( Returns: dict: Updated community information """ - data = {} + data: dict[str, Any] = {} if name is not None: data["name"] = name if summary is not None: diff --git a/py/shared/abstractions/__init__.py b/py/shared/abstractions/__init__.py index 8c55df1b8..8e70a0008 100644 --- a/py/shared/abstractions/__init__.py +++ b/py/shared/abstractions/__init__.py @@ -50,7 +50,6 @@ KGEntityResult, KGGlobalResult, KGRelationshipResult, - KGSearchMethod, KGSearchResultType, SearchSettings, ) @@ -110,7 +109,6 @@ # Search abstractions "AggregateSearchResult", "GraphSearchResult", - "KGSearchMethod", "KGSearchResultType", "KGEntityResult", "KGRelationshipResult", diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index 991588282..fd3502691 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -62,10 +62,6 @@ class KGSearchResultType(str, Enum): COMMUNITY = "community" -class KGSearchMethod(str, Enum): - LOCAL = "local" - - class KGEntityResult(R2RSerializable): name: str description: str From 7620d9f221655b6cb9327014e91f34b2f351f675 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Sat, 30 Nov 2024 14:30:38 -0600 Subject: [PATCH 25/28] More --- py/core/pipes/kg/storage.py | 1 - py/core/providers/database/graph.py | 60 +---------------------------- 2 files changed, 1 insertion(+), 60 deletions(-) diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index 53a0f132e..510cd5ca6 100644 --- a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -5,7 +5,6 @@ from core.base import AsyncState, KGExtraction, R2RDocumentProcessingError from core.base.pipes.base_pipe import AsyncPipe -from core.providers.database.graph import DataLevel from core.providers.database.postgres import PostgresDBProvider from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index c1fa9bac4..105eb409b 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -2829,7 +2829,7 @@ async def perform_graph_clustering( # relationship_ids_cache, leiden_params, collection_id # ) # else: - num_communities = await self._cluster_and_add_community_info( + return await self._cluster_and_add_community_info( relationships=relationships, relationship_ids_cache=relationship_ids_cache, leiden_params=leiden_params, @@ -2837,10 +2837,6 @@ async def perform_graph_clustering( # graph_id=collection_id, ) - return num_communities - - ####################### MANAGEMENT METHODS ####################### - async def get_entity_map( self, offset: int, limit: int, document_id: UUID ) -> dict[str, dict[str, list[dict[str, Any]]]]: @@ -3448,8 +3444,6 @@ async def _compute_leiden_communities( except ImportError as e: raise ImportError("Please install the graspologic package.") from e - ####################### UTILITY METHODS ####################### - async def get_existing_document_entity_chunk_ids( self, document_id: UUID ) -> list[str]: @@ -3463,23 +3457,6 @@ async def get_existing_document_entity_chunk_ids( ) ] - async def create_vector_index(self): - # need to implement this. Just call vector db provider's create_vector_index method. - # this needs to be run periodically for every collection. - raise NotImplementedError - - async def structured_query(self): - raise NotImplementedError - - async def update_extraction_prompt(self): - raise NotImplementedError - - async def update_kg_search_prompt(self): - raise NotImplementedError - - async def upsert_relationships(self): - raise NotImplementedError - async def get_entity_count( self, collection_id: Optional[UUID] = None, @@ -3528,41 +3505,6 @@ async def get_entity_count( "count" ] - async def get_relationship_count( - self, - collection_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - ) -> int: - if collection_id is None and document_id is None: - raise ValueError( - "Either collection_id or document_id must be provided." - ) - - conditions = [] - params = [] - - if collection_id: - conditions.append( - f""" - document_id = ANY( - SELECT document_id FROM {self._get_table_name("document_info")} - WHERE $1 = ANY(collection_ids) - ) - """ - ) - params.append(str(collection_id)) - else: - conditions.append("document_id = $1") - params.append(str(document_id)) - - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("relationship")} - WHERE {" AND ".join(conditions)} - """ - return (await self.connection_manager.fetch_query(QUERY, params))[0][ - "count" - ] - async def update_entity_descriptions(self, entities: list[Entity]): query = f""" From 6d395a1f4c35c576b8d32accdbc08a6de33956da Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Sat, 30 Nov 2024 14:41:16 -0600 Subject: [PATCH 26/28] remove chunk_entity --- py/core/main/services/kg_service.py | 8 ---- py/core/providers/database/graph.py | 60 ----------------------------- 2 files changed, 68 deletions(-) diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 22c90a96f..6334be62f 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -566,14 +566,6 @@ async def kg_entity_description( return all_results - @telemetry_event("get_graph_status") - async def get_graph_status( - self, - collection_id: UUID, - **kwargs, - ): - raise NotImplementedError("Not implemented") - @telemetry_event("kg_clustering") async def kg_clustering( self, diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 105eb409b..58673a7d6 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -2199,7 +2199,6 @@ async def delete_node_via_document_id( # Execute separate DELETE queries delete_queries = [ - f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = $1", f"DELETE FROM {self._get_table_name('relationship')} WHERE document_id = $1", f"DELETE FROM {self._get_table_name('entity')} WHERE document_id = $1", ] @@ -2716,7 +2715,6 @@ async def delete_graph_for_collection( # TODO: make these queries more efficient. Pass the document_ids as params. if cascade: DELETE_QUERIES += [ - f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1::uuid[]);", f"DELETE FROM {self._get_table_name('relationship')} WHERE document_id = ANY($1::uuid[]);", f"DELETE FROM {self._get_table_name('entity')} WHERE document_id = ANY($1::uuid[]);", f"DELETE FROM {self._get_table_name('graph_entity')} WHERE collection_id = $1;", @@ -2905,64 +2903,6 @@ async def get_entity_map( return entity_map - async def get_graph_status(self, collection_id: UUID) -> dict: - # check document_info table for the documents in the collection and return the status of each document - kg_extraction_statuses = await self.connection_manager.fetch_query( - f"SELECT document_id, extraction_status FROM {self._get_table_name('document_info')} WHERE collection_id = $1", - [collection_id], - ) - - document_ids = [ - doc_id["document_id"] for doc_id in kg_extraction_statuses - ] - - graph_cluster_statuses = await self.connection_manager.fetch_query( - f"SELECT enrichment_status FROM {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)} WHERE id = $1", - [collection_id], - ) - - # entity and relationship counts - chunk_entity_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1)", - [document_ids], - ) - - relationship_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('relationship')} WHERE document_id = ANY($1)", - [document_ids], - ) - - entity_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE document_id = ANY($1)", - [document_ids], - ) - - graph_entity_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('graph_entity')} WHERE collection_id = $1", - [collection_id], - ) - - community_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('community')} WHERE collection_id = $1", - [collection_id], - ) - - return { - "kg_extraction_statuses": kg_extraction_statuses, - "graph_cluster_status": graph_cluster_statuses[0][ - "enrichment_status" - ], - "chunk_entity_count": chunk_entity_count[0]["count"], - "relationship_count": relationship_count[0]["count"], - "entity_count": entity_count[0]["count"], - "graph_entity_count": graph_entity_count[0]["count"], - "community_count": community_count[0]["count"], - } - - ####################### ESTIMATION METHODS ####################### - - ####################### GRAPH SEARCH METHODS ####################### - def _build_filters( self, filters: dict, parameters: list[Union[str, int, bytes]] ) -> str: From ff8c4999abe95099f64dd84e1ec0de657dca3359 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Sat, 30 Nov 2024 14:54:33 -0600 Subject: [PATCH 27/28] Delete bad, unused methods --- py/core/pipes/kg/clustering.py | 3 +- py/core/providers/database/graph.py | 96 ++--------------------------- 2 files changed, 5 insertions(+), 94 deletions(-) diff --git a/py/core/pipes/kg/clustering.py b/py/core/pipes/kg/clustering.py index 1103c09b5..c9275c240 100644 --- a/py/core/pipes/kg/clustering.py +++ b/py/core/pipes/kg/clustering.py @@ -54,9 +54,8 @@ async def cluster_kg( num_communities = await self.database_provider.graph_handler.perform_graph_clustering( collection_id=collection_id, - # graph_id=graph_id, leiden_params=leiden_params, - ) # type: ignore + ) logger.info( f"Clustering completed. Generated {num_communities} communities." diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 58673a7d6..cd8e5532d 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -1710,58 +1710,6 @@ async def add_entities_v3( # return True - async def add_relationships_v3( - self, id: UUID, relationship_ids: list[UUID], copy_data: bool = True - ) -> bool: - """ - Add relationships to the graph. - """ - QUERY = f""" - UPDATE {self._get_table_name("relationship")} - SET graph_ids = array_append(graph_ids, $1) - WHERE id = ANY($2) - """ - await self.connection_manager.execute_query( - QUERY, [id, relationship_ids] - ) - - if copy_data: - QUERY = f""" - INSERT INTO {self._get_table_name("graph_relationship")} - SELECT * FROM {self._get_table_name("relationship")} - WHERE id = ANY($1) - """ - await self.connection_manager.execute_query( - QUERY, [relationship_ids] - ) - - return True - - async def remove_relationships( - self, id: UUID, relationship_ids: list[UUID], delete_data: bool = True - ) -> bool: - """ - Remove relationships from the graph. - """ - QUERY = f""" - UPDATE {self._get_table_name("relationship")} - SET graph_ids = array_remove(graph_ids, $1) - WHERE id = ANY($2) - """ - await self.connection_manager.execute_query( - QUERY, [id, relationship_ids] - ) - - if delete_data: - QUERY = f""" - DELETE FROM {self._get_table_name("graph_relationship")} WHERE id = ANY($1) - """ - await self.connection_manager.execute_query( - QUERY, [relationship_ids] - ) - - return True - async def update( self, graph_id: UUID, @@ -2236,30 +2184,6 @@ async def delete_node_via_document_id( return None return None - ##################### RELATIONSHIP METHODS ##################### - - # DEPRECATED - async def add_relationships( - self, - relationships: list[Relationship], - table_name: str = "relationship", - ): # type: ignore - """ - Upsert relationships into the relationship table. These are raw relationships extracted from the document. - - Args: - relationships: list[Relationship]: list of relationships to upsert - table_name: str: name of the table to upsert into - - Returns: - result: asyncpg.Record: result of the upsert operation - """ - return await _add_objects( - objects=[ele.to_dict() for ele in relationships], - full_table_name=self._get_table_name(table_name), - connection_manager=self.connection_manager, - ) - async def get_all_relationships( self, collection_id: UUID | None, @@ -2360,7 +2284,7 @@ async def get( # Build conditions and parameters for listing relationships conditions = ["parent_id = $1"] - params = [parent_id] + params: list[Any] = [parent_id] param_index = 2 if entity_names: @@ -2748,10 +2672,8 @@ async def delete_graph_for_collection( async def perform_graph_clustering( self, - collection_id: UUID | None, - # graph_id: UUID | None, + collection_id: UUID, leiden_params: dict[str, Any], - use_community_cache: bool = False, ) -> int: """ Leiden clustering algorithm to cluster the knowledge graph relationships into communities. @@ -2770,8 +2692,6 @@ async def perform_graph_clustering( check_directed: bool = True, """ - start_time = time.time() - # # relationships = await self.get_all_relationships( # # collection_id, collection_id # , graph_id # # ) @@ -3079,8 +2999,6 @@ async def graph_search( print("output = ", output) yield output - ####################### GRAPH CLUSTERING METHODS ####################### - async def _create_graph_and_cluster( self, relationships: list[Relationship], leiden_params: dict[str, Any] ) -> Any: @@ -3096,11 +3014,7 @@ async def _create_graph_and_cluster( logger.info(f"Graph has {len(G.nodes)} nodes and {len(G.edges)} edges") - hierarchical_communities = await self._compute_leiden_communities( - G, leiden_params - ) - - return hierarchical_communities + return await self._compute_leiden_communities(G, leiden_params) async def _cluster_and_add_community_info( self, @@ -3456,7 +3370,7 @@ async def update_entity_descriptions(self, entities: list[Entity]): inputs = [ ( entity.name, - entity.graph_id, + entity.parent_id, entity.description, entity.description_embedding, ) @@ -3465,8 +3379,6 @@ async def update_entity_descriptions(self, entities: list[Entity]): await self.connection_manager.execute_many(query, inputs) # type: ignore - ####################### PRIVATE METHODS ########################## - def _json_serialize(obj): if isinstance(obj, UUID): From d4bfba78e74e0a13a69d16aee593f7bf66f55cee Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:08:01 -0600 Subject: [PATCH 28/28] More --- py/core/main/api/v3/documents_router.py | 97 ++++++++++++------------- 1 file changed, 46 insertions(+), 51 deletions(-) diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index e9410210f..1756ca748 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -13,20 +13,19 @@ from core.base import R2RException, RunType, generate_document_id from core.base.abstractions import ( - Entity, KGCreationSettings, KGRunType, - Relationship, ) from core.base.api.models import ( GenericBooleanResponse, - PaginatedResultsWrapper, WrappedBooleanResponse, WrappedChunksResponse, WrappedCollectionsResponse, WrappedDocumentResponse, WrappedDocumentsResponse, + WrappedEntitiesResponse, WrappedIngestionResponse, + WrappedRelationshipsResponse, ) from core.providers import ( HatchetOrchestrationProvider, @@ -761,16 +760,16 @@ async def list_chunks( ..., description="The ID of the document to retrieve chunks for.", ), - offset: Optional[int] = Query( + offset: int = Query( 0, ge=0, - description="The offset of the first chunk to retrieve.", + description="Specifies the number of objects to skip. Defaults to 0.", ), - limit: Optional[int] = Query( + limit: int = Query( 100, - ge=0, - le=20_000, - description="The maximum number of chunks to retrieve, up to 20,000.", + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), include_vectors: Optional[bool] = Query( False, @@ -1263,7 +1262,7 @@ async def extract( description="Whether to run the entities and relationships extraction process with orchestration.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): + ) -> WrappedIngestionResponse: """ Extracts entities and relationships from a document. The entities and relationships extraction process involves: @@ -1290,7 +1289,6 @@ async def extract( settings_dict=settings, # type: ignore ) - # If the run type is estimate, return an estimate of the creation cost if run_type is KGRunType.ESTIMATE: return { # type: ignore "message": "Estimate retrieved successfully", @@ -1303,30 +1301,27 @@ async def extract( kg_creation_settings=server_kg_creation_settings, ), } - else: - # Otherwise, create the graph - if run_with_orchestration: - workflow_input = { - "document_id": str(id), - "kg_creation_settings": server_kg_creation_settings.model_dump_json(), - "user": auth_user.json(), - } - return await self.orchestration_provider.run_workflow( # type: ignore - "extract-triples", {"request": workflow_input}, {} - ) - else: - from core.main.orchestration import simple_kg_factory + if run_with_orchestration: + workflow_input = { + "document_id": str(id), + "kg_creation_settings": server_kg_creation_settings.model_dump_json(), + "user": auth_user.json(), + } - logger.info( - "Running extract-triples without orchestration." - ) - simple_kg = simple_kg_factory(self.services["kg"]) - await simple_kg["extract-triples"](workflow_input) # type: ignore - return { # type: ignore - "message": "Graph created successfully.", - "task_id": None, - } + return await self.orchestration_provider.run_workflow( + "extract-triples", {"request": workflow_input}, {} + ) + else: + from core.main.orchestration import simple_kg_factory + + logger.info("Running extract-triples without orchestration.") + simple_kg = simple_kg_factory(self.services["kg"]) + await simple_kg["extract-triples"](workflow_input) + return { # type: ignore + "message": "Graph created successfully.", + "task_id": None, + } @self.router.get( "/documents/{id}/entities", @@ -1357,23 +1352,23 @@ async def get_entities( ..., description="The ID of the document to retrieve entities from.", ), - offset: Optional[int] = Query( + offset: int = Query( 0, ge=0, - description="The offset of the first entity to retrieve.", + description="Specifies the number of objects to skip. Defaults to 0.", ), - limit: Optional[int] = Query( + limit: int = Query( 100, - ge=0, - le=20_000, - description="The maximum number of entities to retrieve, up to 20,000.", + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), include_embeddings: Optional[bool] = Query( False, description="Whether to include vector embeddings in the response.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Entity]]: + ) -> WrappedEntitiesResponse: """ Retrieves the entities that were extracted from a document. These represent important semantic elements like people, places, organizations, concepts, etc. @@ -1402,7 +1397,7 @@ async def get_entities( raise R2RException("Document not found.", 404) # Get all entities for this document from the document_entity table - entities, total_count = ( + entities, count = ( await self.providers.database.graph_handler.entities.get( parent_id=id, store_type="document", @@ -1412,7 +1407,7 @@ async def get_entities( ) ) - return entities, {"total_entries": total_count} + return entities, {"total_entries": count} # type: ignore @self.router.get( "/documents/{id}/relationships", @@ -1482,16 +1477,16 @@ async def list_relationships( ..., description="The ID of the document to retrieve relationships for.", ), - offset: Optional[int] = Query( + offset: int = Query( 0, ge=0, - description="The offset of the first relationship to retrieve.", + description="Specifies the number of objects to skip. Defaults to 0.", ), - limit: Optional[int] = Query( + limit: int = Query( 100, - ge=0, - le=20_000, - description="The maximum number of relationships to retrieve, up to 20,000.", + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), entity_names: Optional[list[str]] = Query( None, @@ -1502,7 +1497,7 @@ async def list_relationships( description="Filter relationships by specific relationship types.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Relationship]]: + ) -> WrappedRelationshipsResponse: """ Retrieves the relationships between entities that were extracted from a document. These represent connections and interactions between entities found in the text. @@ -1531,7 +1526,7 @@ async def list_relationships( raise R2RException("Document not found.", 404) # Get relationships for this document - relationships, total_count = ( + relationships, count = ( await self.providers.database.graph_handler.relationships.get( parent_id=id, store_type="document", @@ -1542,7 +1537,7 @@ async def list_relationships( ) ) - return relationships, {"total_entries": total_count} + return relationships, {"total_entries": count} # type: ignore @staticmethod async def _process_file(file):