Skip to content

Commit

Permalink
Remove async driver support from the KG creation pipeline (neo4j#201)
Browse files Browse the repository at this point in the history
* Removed async driver support from Neo4jWriter

* Removes neo4j_graphrag.utils.execute_query

* Actually removes neo4j_graphrag.utils.execute_query

* Updated CHANGELOG

* Removed references to max_concurrency
  • Loading branch information
alexthomas93 authored Oct 23, 2024
1 parent 9391662 commit 57529d4
Show file tree
Hide file tree
Showing 13 changed files with 39 additions and 271 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
### Changed
- Vector and Hybrid retrievers used with `return_properties` now also return the node labels (`nodeLabels`) and the node's element ID (`id`).
- `HybridRetriever` now filters out the embedding property index in `self.vector_index_name` from the retriever result by default.
- Removed support for `neo4j.AsyncDriver` in the KG creation pipeline, affecting `Neo4jWriter` and related components.
- Updated examples and unit tests to reflect the removal of async driver support.


## 1.1.0
Expand Down
6 changes: 2 additions & 4 deletions docs/source/user_guide_kg_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -433,10 +433,8 @@ to a Neo4j database:
graph = Neo4jGraph(nodes=[], relationships=[])
await writer.run(graph)
To improve insert performances, it is possible to act on two parameters:

- `batch_size`: the number of nodes/relationships to be processed in each batch (default is 1000).
- `max_concurrency`: the max number of concurrent queries (default is 5).
Adjust the `batch_size` parameter of `Neo4jWriter` to optimize insert performance.
This parameter controls the number of nodes or relationships inserted per batch, with a default value of 1000.

See :ref:`neo4jgraph`.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
a specific signature for the run method, which makes it very flexible.
"""

from typing import Any, Optional, Union
from typing import Any, Optional

import neo4j
from neo4j_graphrag.experimental.components.resolver import EntityResolver
Expand All @@ -12,7 +12,7 @@
class MyEntityResolver(EntityResolver):
def __init__(
self,
driver: Union[neo4j.Driver, neo4j.AsyncDriver],
driver: neo4j.Driver,
filter_query: Optional[str] = None,
) -> None:
super().__init__(driver, filter_query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ async def main(driver: neo4j.Driver, graph: Neo4jGraph) -> KGWriterModel:
driver,
# optionally, configure the neo4j database
# neo4j_database="neo4j",
# you can tune batch_size and max_concurrency to
# you can tune batch_size to
# improve speed
# batch_size=1000,
# max_concurrency=5,
)
result = await writer.run(graph=graph)
return result
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


async def define_and_run_pipeline(
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> PipelineResult:
from neo4j_graphrag.experimental.pipeline import Pipeline

Expand Down Expand Up @@ -131,11 +131,11 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.AsyncGraphDatabase.driver(
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
res = await define_and_run_pipeline(driver, llm)
await driver.close()
driver.close()
await llm.async_client.close()
return res

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


async def define_and_run_pipeline(
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> PipelineResult:
"""This is where we define and run the KG builder pipeline, instantiating a few
components:
Expand Down Expand Up @@ -148,11 +148,11 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.AsyncGraphDatabase.driver(
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
res = await define_and_run_pipeline(driver, llm)
await driver.close()
driver.close()
await llm.async_client.close()
return res

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@


async def define_and_run_pipeline(
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> None:
"""This is where we define and run the KG builder pipeline, instantiating a few
components:
Expand Down Expand Up @@ -144,11 +144,11 @@ async def main() -> None:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.AsyncGraphDatabase.driver(
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
await define_and_run_pipeline(driver, llm)
await driver.close()
driver.close()
await llm.async_client.close()


Expand Down
6 changes: 3 additions & 3 deletions examples/kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@


async def define_and_run_pipeline(
neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> PipelineResult:
from neo4j_graphrag.experimental.pipeline import Pipeline

Expand Down Expand Up @@ -137,11 +137,11 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.AsyncGraphDatabase.driver(
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
res = await define_and_run_pipeline(driver, llm)
await driver.close()
driver.close()
await llm.async_client.close()
return res

Expand Down
85 changes: 11 additions & 74 deletions src/neo4j_graphrag/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.
from __future__ import annotations

import asyncio
import inspect
import logging
from abc import abstractmethod
from typing import Any, Generator, Literal, Optional
Expand Down Expand Up @@ -87,21 +85,21 @@ class Neo4jWriter(KGWriter):
Args:
driver (neo4j.Driver): The Neo4j driver to connect to the database.
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000.
Example:
.. code-block:: python
from neo4j import AsyncGraphDatabase
from neo4j import GraphDatabase
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
from neo4j_graphrag.experimental.pipeline import Pipeline
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
pipeline = Pipeline()
Expand All @@ -111,15 +109,13 @@ class Neo4jWriter(KGWriter):

def __init__(
self,
driver: neo4j.driver,
driver: neo4j.Driver,
neo4j_database: Optional[str] = None,
batch_size: int = 1000,
max_concurrency: int = 5,
):
self.driver = driver
self.neo4j_database = neo4j_database
self.batch_size = batch_size
self.max_concurrency = max_concurrency
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()

def _db_setup(self) -> None:
Expand All @@ -129,13 +125,6 @@ def _db_setup(self) -> None:
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
)

async def _async_db_setup(self) -> None:
# create index on __Entity__.id
# used when creating the relationships
await self.driver.execute_query(
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
)

@staticmethod
def _nodes_to_rows(
nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig
Expand Down Expand Up @@ -166,23 +155,6 @@ def _upsert_nodes(
else:
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)

async def _async_upsert_nodes(
self,
nodes: list[Neo4jNode],
lexical_graph_config: LexicalGraphConfig,
sem: asyncio.Semaphore,
) -> None:
"""Asynchronously upserts a single node into the Neo4j database."
Args:
nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
"""
async with sem:
parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)}
await self.driver.execute_query(
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
)

def _get_version(self) -> tuple[int, ...]:
records, _, _ = self.driver.execute_query(
"CALL dbms.components()", database_=self.neo4j_database
Expand Down Expand Up @@ -220,26 +192,6 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
else:
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)

async def _async_upsert_relationships(
self, rels: list[Neo4jRelationship], sem: asyncio.Semaphore
) -> None:
"""Asynchronously upserts a single relationship into the Neo4j database.
Args:
rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
"""
async with sem:
parameters = {"rows": [rel.model_dump() for rel in rels]}
if self.is_version_5_23_or_above:
await self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
parameters_=parameters,
)
else:
await self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY, parameters_=parameters
)

@validate_call
async def run(
self,
Expand All @@ -253,28 +205,13 @@ async def run(
lexical_graph_config (LexicalGraphConfig):
"""
try:
if inspect.iscoroutinefunction(self.driver.execute_query):
await self._async_db_setup()
sem = asyncio.Semaphore(self.max_concurrency)
node_tasks = [
self._async_upsert_nodes(batch, lexical_graph_config, sem)
for batch in batched(graph.nodes, self.batch_size)
]
await asyncio.gather(*node_tasks)

rel_tasks = [
self._async_upsert_relationships(batch, sem)
for batch in batched(graph.relationships, self.batch_size)
]
await asyncio.gather(*rel_tasks)
else:
self._db_setup()

for batch in batched(graph.nodes, self.batch_size):
self._upsert_nodes(batch, lexical_graph_config)

for batch in batched(graph.relationships, self.batch_size):
self._upsert_relationships(batch)
self._db_setup()

for batch in batched(graph.nodes, self.batch_size):
self._upsert_nodes(batch, lexical_graph_config)

for batch in batched(graph.relationships, self.batch_size):
self._upsert_relationships(batch)

return KGWriterModel(
status="SUCCESS",
Expand Down
23 changes: 8 additions & 15 deletions src/neo4j_graphrag/experimental/components/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Any, Optional, Union
from typing import Any, Optional

import neo4j

from neo4j_graphrag.experimental.components.types import ResolutionStats
from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.utils import execute_query


class EntityResolver(Component, abc.ABC):
Expand All @@ -32,7 +31,7 @@ class EntityResolver(Component, abc.ABC):

def __init__(
self,
driver: Union[neo4j.Driver, neo4j.AsyncDriver],
driver: neo4j.Driver,
filter_query: Optional[str] = None,
) -> None:
self.driver = driver
Expand All @@ -56,22 +55,22 @@ class SinglePropertyExactMatchResolver(EntityResolver):
.. code-block:: python
from neo4j import AsyncGraphDatabase
from neo4j import GraphDatabase
from neo4j_graphrag.experimental.components.resolver import SinglePropertyExactMatchResolver
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
resolver = SinglePropertyExactMatchResolver(driver=driver, neo4j_database=DATABASE)
await resolver.run() # no expected parameters
"""

def __init__(
self,
driver: Union[neo4j.Driver, neo4j.AsyncDriver],
driver: neo4j.Driver,
filter_query: Optional[str] = None,
resolve_property: str = "name",
neo4j_database: Optional[str] = None,
Expand All @@ -94,11 +93,7 @@ async def run(self) -> ResolutionStats:
if self.filter_query:
match_query += self.filter_query
stat_query = f"{match_query} RETURN count(entity) as c"
records, _, _ = await execute_query(
self.driver,
stat_query,
database_=self.database,
)
records, _, _ = self.driver.execute_query(stat_query, database_=self.database)
number_of_nodes_to_resolve = records[0].get("c")
if number_of_nodes_to_resolve == 0:
return ResolutionStats(
Expand Down Expand Up @@ -130,10 +125,8 @@ async def run(self) -> ResolutionStats:
"YIELD node "
"RETURN count(node) as c "
)
records, _, _ = await execute_query(
self.driver,
merge_nodes_query,
database_=self.database,
records, _, _ = self.driver.execute_query(
merge_nodes_query, database_=self.database
)
number_of_created_nodes = records[0].get("c")
return ResolutionStats(
Expand Down
14 changes: 1 addition & 13 deletions src/neo4j_graphrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,11 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Optional

import neo4j
from typing import Optional


def validate_search_query_input(
query_text: Optional[str] = None, query_vector: Optional[list[float]] = None
) -> None:
if not (bool(query_vector) ^ bool(query_text)):
raise ValueError("You must provide exactly one of query_vector or query_text.")


async def execute_query(
driver: neo4j.Driver | neo4j.AsyncDriver, query: str, **kwargs: Any
) -> Any:
if isinstance(driver, neo4j.AsyncDriver):
result = await driver.execute_query(query, **kwargs)
else:
result = driver.execute_query(query, **kwargs)
return result
Loading

0 comments on commit 57529d4

Please sign in to comment.