diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 498daa5c90..4edd65e3d9 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -613,7 +613,7 @@ def _validate_input(self, data: Dict[str, Any]): # TODO: We're ignoring these linting rules for the time being, after we properly optimize this function we'll remove the noqa def run( # noqa: C901, PLR0912, PLR0915 pylint: disable=too-many-branches - self, data: Dict[str, Any], debug: bool = False + self, data: Dict[str, Any], debug: bool = False, include_outputs_from: Optional[Set[str]] = None ) -> Dict[str, Any]: """ Runs the pipeline with given input data. @@ -623,8 +623,16 @@ def run( # noqa: C901, PLR0912, PLR0915 pylint: disable=too-many-branches and its value is a dictionary of that component's input parameters. :param debug: Set to True to collect and return debug information. + :param include_outputs_from: + Set of component names whose individual outputs are to be + included in the pipeline's output. For components that are + invoked multiple times (in a loop), only the last-produced + output is included. :returns: - A dictionary containing the pipeline's output. + A dictionary where each entry corresponds to a component name + and its output. If `include_outputs_from` is `None`, this dictionary + will only contain the outputs of leaf components, i.e., components + without outgoing connections. :raises PipelineRuntimeError: If a component fails or returns unexpected output. @@ -756,6 +764,8 @@ def run(self, word: str): # The waiting_for_input list is used to keep track of components that are waiting for input. waiting_for_input: List[Tuple[str, Component]] = [] + include_outputs_from = set() if include_outputs_from is None else include_outputs_from + with tracing.tracer.trace( "haystack.pipeline.run", tags={ @@ -765,7 +775,11 @@ def run(self, word: str): }, ): # This is what we'll return at the end - final_outputs = {} + final_outputs: Dict[Any, Any] = {} + + # Cache for extra outputs, if enabled. + extra_outputs: Dict[Any, Any] = {} + while len(to_run) > 0: name, comp = to_run.pop(0) @@ -826,6 +840,11 @@ def run(self, word: str): span.set_tags(tags={"haystack.component.visits": self.graph.nodes[name]["visits"]}) span.set_content_tag("haystack.component.output", res) + if name in include_outputs_from: + # Deepcopy the outputs to prevent downstream nodes from modifying them + # We don't care about loops - Always store the last output. + extra_outputs[name] = deepcopy(res) + # Reset the waiting for input previous states, we managed to run a component before_last_waiting_for_input = None last_waiting_for_input = None @@ -988,6 +1007,11 @@ def run(self, word: str): waiting_for_input.remove((name, comp)) to_run.append((name, comp)) + if len(include_outputs_from) > 0: + for name, output in extra_outputs.items(): + if name not in final_outputs: + final_outputs[name] = output + return final_outputs def _prepare_component_input_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]: diff --git a/releasenotes/notes/pipeline-intermediate-outputs-7cb8e71f79532ec1.yaml b/releasenotes/notes/pipeline-intermediate-outputs-7cb8e71f79532ec1.yaml new file mode 100644 index 0000000000..38a0a9ab8f --- /dev/null +++ b/releasenotes/notes/pipeline-intermediate-outputs-7cb8e71f79532ec1.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + `pipeline.run` accepts a set of component names whose intermediate outputs are returned in the final + pipeline output dictionary. diff --git a/test/core/pipeline/test_intermediate_outputs.py b/test/core/pipeline/test_intermediate_outputs.py new file mode 100644 index 0000000000..c4624233e4 --- /dev/null +++ b/test/core/pipeline/test_intermediate_outputs.py @@ -0,0 +1,61 @@ +import logging + +from haystack.components.others import Multiplexer +from haystack.core.pipeline import Pipeline +from haystack.testing.sample_components import Accumulate, AddFixedValue, Double, Threshold + +logging.basicConfig(level=logging.DEBUG) + + +def test_pipeline_intermediate_outputs(): + pipeline = Pipeline() + pipeline.add_component("first_addition", AddFixedValue(add=2)) + pipeline.add_component("second_addition", AddFixedValue()) + pipeline.add_component("double", Double()) + pipeline.connect("first_addition", "double") + pipeline.connect("double", "second_addition") + + results = pipeline.run( + {"first_addition": {"value": 1}}, include_outputs_from={"first_addition", "second_addition", "double"} + ) + assert results == {"second_addition": {"result": 7}, "first_addition": {"result": 3}, "double": {"value": 6}} + + results = pipeline.run({"first_addition": {"value": 1}}, include_outputs_from={"double"}) + assert results == {"second_addition": {"result": 7}, "double": {"value": 6}} + + +def test_pipeline_with_loops_intermediate_outputs(): + accumulator = Accumulate() + + pipeline = Pipeline(max_loops_allowed=10) + pipeline.add_component("add_one", AddFixedValue(add=1)) + pipeline.add_component("multiplexer", Multiplexer(type_=int)) + pipeline.add_component("below_10", Threshold(threshold=10)) + pipeline.add_component("below_5", Threshold(threshold=5)) + pipeline.add_component("add_three", AddFixedValue(add=3)) + pipeline.add_component("accumulator", accumulator) + pipeline.add_component("add_two", AddFixedValue(add=2)) + + pipeline.connect("add_one.result", "multiplexer") + pipeline.connect("multiplexer.value", "below_10.value") + pipeline.connect("below_10.below", "accumulator.value") + pipeline.connect("accumulator.value", "below_5.value") + pipeline.connect("below_5.above", "add_three.value") + pipeline.connect("below_5.below", "multiplexer") + pipeline.connect("add_three.result", "multiplexer") + pipeline.connect("below_10.above", "add_two.value") + + results = pipeline.run( + {"add_one": {"value": 3}}, + include_outputs_from={"add_two", "add_one", "multiplexer", "below_10", "accumulator", "below_5", "add_three"}, + ) + + assert results == { + "add_two": {"result": 13}, + "add_one": {"result": 4}, + "multiplexer": {"value": 11}, + "below_10": {"above": 11}, + "accumulator": {"value": 8}, + "below_5": {"above": 8}, + "add_three": {"result": 11}, + }