From 45fe4ee8ee1044d6681f6fffa58efac5db6644f3 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 22 Jan 2025 10:24:36 +0100 Subject: [PATCH] Move types to types.py --- examples/README.md | 1 + .../pipeline/pipeline_with_notifications.py | 2 +- .../experimental/pipeline/notification.py | 41 +----------------- .../experimental/pipeline/pipeline.py | 2 +- .../experimental/pipeline/types.py | 42 ++++++++++++++++++- .../experimental/pipeline/test_pipeline.py | 4 +- 6 files changed, 48 insertions(+), 44 deletions(-) diff --git a/examples/README.md b/examples/README.md index d3e4ef54..7a53aaec 100644 --- a/examples/README.md +++ b/examples/README.md @@ -101,6 +101,7 @@ are listed in [the last section of this file](#customize). - [Process multiple documents](./customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py) - [Export lexical graph creation into another pipeline](./customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py) - [Build pipeline from config file](customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py) +- [Add event listener to get notification about Pipeline progress](./customize/build_graph/pipeline/pipeline_with_notifications.py) #### Components diff --git a/examples/customize/build_graph/pipeline/pipeline_with_notifications.py b/examples/customize/build_graph/pipeline/pipeline_with_notifications.py index 441fa696..db30ddd6 100644 --- a/examples/customize/build_graph/pipeline/pipeline_with_notifications.py +++ b/examples/customize/build_graph/pipeline/pipeline_with_notifications.py @@ -14,8 +14,8 @@ FixedSizeSplitter, ) from neo4j_graphrag.experimental.pipeline import Pipeline -from neo4j_graphrag.experimental.pipeline.notification import Event from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult +from neo4j_graphrag.experimental.pipeline.types import Event logger = logging.getLogger(__name__) logging.basicConfig() diff --git a/src/neo4j_graphrag/experimental/pipeline/notification.py b/src/neo4j_graphrag/experimental/pipeline/notification.py index 9fdbba6f..4e8a4a56 100644 --- a/src/neo4j_graphrag/experimental/pipeline/notification.py +++ b/src/neo4j_graphrag/experimental/pipeline/notification.py @@ -15,48 +15,11 @@ from __future__ import annotations import datetime -import enum -from collections.abc import Awaitable -from typing import Any, Optional, Protocol +from typing import Any, Optional from pydantic import BaseModel -from neo4j_graphrag.experimental.pipeline.types import RunResult - - -class EventType(enum.Enum): - PIPELINE_STARTED = "PIPELINE_STARTED" - TASK_STARTED = "TASK_STARTED" - TASK_FINISHED = "TASK_FINISHED" - PIPELINE_FINISHED = "PIPELINE_FINISHED" - - @property - def is_pipeline_event(self) -> bool: - return self in [EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED] - - @property - def is_task_event(self) -> bool: - return self in [EventType.TASK_STARTED, EventType.TASK_FINISHED] - - -class Event(BaseModel): - event_type: EventType - run_id: str - timestamp: datetime.datetime - message: Optional[str] = None - payload: Optional[dict[str, Any]] = None - - -class PipelineEvent(Event): - pass - - -class TaskEvent(Event): - task_name: str - - -class EventCallbackProtocol(Protocol): - def __call__(self, event: Event) -> Awaitable[None]: ... +from neo4j_graphrag.experimental.pipeline.types import RunResult, EventCallbackProtocol, Event, PipelineEvent, TaskEvent, EventType class EventNotifier: diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index a522ccc0..77526785 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -33,7 +33,6 @@ from neo4j_graphrag.experimental.pipeline.exceptions import ( PipelineDefinitionError, ) -from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator from neo4j_graphrag.experimental.pipeline.pipeline_graph import ( PipelineEdge, @@ -44,6 +43,7 @@ from neo4j_graphrag.experimental.pipeline.types import ( ComponentDefinition, ConnectionDefinition, + EventCallbackProtocol, PipelineDefinition, RunResult, ) diff --git a/src/neo4j_graphrag/experimental/pipeline/types.py b/src/neo4j_graphrag/experimental/pipeline/types.py index a63eb7e6..bb9d88a8 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types.py +++ b/src/neo4j_graphrag/experimental/pipeline/types.py @@ -17,7 +17,8 @@ import datetime import enum from collections import defaultdict -from typing import Any, Optional, Union +from collections.abc import Awaitable +from typing import Any, Optional, Protocol, Union from pydantic import BaseModel, ConfigDict, Field @@ -62,6 +63,45 @@ class RunResult(BaseModel): ) +class EventType(enum.Enum): + PIPELINE_STARTED = "PIPELINE_STARTED" + TASK_STARTED = "TASK_STARTED" + TASK_FINISHED = "TASK_FINISHED" + PIPELINE_FINISHED = "PIPELINE_FINISHED" + + @property + def is_pipeline_event(self) -> bool: + return self in [EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED] + + @property + def is_task_event(self) -> bool: + return self in [EventType.TASK_STARTED, EventType.TASK_FINISHED] + + +class Event(BaseModel): + event_type: EventType + run_id: str + """Pipeline unique run_id, same as the one returned in PipelineResult after pipeline.run""" + timestamp: datetime.datetime + message: Optional[str] = None + """Optional information about the status""" + payload: Optional[dict[str, Any]] = None + """Input or output data depending on the type of event""" + + +class PipelineEvent(Event): + pass + + +class TaskEvent(Event): + task_name: str + """Name of the task as defined in pipeline.add_component""" + + +class EventCallbackProtocol(Protocol): + def __call__(self, event: Event) -> Awaitable[None]: ... + + EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] RelationInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] """Types derived from the SchemaEntity and SchemaRelation types, diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py index eaec0d64..c149872f 100644 --- a/tests/unit/experimental/pipeline/test_pipeline.py +++ b/tests/unit/experimental/pipeline/test_pipeline.py @@ -24,13 +24,13 @@ import pytest from neo4j_graphrag.experimental.pipeline import Component, Pipeline from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError -from neo4j_graphrag.experimental.pipeline.notification import ( +from neo4j_graphrag.experimental.pipeline.types import ( EventCallbackProtocol, EventType, PipelineEvent, + RunResult, TaskEvent, ) -from neo4j_graphrag.experimental.pipeline.types import RunResult from .components import ( ComponentAdd,