From 57529d4bb6972a5e96ece7d51255c46e568d1ed4 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 23 Oct 2024 17:34:48 +0100 Subject: [PATCH] Remove async driver support from the KG creation pipeline (#201) * Removed async driver support from Neo4jWriter * Removes neo4j_graphrag.utils.execute_query * Actually removes neo4j_graphrag.utils.execute_query * Updated CHANGELOG * Removed references to max_concurrency --- CHANGELOG.md | 2 + docs/source/user_guide_kg_builder.rst | 6 +- .../components/resolvers/custom_resolver.py | 4 +- .../components/writers/neo4j_writer.py | 3 +- .../pipeline/kg_builder_from_pdf.py | 6 +- .../pipeline/kg_builder_from_text.py | 6 +- ...builder_two_documents_entity_resolution.py | 6 +- examples/kg_builder.py | 6 +- .../experimental/components/kg_writer.py | 85 ++--------- .../experimental/components/resolver.py | 23 +-- src/neo4j_graphrag/utils.py | 14 +- tests/unit/conftest.py | 5 - .../experimental/components/test_kg_writer.py | 144 ------------------ 13 files changed, 39 insertions(+), 271 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 515d467f..42aaa9eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ ### Changed - Vector and Hybrid retrievers used with `return_properties` now also return the node labels (`nodeLabels`) and the node's element ID (`id`). - `HybridRetriever` now filters out the embedding property index in `self.vector_index_name` from the retriever result by default. +- Removed support for neo4j.AsyncDriver in the KG creation pipeline, affecting Neo4jWriter and related components. +- Updated examples and unit tests to reflect the removal of async driver support. ## 1.1.0 diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index ab90dba1..206ef226 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -433,10 +433,8 @@ to a Neo4j database: graph = Neo4jGraph(nodes=[], relationships=[]) await writer.run(graph) -To improve insert performances, it is possible to act on two parameters: - -- `batch_size`: the number of nodes/relationships to be processed in each batch (default is 1000). -- `max_concurrency`: the max number of concurrent queries (default is 5). +Adjust the batch_size parameter of `Neo4jWriter` to optimize insert performance. +This parameter controls the number of nodes or relationships inserted per batch, with a default value of 1000. See :ref:`neo4jgraph`. diff --git a/examples/customize/build_graph/components/resolvers/custom_resolver.py b/examples/customize/build_graph/components/resolvers/custom_resolver.py index b7f1bf4f..26375f18 100644 --- a/examples/customize/build_graph/components/resolvers/custom_resolver.py +++ b/examples/customize/build_graph/components/resolvers/custom_resolver.py @@ -2,7 +2,7 @@ a specific signature for the run method, which makes it very flexible. """ -from typing import Any, Optional, Union +from typing import Any, Optional import neo4j from neo4j_graphrag.experimental.components.resolver import EntityResolver @@ -12,7 +12,7 @@ class MyEntityResolver(EntityResolver): def __init__( self, - driver: Union[neo4j.Driver, neo4j.AsyncDriver], + driver: neo4j.Driver, filter_query: Optional[str] = None, ) -> None: super().__init__(driver, filter_query) diff --git a/examples/customize/build_graph/components/writers/neo4j_writer.py b/examples/customize/build_graph/components/writers/neo4j_writer.py index f85acc20..60dbcea9 100644 --- a/examples/customize/build_graph/components/writers/neo4j_writer.py +++ b/examples/customize/build_graph/components/writers/neo4j_writer.py @@ -11,10 +11,9 @@ async def main(driver: neo4j.Driver, graph: Neo4jGraph) -> KGWriterModel: driver, # optionally, configure the neo4j database # neo4j_database="neo4j", - # you can tune batch_size and max_concurrency to + # you can tune batch_size to # improve speed # batch_size=1000, - # max_concurrency=5, ) result = await writer.run(graph=graph) return result diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py index a841b92b..e81d482c 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py @@ -39,7 +39,7 @@ async def define_and_run_pipeline( - neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface + neo4j_driver: neo4j.Driver, llm: LLMInterface ) -> PipelineResult: from neo4j_graphrag.experimental.pipeline import Pipeline @@ -131,11 +131,11 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.AsyncGraphDatabase.driver( + driver = neo4j.GraphDatabase.driver( "bolt://localhost:7687", auth=("neo4j", "password") ) res = await define_and_run_pipeline(driver, llm) - await driver.close() + driver.close() await llm.async_client.close() return res diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_text.py b/examples/customize/build_graph/pipeline/kg_builder_from_text.py index a3acd4e2..2beed124 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_text.py @@ -39,7 +39,7 @@ async def define_and_run_pipeline( - neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface + neo4j_driver: neo4j.Driver, llm: LLMInterface ) -> PipelineResult: """This is where we define and run the KG builder pipeline, instantiating a few components: @@ -148,11 +148,11 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.AsyncGraphDatabase.driver( + driver = neo4j.GraphDatabase.driver( "bolt://localhost:7687", auth=("neo4j", "password") ) res = await define_and_run_pipeline(driver, llm) - await driver.close() + driver.close() await llm.async_client.close() return res diff --git a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py index eb68b022..d6f5e9ae 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py +++ b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py @@ -40,7 +40,7 @@ async def define_and_run_pipeline( - neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface + neo4j_driver: neo4j.Driver, llm: LLMInterface ) -> None: """This is where we define and run the KG builder pipeline, instantiating a few components: @@ -144,11 +144,11 @@ async def main() -> None: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.AsyncGraphDatabase.driver( + driver = neo4j.GraphDatabase.driver( "bolt://localhost:7687", auth=("neo4j", "password") ) await define_and_run_pipeline(driver, llm) - await driver.close() + driver.close() await llm.async_client.close() diff --git a/examples/kg_builder.py b/examples/kg_builder.py index 25917101..650473e4 100644 --- a/examples/kg_builder.py +++ b/examples/kg_builder.py @@ -45,7 +45,7 @@ async def define_and_run_pipeline( - neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface + neo4j_driver: neo4j.Driver, llm: LLMInterface ) -> PipelineResult: from neo4j_graphrag.experimental.pipeline import Pipeline @@ -137,11 +137,11 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.AsyncGraphDatabase.driver( + driver = neo4j.GraphDatabase.driver( "bolt://localhost:7687", auth=("neo4j", "password") ) res = await define_and_run_pipeline(driver, llm) - await driver.close() + driver.close() await llm.async_client.close() return res diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index b27ae0c8..e1316f11 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -14,8 +14,6 @@ # limitations under the License. from __future__ import annotations -import asyncio -import inspect import logging from abc import abstractmethod from typing import Any, Generator, Literal, Optional @@ -87,13 +85,13 @@ class Neo4jWriter(KGWriter): Args: driver (neo4j.driver): The Neo4j driver to connect to the database. neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided. - max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM. + batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000. Example: .. code-block:: python - from neo4j import AsyncGraphDatabase + from neo4j import GraphDatabase from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter from neo4j_graphrag.experimental.pipeline import Pipeline @@ -101,7 +99,7 @@ class Neo4jWriter(KGWriter): AUTH = ("neo4j", "password") DATABASE = "neo4j" - driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE) + driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE) pipeline = Pipeline() @@ -111,15 +109,13 @@ class Neo4jWriter(KGWriter): def __init__( self, - driver: neo4j.driver, + driver: neo4j.Driver, neo4j_database: Optional[str] = None, batch_size: int = 1000, - max_concurrency: int = 5, ): self.driver = driver self.neo4j_database = neo4j_database self.batch_size = batch_size - self.max_concurrency = max_concurrency self.is_version_5_23_or_above = self._check_if_version_5_23_or_above() def _db_setup(self) -> None: @@ -129,13 +125,6 @@ def _db_setup(self) -> None: "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)" ) - async def _async_db_setup(self) -> None: - # create index on __Entity__.id - # used when creating the relationships - await self.driver.execute_query( - "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)" - ) - @staticmethod def _nodes_to_rows( nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig @@ -166,23 +155,6 @@ def _upsert_nodes( else: self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters) - async def _async_upsert_nodes( - self, - nodes: list[Neo4jNode], - lexical_graph_config: LexicalGraphConfig, - sem: asyncio.Semaphore, - ) -> None: - """Asynchronously upserts a single node into the Neo4j database." - - Args: - nodes (list[Neo4jNode]): The nodes batch to upsert into the database. - """ - async with sem: - parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)} - await self.driver.execute_query( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters - ) - def _get_version(self) -> tuple[int, ...]: records, _, _ = self.driver.execute_query( "CALL dbms.components()", database_=self.neo4j_database @@ -220,26 +192,6 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None: else: self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters) - async def _async_upsert_relationships( - self, rels: list[Neo4jRelationship], sem: asyncio.Semaphore - ) -> None: - """Asynchronously upserts a single relationship into the Neo4j database. - - Args: - rels (list[Neo4jRelationship]): The relationships batch to upsert into the database. - """ - async with sem: - parameters = {"rows": [rel.model_dump() for rel in rels]} - if self.is_version_5_23_or_above: - await self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_=parameters, - ) - else: - await self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY, parameters_=parameters - ) - @validate_call async def run( self, @@ -253,28 +205,13 @@ async def run( lexical_graph_config (LexicalGraphConfig): """ try: - if inspect.iscoroutinefunction(self.driver.execute_query): - await self._async_db_setup() - sem = asyncio.Semaphore(self.max_concurrency) - node_tasks = [ - self._async_upsert_nodes(batch, lexical_graph_config, sem) - for batch in batched(graph.nodes, self.batch_size) - ] - await asyncio.gather(*node_tasks) - - rel_tasks = [ - self._async_upsert_relationships(batch, sem) - for batch in batched(graph.relationships, self.batch_size) - ] - await asyncio.gather(*rel_tasks) - else: - self._db_setup() - - for batch in batched(graph.nodes, self.batch_size): - self._upsert_nodes(batch, lexical_graph_config) - - for batch in batched(graph.relationships, self.batch_size): - self._upsert_relationships(batch) + self._db_setup() + + for batch in batched(graph.nodes, self.batch_size): + self._upsert_nodes(batch, lexical_graph_config) + + for batch in batched(graph.relationships, self.batch_size): + self._upsert_relationships(batch) return KGWriterModel( status="SUCCESS", diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index 1dee9e10..d9fa8f99 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import Any, Optional, Union +from typing import Any, Optional import neo4j from neo4j_graphrag.experimental.components.types import ResolutionStats from neo4j_graphrag.experimental.pipeline import Component -from neo4j_graphrag.utils import execute_query class EntityResolver(Component, abc.ABC): @@ -32,7 +31,7 @@ class EntityResolver(Component, abc.ABC): def __init__( self, - driver: Union[neo4j.Driver, neo4j.AsyncDriver], + driver: neo4j.Driver, filter_query: Optional[str] = None, ) -> None: self.driver = driver @@ -56,14 +55,14 @@ class SinglePropertyExactMatchResolver(EntityResolver): .. code-block:: python - from neo4j import AsyncGraphDatabase + from neo4j import GraphDatabase from neo4j_graphrag.experimental.components.resolver import SinglePropertyExactMatchResolver URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") DATABASE = "neo4j" - driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE) + driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) resolver = SinglePropertyExactMatchResolver(driver=driver, neo4j_database=DATABASE) await resolver.run() # no expected parameters @@ -71,7 +70,7 @@ class SinglePropertyExactMatchResolver(EntityResolver): def __init__( self, - driver: Union[neo4j.Driver, neo4j.AsyncDriver], + driver: neo4j.Driver, filter_query: Optional[str] = None, resolve_property: str = "name", neo4j_database: Optional[str] = None, @@ -94,11 +93,7 @@ async def run(self) -> ResolutionStats: if self.filter_query: match_query += self.filter_query stat_query = f"{match_query} RETURN count(entity) as c" - records, _, _ = await execute_query( - self.driver, - stat_query, - database_=self.database, - ) + records, _, _ = self.driver.execute_query(stat_query, database_=self.database) number_of_nodes_to_resolve = records[0].get("c") if number_of_nodes_to_resolve == 0: return ResolutionStats( @@ -130,10 +125,8 @@ async def run(self) -> ResolutionStats: "YIELD node " "RETURN count(node) as c " ) - records, _, _ = await execute_query( - self.driver, - merge_nodes_query, - database_=self.database, + records, _, _ = self.driver.execute_query( + merge_nodes_query, database_=self.database ) number_of_created_nodes = records[0].get("c") return ResolutionStats( diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index 77109046..e86f7588 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -14,9 +14,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Optional - -import neo4j +from typing import Optional def validate_search_query_input( @@ -24,13 +22,3 @@ def validate_search_query_input( ) -> None: if not (bool(query_vector) ^ bool(query_text)): raise ValueError("You must provide exactly one of query_vector or query_text.") - - -async def execute_query( - driver: neo4j.Driver | neo4j.AsyncDriver, query: str, **kwargs: Any -) -> Any: - if isinstance(driver, neo4j.AsyncDriver): - result = await driver.execute_query(query, **kwargs) - else: - result = driver.execute_query(query, **kwargs) - return result diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index f9f8ef09..829cad23 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -34,11 +34,6 @@ def driver() -> MagicMock: return MagicMock(spec=neo4j.Driver) -@pytest.fixture(scope="function") -def async_driver() -> MagicMock: - return MagicMock(spec=neo4j.AsyncDriver) - - @pytest.fixture(scope="function") def embedder() -> MagicMock: return MagicMock(spec=Embedder) diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index cdedb818..0e8fb0aa 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -228,52 +228,6 @@ async def test_run(_: Mock, driver: MagicMock) -> None: ) -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 22, 0), -) -@pytest.mark.asyncio -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._async_db_setup", - return_value=None, -) -async def test_run_async_driver(_: Mock, async_driver: MagicMock) -> None: - neo4j_writer = Neo4jWriter(driver=async_driver) - node = Neo4jNode(id="1", label="Label") - rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") - graph = Neo4jGraph(nodes=[node], relationships=[rel]) - await neo4j_writer.run(graph=graph) - async_driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, - parameters_={ - "rows": [ - { - "label": "Label", - "labels": ["Label", "__Entity__"], - "id": "1", - "properties": {}, - "embedding_properties": None, - } - ] - }, - ) - parameters_ = { - "rows": [ - { - "type": "RELATIONSHIP", - "start_node_id": "1", - "end_node_id": "2", - "properties": {}, - "embedding_properties": None, - } - ] - } - async_driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, - parameters_=parameters_, - ) - - @pytest.mark.asyncio @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._db_setup", @@ -367,101 +321,3 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters_, ) - - -@pytest.mark.asyncio -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._async_db_setup", - return_value=None, -) -async def test_run_async_driver_is_version_below_5_23(_: Mock) -> None: - async_driver = MagicMock() - async_driver.execute_query = Mock( - return_value=([{"versions": ["5.22.0"]}], None, None) - ) - - neo4j_writer = Neo4jWriter(driver=async_driver) - - node = Neo4jNode(id="1", label="Label") - rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") - graph = Neo4jGraph(nodes=[node], relationships=[rel]) - await neo4j_writer.run(graph=graph) - - async_driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, - parameters_={ - "rows": [ - { - "label": "Label", - "labels": ["Label", "__Entity__"], - "id": "1", - "properties": {}, - "embedding_properties": None, - } - ] - }, - ) - parameters_ = { - "rows": [ - { - "type": "RELATIONSHIP", - "start_node_id": "1", - "end_node_id": "2", - "properties": {}, - "embedding_properties": None, - } - ] - } - async_driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, - parameters_=parameters_, - ) - - -@pytest.mark.asyncio -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._async_db_setup", - return_value=None, -) -async def test_run_async_driver_is_version_5_23_or_above(_: Mock) -> None: - async_driver = MagicMock() - async_driver.execute_query = Mock( - return_value=([{"versions": ["5.23.0"]}], None, None) - ) - - neo4j_writer = Neo4jWriter(driver=async_driver) - - node = Neo4jNode(id="1", label="Label") - rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") - graph = Neo4jGraph(nodes=[node], relationships=[rel]) - await neo4j_writer.run(graph=graph) - - async_driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_={ - "rows": [ - { - "label": "Label", - "labels": ["Label", "__Entity__"], - "id": "1", - "properties": {}, - "embedding_properties": None, - } - ] - }, - ) - parameters_ = { - "rows": [ - { - "type": "RELATIONSHIP", - "start_node_id": "1", - "end_node_id": "2", - "properties": {}, - "embedding_properties": None, - } - ] - } - async_driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_=parameters_, - )