Skip to content

Commit

Permalink
added Select object for transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
zilto authored and zilto committed Sep 9, 2024
1 parent e380a1a commit 398f0aa
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 13 deletions.
3 changes: 2 additions & 1 deletion burr/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from burr.core.action import Action, Condition, Result, action, default, expr, when
from burr.core.action import Action, Condition, Result, Select, action, default, expr, when
from burr.core.application import (
Application,
ApplicationBuilder,
Expand All @@ -18,6 +18,7 @@
"default",
"expr",
"Result",
"Select",
"State",
"when",
]
45 changes: 45 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
List,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
Union,
Expand Down Expand Up @@ -378,6 +379,50 @@ def __invert__(self):
# exists = Condition.exists


# TODO type `resolver` to prevent user-facing type-mismatch
# e.g., a user provided `def foo(state: State, actions: list)`
# would be too restrictive for a `Sequence` type
class Select(Function):
def __init__(
self,
keys: List[str],
resolver: Callable[[State, Sequence[Action]], str],
name: str = None,
):
self._keys = keys
self._resolver = resolver
self._name = name
# TODO add a `default` kwarg;
# could an Action, action_name: str, or action_idx: int
# `default` value could be returned if `_resolver` returns None

@property
def name(self) -> str:
return self._name

@property
def reads(self) -> list[str]:
return self._keys

@property
def resolver(self) -> Callable[[State, Sequence[Action]], str]:
return self._resolver

def __repr__(self) -> str:
return f"select: {self._name}"

def _validate(self, state: State):
missing_keys = set(self._keys) - set(state.keys())
if missing_keys:
raise ValueError(
f"Missing keys in state required by condition: {self} {', '.join(missing_keys)}"
)

def run(self, state: State, possible_actions: Sequence[Action]) -> str:
self._validate(state)
return self._resolver(state, possible_actions)


class Result(Action):
def __init__(self, *fields: str):
"""Represents a result action. This is purely a convenience class to
Expand Down
5 changes: 4 additions & 1 deletion burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Expand All @@ -33,6 +34,7 @@
Condition,
Function,
Reducer,
Select,
SingleStepAction,
SingleStepStreamingAction,
StreamingAction,
Expand Down Expand Up @@ -2005,7 +2007,8 @@ def with_actions(
def with_transitions(
self,
*transitions: Union[
Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition]
Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]]],
Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]], Union[Condition, Select]],
],
) -> "ApplicationBuilder":
"""Adds transitions to the application. Transitions are specified as tuples of either:
Expand Down
41 changes: 30 additions & 11 deletions burr/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import inspect
import logging
import pathlib
from typing import Any, Callable, List, Literal, Optional, Set, Tuple, Union
from typing import Any, Callable, List, Literal, Optional, Sequence, Set, Tuple, Union

from burr import telemetry
from burr.core.action import Action, Condition, create_action, default
from burr.core.action import Action, Condition, Select, create_action, default
from burr.core.state import State
from burr.core.validation import BASE_ERROR_MESSAGE, assert_set

Expand Down Expand Up @@ -118,6 +118,13 @@ def get_next_node(
return self._action_map[entrypoint]
possibilities = self._adjacency_map[prior_step]
for next_action, condition in possibilities:
# When `Select` is used, all possibilities have the same `condition` attached.
# Hitting a `Select` will necessarily exit the for loop
if isinstance(condition, Select):
possible_actions = [self._action_map[p[0]] for p in possibilities]
selected_action = condition.run(state, possible_actions)
return self._action_map[selected_action]

if condition.run(state)[Condition.KEY]:
return self._action_map[next_action]
return None
Expand Down Expand Up @@ -235,7 +242,7 @@ class GraphBuilder:

def __init__(self):
"""Initializes the graph builder."""
self.transitions: Optional[List[Tuple[str, str, Condition]]] = None
self.transitions: Optional[List[Tuple[str, str, Union[Condition, Select]]]] = None
self.actions: Optional[List[Action]] = None

def with_actions(
Expand Down Expand Up @@ -269,7 +276,8 @@ def with_actions(
def with_transitions(
self,
*transitions: Union[
Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition]
Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]]],
Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]], Union[Condition, Select]],
],
) -> "GraphBuilder":
"""Adds transitions to the graph. Transitions are specified as tuples of either:
Expand All @@ -291,14 +299,25 @@ def with_transitions(
condition = conditions[0]
else:
condition = default
if not isinstance(from_, list):
# check required because issubclass(str, Sequence) == True
if isinstance(from_, Sequence) and not isinstance(from_, str):
from_ = [*from_]
else:
from_ = [from_]
for action in from_:
if not isinstance(action, str):
raise ValueError(f"Transition source must be a string, not {action}")
if not isinstance(to_, str):
raise ValueError(f"Transition target must be a string, not {to_}")
self.transitions.append((action, to_, condition))
if isinstance(to_, Sequence) and not isinstance(to_, str):
if not isinstance(condition, Select):
raise ValueError(
"Transition with multiple targets require a `Select` condition."
)
else:
to_ = [to_]
for source in from_:
for target in to_:
if not isinstance(source, str):
raise ValueError(f"Transition source must be a string, not {source}")
if not isinstance(target, str):
raise ValueError(f"Transition target must be a string, not {to_}")
self.transitions.append((source, target, condition))
return self

def build(self) -> Graph:
Expand Down
34 changes: 34 additions & 0 deletions tests/core/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Function,
Input,
Result,
Select,
SingleStepAction,
SingleStepStreamingAction,
StreamingAction,
Expand Down Expand Up @@ -200,6 +201,39 @@ def test_condition_lmda():
# assert cond.run(State({"foo" : "bar"})) == {Condition.KEY: False}


def test_select_constant():
select = Select([], resolver=lambda *args: "foo")
selected_action = select.run(State(), [])

assert selected_action == "foo"


def test_select_determistic():
@action(reads=[], writes=[])
def bar(state):
return state

@action(reads=[], writes=[])
def baz(state):
return state

def length_resolver(state: State, actions: list[Action]) -> str:
foo = state["foo"]
action_idx = len(foo) % len(actions)
return actions[action_idx].name

foo1 = "len=3" # % 2 = 1
foo2 = "len_is_8" # % 2 = 0
actions = [create_action(bar, "bar"), create_action(baz, "baz")]
select = Select(["foo"], resolver=length_resolver)

selected_1 = select.run(State({"foo": foo1}), possible_actions=actions)
assert selected_1 == actions[len(foo1) % len(actions)].name

selected_2 = select.run(State({"foo": foo2}), possible_actions=actions)
assert selected_2 == actions[len(foo2) % len(actions)].name


def test_result():
result = Result("foo", "bar")
assert result.run(State({"foo": "baz", "bar": "qux", "baz": "quux"})) == {
Expand Down

0 comments on commit 398f0aa

Please sign in to comment.