From 4e04d0489d1a78072333b8a23b308e0cbefcdf6f Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 6 Dec 2024 10:06:38 +0100 Subject: [PATCH] Close instantiated drivers --- .../pipeline/config/pipeline_config.py | 4 ++++ .../experimental/pipeline/config/runner.py | 17 +++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py index 83703cdd..6eb01048 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py @@ -144,6 +144,10 @@ def parse( def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: return user_input + def close(self) -> None: + for driver in self._global_data.get("neo4j_config", {}).values(): + driver.close() + def get_neo4j_driver_by_name(self, name: str) -> neo4j.Driver: drivers: dict[str, neo4j.Driver] = self._global_data.get("neo4j_config", {}) return drivers[name] diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index fdb01edc..be7bff6d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -89,22 +89,24 @@ def __init__( self, pipeline_definition: PipelineDefinition, config: AbstractPipelineConfig | None = None, + do_cleaning: bool = False, ) -> None: self.config = config self.pipeline = Pipeline.from_definition(pipeline_definition) self.run_params = pipeline_definition.get_run_params() + self.do_cleaning = do_cleaning @classmethod - def from_config(cls, config: AbstractPipelineConfig | dict[str, Any]) -> Self: + def from_config(cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False) -> Self: wrapper = PipelineConfigWrapper.model_validate({"config": config}) - return cls(wrapper.parse(), config=wrapper.config) + return cls(wrapper.parse(), config=wrapper.config, do_cleaning=do_cleaning) @classmethod def from_config_file(cls, file_path: Union[str, Path]) -> Self: if not isinstance(file_path, str): file_path = str(file_path) data = ConfigReader().read(file_path) - return cls.from_config(data) + return cls.from_config(data, do_cleaning=True) async def run(self, user_input: dict[str, Any]) -> PipelineResult: # pipeline_conditional_run_params = self. @@ -114,4 +116,11 @@ async def run(self, user_input: dict[str, Any]) -> PipelineResult: ) else: run_param = deep_update(self.run_params, user_input) - return await self.pipeline.run(data=run_param) + result = await self.pipeline.run(data=run_param) + if self.do_cleaning: + self.close() + return result + + def close(self) -> None: + if self.config: + self.config.close()