-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
226 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from haystack.preview.components.routers.file_type_router import FileTypeRouter | ||
from haystack.preview.components.routers.metadata_router import MetadataRouter | ||
from haystack.preview.components.routers.router import Router | ||
from haystack.preview.components.routers.conditional_router import ConditionalRouter | ||
|
||
|
||
__all__ = ["FileTypeRouter", "MetadataRouter", "Router"] | ||
__all__ = ["FileTypeRouter", "MetadataRouter", "Router", "ConditionalRouter"] |
128 changes: 128 additions & 0 deletions
128
haystack/preview/components/routers/conditional_router.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import logging | ||
from typing import List, Dict, Any, Set | ||
|
||
from jinja2.nativetypes import NativeEnvironment | ||
from jinja2 import meta | ||
|
||
from haystack.preview import component | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class NoRouteSelectedException(Exception): | ||
"""Exception raised when no route is selected in Router.""" | ||
|
||
|
||
class RouteConditionException(Exception): | ||
"""Exception raised when there is an error parsing or evaluating the condition expression in Router.""" | ||
|
||
|
||
@component | ||
class ConditionalRouter: | ||
""" | ||
The Router class orchestrates the flow of data by evaluating specified route conditions | ||
to determine the appropriate route among a set of provided route alternatives. | ||
To use a Router in Haystack 2.x pipelines we first define a list called routes, where each element | ||
is a dictionary representing a route. | ||
Each route dictionary contains three keys: condition, output, and output_type. | ||
The condition is a string containing a Jinja2 boolean expression that will be evaluated to determine if | ||
this route should be selected. The output is a string specifying the name of the output slot for | ||
this route, and output_type is a string representation of the expected type of the output. | ||
Example: | ||
In this example, we create a `Router` instance with two routes. | ||
The first route will be selected if the number of streams is less than 2, | ||
and will output the `query` variable. The second route will be selected | ||
if the number of streams is 2 or more, and will output the `streams` variable. | ||
These variables need to be provided in the pipeline `run()` method. | ||
```python | ||
routes = [ | ||
{"condition": "{{streams|length < 2}}", "output": "query", "output_type": "str"}, | ||
{"condition": "{{streams|length < 2}}", "output": "streams", "output_type": "List[ByteStream]"} | ||
] | ||
router = Router(routes=routes) | ||
``` | ||
""" | ||
|
||
def __init__(self, routes: List[Dict]): | ||
""" | ||
Initialize the Router with a list of routes and the routing variables. | ||
:param routes: A list of dictionaries, each representing a route with a | ||
boolean condition expression (`condition`), an output slot (`output`), | ||
and the output type as a string representation (`output_type`). | ||
:param routing_variables: A list of additional pipeline variables that are | ||
used in the boolean condition expressions or as outputs of the router. | ||
These variables should be provided by either the pipeline `run()` | ||
method or by a previous component to the router in the pipeline. | ||
""" | ||
self._validate_routes(routes) | ||
self.routes: List[dict] = routes | ||
|
||
# Create a Jinja native environment to extract variables from the condition templates | ||
env = NativeEnvironment() | ||
|
||
# Inspect the routes to determine input and output types. | ||
|
||
input_names: Set[str] = set() # let's just store the name, type will always be Any | ||
output_types: Dict[str, str] = {} | ||
for route in routes: | ||
# Input types must include any variable that needs to be sent in output | ||
input_names.add(route["output"]) | ||
# Also add any additional variable that might be used within a "condition" expression. | ||
ast = env.parse(route["condition"]) | ||
input_names.update(meta.find_undeclared_variables(ast)) | ||
|
||
output_types.update({route["output"]: route["output_type"]}) | ||
|
||
component.set_input_types(self, **{var: Any for var in input_names}) | ||
component.set_output_types(self, **output_types) | ||
|
||
def run(self, **kwargs): | ||
""" | ||
Executes the routing logic by evaluating the specified boolean condition expressions | ||
for each route in the order they are listed. The method directs the flow | ||
of data to the output slot specified in the first route whose expression | ||
evaluates to True. If no route's expression evaluates to True, an exception | ||
is raised. | ||
:param kwargs: A dictionary containing the pipeline variables, which should | ||
include all variables used in the "condition" templates. | ||
:return: A dictionary containing the output slot and the corresponding result, | ||
based on the first route whose expression evaluates to True. | ||
:raises NoRouteSelectedException: If no route's expression evaluates to True. | ||
""" | ||
# Create a Jinja native environment evaluate the condition templates | ||
env = NativeEnvironment() | ||
|
||
for route in self.routes: | ||
try: | ||
t = env.from_string(route["condition"]) | ||
if t.render(**kwargs): | ||
output_slot = route["output"] | ||
return {output_slot: kwargs[output_slot]} | ||
except Exception as e: | ||
raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e | ||
|
||
raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}") | ||
|
||
def _validate_routes(self, routes: List[Dict]): | ||
for route in routes: | ||
try: | ||
keys = set(route.keys()) | ||
except AttributeError: | ||
raise ValueError(f"Route must be a dictionary, got: {route}") | ||
|
||
if not {"condition", "output", "output_type"}.issubset(keys): | ||
raise ValueError("Each route must contain 'condition', 'output', and 'output_type' keys.") |
96 changes: 96 additions & 0 deletions
96
test/preview/components/routers/test_conditional_router.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from unittest import mock | ||
from typing import List | ||
|
||
import pytest | ||
|
||
from haystack.preview.components.routers import ConditionalRouter | ||
from haystack.preview.components.routers.conditional_router import NoRouteSelectedException | ||
|
||
|
||
class TestRouter: | ||
@pytest.fixture | ||
def routes(self): | ||
return [ | ||
{"condition": "{{streams|length < 2}}", "output": "query", "output_type": str}, | ||
{"condition": "{{streams|length >= 2}}", "output": "streams", "output_type": List[int]}, | ||
] | ||
|
||
@pytest.fixture | ||
def router(self, routes): | ||
return ConditionalRouter(routes) | ||
|
||
@pytest.mark.unit | ||
def test_router_initialized(self, routes): | ||
router = ConditionalRouter(routes) | ||
|
||
assert router.routes == routes | ||
assert set(router.__canals_input__.keys()) == {"query", "streams"} | ||
assert set(router.__canals_output__.keys()) == {"query", "streams"} | ||
|
||
@pytest.mark.unit | ||
def test_router_evaluate_condition_expressions(self, router): | ||
# first route should be selected | ||
kwargs = {"streams": [1, 2, 3], "query": "test"} | ||
result = router.run(**kwargs) | ||
assert result == {"streams": [1, 2, 3]} | ||
|
||
# second route should be selected | ||
kwargs = {"streams": [1], "query": "test"} | ||
result = router.run(**kwargs) | ||
assert result == {"query": "test"} | ||
|
||
@pytest.mark.unit | ||
def test_complex_condition(self): | ||
routes = [ | ||
{ | ||
"condition": "{{messages[-1].metadata.finish_reason == 'function_call'}}", | ||
"output": "streams", | ||
"output_type": List[int], | ||
}, | ||
{"condition": "{{True}}", "output": "query", "output_type": str}, # catch-all condition | ||
] | ||
router = ConditionalRouter(routes) | ||
message = mock.MagicMock() | ||
message.metadata.finish_reason = "function_call" | ||
result = router.run(messages=[message], streams=[1, 2, 3], query="my query") | ||
assert result == {"streams": [1, 2, 3]} | ||
|
||
@pytest.mark.unit | ||
def test_router_no_route(self, router): | ||
# should raise an exception | ||
router = ConditionalRouter( | ||
[ | ||
{"condition": "{{streams|length < 2}}", "output": "query", "output_type": str}, | ||
{"condition": "{{streams|length >= 5}}", "output": "streams", "output_type": List[int]}, | ||
] | ||
) | ||
|
||
kwargs = {"streams": [1, 2, 3], "query": "test"} | ||
with pytest.raises(NoRouteSelectedException): | ||
router.run(**kwargs) | ||
|
||
@pytest.mark.unit | ||
def test_router_raises_value_error_if_route_not_dictionary(self): | ||
""" | ||
Router raises a ValueError if each route is not a dictionary | ||
""" | ||
routes = [ | ||
{"condition": "{{streams|length < 2}}", "output": "query", "output_type": str}, | ||
["{{streams|length >= 2}}", "streams", List[int]], | ||
] | ||
|
||
with pytest.raises(ValueError): | ||
ConditionalRouter(routes) | ||
|
||
@pytest.mark.unit | ||
def test_router_raises_value_error_if_route_missing_keys(self): | ||
""" | ||
Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys | ||
""" | ||
routes = [ | ||
{"condition": "{{streams|length < 2}}", "output": "query"}, | ||
{"condition": "{{streams|length < 2}}", "output_type": str}, | ||
] | ||
|
||
with pytest.raises(ValueError): | ||
ConditionalRouter(routes) |