Skip to content

Commit

Permalink
Raise error on wrong parameter mapping during pipeline definition (ne…
Browse files Browse the repository at this point in the history
…o4j#124)

* Raise error when param is mapped twice or is not a valid input

* Pipeline invalidation method, to reinitialize param mapping and missing params dict
  • Loading branch information
stellasia authored Sep 17, 2024
1 parent 75138c1 commit b90003b
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

### Changed
- Pipeline run method now returns a PipelineResult object.
- Improved parameter validation for pipelines (#124). A pipeline now raises an error before a run starts if:
  - the same parameter is mapped twice
  - or a parameter is defined in the mapping but is not a valid component input


## 0.5.0
Expand Down
19 changes: 17 additions & 2 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def add_component(self, component: Component, name: str) -> None:
task = TaskPipelineNode(name, component)
self.add_node(task)
# invalidate the pipeline if it was already validated
self.is_validated = False
self.invalidate()

def set_component(self, name: str, component: Component) -> None:
"""Replace a component with another. If 'name' is not yet in the pipeline,
Expand All @@ -439,7 +439,7 @@ def set_component(self, name: str, component: Component) -> None:
task = TaskPipelineNode(name, component)
self.set_node(task)
# invalidate the pipeline if it was already validated
self.is_validated = False
self.invalidate()

def connect(
self,
Expand Down Expand Up @@ -475,7 +475,12 @@ def connect(
if self.is_cyclic():
raise PipelineDefinitionError("Cyclic graph are not allowed")
# invalidate the pipeline if it was already validated
self.invalidate()

def invalidate(self) -> None:
    """Mark the pipeline as not validated and drop cached validation state.

    Replaces ``param_mapping`` and ``missing_inputs`` with fresh, empty
    containers and clears the ``is_validated`` flag.
    """
    self.missing_inputs = defaultdict()
    self.param_mapping = defaultdict(dict)
    self.is_validated = False

def validate_parameter_mapping(self) -> None:
"""Go through the graph and make sure parameter mapping is valid
Expand Down Expand Up @@ -520,6 +525,8 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
Considering the naming {param => target (component, [output_parameter]) },
the mapping is valid if:
- 'param' is a valid input for task
- 'param' has not already been mapped
- The target component exists in the pipeline and, if specified, the
target output parameter is a valid field in the target component's
result model.
Expand All @@ -543,6 +550,14 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
# check that the previous component is actually returning
# the mapped parameter
for param, path in edge_inputs.items():
if param in self.param_mapping[task.name]:
raise PipelineDefinitionError(
f"Parameter '{param}' already mapped to {self.param_mapping[task.name][param]}"
)
if param not in task.component.component_inputs:
raise PipelineDefinitionError(
f"Parameter '{param}' is not a valid input for component '{task.name}' of type '{task.component.__class__.__name__}'"
)
try:
source_component_name, param_name = path.split(".")
except ValueError:
Expand Down
60 changes: 60 additions & 0 deletions tests/unit/experimental/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,33 @@ def test_pipeline_parameter_validation_one_component_all_good() -> None:
assert is_valid is True


def test_pipeline_invalidate() -> None:
    """invalidate() clears the validated flag and empties both state dicts."""
    pipeline = Pipeline()
    # Arrange: fake the state of an already validated pipeline.
    pipeline.is_validated = True
    pipeline.param_mapping = {
        "a": {"key": {"component": "component", "param": "param"}}
    }
    pipeline.missing_inputs = {"a": ["other_key"]}

    # Act
    pipeline.invalidate()

    # Assert: the flag is reset and both containers are empty.
    assert pipeline.is_validated is False
    mapping_size = len(pipeline.param_mapping)
    assert mapping_size == 0
    missing_size = len(pipeline.missing_inputs)
    assert missing_size == 0


def test_pipeline_parameter_validation_called_twice() -> None:
    """Re-validating a task raises unless the pipeline is invalidated first."""
    pipeline = Pipeline()
    first, second = ComponentPassThrough(), ComponentPassThrough()
    pipeline.add_component(first, "a")
    pipeline.add_component(second, "b")
    pipeline.connect("a", "b", {"value": "a.result"})

    def validate_b() -> bool:
        # Look the node up on every call, as each validation attempt does.
        return pipeline.validate_parameter_mapping_for_task(
            pipeline.get_node_by_name("b")
        )

    # First validation succeeds.
    assert validate_b() is True
    # A second validation without invalidation is rejected.
    with pytest.raises(PipelineDefinitionError):
        validate_b()
    # After invalidation the task can be validated again.
    pipeline.invalidate()
    assert validate_b() is True


def test_pipeline_parameter_validation_one_component_input_param_missing() -> None:
pipe = Pipeline()
component_a = ComponentPassThrough()
Expand All @@ -105,6 +132,39 @@ def test_pipeline_parameter_validation_one_component_input_param_missing() -> No
assert pipe.missing_inputs["a"] == ["value"]


def test_pipeline_parameter_validation_param_mapped_twice() -> None:
    """Mapping the same input of 'c' from two source components must raise."""
    pipe = Pipeline()
    component_a = ComponentPassThrough()
    component_b = ComponentPassThrough()
    component_c = ComponentPassThrough()
    pipe.add_component(component_a, "a")
    pipe.add_component(component_b, "b")
    pipe.add_component(component_c, "c")
    # Both edges feed the same 'value' input of component 'c'.
    pipe.connect("a", "c", {"value": "a.result"})
    pipe.connect("b", "c", {"value": "b.result"})
    with pytest.raises(PipelineDefinitionError) as excinfo:
        pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("c"))
    # Check the raised exception's own message (excinfo.value) rather than
    # the string form of pytest's ExceptionInfo wrapper.
    assert (
        "Parameter 'value' already mapped to {'component': 'a', 'param': 'result'}"
        in str(excinfo.value)
    )


def test_pipeline_parameter_validation_unexpected_input() -> None:
    """Mapping a parameter that is not an input of the target must raise."""
    pipe = Pipeline()
    component_a = ComponentPassThrough()
    component_b = ComponentPassThrough()
    pipe.add_component(component_a, "a")
    pipe.add_component(component_b, "b")
    # 'unexpected_input_name' is the mapped parameter the error should name.
    pipe.connect("a", "b", {"unexpected_input_name": "a.result"})
    with pytest.raises(PipelineDefinitionError) as excinfo:
        pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("b"))
    # Check the raised exception's own message (excinfo.value) rather than
    # the string form of pytest's ExceptionInfo wrapper.
    assert (
        "Parameter 'unexpected_input_name' is not a valid input for component 'b' of type 'ComponentPassThrough'"
        in str(excinfo.value)
    )


def test_pipeline_parameter_validation_connected_components_input() -> None:
"""Parameter for component 'b' comes from the pipeline inputs"""
pipe = Pipeline()
Expand Down

0 comments on commit b90003b

Please sign in to comment.