Skip to content

Commit

Permalink
Added an asynchronous Neo4jWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Sep 2, 2024
1 parent 81245dd commit b063107
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 16 deletions.
3 changes: 3 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ Database Interaction

.. autofunction:: neo4j_genai.indexes.upsert_vector_on_relationship

.. autofunction:: neo4j_genai.indexes.async_upsert_vector

.. autofunction:: neo4j_genai.indexes.async_upsert_vector_on_relationship

******
Errors
Expand Down
141 changes: 140 additions & 1 deletion src/neo4j_genai/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations

import asyncio
import logging
from abc import abstractmethod
from typing import Literal, Optional
Expand All @@ -27,7 +28,12 @@
Neo4jRelationship,
)
from neo4j_genai.experimental.pipeline.component import Component, DataModel
from neo4j_genai.indexes import upsert_vector, upsert_vector_on_relationship
from neo4j_genai.indexes import (
async_upsert_vector,
async_upsert_vector_on_relationship,
upsert_vector,
upsert_vector_on_relationship,
)
from neo4j_genai.neo4j_queries import UPSERT_NODE_QUERY, UPSERT_RELATIONSHIP_QUERY

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -173,3 +179,136 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
except neo4j.exceptions.ClientError as e:
logger.exception(e)
return KGWriterModel(status="FAILURE")


class AsyncNeo4jWriter(KGWriter):
"""Asynchronously Writes a knowledge graph to a Neo4j database.
Args:
driver (neo4j.AsyncDriver): The asynchronous Neo4j driver to connect to the database.
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
Example:
.. code-block:: python
from neo4j import AsyncGraphDatabase
from neo4j_genai.experimental.components.kg_writer import Neo4jWriter
from neo4j_genai.experimental.pipeline import Pipeline
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
writer = AsyncNeo4jWriter(driver=driver, neo4j_database=DATABASE)
pipeline = Pipeline()
pipeline.add_component("writer", writer)
"""

def __init__(
self,
driver: neo4j.AsyncDriver,
neo4j_database: Optional[str] = None,
max_concurrency: int = 5,
):
self.driver = driver
self.neo4j_database = neo4j_database
self.max_concurrency = max_concurrency

async def _upsert_node(
self,
node: Neo4jNode,
sem: asyncio.Semaphore,
) -> None:
"""Upserts a single node into the Neo4j database."
Args:
node (Neo4jNode): The node to upsert into the database.
"""
async with sem:
# Create the initial node
parameters = {"id": node.id}
if node.properties:
parameters.update(node.properties)
properties = (
"{" + ", ".join(f"{key}: ${key}" for key in parameters.keys()) + "}"
)
query = UPSERT_NODE_QUERY.format(label=node.label, properties=properties)
result = await self.driver.execute_query(query, parameters_=parameters)
node_id = result.records[0]["elementID(n)"]
# Add the embedding properties to the node
if node.embedding_properties:
for prop, vector in node.embedding_properties.items():
await async_upsert_vector(
driver=self.driver,
node_id=node_id,
embedding_property=prop,
vector=vector,
neo4j_database=self.neo4j_database,
)

async def _upsert_relationship(
self, rel: Neo4jRelationship, sem: asyncio.Semaphore
) -> None:
"""Upserts a single relationship into the Neo4j database.
Args:
rel (Neo4jRelationship): The relationship to upsert into the database.
"""
async with sem:
# Create the initial relationship
parameters = {
"start_node_id": rel.start_node_id,
"end_node_id": rel.end_node_id,
}
if rel.properties:
properties = (
"{"
+ ", ".join(f"{key}: ${key}" for key in rel.properties.keys())
+ "}"
)
parameters.update(rel.properties)
else:
properties = "{}"
query = UPSERT_RELATIONSHIP_QUERY.format(
type=rel.type,
properties=properties,
)
result = await self.driver.execute_query(query, parameters_=parameters)
rel_id = result.records[0]["elementID(r)"]
# Add the embedding properties to the relationship
if rel.embedding_properties:
for prop, vector in rel.embedding_properties.items():
await async_upsert_vector_on_relationship(
driver=self.driver,
rel_id=rel_id,
embedding_property=prop,
vector=vector,
neo4j_database=self.neo4j_database,
)

@validate_call
async def run(self, graph: Neo4jGraph) -> KGWriterModel:
"""Upserts a knowledge graph into a Neo4j database.
Args:
graph (Neo4jGraph): The knowledge graph to upsert into the database.
"""
try:
sem = asyncio.Semaphore(self.max_concurrency)
node_tasks = [self._upsert_node(node, sem) for node in graph.nodes]
await asyncio.gather(*node_tasks)

rel_tasks = [
self._upsert_relationship(rel, sem) for rel in graph.relationships
]
await asyncio.gather(*rel_tasks)

return KGWriterModel(status="SUCCESS")
except neo4j.exceptions.ClientError as e:
logger.exception(e)
return KGWriterModel(status="FAILURE")
139 changes: 124 additions & 15 deletions src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
import neo4j
from pydantic import ValidationError

from neo4j_genai.neo4j_queries import (
UPSERT_VECTOR_ON_NODE_QUERY,
UPSERT_VECTOR_ON_RELATIONSHIP_QUERY,
)

from .exceptions import Neo4jIndexError, Neo4jInsertionError
from .types import FulltextIndexModel, VectorIndexModel

Expand Down Expand Up @@ -278,19 +283,14 @@ def upsert_vector(
Neo4jInsertionError: If upserting of the vector fails.
"""
try:
query = (
"MATCH (n) "
"WHERE elementId(n) = $id "
"WITH n "
"CALL db.create.setNodeVectorProperty(n, $embedding_property, $vector) "
"RETURN n"
)
parameters = {
"id": node_id,
"embedding_property": embedding_property,
"vector": vector,
}
driver.execute_query(query, parameters, database_=neo4j_database)
driver.execute_query(
UPSERT_VECTOR_ON_NODE_QUERY, parameters, database_=neo4j_database
)
except neo4j.exceptions.ClientError as e:
raise Neo4jInsertionError(
f"Upserting vector to Neo4j failed: {e.message}"
Expand Down Expand Up @@ -339,19 +339,128 @@ def upsert_vector_on_relationship(
Neo4jInsertionError: If upserting of the vector fails.
"""
try:
query = (
"MATCH ()-[r]->() "
"WHERE elementId(r) = $id "
"WITH r "
"CALL db.create.setRelationshipVectorProperty(r, $embedding_property, $vector) "
"RETURN r"
parameters = {
"id": rel_id,
"embedding_property": embedding_property,
"vector": vector,
}
driver.execute_query(
UPSERT_VECTOR_ON_RELATIONSHIP_QUERY, parameters, database_=neo4j_database
)
except neo4j.exceptions.ClientError as e:
raise Neo4jInsertionError(
f"Upserting vector to Neo4j failed: {e.message}"
) from e


async def async_upsert_vector(
driver: neo4j.AsyncDriver,
node_id: int,
embedding_property: str,
vector: list[float],
neo4j_database: Optional[str] = None,
) -> None:
"""
This method constructs a Cypher query and asynchronously executes it
to upsert (insert or update) a vector property on a specific node.
Example:
.. code-block:: python
from neo4j import AsyncGraphDatabase
from neo4j_genai.indexes import upsert_vector
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
# Connect to Neo4j database
driver = AsyncGraphDatabase.driver(URI, auth=AUTH)
# Upsert the vector data
async_upsert_vector(
driver,
node_id="nodeId",
embedding_property="vectorProperty",
vector=...,
)
Args:
driver (neo4j.AsyncDriver): Neo4j Python asynchronous driver instance.
node_id (int): The id of the node.
embedding_property (str): The name of the property to store the vector in.
vector (list[float]): The vector to store.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Raises:
Neo4jInsertionError: If upserting of the vector fails.
"""
try:
parameters = {
"id": node_id,
"embedding_property": embedding_property,
"vector": vector,
}
await driver.execute_query(
UPSERT_VECTOR_ON_NODE_QUERY, parameters, database_=neo4j_database
)
except neo4j.exceptions.ClientError as e:
raise Neo4jInsertionError(
f"Upserting vector to Neo4j failed: {e.message}"
) from e


async def async_upsert_vector_on_relationship(
driver: neo4j.AsyncDriver,
rel_id: int,
embedding_property: str,
vector: list[float],
neo4j_database: Optional[str] = None,
) -> None:
"""
This method constructs a Cypher query and asynchronously executes it
to upsert (insert or update) a vector property on a specific relationship.
Example:
.. code-block:: python
from neo4j import AsyncGraphDatabase
from neo4j_genai.indexes import upsert_vector_on_relationship
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
# Connect to Neo4j database
driver = AsyncGraphDatabase.driver(URI, auth=AUTH)
# Upsert the vector data
async_upsert_vector_on_relationship(
driver,
node_id="nodeId",
embedding_property="vectorProperty",
vector=...,
)
Args:
driver (neo4j.AsyncDriver): Neo4j Python asynchronous driver instance.
rel_id (int): The id of the relationship.
embedding_property (str): The name of the property to store the vector in.
vector (list[float]): The vector to store.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Raises:
Neo4jInsertionError: If upserting of the vector fails.
"""
try:
parameters = {
"id": rel_id,
"embedding_property": embedding_property,
"vector": vector,
}
driver.execute_query(query, parameters, database_=neo4j_database)
await driver.execute_query(
UPSERT_VECTOR_ON_RELATIONSHIP_QUERY, parameters, database_=neo4j_database
)
except neo4j.exceptions.ClientError as e:
raise Neo4jInsertionError(
f"Upserting vector to Neo4j failed: {e.message}"
Expand Down
16 changes: 16 additions & 0 deletions src/neo4j_genai/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@
"RETURN elementID(r)"
)

UPSERT_VECTOR_ON_NODE_QUERY = (
"MATCH (n) "
"WHERE elementId(n) = $id "
"WITH n "
"CALL db.create.setNodeVectorProperty(n, $embedding_property, $vector) "
"RETURN n"
)

UPSERT_VECTOR_ON_RELATIONSHIP_QUERY = (
"MATCH ()-[r]->() "
"WHERE elementId(r) = $id "
"WITH r "
"CALL db.create.setRelationshipVectorProperty(r, $embedding_property, $vector) "
"RETURN r"
)


def _get_hybrid_query() -> str:
return (
Expand Down

0 comments on commit b063107

Please sign in to comment.