Skip to content

Commit

Permalink
Adds basic TQDM progress bar to KG creation pipeline
Browse files (browse the repository at this point in the history)
  • Loading branch information
alexthomas93 committed Aug 20, 2024
1 parent 77db5e2 commit f162016
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
4 changes: 3 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pinecone-client = {version = "^4.1.0", optional = true}
types-mock = "^5.1.0.20240425"
eval-type-backport = "^0.2.0"
pypdf = "^4.3.1"
+ tqdm = "^4.66.5"

[tool.poetry.group.dev.dependencies]
pylint = "^3.1.0"
Expand Down
8 changes: 6 additions & 2 deletions src/neo4j_genai/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any, AsyncGenerator, Awaitable, Callable, Optional

from pydantic import BaseModel, Field
+ from tqdm.asyncio import tqdm

from neo4j_genai.experimental.pipeline.component import Component, DataModel
from neo4j_genai.experimental.pipeline.exceptions import (
Expand Down Expand Up @@ -410,6 +411,7 @@ def on_task_complete(self, node: TaskPipelineNode, result: RunResult) -> None:
if result.result:
res_to_save = result.result.model_dump()
self.add_result_for_component(node.name, res_to_save, is_final=node.is_leaf())
+ self.pbar.update(1)

def add_result_for_component(
self, name: str, result: dict[str, Any] | None, is_final: bool = False
Expand Down Expand Up @@ -471,12 +473,14 @@ def validate_inputs_config(self, data: dict[str, Any]) -> None:
task.validate_inputs_config(data)

async def run(self, data: dict[str, Any]) -> dict[str, Any]:
- logging.info("Starting pipeline")
+ logger.debug("Starting pipeline")
+ self.pbar = tqdm(total=len(self._nodes), desc="Creating knowledge graph")
start_time = default_timer()
self.validate_inputs_config(data)
self.reinitialize()
orchestrator = Orchestrator(self)
await orchestrator.run(data)
end_time = default_timer()
- logging.info(f"Pipeline finished in {end_time - start_time}s")
+ self.pbar.close()
+ logger.debug(f"Pipeline finished in {end_time - start_time}s")
return self._final_results.all()

0 comments on commit f162016

Please sign in to comment.