Skip to content

Commit

Permalink
Adds transaction logging to state to ensure we only remove the items
Browse files Browse the repository at this point in the history
functions write

Beforehand we had trouble with state manipulation. If we wanted to apply
a default write for a single-step action, there was no way to know what
the action itself wrote, versus what was already in the state. This adds a simple
state log that has a single "flush" operation -- it just keeps track of
all operations since the last "flush" call, and returns those.

This way, all we have to do is flush before the operation, flush after,
and use the "after" results to filter writes so we know which default to
apply.

This also cleans up a bit of the immutability guarantees -- we were
doing a deepcopy on every state update, which has the potential to slow
applications down.
  • Loading branch information
elijahbenizzy committed Jul 17, 2024
1 parent 7a3e145 commit ba5b4bd
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 22 deletions.
59 changes: 49 additions & 10 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from burr.core.graph import Graph, GraphBuilder
from burr.core.persistence import BaseStateLoader, BaseStateSaver
from burr.core.state import State
from burr.core.state import State, StateDelta
from burr.core.validation import BASE_ERROR_MESSAGE
from burr.lifecycle.base import LifecycleAdapter
from burr.lifecycle.internal import LifecycleAdapterSet
Expand Down Expand Up @@ -83,15 +83,44 @@ def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_
_raise_fn_return_validation_error(output, action_name)


def _apply_defaults(state: State, defaults: Dict[str, Any]) -> State:
def _apply_defaults(
state: State,
defaults: Dict[str, Any],
op_list_to_restrict_writes: Optional[List[StateDelta]] = None,
) -> State:
"""Applies default values to the state. This is useful for the cases in which one applies a default value.
:param state: The state object to apply to.
:param defaults: Default values (key/value) to use
:param op_list_to_restrict_writes: The list of operations to restrict writes to, optional.
If this is specified, then it will only apply the defaults to the keys that were *not* written by ops in the op list.
This allows us to track what it has written, and use that to apply defaults.
:return: The state object with the defaults applied.
"""
state_update = {}
state_to_use = state
op_list_writes = None
# In this case we want to restrict to the written sets
if op_list_to_restrict_writes is not None:
op_list_writes = set()
for op in op_list_to_restrict_writes:
op_list_writes.update(op.writes())

# We really don't need to short-circuit but I want to avoid the update function
# So we might as well
if len(defaults) > 0:
for key, value in defaults.items():
if key not in state:
state_update[key] = value
# if we're tracking the op list
# Then we only want to apply defaults
# to keys that have *not* been written to
# This is more restrictive than the next condition
if op_list_writes is not None:
if key not in op_list_writes:
state_update[key] = value
# Otherwise we just apply the defaults to the state itself
else:
if key not in state:
state_update[key] = value
state_to_use = state.update(**state_update)
return state_to_use

Expand Down Expand Up @@ -244,12 +273,13 @@ def _run_single_step_action(
:return: The result of running the action, and the new state
"""
# TODO -- guard all reads/writes with a subset of the state
state.flush_op_list()
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
result, new_state = _adjust_single_step_output(
action.run_and_update(state, **inputs), action.name
)
new_state = _apply_defaults(new_state, action.default_writes)
new_state = _apply_defaults(new_state, action.default_writes, state.flush_op_list())
_validate_result(result, action.name)
out = result, _state_update(state, new_state)
_validate_result(result, action.name)
Expand All @@ -262,6 +292,7 @@ def _run_single_step_streaming_action(
) -> Generator[Tuple[dict, Optional[State]], None, None]:
"""Runs a single step streaming action. This API is internal-facing.
This normalizes + validates the output."""
state.flush_op_list()
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
generator = action.stream_run_and_update(state, **inputs)
Expand All @@ -284,7 +315,9 @@ def _run_single_step_streaming_action(
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
)
_validate_result(result, action.name)
state_update = _apply_defaults(state_update, action.default_writes)
state_update = _apply_defaults(
state_update, action.default_writes, state_update.flush_op_list()
)
_validate_reducer_writes(action, state_update, action.name)
yield result, state_update

Expand All @@ -293,13 +326,14 @@ async def _arun_single_step_action(
action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]]
) -> Tuple[dict, State]:
"""Runs a single step action in async. See the synchronous version for more details."""
state.flush_op_list()
state_to_use = state
state_to_use = _apply_defaults(state_to_use, action.default_reads)
action.validate_inputs(inputs)
result, new_state = _adjust_single_step_output(
await action.run_and_update(state_to_use, **inputs), action.name
)
new_state = _apply_defaults(new_state, action.default_writes)
new_state = _apply_defaults(new_state, action.default_writes, state.flush_op_list())
_validate_result(result, action.name)
_validate_reducer_writes(action, new_state, action.name)
return result, _state_update(state, new_state)
Expand All @@ -309,6 +343,7 @@ async def _arun_single_step_streaming_action(
action: SingleStepStreamingAction, state: State, inputs: Optional[Dict[str, Any]]
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
"""Runs a single step streaming action in async. See the synchronous version for more details."""
state.flush_op_list()
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
generator = action.stream_run_and_update(state, **inputs)
Expand All @@ -331,7 +366,7 @@ async def _arun_single_step_streaming_action(
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
)
_validate_result(result, action.name)
state_update = _apply_defaults(state_update, action.default_writes)
state_update = _apply_defaults(state_update, action.default_writes, state.flush_op_list())
_validate_reducer_writes(action, state_update, action.name)
# TODO -- add guard against zero-length stream
yield result, state_update
Expand All @@ -347,6 +382,7 @@ def _run_multi_step_streaming_action(
This peeks ahead by one so we know when this is done (and when to validate).
"""
state.flush_op_list()
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
generator = action.stream_run(state, **inputs)
Expand All @@ -361,7 +397,7 @@ def _run_multi_step_streaming_action(
yield next_result, None
_validate_result(result, action.name)
state_update = _run_reducer(action, state, result, action.name)
state_update = _apply_defaults(state_update, action.default_writes)
state_update = _apply_defaults(state_update, action.default_writes, state.flush_op_list())
_validate_reducer_writes(action, state_update, action.name)
yield result, state_update

Expand All @@ -370,6 +406,7 @@ async def _arun_multi_step_streaming_action(
action: AsyncStreamingAction, state: State, inputs: Optional[Dict[str, Any]]
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
"""Runs a multi-step streaming action in async. See the synchronous version for more details."""
state.flush_op_list()
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
generator = action.stream_run(state, **inputs)
Expand All @@ -384,7 +421,7 @@ async def _arun_multi_step_streaming_action(
yield next_result, None
_validate_result(result, action.name)
state_update = _run_reducer(action, state, result, action.name)
state_update = _apply_defaults(state_update, action.default_writes)
state_update = _apply_defaults(state_update, action.default_writes, state.flush_op_list())
_validate_reducer_writes(action, state_update, action.name)
yield result, state_update

Expand Down Expand Up @@ -537,6 +574,7 @@ def _step(
) -> Optional[Tuple[Action, dict, State]]:
"""Internal-facing version of step. This is the same as step, but with an additional
parameter to hide hook execution so async can leverage it."""
self._state.flush_op_list() # Just to be sure, this is internal but we don't want to carry too many around
with self.context:
next_action = self.get_next_action()
if next_action is None:
Expand Down Expand Up @@ -668,6 +706,7 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d

async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True):
# we want to increment regardless of failure
self.state.flush_op_list()
with self.context:
next_action = self.get_next_action()
if next_action is None:
Expand Down
58 changes: 50 additions & 8 deletions burr/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import importlib
import inspect
import logging
from typing import Any, Callable, Dict, Iterator, Mapping, Union
from typing import Any, Callable, Dict, Iterator, List, Mapping, Union

from burr.core import serde

Expand Down Expand Up @@ -95,6 +95,11 @@ def writes(self) -> list[str]:
"""Returns the keys that this state delta writes"""
pass

@abc.abstractmethod
def deletes(self) -> list[str]:
"""Returns the keys that this state delta deletes"""
pass

@abc.abstractmethod
def apply_mutate(self, inputs: dict):
"""Applies the state delta to the inputs"""
Expand All @@ -117,6 +122,9 @@ def reads(self) -> list[str]:
def writes(self) -> list[str]:
return list(self.values.keys())

def deletes(self) -> list[str]:
return []

def apply_mutate(self, inputs: dict):
inputs.update(self.values)

Expand All @@ -137,13 +145,21 @@ def reads(self) -> list[str]:
def writes(self) -> list[str]:
return list(self.values.keys())

def deletes(self) -> list[str]:
return []

def apply_mutate(self, inputs: dict):
    """Append each of this delta's values to the list stored under its key.

    ``inputs`` is updated in place, but the list objects it holds are never
    mutated: each affected key is rebound to a freshly built list. This keeps
    any other holder of the previous list (e.g. an earlier state that was only
    shallow-copied) from seeing the appended value.

    :param inputs: The dictionary to apply the append operations to.
    :raises ValueError: If a key is present but its value is not a list.
    """
    for key, value in self.values.items():
        # A missing key is treated as an empty list to append to.
        current = inputs.get(key, [])
        if not isinstance(current, list):
            # Report the offending key/value pair. The message is built from the
            # loop variable so that raising it cannot itself fail.
            raise ValueError(f"Cannot append to non-list value {key}={current}")
        # Build a new list rather than calling .append() so the original list
        # is left untouched. This copies the list (O(n)) but avoids needing a
        # deep copy of the whole state. Mutating values already inside the
        # list remains the caller's responsibility.
        inputs[key] = [*current, value]

def validate(self, input_state: Dict[str, Any]):
incorrect_types = {}
Expand Down Expand Up @@ -171,6 +187,9 @@ def reads(self) -> list[str]:
def writes(self) -> list[str]:
return list(self.values.keys())

def deletes(self) -> list[str]:
return []

def validate(self, input_state: Dict[str, Any]):
incorrect_types = {}
for write_key in self.writes():
Expand Down Expand Up @@ -201,11 +220,14 @@ def name(cls) -> str:
return "delete"

def reads(self) -> list[str]:
return list(self.keys)
return []

def writes(self) -> list[str]:
return []

def deletes(self) -> list[str]:
return list(self.keys)

def apply_mutate(self, inputs: dict):
for key in self.keys:
inputs.pop(key, None)
Expand All @@ -214,19 +236,36 @@ def apply_mutate(self, inputs: dict):
class State(Mapping):
"""An immutable state object. This is the only way to interact with state in Burr."""

def __init__(self, initial_values: Dict[str, Any] = None):
def __init__(self, initial_values: Dict[str, Any] = None, _op_list: list[StateDelta] = None):
if initial_values is None:
initial_values = dict()
self._state = initial_values
self._op_list = _op_list if _op_list is not None else []
self._internal_sequence_id = 0

def flush_op_list(self) -> List[StateDelta]:
"""Flushes the operation list, returning it and clearing it. This is an internal method,
do not use, as it may change."""
op_list = self._op_list
self._op_list = []
return op_list

@property
def op_list(self) -> list[StateDelta]:
"""The list of operations since this was last flushed.
Also an internal property -- do not use, the implementation might change."""
return self._op_list

def apply_operation(self, operation: StateDelta) -> "State":
"""Applies a given operation to the state, returning a new state"""
new_state = copy.deepcopy(self._state) # TODO -- restrict to just the read keys
new_state = copy.copy(self._state) # TODO -- restrict to just the read keys
operation.validate(new_state)
operation.apply_mutate(
new_state
) # todo -- validate that the write keys are the only different ones
return State(new_state)
self._op_list.append(operation)
# we want to carry this on for now
return State(new_state, _op_list=self._op_list)

def get_all(self) -> Dict[str, Any]:
"""Returns the entire state, realized as a dictionary. This is a copy."""
Expand Down Expand Up @@ -327,11 +366,14 @@ def wipe(self, delete: list[str] = None, keep: list[str] = None):
def merge(self, other: "State") -> "State":
"""Merges two states together, overwriting the values in self
with those in other."""
return State({**self.get_all(), **other.get_all()})
return State({**self.get_all(), **other.get_all()}, _op_list=self._op_list)

def subset(self, *keys: str, ignore_missing: bool = True) -> "State":
"""Returns a subset of the state, with only the given keys"""
return State({key: self[key] for key in keys if key in self or not ignore_missing})
return State(
{key: self[key] for key in keys if key in self or not ignore_missing},
_op_list=self._op_list,
)

def __getitem__(self, __k: str) -> Any:
return self._state[__k]
Expand Down
5 changes: 3 additions & 2 deletions docs/reference/state.rst
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
=================
=====
State
=================
=====

Use the state API to manipulate the state of the application.

.. autoclass:: burr.core.state.State
:members:
:exclude-members: op_list, flush_op_list

.. automethod:: __init__

Expand Down
Loading

0 comments on commit ba5b4bd

Please sign in to comment.