diff --git a/examples/pipeline/kg_builder_example.py b/examples/pipeline/kg_builder_example.py index 03acb1ac..8dc28a2f 100644 --- a/examples/pipeline/kg_builder_example.py +++ b/examples/pipeline/kg_builder_example.py @@ -79,7 +79,8 @@ async def main(neo4j_driver: neo4j.Driver) -> None: # Run the knowledge graph building process with text input text_input = "John Doe lives in New York City." - text_result = await kg_builder_text.run_async(text=text_input) + # text_result = await kg_builder_text.run_async(text=text_input) + text_result = kg_builder_text.run(text=text_input) print(f"Text Processing Result: {text_result}") await llm.async_client.close() diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index 4e07a1d6..999b3084 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -19,7 +19,7 @@ from neo4j_graphrag.experimental.components.types import ResolutionStats from neo4j_graphrag.experimental.pipeline import Component -from neo4j_graphrag.utils import execute_query +from neo4j_graphrag.utils import execute_query, async_to_sync class EntityResolver(Component, abc.ABC): @@ -140,3 +140,5 @@ async def run(self) -> ResolutionStats: number_of_nodes_to_resolve=number_of_nodes_to_resolve, number_of_created_nodes=number_of_created_nodes, ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 84cd5bc0..efbd9bcb 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.utils import async_to_sync class DataModel(BaseModel): @@ -63,6 +64,8 @@ def __new__( } for f, field in return_model.model_fields.items() } + # create sync method: + attrs["run_sync"] = async_to_sync(run_method) return type.__new__(meta, name, bases, attrs) diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 9f7ff488..5d7d6bf0 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -43,6 +43,7 @@ from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.utils import run_sync class SimpleKGPipelineConfig(BaseModel): @@ -225,6 +226,10 @@ async def run_async( pipe_inputs = self._prepare_inputs(file_path=file_path, text=text) return await self.pipeline.run(pipe_inputs) + def run(self, file_path: Optional[str] = None, text: Optional[str] = None) -> PipelineResult: + """Run pipeline synchronously""" + return run_sync(self, file_path=file_path, text=text) + def _prepare_inputs( self, file_path: Optional[str], text: Optional[str] ) -> dict[str, Any]: diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 3d004eb8..a8a6efed 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -24,6 +24,8 @@ from timeit import default_timer from typing import Any, AsyncGenerator, Optional +from neo4j_graphrag.utils import async_to_sync + try: import pygraphviz as pgv except ImportError: @@ -105,6 +107,8 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None: logger.debug(f"TASK RESULT {self.name=} {res=}") return res + run_sync = async_to_sync(run) + class Orchestrator: """Orchestrate a pipeline. @@ -618,3 +622,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult: run_id=orchestrator.run_id, result=await self.final_results.get(orchestrator.run_id), ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index 5f1b322f..3a0f7419 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -15,7 +15,10 @@ from __future__ import annotations import inspect +from functools import wraps from typing import Any, Optional, Union +import asyncio +import concurrent.futures import neo4j @@ -37,3 +40,17 @@ async def execute_query( # but we're sure at this stage we do not have a coroutine anymore records, _, _ = driver.execute_query(query, **kwargs) # type: ignore[misc] return records # type: ignore[no-any-return] + + +def run_sync(function, *args, **kwargs): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(lambda: asyncio.run(function(*args, **kwargs))) + return_value = future.result() + return return_value + + +def async_to_sync(func): + @wraps(func) + def wrapper(*args, **kwargs): + return run_sync(func, *args, **kwargs) + return wrapper