Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyaspimpalgaonkar committed Nov 20, 2024
1 parent 47a4223 commit 6ee5254
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 8 deletions.
14 changes: 6 additions & 8 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,9 +669,7 @@ async def add_entity_to_graph(
"""
Adds a list of entities to the graph by their IDs.
"""
raise NotImplementedError("Not implemented")


return await self.services["kg"].documents.graph_handler.entities.add_to_graph(id, entity_ids)

@self.router.delete(
"/graphs/{id}/entities/{entity_id}",
Expand Down Expand Up @@ -710,7 +708,7 @@ async def remove_entity_from_graph(
"""
Removes an entity from the graph by its ID.
"""
raise NotImplementedError("Not implemented")
return await self.services["kg"].documents.graph_handler.entities.remove_from_graph(id, [entity_id])


@self.router.post(
Expand Down Expand Up @@ -750,7 +748,7 @@ async def add_relationship_to_graph(
"""
Adds a list of relationships to the graph by their IDs.
"""
raise NotImplementedError("Not implemented")
return await self.services["kg"].documents.graph_handler.relationships.add_to_graph(id, relationship_ids)



Expand Down Expand Up @@ -791,7 +789,7 @@ async def remove_relationship_from_graph(
"""
Removes a relationship from the graph by its ID.
"""
raise NotImplementedError("Not implemented")
return await self.services["kg"].documents.graph_handler.relationships.remove_from_graph(id, [relationship_id])



Expand Down Expand Up @@ -832,7 +830,7 @@ async def add_communities_to_graph(
"""
Adds a list of communities to the graph by their IDs.
"""
raise NotImplementedError("Not implemented")
return await self.services["kg"].documents.graph_handler.communities.add_to_graph(id, community_ids)



Expand Down Expand Up @@ -873,5 +871,5 @@ async def remove_community_from_graph(
"""
Removes a community from the graph by its ID.
"""
raise NotImplementedError("Not implemented")
return await self.services["kg"].documents.graph_handler.communities.remove_from_graph(id, [community_id])

50 changes: 50 additions & 0 deletions py/core/providers/database/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,25 @@ async def delete(
connection_manager=self.connection_manager,
)

async def add_to_graph(self, graph_id: UUID, entity_ids: list[UUID]) -> None:
QUERY = f"""
UPDATE {self._get_table_name("graph_entity")}
SET graph_ids = CASE
WHEN graph_ids IS NULL THEN ARRAY[$1]
WHEN NOT ($1 = ANY(graph_ids)) THEN array_append(graph_ids, $1)
ELSE graph_ids
END
WHERE id = ANY($2)
"""
return await self.connection_manager.execute_query(QUERY, [graph_id, entity_ids])

async def remove_from_graph(self, graph_id: UUID, entity_ids: list[UUID]) -> None:
QUERY = f"""
UPDATE {self._get_table_name("graph_entity")}
SET graph_ids = array_remove(graph_ids, $1)
WHERE id = ANY($2)
"""
return await self.connection_manager.execute_query(QUERY, [graph_id, entity_ids])

class PostgresRelationshipHandler(RelationshipHandler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -488,6 +507,26 @@ async def delete(
return await self.connection_manager.fetchrow_query(
QUERY, [relationship_id]
)

async def add_to_graph(self, graph_id: UUID, relationship_ids: list[UUID]) -> None:
QUERY = f"""
UPDATE {self._get_table_name("graph_relationship")}
SET graph_ids = CASE
WHEN graph_ids IS NULL THEN ARRAY[$1]
WHEN NOT ($1 = ANY(graph_ids)) THEN array_append(graph_ids, $1)
ELSE graph_ids
END
WHERE id = ANY($2)
"""
return await self.connection_manager.execute_query(QUERY, [graph_id, relationship_ids])

async def remove_from_graph(self, graph_id: UUID, relationship_ids: list[UUID]) -> None:
QUERY = f"""
UPDATE {self._get_table_name("graph_relationship")}
SET graph_ids = array_remove(graph_ids, $1)
WHERE id = ANY($2)
"""
return await self.connection_manager.execute_query(QUERY, [graph_id, relationship_ids])


class PostgresCommunityHandler(CommunityHandler):
Expand Down Expand Up @@ -607,6 +646,17 @@ async def get(
)
]

async def add_to_graph(self, graph_id: UUID, community_ids: list[UUID]) -> None:
QUERY = f"""
UPDATE {self._get_table_name("graph_community")} SET graph_id = $1 WHERE id = ANY($2)
"""
return await self.connection_manager.execute_query(QUERY, [graph_id, community_ids])

async def remove_from_graph(self, graph_id: UUID, community_ids: list[UUID]) -> None:
QUERY = f"""
UPDATE {self._get_table_name("graph_community")} SET graph_id = NULL WHERE id = ANY($1)
"""
return await self.connection_manager.execute_query(QUERY, [community_ids])

class PostgresGraphHandler(GraphHandler):
"""Handler for Knowledge Graph METHODS in PostgreSQL."""
Expand Down

0 comments on commit 6ee5254

Please sign in to comment.