Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for returning intermediate outputs of pipeline components #7504

Merged
merged 6 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def _validate_input(self, data: Dict[str, Any]):

# TODO: We're ignoring these linting rules for the time being, after we properly optimize this function we'll remove the noqa
def run( # noqa: C901, PLR0912, PLR0915 pylint: disable=too-many-branches
self, data: Dict[str, Any], debug: bool = False
self, data: Dict[str, Any], debug: bool = False, include_outputs_from: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""
Runs the pipeline with given input data.
Expand All @@ -623,8 +623,16 @@ def run( # noqa: C901, PLR0912, PLR0915 pylint: disable=too-many-branches
and its value is a dictionary of that component's input parameters.
:param debug:
Set to True to collect and return debug information.
:param include_outputs_from:
Set of component names whose individual outputs are to be
included in the pipeline's output. For components that are
invoked multiple times (in a loop), only the last-produced
output is included.
:returns:
A dictionary containing the pipeline's output.
A dictionary where each entry corresponds to a component name
and its output. If `include_outputs_from` is `None`, this dictionary
will only contain the outputs of leaf components, i.e., components
without outgoing connections.

:raises PipelineRuntimeError:
If a component fails or returns unexpected output.
Expand Down Expand Up @@ -756,6 +764,8 @@ def run(self, word: str):
# The waiting_for_input list is used to keep track of components that are waiting for input.
waiting_for_input: List[Tuple[str, Component]] = []

include_outputs_from = set() if include_outputs_from is None else include_outputs_from

with tracing.tracer.trace(
"haystack.pipeline.run",
tags={
Expand All @@ -765,7 +775,11 @@ def run(self, word: str):
},
):
# This is what we'll return at the end
final_outputs = {}
final_outputs: Dict[Any, Any] = {}

# Cache for extra outputs, if enabled.
extra_outputs: Dict[Any, Any] = {}

while len(to_run) > 0:
name, comp = to_run.pop(0)

Expand Down Expand Up @@ -826,6 +840,11 @@ def run(self, word: str):
span.set_tags(tags={"haystack.component.visits": self.graph.nodes[name]["visits"]})
span.set_content_tag("haystack.component.output", res)

if name in include_outputs_from:
# Deepcopy the outputs to prevent downstream nodes from modifying them
# We don't care about loops - Always store the last output.
extra_outputs[name] = deepcopy(res)

# Reset the waiting for input previous states, we managed to run a component
before_last_waiting_for_input = None
last_waiting_for_input = None
Expand Down Expand Up @@ -988,6 +1007,11 @@ def run(self, word: str):
waiting_for_input.remove((name, comp))
to_run.append((name, comp))

if len(include_outputs_from) > 0:
for name, output in extra_outputs.items():
if name not in final_outputs:
final_outputs[name] = output

return final_outputs

def _prepare_component_input_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
`pipeline.run` accepts a set of component names whose intermediate outputs are returned in the final
pipeline output dictionary.
61 changes: 61 additions & 0 deletions test/core/pipeline/test_intermediate_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import logging

from haystack.components.others import Multiplexer
from haystack.core.pipeline import Pipeline
from haystack.testing.sample_components import Accumulate, AddFixedValue, Double, Threshold

logging.basicConfig(level=logging.DEBUG)


def test_pipeline_intermediate_outputs():
pipeline = Pipeline()
pipeline.add_component("first_addition", AddFixedValue(add=2))
pipeline.add_component("second_addition", AddFixedValue())
pipeline.add_component("double", Double())
pipeline.connect("first_addition", "double")
pipeline.connect("double", "second_addition")

results = pipeline.run(
{"first_addition": {"value": 1}}, include_outputs_from={"first_addition", "second_addition", "double"}
)
assert results == {"second_addition": {"result": 7}, "first_addition": {"result": 3}, "double": {"value": 6}}

results = pipeline.run({"first_addition": {"value": 1}}, include_outputs_from={"double"})
assert results == {"second_addition": {"result": 7}, "double": {"value": 6}}


def test_pipeline_with_loops_intermediate_outputs():
accumulator = Accumulate()

pipeline = Pipeline(max_loops_allowed=10)
pipeline.add_component("add_one", AddFixedValue(add=1))
pipeline.add_component("multiplexer", Multiplexer(type_=int))
pipeline.add_component("below_10", Threshold(threshold=10))
pipeline.add_component("below_5", Threshold(threshold=5))
pipeline.add_component("add_three", AddFixedValue(add=3))
pipeline.add_component("accumulator", accumulator)
pipeline.add_component("add_two", AddFixedValue(add=2))

pipeline.connect("add_one.result", "multiplexer")
pipeline.connect("multiplexer.value", "below_10.value")
pipeline.connect("below_10.below", "accumulator.value")
pipeline.connect("accumulator.value", "below_5.value")
pipeline.connect("below_5.above", "add_three.value")
pipeline.connect("below_5.below", "multiplexer")
pipeline.connect("add_three.result", "multiplexer")
pipeline.connect("below_10.above", "add_two.value")

results = pipeline.run(
{"add_one": {"value": 3}},
include_outputs_from={"add_two", "add_one", "multiplexer", "below_10", "accumulator", "below_5", "add_three"},
)

assert results == {
"add_two": {"result": 13},
"add_one": {"result": 4},
"multiplexer": {"value": 11},
"below_10": {"above": 11},
"accumulator": {"value": 8},
"below_5": {"above": 8},
"add_three": {"result": 11},
}