diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py index a841b92b..e81d482c 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py @@ -39,7 +39,7 @@ async def define_and_run_pipeline( - neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface + neo4j_driver: neo4j.Driver, llm: LLMInterface ) -> PipelineResult: from neo4j_graphrag.experimental.pipeline import Pipeline @@ -131,11 +131,11 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.AsyncGraphDatabase.driver( + driver = neo4j.GraphDatabase.driver( "bolt://localhost:7687", auth=("neo4j", "password") ) res = await define_and_run_pipeline(driver, llm) - await driver.close() + driver.close() await llm.async_client.close() return res diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_text.py b/examples/customize/build_graph/pipeline/kg_builder_from_text.py index a3acd4e2..2beed124 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_text.py @@ -39,7 +39,7 @@ async def define_and_run_pipeline( - neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface + neo4j_driver: neo4j.Driver, llm: LLMInterface ) -> PipelineResult: """This is where we define and run the KG builder pipeline, instantiating a few components: @@ -148,11 +148,11 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.AsyncGraphDatabase.driver( + driver = neo4j.GraphDatabase.driver( "bolt://localhost:7687", auth=("neo4j", "password") ) res = await define_and_run_pipeline(driver, llm) - await driver.close() + driver.close() await llm.async_client.close() return res diff --git a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py index eb68b022..d6f5e9ae 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py +++ b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py @@ -40,7 +40,7 @@ async def define_and_run_pipeline( - neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface + neo4j_driver: neo4j.Driver, llm: LLMInterface ) -> None: """This is where we define and run the KG builder pipeline, instantiating a few components: @@ -144,11 +144,11 @@ async def main() -> None: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.AsyncGraphDatabase.driver( + driver = neo4j.GraphDatabase.driver( "bolt://localhost:7687", auth=("neo4j", "password") ) await define_and_run_pipeline(driver, llm) - await driver.close() + driver.close() await llm.async_client.close() diff --git a/examples/kg_builder.py b/examples/kg_builder.py index 25917101..650473e4 100644 --- a/examples/kg_builder.py +++ b/examples/kg_builder.py @@ -45,7 +45,7 @@ async def define_and_run_pipeline( - neo4j_driver: neo4j.AsyncDriver, llm: LLMInterface + neo4j_driver: neo4j.Driver, llm: LLMInterface ) -> PipelineResult: from neo4j_graphrag.experimental.pipeline import Pipeline @@ -137,11 +137,11 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - driver = neo4j.AsyncGraphDatabase.driver( + driver = neo4j.GraphDatabase.driver( "bolt://localhost:7687", auth=("neo4j", "password") ) res = await define_and_run_pipeline(driver, llm) - await driver.close() + driver.close() await llm.async_client.close() return res diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index b27ae0c8..e1316f11 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -14,8 +14,6 @@ # limitations under the License. from __future__ import annotations -import asyncio -import inspect import logging from abc import abstractmethod from typing import Any, Generator, Literal, Optional @@ -87,13 +85,13 @@ class Neo4jWriter(KGWriter): Args: driver (neo4j.driver): The Neo4j driver to connect to the database. neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided. - max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM. + batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000. Example: .. code-block:: python - from neo4j import AsyncGraphDatabase + from neo4j import GraphDatabase from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter from neo4j_graphrag.experimental.pipeline import Pipeline @@ -101,7 +99,7 @@ class Neo4jWriter(KGWriter): AUTH = ("neo4j", "password") DATABASE = "neo4j" - driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE) + driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE) pipeline = Pipeline() @@ -111,15 +109,13 @@ class Neo4jWriter(KGWriter): def __init__( self, - driver: neo4j.driver, + driver: neo4j.Driver, neo4j_database: Optional[str] = None, batch_size: int = 1000, - max_concurrency: int = 5, ): self.driver = driver self.neo4j_database = neo4j_database self.batch_size = batch_size - self.max_concurrency = max_concurrency self.is_version_5_23_or_above = self._check_if_version_5_23_or_above() def _db_setup(self) -> None: @@ -129,13 +125,6 @@ def _db_setup(self) -> None: "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)" ) - async def _async_db_setup(self) -> None: - # create index on __Entity__.id - # used when creating the relationships - await self.driver.execute_query( - "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)" - ) - @staticmethod def _nodes_to_rows( nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig @@ -166,23 +155,6 @@ def _upsert_nodes( else: self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters) - async def _async_upsert_nodes( - self, - nodes: list[Neo4jNode], - lexical_graph_config: LexicalGraphConfig, - sem: asyncio.Semaphore, - ) -> None: - """Asynchronously upserts a single node into the Neo4j database." - - Args: - nodes (list[Neo4jNode]): The nodes batch to upsert into the database. - """ - async with sem: - parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)} - await self.driver.execute_query( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters - ) - def _get_version(self) -> tuple[int, ...]: records, _, _ = self.driver.execute_query( "CALL dbms.components()", database_=self.neo4j_database @@ -220,26 +192,6 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None: else: self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters) - async def _async_upsert_relationships( - self, rels: list[Neo4jRelationship], sem: asyncio.Semaphore - ) -> None: - """Asynchronously upserts a single relationship into the Neo4j database. - - Args: - rels (list[Neo4jRelationship]): The relationships batch to upsert into the database. - """ - async with sem: - parameters = {"rows": [rel.model_dump() for rel in rels]} - if self.is_version_5_23_or_above: - await self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_=parameters, - ) - else: - await self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY, parameters_=parameters - ) - @validate_call async def run( self, @@ -253,28 +205,13 @@ async def run( lexical_graph_config (LexicalGraphConfig): """ try: - if inspect.iscoroutinefunction(self.driver.execute_query): - await self._async_db_setup() - sem = asyncio.Semaphore(self.max_concurrency) - node_tasks = [ - self._async_upsert_nodes(batch, lexical_graph_config, sem) - for batch in batched(graph.nodes, self.batch_size) - ] - await asyncio.gather(*node_tasks) - - rel_tasks = [ - self._async_upsert_relationships(batch, sem) - for batch in batched(graph.relationships, self.batch_size) - ] - await asyncio.gather(*rel_tasks) - else: - self._db_setup() - - for batch in batched(graph.nodes, self.batch_size): - self._upsert_nodes(batch, lexical_graph_config) - - for batch in batched(graph.relationships, self.batch_size): - self._upsert_relationships(batch) + self._db_setup() + + for batch in batched(graph.nodes, self.batch_size): + self._upsert_nodes(batch, lexical_graph_config) + + for batch in batched(graph.relationships, self.batch_size): + self._upsert_relationships(batch) return KGWriterModel( status="SUCCESS", diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index 1dee9e10..ce1e9eaa 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -56,14 +56,14 @@ class SinglePropertyExactMatchResolver(EntityResolver): .. code-block:: python - from neo4j import AsyncGraphDatabase + from neo4j import GraphDatabase from neo4j_graphrag.experimental.components.resolver import SinglePropertyExactMatchResolver URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") DATABASE = "neo4j" - driver = AsyncGraphDatabase.driver(URI, auth=AUTH, database=DATABASE) + driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) resolver = SinglePropertyExactMatchResolver(driver=driver, neo4j_database=DATABASE) await resolver.run() # no expected parameters diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index cdedb818..0e8fb0aa 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -228,52 +228,6 @@ async def test_run(_: Mock, driver: MagicMock) -> None: ) -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 22, 0), -) -@pytest.mark.asyncio -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._async_db_setup", - return_value=None, -) -async def test_run_async_driver(_: Mock, async_driver: MagicMock) -> None: - neo4j_writer = Neo4jWriter(driver=async_driver) - node = Neo4jNode(id="1", label="Label") - rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") - graph = Neo4jGraph(nodes=[node], relationships=[rel]) - await neo4j_writer.run(graph=graph) - async_driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, - parameters_={ - "rows": [ - { - "label": "Label", - "labels": ["Label", "__Entity__"], - "id": "1", - "properties": {}, - "embedding_properties": None, - } - ] - }, - ) - parameters_ = { - "rows": [ - { - "type": "RELATIONSHIP", - "start_node_id": "1", - "end_node_id": "2", - "properties": {}, - "embedding_properties": None, - } - ] - } - async_driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, - parameters_=parameters_, - ) - - @pytest.mark.asyncio @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._db_setup", @@ -367,101 +321,3 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters_, ) - - -@pytest.mark.asyncio -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._async_db_setup", - return_value=None, -) -async def test_run_async_driver_is_version_below_5_23(_: Mock) -> None: - async_driver = MagicMock() - async_driver.execute_query = Mock( - return_value=([{"versions": ["5.22.0"]}], None, None) - ) - - neo4j_writer = Neo4jWriter(driver=async_driver) - - node = Neo4jNode(id="1", label="Label") - rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") - graph = Neo4jGraph(nodes=[node], relationships=[rel]) - await neo4j_writer.run(graph=graph) - - async_driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, - parameters_={ - "rows": [ - { - "label": "Label", - "labels": ["Label", "__Entity__"], - "id": "1", - "properties": {}, - "embedding_properties": None, - } - ] - }, - ) - parameters_ = { - "rows": [ - { - "type": "RELATIONSHIP", - "start_node_id": "1", - "end_node_id": "2", - "properties": {}, - "embedding_properties": None, - } - ] - } - async_driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, - parameters_=parameters_, - ) - - -@pytest.mark.asyncio -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._async_db_setup", - return_value=None, -) -async def test_run_async_driver_is_version_5_23_or_above(_: Mock) -> None: - async_driver = MagicMock() - async_driver.execute_query = Mock( - return_value=([{"versions": ["5.23.0"]}], None, None) - ) - - neo4j_writer = Neo4jWriter(driver=async_driver) - - node = Neo4jNode(id="1", label="Label") - rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") - graph = Neo4jGraph(nodes=[node], relationships=[rel]) - await neo4j_writer.run(graph=graph) - - async_driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_={ - "rows": [ - { - "label": "Label", - "labels": ["Label", "__Entity__"], - "id": "1", - "properties": {}, - "embedding_properties": None, - } - ] - }, - ) - parameters_ = { - "rows": [ - { - "type": "RELATIONSHIP", - "start_node_id": "1", - "end_node_id": "2", - "properties": {}, - "embedding_properties": None, - } - ] - } - async_driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_=parameters_, - )