Skip to content

Commit

Permalink
Neo4jWriter improvements (#151)
Browse files Browse the repository at this point in the history
* Update Cypher queries for nodes

* Mypy

* Set embeddings in same query (for nodes)

* Fix e2e + mypy

* Merge queries for relationships

* Ruff

* Unused imports

* CHANGELOG update + elementId instead of elementID (seems to be the convention)
  • Loading branch information
stellasia authored Sep 25, 2024
1 parent 89411ca commit ff1c6ee
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 119 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

### Fixed
- Resolved import issue with the Vertex AI Embeddings class.
- Resolved issue where Neo4jWriter component would raise an error if the start or end node ID was not defined properly in the input.

### Changed
- Moved the Embedder class to the neo4j_graphrag.embeddings directory for better organization alongside other custom embedders.
- Neo4jWriter component now runs a single query to merge a node and set its embeddings, if any.

## 0.6.3
### Changed
Expand Down
99 changes: 29 additions & 70 deletions src/neo4j_graphrag/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import asyncio
import inspect
import logging
from abc import abstractmethod
from typing import Any, Dict, Literal, Optional, Tuple
Expand All @@ -28,12 +29,6 @@
Neo4jRelationship,
)
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.indexes import (
async_upsert_vector,
async_upsert_vector_on_relationship,
upsert_vector,
upsert_vector_on_relationship,
)
from neo4j_graphrag.neo4j_queries import UPSERT_NODE_QUERY, UPSERT_RELATIONSHIP_QUERY

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -102,15 +97,26 @@ def __init__(
self.neo4j_database = neo4j_database
self.max_concurrency = max_concurrency

def _db_setup(self) -> None:
    """Create the index on ``__Entity__.id`` if it does not already exist.

    The statement is sent to the database configured on the writer
    (``self.neo4j_database``) rather than unconditionally to the driver's
    default database; a value of ``None`` leaves the choice to the driver.
    """
    self.driver.execute_query(
        "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__Entity__) ON (n.id)",
        database_=self.neo4j_database,
    )

async def _async_db_setup(self) -> None:
    """Create the index on ``__Entity__.id`` if it does not already exist.

    Asynchronous counterpart of the synchronous setup step. The statement is
    sent to the database configured on the writer (``self.neo4j_database``);
    a value of ``None`` leaves the choice to the driver.
    """
    await self.driver.execute_query(
        "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__Entity__) ON (n.id)",
        database_=self.neo4j_database,
    )

def _get_node_query(self, node: Neo4jNode) -> Tuple[str, Dict[str, Any]]:
    """Build the Cypher statement and parameter map used to upsert ``node``.

    Args:
        node (Neo4jNode): The node to upsert into the database.

    Returns:
        Tuple[str, Dict[str, Any]]: ``UPSERT_NODE_QUERY`` formatted with the
        node label, and a parameter map with the keys ``id``, ``properties``
        and ``embeddings``.
    """
    # Properties and embeddings are passed as whole maps, so only the label
    # has to be interpolated into the query text.
    parameters = {
        "id": node.id,
        "properties": node.properties or {},
        "embeddings": node.embedding_properties,
    }
    query = UPSERT_NODE_QUERY.format(label=node.label)
    return query, parameters

def _upsert_node(self, node: Neo4jNode) -> None:
Expand All @@ -120,18 +126,7 @@ def _upsert_node(self, node: Neo4jNode) -> None:
node (Neo4jNode): The node to upsert into the database.
"""
query, parameters = self._get_node_query(node)
result = self.driver.execute_query(query, parameters_=parameters)
node_id = result.records[0]["elementID(n)"]
# Add the embedding properties to the node
if node.embedding_properties:
for prop, vector in node.embedding_properties.items():
upsert_vector(
driver=self.driver,
node_id=node_id,
embedding_property=prop,
vector=vector,
neo4j_database=self.neo4j_database,
)
self.driver.execute_query(query, parameters_=parameters)

async def _async_upsert_node(
self,
Expand All @@ -145,35 +140,18 @@ async def _async_upsert_node(
"""
async with sem:
query, parameters = self._get_node_query(node)
result = await self.driver.execute_query(query, parameters_=parameters)
node_id = result.records[0]["elementID(n)"]
# Add the embedding properties to the node
if node.embedding_properties:
for prop, vector in node.embedding_properties.items():
await async_upsert_vector(
driver=self.driver,
node_id=node_id,
embedding_property=prop,
vector=vector,
neo4j_database=self.neo4j_database,
)
await self.driver.execute_query(query, parameters_=parameters)

def _get_rel_query(self, rel: Neo4jRelationship) -> Tuple[str, Dict[str, Any]]:
    """Build the Cypher statement and parameter map used to upsert ``rel``.

    Args:
        rel (Neo4jRelationship): The relationship to upsert into the database.

    Returns:
        Tuple[str, Dict[str, Any]]: ``UPSERT_RELATIONSHIP_QUERY`` formatted
        with the relationship type, and a parameter map with the keys
        ``start_node_id``, ``end_node_id``, ``properties`` and ``embeddings``.
    """
    # Relationship properties stay nested under "properties": copying them
    # into the top-level map would let a property named e.g. "embeddings" or
    # "start_node_id" overwrite the query parameter of the same name.
    parameters = {
        "start_node_id": rel.start_node_id,
        "end_node_id": rel.end_node_id,
        "properties": rel.properties or {},
        "embeddings": rel.embedding_properties,
    }
    query = UPSERT_RELATIONSHIP_QUERY.format(type=rel.type)
    return query, parameters

Expand All @@ -184,18 +162,7 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
rel (Neo4jRelationship): The relationship to upsert into the database.
"""
query, parameters = self._get_rel_query(rel)
result = self.driver.execute_query(query, parameters_=parameters)
rel_id = result.records[0]["elementID(r)"]
# Add the embedding properties to the relationship
if rel.embedding_properties:
for prop, vector in rel.embedding_properties.items():
upsert_vector_on_relationship(
driver=self.driver,
rel_id=rel_id,
embedding_property=prop,
vector=vector,
neo4j_database=self.neo4j_database,
)
self.driver.execute_query(query, parameters_=parameters)

async def _async_upsert_relationship(
self, rel: Neo4jRelationship, sem: asyncio.Semaphore
Expand All @@ -207,18 +174,7 @@ async def _async_upsert_relationship(
"""
async with sem:
query, parameters = self._get_rel_query(rel)
result = await self.driver.execute_query(query, parameters_=parameters)
rel_id = result.records[0]["elementID(r)"]
# Add the embedding properties to the relationship
if rel.embedding_properties:
for prop, vector in rel.embedding_properties.items():
await async_upsert_vector_on_relationship(
driver=self.driver,
rel_id=rel_id,
embedding_property=prop,
vector=vector,
neo4j_database=self.neo4j_database,
)
await self.driver.execute_query(query, parameters_=parameters)

@validate_call
async def run(self, graph: Neo4jGraph) -> KGWriterModel:
Expand All @@ -228,7 +184,8 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
graph (Neo4jGraph): The knowledge graph to upsert into the database.
"""
try:
if isinstance(self.driver, neo4j.AsyncDriver):
if inspect.iscoroutinefunction(self.driver.execute_query):
await self._async_db_setup()
sem = asyncio.Semaphore(self.max_concurrency)
node_tasks = [
self._async_upsert_node(node, sem) for node in graph.nodes
Expand All @@ -241,6 +198,8 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
]
await asyncio.gather(*rel_tasks)
else:
self._db_setup()

for node in graph.nodes:
self._upsert_node(node)

Expand Down
24 changes: 20 additions & 4 deletions src/neo4j_graphrag/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,28 @@
"YIELD node, score"
)

# Template for upserting one node; fill it with ``.format(label=...)``.
# Doubled braces are literal Cypher braces escaped for ``str.format``.
# Expected query parameters: $id, $properties (map) and $embeddings (map or null).
UPSERT_NODE_QUERY = (
    # Nodes are merged on the shared __Entity__ label and their id ...
    "MERGE (n:__Entity__ {{id: $id}}) "
    # ... then the specific label and the property map are added on top.
    "WITH n SET n:`{label}`, n += $properties "
    "WITH n CALL {{ "
    # The WHERE clause leaves no row to process when $embeddings is null,
    # so the vector procedure is only called when embeddings were supplied.
    "WITH n WITH n WHERE $embeddings IS NOT NULL "
    "UNWIND keys($embeddings) as emb "
    "CALL db.create.setNodeVectorProperty(n, emb, $embeddings[emb]) "
    "}} "
    "RETURN elementId(n)"
)

# Template for upserting one relationship; fill it with ``.format(type=...)``.
# Doubled braces are literal Cypher braces escaped for ``str.format``.
# Expected query parameters: $start_node_id, $end_node_id, $properties (map)
# and $embeddings (map or null).
# NOTE(review): the MATCH patterns carry no label, so the index on
# __Entity__(id) is presumably not usable for these lookups - confirm.
UPSERT_RELATIONSHIP_QUERY = (
    "MATCH (start {{ id: $start_node_id }}) "
    "MATCH (end {{ id: $end_node_id }}) "
    # Merge on the type only; properties are applied afterwards.
    "MERGE (start)-[r:{type}]->(end) "
    "WITH r SET r += $properties "
    "WITH r CALL {{ "
    # The WHERE clause leaves no row to process when $embeddings is null,
    # so the vector procedure is only called when embeddings were supplied.
    "WITH r WITH r WHERE $embeddings IS NOT NULL "
    "UNWIND keys($embeddings) as emb "
    "CALL db.create.setRelationshipVectorProperty(r, emb, $embeddings[emb]) "
    "}} "
    "RETURN elementId(r)"
)

UPSERT_VECTOR_ON_NODE_QUERY = (
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/test_kg_writer_component_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
assert "a" and "b" and "r" in record.keys()

node_a = record["a"]
assert start_node.label == list(node_a.labels)[0]
assert start_node.label in list(node_a.labels)
assert start_node.id == str(node_a.get("id"))
if start_node.properties:
for key, val in start_node.properties.items():
Expand All @@ -66,7 +66,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
assert node_a.get(key) == [1.0, 2.0, 3.0]

node_b = record["b"]
assert end_node.label == list(node_b.labels)[0]
assert end_node.label in list(node_b.labels)
assert end_node.id == str(node_b.get("id"))
if end_node.properties:
for key, val in end_node.properties.items():
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def driver() -> MagicMock:
return MagicMock(spec=neo4j.Driver)


@pytest.fixture(scope="function")
def async_driver() -> MagicMock:
    """Function-scoped fixture returning a ``MagicMock`` specced on ``neo4j.AsyncDriver``."""
    mocked_driver = MagicMock(spec=neo4j.AsyncDriver)
    return mocked_driver


@pytest.fixture(scope="function")
def embedder() -> MagicMock:
    """Function-scoped fixture returning a ``MagicMock`` specced on ``Embedder``."""
    mocked_embedder = MagicMock(spec=Embedder)
    return mocked_embedder
Expand Down
Loading

0 comments on commit ff1c6ee

Please sign in to comment.