Skip to content

Commit

Permalink
add alternative implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Oct 29, 2023
1 parent 4186bbf commit fbbb2ff
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 1 deletion.
3 changes: 2 additions & 1 deletion haystack/preview/components/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from haystack.preview.components.routers.file_type_router import FileTypeRouter
from haystack.preview.components.routers.metadata_router import MetadataRouter
from haystack.preview.components.routers.router import Router
from haystack.preview.components.routers.conditional_router import ConditionalRouter


__all__ = ["FileTypeRouter", "MetadataRouter", "Router"]
__all__ = ["FileTypeRouter", "MetadataRouter", "Router", "ConditionalRouter"]
128 changes: 128 additions & 0 deletions haystack/preview/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import logging
from typing import List, Dict, Any, Set

from jinja2.nativetypes import NativeEnvironment
from jinja2 import meta

from haystack.preview import component

logger = logging.getLogger(__name__)


class NoRouteSelectedException(Exception):
"""Exception raised when no route is selected in Router."""


class RouteConditionException(Exception):
"""Exception raised when there is an error parsing or evaluating the condition expression in Router."""


@component
class ConditionalRouter:
"""
The Router class orchestrates the flow of data by evaluating specified route conditions
to determine the appropriate route among a set of provided route alternatives.
To use a Router in Haystack 2.x pipelines we first define a list called routes, where each element
is a dictionary representing a route.
Each route dictionary contains three keys: condition, output, and output_type.
The condition is a string containing a Jinja2 boolean expression that will be evaluated to determine if
this route should be selected. The output is a string specifying the name of the output slot for
this route, and output_type is a string representation of the expected type of the output.
Example:
In this example, we create a `Router` instance with two routes.
The first route will be selected if the number of streams is less than 2,
and will output the `query` variable. The second route will be selected
if the number of streams is 2 or more, and will output the `streams` variable.
These variables need to be provided in the pipeline `run()` method.
```python
routes = [
{"condition": "{{streams|length < 2}}", "output": "query", "output_type": "str"},
{"condition": "{{streams|length < 2}}", "output": "streams", "output_type": "List[ByteStream]"}
]
router = Router(routes=routes)
```
"""

def __init__(self, routes: List[Dict]):
"""
Initialize the Router with a list of routes and the routing variables.
:param routes: A list of dictionaries, each representing a route with a
boolean condition expression (`condition`), an output slot (`output`),
and the output type as a string representation (`output_type`).
:param routing_variables: A list of additional pipeline variables that are
used in the boolean condition expressions or as outputs of the router.
These variables should be provided by either the pipeline `run()`
method or by a previous component to the router in the pipeline.
"""
self._validate_routes(routes)
self.routes: List[dict] = routes

# Create a Jinja native environment to extract variables from the condition templates
env = NativeEnvironment()

# Inspect the routes to determine input and output types.

input_names: Set[str] = set() # let's just store the name, type will always be Any
output_types: Dict[str, str] = {}
for route in routes:
# Input types must include any variable that needs to be sent in output
input_names.add(route["output"])
# Also add any additional variable that might be used within a "condition" expression.
ast = env.parse(route["condition"])
input_names.update(meta.find_undeclared_variables(ast))

output_types.update({route["output"]: route["output_type"]})

component.set_input_types(self, **{var: Any for var in input_names})
component.set_output_types(self, **output_types)

def run(self, **kwargs):
"""
Executes the routing logic by evaluating the specified boolean condition expressions
for each route in the order they are listed. The method directs the flow
of data to the output slot specified in the first route whose expression
evaluates to True. If no route's expression evaluates to True, an exception
is raised.
:param kwargs: A dictionary containing the pipeline variables, which should
include all variables used in the "condition" templates.
:return: A dictionary containing the output slot and the corresponding result,
based on the first route whose expression evaluates to True.
:raises NoRouteSelectedException: If no route's expression evaluates to True.
"""
# Create a Jinja native environment evaluate the condition templates
env = NativeEnvironment()

for route in self.routes:
try:
t = env.from_string(route["condition"])
if t.render(**kwargs):
output_slot = route["output"]
return {output_slot: kwargs[output_slot]}
except Exception as e:
raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e

raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")

def _validate_routes(self, routes: List[Dict]):
for route in routes:
try:
keys = set(route.keys())
except AttributeError:
raise ValueError(f"Route must be a dictionary, got: {route}")

if not {"condition", "output", "output_type"}.issubset(keys):
raise ValueError("Each route must contain 'condition', 'output', and 'output_type' keys.")
96 changes: 96 additions & 0 deletions test/preview/components/routers/test_conditional_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from unittest import mock
from typing import List

import pytest

from haystack.preview.components.routers import ConditionalRouter
from haystack.preview.components.routers.conditional_router import NoRouteSelectedException


class TestRouter:
@pytest.fixture
def routes(self):
return [
{"condition": "{{streams|length < 2}}", "output": "query", "output_type": str},
{"condition": "{{streams|length >= 2}}", "output": "streams", "output_type": List[int]},
]

@pytest.fixture
def router(self, routes):
return ConditionalRouter(routes)

@pytest.mark.unit
def test_router_initialized(self, routes):
router = ConditionalRouter(routes)

assert router.routes == routes
assert set(router.__canals_input__.keys()) == {"query", "streams"}
assert set(router.__canals_output__.keys()) == {"query", "streams"}

@pytest.mark.unit
def test_router_evaluate_condition_expressions(self, router):
# first route should be selected
kwargs = {"streams": [1, 2, 3], "query": "test"}
result = router.run(**kwargs)
assert result == {"streams": [1, 2, 3]}

# second route should be selected
kwargs = {"streams": [1], "query": "test"}
result = router.run(**kwargs)
assert result == {"query": "test"}

@pytest.mark.unit
def test_complex_condition(self):
routes = [
{
"condition": "{{messages[-1].metadata.finish_reason == 'function_call'}}",
"output": "streams",
"output_type": List[int],
},
{"condition": "{{True}}", "output": "query", "output_type": str}, # catch-all condition
]
router = ConditionalRouter(routes)
message = mock.MagicMock()
message.metadata.finish_reason = "function_call"
result = router.run(messages=[message], streams=[1, 2, 3], query="my query")
assert result == {"streams": [1, 2, 3]}

@pytest.mark.unit
def test_router_no_route(self, router):
# should raise an exception
router = ConditionalRouter(
[
{"condition": "{{streams|length < 2}}", "output": "query", "output_type": str},
{"condition": "{{streams|length >= 5}}", "output": "streams", "output_type": List[int]},
]
)

kwargs = {"streams": [1, 2, 3], "query": "test"}
with pytest.raises(NoRouteSelectedException):
router.run(**kwargs)

@pytest.mark.unit
def test_router_raises_value_error_if_route_not_dictionary(self):
"""
Router raises a ValueError if each route is not a dictionary
"""
routes = [
{"condition": "{{streams|length < 2}}", "output": "query", "output_type": str},
["{{streams|length >= 2}}", "streams", List[int]],
]

with pytest.raises(ValueError):
ConditionalRouter(routes)

@pytest.mark.unit
def test_router_raises_value_error_if_route_missing_keys(self):
"""
Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys
"""
routes = [
{"condition": "{{streams|length < 2}}", "output": "query"},
{"condition": "{{streams|length < 2}}", "output_type": str},
]

with pytest.raises(ValueError):
ConditionalRouter(routes)

0 comments on commit fbbb2ff

Please sign in to comment.