Skip to content

Commit

Permalink
Removed async driver support from Neo4jWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Oct 23, 2024
1 parent 9391662 commit 8ac25bc
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


async def define_and_run_pipeline(
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> PipelineResult:
from neo4j_graphrag.experimental.pipeline import Pipeline

Expand Down Expand Up @@ -131,11 +131,11 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.AsyncGraphDatabase.driver(
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
res = await define_and_run_pipeline(driver, llm)
await driver.close()
driver.close()
await llm.async_client.close()
return res

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


async def define_and_run_pipeline(
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> PipelineResult:
"""This is where we define and run the KG builder pipeline, instantiating a few
components:
Expand Down Expand Up @@ -148,11 +148,11 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.AsyncGraphDatabase.driver(
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
res = await define_and_run_pipeline(driver, llm)
await driver.close()
driver.close()
await llm.async_client.close()
return res

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@


async def define_and_run_pipeline(
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> None:
"""This is where we define and run the KG builder pipeline, instantiating a few
components:
Expand Down Expand Up @@ -144,11 +144,11 @@ async def main() -> None:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.AsyncGraphDatabase.driver(
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
await define_and_run_pipeline(driver, llm)
await driver.close()
driver.close()
await llm.async_client.close()


Expand Down
6 changes: 3 additions & 3 deletions examples/kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@


async def define_and_run_pipeline(
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> PipelineResult:
from neo4j_graphrag.experimental.pipeline import Pipeline

Expand Down Expand Up @@ -137,11 +137,11 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.AsyncGraphDatabase.driver(
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
res = await define_and_run_pipeline(driver, llm)
await driver.close()
driver.close()
await llm.async_client.close()
return res

Expand Down
85 changes: 11 additions & 74 deletions src/neo4j_graphrag/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.
from __future__ import annotations

import asyncio
import inspect
import logging
from abc import abstractmethod
from typing import Any, Generator, Literal, Optional
Expand Down Expand Up @@ -87,21 +85,21 @@ class Neo4jWriter(KGWriter):
Args:
driver (neo4j.driver): The Neo4j driver to connect to the database.
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000.
Example:
.. code-block:: python
from neo4j import AsyncGraphDatabase
from neo4j import GraphDatabase
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
from neo4j_graphrag.experimental.pipeline import Pipeline
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
pipeline = Pipeline()
Expand All @@ -111,15 +109,13 @@ class Neo4jWriter(KGWriter):

def __init__(
self,
driver: neo4j.driver,
driver: neo4j.Driver,
neo4j_database: Optional[str] = None,
batch_size: int = 1000,
max_concurrency: int = 5,
):
self.driver = driver
self.neo4j_database = neo4j_database
self.batch_size = batch_size
self.max_concurrency = max_concurrency
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()

def _db_setup(self) -> None:
Expand All @@ -129,13 +125,6 @@ def _db_setup(self) -> None:
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
)

async def _async_db_setup(self) -> None:
# create index on __Entity__.id
# used when creating the relationships
await self.driver.execute_query(
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
)

@staticmethod
def _nodes_to_rows(
nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig
Expand Down Expand Up @@ -166,23 +155,6 @@ def _upsert_nodes(
else:
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)

async def _async_upsert_nodes(
self,
nodes: list[Neo4jNode],
lexical_graph_config: LexicalGraphConfig,
sem: asyncio.Semaphore,
) -> None:
"""Asynchronously upserts a single node into the Neo4j database."
Args:
nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
"""
async with sem:
parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)}
await self.driver.execute_query(
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
)

def _get_version(self) -> tuple[int, ...]:
records, _, _ = self.driver.execute_query(
"CALL dbms.components()", database_=self.neo4j_database
Expand Down Expand Up @@ -220,26 +192,6 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
else:
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)

async def _async_upsert_relationships(
self, rels: list[Neo4jRelationship], sem: asyncio.Semaphore
) -> None:
"""Asynchronously upserts a single relationship into the Neo4j database.
Args:
rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
"""
async with sem:
parameters = {"rows": [rel.model_dump() for rel in rels]}
if self.is_version_5_23_or_above:
await self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
parameters_=parameters,
)
else:
await self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY, parameters_=parameters
)

@validate_call
async def run(
self,
Expand All @@ -253,28 +205,13 @@ async def run(
lexical_graph_config (LexicalGraphConfig):
"""
try:
if inspect.iscoroutinefunction(self.driver.execute_query):
await self._async_db_setup()
sem = asyncio.Semaphore(self.max_concurrency)
node_tasks = [
self._async_upsert_nodes(batch, lexical_graph_config, sem)
for batch in batched(graph.nodes, self.batch_size)
]
await asyncio.gather(*node_tasks)

rel_tasks = [
self._async_upsert_relationships(batch, sem)
for batch in batched(graph.relationships, self.batch_size)
]
await asyncio.gather(*rel_tasks)
else:
self._db_setup()

for batch in batched(graph.nodes, self.batch_size):
self._upsert_nodes(batch, lexical_graph_config)

for batch in batched(graph.relationships, self.batch_size):
self._upsert_relationships(batch)
self._db_setup()

for batch in batched(graph.nodes, self.batch_size):
self._upsert_nodes(batch, lexical_graph_config)

for batch in batched(graph.relationships, self.batch_size):
self._upsert_relationships(batch)

return KGWriterModel(
status="SUCCESS",
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_graphrag/experimental/components/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ class SinglePropertyExactMatchResolver(EntityResolver):
.. code-block:: python
from neo4j import AsyncGraphDatabase
from neo4j import GraphDatabase
from neo4j_graphrag.experimental.components.resolver import SinglePropertyExactMatchResolver
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
resolver = SinglePropertyExactMatchResolver(driver=driver, neo4j_database=DATABASE)
await resolver.run() # no expected parameters
Expand Down
Loading

0 comments on commit 8ac25bc

Please sign in to comment.