diff --git a/CHANGELOG.md b/CHANGELOG.md index 15d750d4..655d7608 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ ### Changed - Pipeline run method now return a PipelineResult object. +- Improved parameter validation for pipelines (#124). Pipeline now raise an error before a run starts if: + - the same parameter is mapped twice + - or a parameter is defined in the mapping but is not a valid component input ## 0.5.0 diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 31088c8b..97b4b53f 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -430,7 +430,7 @@ def add_component(self, component: Component, name: str) -> None: task = TaskPipelineNode(name, component) self.add_node(task) # invalidate the pipeline if it was already validated - self.is_validated = False + self.invalidate() def set_component(self, name: str, component: Component) -> None: """Replace a component with another. If 'name' is not yet in the pipeline, @@ -439,7 +439,7 @@ def set_component(self, name: str, component: Component) -> None: task = TaskPipelineNode(name, component) self.set_node(task) # invalidate the pipeline if it was already validated - self.is_validated = False + self.invalidate() def connect( self, @@ -475,7 +475,12 @@ def connect( if self.is_cyclic(): raise PipelineDefinitionError("Cyclic graph are not allowed") # invalidate the pipeline if it was already validated + self.invalidate() + + def invalidate(self) -> None: self.is_validated = False + self.param_mapping = defaultdict(dict) + self.missing_inputs = defaultdict() def validate_parameter_mapping(self) -> None: """Go through the graph and make sure parameter mapping is valid @@ -520,6 +525,8 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool: Considering the naming {param => target (component, [output_parameter]) }, the mapping is valid if: + - 'param' is a valid input for task + - 'param' has not already been mapped - The target component exists in the pipeline and, if specified, the target output parameter is a valid field in the target component's result model. @@ -543,6 +550,14 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool: # check that the previous component is actually returning # the mapped parameter for param, path in edge_inputs.items(): + if param in self.param_mapping[task.name]: + raise PipelineDefinitionError( + f"Parameter '{param}' already mapped to {self.param_mapping[task.name][param]}" + ) + if param not in task.component.component_inputs: + raise PipelineDefinitionError( + f"Parameter '{param}' is not a valid input for component '{task.name}' of type '{task.component.__class__.__name__}'" + ) try: source_component_name, param_name = path.split(".") except ValueError: diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py index 89602503..97c8b481 100644 --- a/tests/unit/experimental/pipeline/test_pipeline.py +++ b/tests/unit/experimental/pipeline/test_pipeline.py @@ -97,6 +97,33 @@ def test_pipeline_parameter_validation_one_component_all_good() -> None: assert is_valid is True +def test_pipeline_invalidate() -> None: + pipe = Pipeline() + pipe.is_validated = True + pipe.param_mapping = {"a": {"key": {"component": "component", "param": "param"}}} + pipe.missing_inputs = {"a": ["other_key"]} + pipe.invalidate() + assert pipe.is_validated is False + assert len(pipe.param_mapping) == 0 + assert len(pipe.missing_inputs) == 0 + + +def test_pipeline_parameter_validation_called_twice() -> None: + pipe = Pipeline() + component_a = ComponentPassThrough() + component_b = ComponentPassThrough() + pipe.add_component(component_a, "a") + pipe.add_component(component_b, "b") + pipe.connect("a", "b", {"value": "a.result"}) + is_valid = pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("b")) + assert is_valid is True + with pytest.raises(PipelineDefinitionError): + pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("b")) + pipe.invalidate() + is_valid = pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("b")) + assert is_valid is True + + def test_pipeline_parameter_validation_one_component_input_param_missing() -> None: pipe = Pipeline() component_a = ComponentPassThrough() @@ -105,6 +132,39 @@ def test_pipeline_parameter_validation_one_component_input_param_missing() -> No assert pipe.missing_inputs["a"] == ["value"] +def test_pipeline_parameter_validation_param_mapped_twice() -> None: + pipe = Pipeline() + component_a = ComponentPassThrough() + component_b = ComponentPassThrough() + component_c = ComponentPassThrough() + pipe.add_component(component_a, "a") + pipe.add_component(component_b, "b") + pipe.add_component(component_c, "c") + pipe.connect("a", "c", {"value": "a.result"}) + pipe.connect("b", "c", {"value": "b.result"}) + with pytest.raises(PipelineDefinitionError) as excinfo: + pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("c")) + assert ( + "Parameter 'value' already mapped to {'component': 'a', 'param': 'result'}" + in str(excinfo) + ) + + +def test_pipeline_parameter_validation_unexpected_input() -> None: + pipe = Pipeline() + component_a = ComponentPassThrough() + component_b = ComponentPassThrough() + pipe.add_component(component_a, "a") + pipe.add_component(component_b, "b") + pipe.connect("a", "b", {"unexpected_input_name": "a.result"}) + with pytest.raises(PipelineDefinitionError) as excinfo: + pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("b")) + assert ( + "Parameter 'unexpected_input_name' is not a valid input for component 'b' of type 'ComponentPassThrough'" + in str(excinfo) + ) + + def test_pipeline_parameter_validation_connected_components_input() -> None: """Parameter for component 'b' comes from the pipeline inputs""" pipe = Pipeline()