diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 6f1690339..866923094 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -669,9 +669,7 @@ async def add_entity_to_graph( """ Adds a list of entities to the graph by their IDs. """ - raise NotImplementedError("Not implemented") - - + return await self.services["kg"].documents.graph_handler.entities.add_to_graph(id, entity_ids) @self.router.delete( "/graphs/{id}/entities/{entity_id}", @@ -710,7 +708,7 @@ async def remove_entity_from_graph( """ Removes an entity from the graph by its ID. """ - raise NotImplementedError("Not implemented") + return await self.services["kg"].documents.graph_handler.entities.remove_from_graph(id, [entity_id]) @self.router.post( @@ -750,7 +748,7 @@ async def add_relationship_to_graph( """ Adds a list of relationships to the graph by their IDs. """ - raise NotImplementedError("Not implemented") + return await self.services["kg"].documents.graph_handler.relationships.add_to_graph(id, relationship_ids) @@ -791,7 +789,7 @@ async def remove_relationship_from_graph( """ Removes a relationship from the graph by its ID. """ - raise NotImplementedError("Not implemented") + return await self.services["kg"].documents.graph_handler.relationships.remove_from_graph(id, [relationship_id]) @@ -832,7 +830,7 @@ async def add_communities_to_graph( """ Adds a list of communities to the graph by their IDs. """ - raise NotImplementedError("Not implemented") + return await self.services["kg"].documents.graph_handler.communities.add_to_graph(id, community_ids) @@ -873,5 +871,5 @@ async def remove_community_from_graph( """ Removes a community from the graph by its ID. """ - raise NotImplementedError("Not implemented") + return await self.services["kg"].documents.graph_handler.communities.remove_from_graph(id, [community_id]) diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 48933766a..bd352973e 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -319,6 +319,25 @@ async def delete( connection_manager=self.connection_manager, ) + async def add_to_graph(self, graph_id: UUID, entity_ids: list[UUID]) -> None: + QUERY = f""" + UPDATE {self._get_table_name("graph_entity")} + SET graph_ids = CASE + WHEN graph_ids IS NULL THEN ARRAY[$1] + WHEN NOT ($1 = ANY(graph_ids)) THEN array_append(graph_ids, $1) + ELSE graph_ids + END + WHERE id = ANY($2) + """ + return await self.connection_manager.execute_query(QUERY, [graph_id, entity_ids]) + + async def remove_from_graph(self, graph_id: UUID, entity_ids: list[UUID]) -> None: + QUERY = f""" + UPDATE {self._get_table_name("graph_entity")} + SET graph_ids = array_remove(graph_ids, $1) + WHERE id = ANY($2) + """ + return await self.connection_manager.execute_query(QUERY, [graph_id, entity_ids]) class PostgresRelationshipHandler(RelationshipHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -488,6 +507,26 @@ async def delete( return await self.connection_manager.fetchrow_query( QUERY, [relationship_id] ) + + async def add_to_graph(self, graph_id: UUID, relationship_ids: list[UUID]) -> None: + QUERY = f""" + UPDATE {self._get_table_name("graph_relationship")} + SET graph_ids = CASE + WHEN graph_ids IS NULL THEN ARRAY[$1] + WHEN NOT ($1 = ANY(graph_ids)) THEN array_append(graph_ids, $1) + ELSE graph_ids + END + WHERE id = ANY($2) + """ + return await self.connection_manager.execute_query(QUERY, [graph_id, relationship_ids]) + + async def remove_from_graph(self, graph_id: UUID, relationship_ids: list[UUID]) -> None: + QUERY = f""" + UPDATE {self._get_table_name("graph_relationship")} + SET graph_ids = array_remove(graph_ids, $1) + WHERE id = ANY($2) + """ + return await self.connection_manager.execute_query(QUERY, [graph_id, relationship_ids]) class PostgresCommunityHandler(CommunityHandler): @@ -607,6 +646,17 @@ async def get( ) ] + async def add_to_graph(self, graph_id: UUID, community_ids: list[UUID]) -> None: + QUERY = f""" + UPDATE {self._get_table_name("graph_community")} SET graph_id = $1 WHERE id = ANY($2) + """ + return await self.connection_manager.execute_query(QUERY, [graph_id, community_ids]) + + async def remove_from_graph(self, graph_id: UUID, community_ids: list[UUID]) -> None: + QUERY = f""" + UPDATE {self._get_table_name("graph_community")} SET graph_id = NULL WHERE id = ANY($1) + """ + return await self.connection_manager.execute_query(QUERY, [community_ids]) class PostgresGraphHandler(GraphHandler): """Handler for Knowledge Graph METHODS in PostgreSQL."""