diff --git a/src/neo4j_genai/experimental/components/kg_writer.py b/src/neo4j_genai/experimental/components/kg_writer.py index f510b868..8d833b78 100644 --- a/src/neo4j_genai/experimental/components/kg_writer.py +++ b/src/neo4j_genai/experimental/components/kg_writer.py @@ -17,7 +17,7 @@ import asyncio import logging from abc import abstractmethod -from typing import Literal, Optional +from typing import Any, Dict, Literal, Optional, Tuple import neo4j from pydantic import validate_call @@ -102,12 +102,7 @@ def __init__( self.neo4j_database = neo4j_database self.max_concurrency = max_concurrency - def _upsert_node(self, node: Neo4jNode) -> None: - """Upserts a single node into the Neo4j database." - - Args: - node (Neo4jNode): The node to upsert into the database. - """ + def _get_node_query(self, node: Neo4jNode) -> Tuple[str, Dict[str, Any]]: # Create the initial node parameters = {"id": node.id} if node.properties: @@ -116,6 +111,15 @@ def _upsert_node(self, node: Neo4jNode) -> None: "{" + ", ".join(f"{key}: ${key}" for key in parameters.keys()) + "}" ) query = UPSERT_NODE_QUERY.format(label=node.label, properties=properties) + return query, parameters + + def _upsert_node(self, node: Neo4jNode) -> None: + """Upserts a single node into the Neo4j database. + + Args: + node (Neo4jNode): The node to upsert into the database. + """ + query, parameters = self._get_node_query(node) result = self.driver.execute_query(query, parameters_=parameters) node_id = result.records[0]["elementID(n)"] # Add the embedding properties to the node @@ -140,14 +144,7 @@ async def _async_upsert_node( node (Neo4jNode): The node to upsert into the database. 
""" async with sem: - # Create the initial node - parameters = {"id": node.id} - if node.properties: - parameters.update(node.properties) - properties = ( - "{" + ", ".join(f"{key}: ${key}" for key in parameters.keys()) + "}" - ) - query = UPSERT_NODE_QUERY.format(label=node.label, properties=properties) + query, parameters = self._get_node_query(node) result = await self.driver.execute_query(query, parameters_=parameters) node_id = result.records[0]["elementID(n)"] # Add the embedding properties to the node @@ -161,12 +158,7 @@ async def _async_upsert_node( neo4j_database=self.neo4j_database, ) - def _upsert_relationship(self, rel: Neo4jRelationship) -> None: - """Upserts a single relationship into the Neo4j database. - - Args: - rel (Neo4jRelationship): The relationship to upsert into the database. - """ + def _get_rel_query(self, rel: Neo4jRelationship) -> Tuple[str, Dict[str, Any]]: # Create the initial relationship parameters = { "start_node_id": rel.start_node_id, @@ -183,6 +175,15 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None: type=rel.type, properties=properties, ) + return query, parameters + + def _upsert_relationship(self, rel: Neo4jRelationship) -> None: + """Upserts a single relationship into the Neo4j database. + + Args: + rel (Neo4jRelationship): The relationship to upsert into the database. + """ + query, parameters = self._get_rel_query(rel) result = self.driver.execute_query(query, parameters_=parameters) rel_id = result.records[0]["elementID(r)"] # Add the embedding properties to the relationship @@ -205,24 +206,7 @@ async def _async_upsert_relationship( rel (Neo4jRelationship): The relationship to upsert into the database. 
""" async with sem: - # Create the initial relationship - parameters = { - "start_node_id": rel.start_node_id, - "end_node_id": rel.end_node_id, - } - if rel.properties: - properties = ( - "{" - + ", ".join(f"{key}: ${key}" for key in rel.properties.keys()) - + "}" - ) - parameters.update(rel.properties) - else: - properties = "{}" - query = UPSERT_RELATIONSHIP_QUERY.format( - type=rel.type, - properties=properties, - ) + query, parameters = self._get_rel_query(rel) result = await self.driver.execute_query(query, parameters_=parameters) rel_id = result.records[0]["elementID(r)"] # Add the embedding properties to the relationship