Skip to content

Commit

Permalink
Close instantiated drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Dec 6, 2024
1 parent ce21353 commit 4e04d04
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ def parse(
def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
return user_input

def close(self) -> None:
for driver in self._global_data.get("neo4j_config", {}).values():
driver.close()

def get_neo4j_driver_by_name(self, name: str) -> neo4j.Driver:
drivers: dict[str, neo4j.Driver] = self._global_data.get("neo4j_config", {})
return drivers[name]
Expand Down
17 changes: 13 additions & 4 deletions src/neo4j_graphrag/experimental/pipeline/config/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,24 @@ def __init__(
self,
pipeline_definition: PipelineDefinition,
config: AbstractPipelineConfig | None = None,
do_cleaning: bool = False,
) -> None:
self.config = config
self.pipeline = Pipeline.from_definition(pipeline_definition)
self.run_params = pipeline_definition.get_run_params()
self.do_cleaning = do_cleaning

@classmethod
def from_config(cls, config: AbstractPipelineConfig | dict[str, Any]) -> Self:
def from_config(cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False) -> Self:
wrapper = PipelineConfigWrapper.model_validate({"config": config})
return cls(wrapper.parse(), config=wrapper.config)
return cls(wrapper.parse(), config=wrapper.config, do_cleaning=do_cleaning)

@classmethod
def from_config_file(cls, file_path: Union[str, Path]) -> Self:
if not isinstance(file_path, str):
file_path = str(file_path)
data = ConfigReader().read(file_path)
return cls.from_config(data)
return cls.from_config(data, do_cleaning=True)

async def run(self, user_input: dict[str, Any]) -> PipelineResult:
# pipeline_conditional_run_params = self.
Expand All @@ -114,4 +116,11 @@ async def run(self, user_input: dict[str, Any]) -> PipelineResult:
)
else:
run_param = deep_update(self.run_params, user_input)
return await self.pipeline.run(data=run_param)
result = await self.pipeline.run(data=run_param)
if self.do_cleaning:
self.close()
return result

def close(self) -> None:
if self.config:
self.config.close()

0 comments on commit 4e04d04

Please sign in to comment.