Skip to content

Commit

Permalink
Move types to types.py
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Jan 22, 2025
1 parent a51a4bc commit 45fe4ee
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 44 deletions.
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ are listed in [the last section of this file](#customize).
- [Process multiple documents](./customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py)
- [Export lexical graph creation into another pipeline](./customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py)
- [Build pipeline from config file](customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py)
- [Add an event listener to get notifications about pipeline progress](./customize/build_graph/pipeline/pipeline_with_notifications.py)


#### Components
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
FixedSizeSplitter,
)
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.notification import Event
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import Event

logger = logging.getLogger(__name__)
logging.basicConfig()
Expand Down
41 changes: 2 additions & 39 deletions src/neo4j_graphrag/experimental/pipeline/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,48 +15,11 @@
from __future__ import annotations

import datetime
import enum
from collections.abc import Awaitable
from typing import Any, Optional, Protocol
from typing import Any, Optional

from pydantic import BaseModel

from neo4j_graphrag.experimental.pipeline.types import RunResult


class EventType(enum.Enum):
    """Kinds of pipeline notification events.

    Members are split into two groups: pipeline-level events
    (``PIPELINE_*``) and task-level events (``TASK_*``).
    """

    PIPELINE_STARTED = "PIPELINE_STARTED"
    TASK_STARTED = "TASK_STARTED"
    TASK_FINISHED = "TASK_FINISHED"
    PIPELINE_FINISHED = "PIPELINE_FINISHED"

    @property
    def is_pipeline_event(self) -> bool:
        """Return True when this member is a ``PIPELINE_*`` event type."""
        return self in (EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED)

    @property
    def is_task_event(self) -> bool:
        """Return True when this member is a ``TASK_*`` event type."""
        return self in (EventType.TASK_STARTED, EventType.TASK_FINISHED)


class Event(BaseModel):
    """Base model for notification events; subclassed by PipelineEvent and TaskEvent."""

    # Which kind of event this is (see EventType).
    event_type: EventType
    # Identifier of the pipeline run the event belongs to.
    run_id: str
    timestamp: datetime.datetime
    # Optional free-text information about the status.
    message: Optional[str] = None
    # Optional data attached to the event.
    payload: Optional[dict[str, Any]] = None


class PipelineEvent(Event):
    """Event subclass for pipeline-level notifications; adds no fields to Event."""


class TaskEvent(Event):
    """Event subclass for task-level notifications; adds the task name to Event."""

    # Name of the task the event refers to.
    task_name: str


class EventCallbackProtocol(Protocol):
    """Structural type for event callbacks.

    Any callable that accepts a single ``event`` argument and returns an
    awaitable resolving to ``None`` (for example an ``async def`` function)
    matches this protocol.
    """

    def __call__(self, event: Event) -> Awaitable[None]: ...
from neo4j_graphrag.experimental.pipeline.types import RunResult, EventCallbackProtocol, Event, PipelineEvent, TaskEvent, EventType


class EventNotifier:
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from neo4j_graphrag.experimental.pipeline.exceptions import (
PipelineDefinitionError,
)
from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol
from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator
from neo4j_graphrag.experimental.pipeline.pipeline_graph import (
PipelineEdge,
Expand All @@ -44,6 +43,7 @@
from neo4j_graphrag.experimental.pipeline.types import (
ComponentDefinition,
ConnectionDefinition,
EventCallbackProtocol,
PipelineDefinition,
RunResult,
)
Expand Down
42 changes: 41 additions & 1 deletion src/neo4j_graphrag/experimental/pipeline/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import datetime
import enum
from collections import defaultdict
from typing import Any, Optional, Union
from collections.abc import Awaitable
from typing import Any, Optional, Protocol, Union

from pydantic import BaseModel, ConfigDict, Field

Expand Down Expand Up @@ -62,6 +63,45 @@ class RunResult(BaseModel):
)


class EventType(enum.Enum):
    """Kinds of pipeline notification events.

    Members are split into two groups: pipeline-level events
    (``PIPELINE_*``) and task-level events (``TASK_*``).
    """

    PIPELINE_STARTED = "PIPELINE_STARTED"
    TASK_STARTED = "TASK_STARTED"
    TASK_FINISHED = "TASK_FINISHED"
    PIPELINE_FINISHED = "PIPELINE_FINISHED"

    @property
    def is_pipeline_event(self) -> bool:
        """Return True when this member is a ``PIPELINE_*`` event type."""
        return self in (EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED)

    @property
    def is_task_event(self) -> bool:
        """Return True when this member is a ``TASK_*`` event type."""
        return self in (EventType.TASK_STARTED, EventType.TASK_FINISHED)


class Event(BaseModel):
    """Base model for notification events; subclassed by PipelineEvent and TaskEvent."""

    event_type: EventType
    run_id: str
    """Pipeline unique run_id, same as the one returned in PipelineResult after pipeline.run"""
    timestamp: datetime.datetime
    message: Optional[str] = None
    """Optional information about the status"""
    payload: Optional[dict[str, Any]] = None
    """Input or output data depending on the type of event"""


class PipelineEvent(Event):
    """Event subclass for pipeline-level notifications; adds no fields to Event."""


class TaskEvent(Event):
    """Event subclass for task-level notifications; adds the task name to Event."""

    task_name: str
    """Name of the task as defined in pipeline.add_component"""


class EventCallbackProtocol(Protocol):
    """Structural type for event callbacks.

    Any callable that accepts a single ``event`` argument and returns an
    awaitable resolving to ``None`` (for example an ``async def`` function)
    matches this protocol.
    """

    def __call__(self, event: Event) -> Awaitable[None]: ...


EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]]
RelationInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]]
"""Types derived from the SchemaEntity and SchemaRelation types,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/experimental/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import pytest
from neo4j_graphrag.experimental.pipeline import Component, Pipeline
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.notification import (
from neo4j_graphrag.experimental.pipeline.types import (
EventCallbackProtocol,
EventType,
PipelineEvent,
RunResult,
TaskEvent,
)
from neo4j_graphrag.experimental.pipeline.types import RunResult

from .components import (
ComponentAdd,
Expand Down

0 comments on commit 45fe4ee

Please sign in to comment.