From 8389da97fbbe1126bb64f6dadbf1ab5717138e25 Mon Sep 17 00:00:00 2001 From: jernejfrank Date: Mon, 30 Dec 2024 20:57:24 +0800 Subject: [PATCH] Add async persistence interface Adding async support for persistence and refactoring builder: - classes for building async persistence adapters / hooks - builder extended to include async initializer/persister, async build - builder refactored - application added async validation, warning/error when async hooks not invoked - automatic build for app in parallelism made backward compatible --- burr/core/application.py | 302 ++++++++++++++++++++++++++++++++------- burr/core/parallelism.py | 27 +++- burr/core/persistence.py | 117 ++++++++++++++- 3 files changed, 389 insertions(+), 57 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index 6c4eb054..b131fb3f 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -44,7 +44,12 @@ StreamingResultContainer, ) from burr.core.graph import Graph, GraphBuilder -from burr.core.persistence import BaseStateLoader, BaseStateSaver +from burr.core.persistence import ( + AsyncBaseStateLoader, + AsyncBaseStateSaver, + BaseStateLoader, + BaseStateSaver, +) from burr.core.state import State from burr.core.typing import ActionSchema, DictBasedTypingSystem, TypingSystem from burr.core.validation import BASE_ERROR_MESSAGE @@ -87,7 +92,9 @@ def _raise_fn_return_validation_error(output: Any, action_name: str): def _adjust_single_step_output( - output: Union[State, Tuple[dict, State]], action_name: str, action_schema: ActionSchema + output: Union[State, Tuple[dict, State]], + action_name: str, + action_schema: ActionSchema, ): """Adjusts the output of a single step action to be a tuple of (result, state) or just state""" @@ -839,6 +846,7 @@ def step(self, inputs: Optional[Dict[str, Any]] = None) -> Optional[Tuple[Action """ # we need to increment the sequence before we start computing # that way if we're replaying from state, we don't get stuck + 
self.validate_correct_async_use() self._increment_sequence_id() out = self._step(inputs=inputs, _run_hooks=True) return out @@ -1037,7 +1045,10 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True ) else: result = await _arun_function( - next_action, self._state, inputs=action_inputs, name=next_action.name + next_action, + self._state, + inputs=action_inputs, + name=next_action.name, ) new_state = _run_reducer(next_action, self._state, result, next_action.name) new_state = self._update_internal_state_value(new_state, next_action) @@ -1174,6 +1185,7 @@ def iterate( :return: Each iteration returns the result of running `step`. This generator also returns a tuple of [action, result, current state] """ + self.validate_correct_async_use() halt_before, halt_after, inputs = self._clean_iterate_params( halt_before, halt_after, inputs ) @@ -1237,6 +1249,7 @@ def run( Note that this is only used for the first iteration -- subsequent iterations will not use this. :return: The final state, and the results of running the actions in the order that they were specified. """ + self.validate_correct_async_use() gen = self.iterate(halt_before=halt_before, halt_after=halt_after, inputs=inputs) while True: try: @@ -1382,6 +1395,7 @@ def streaming_response(state: State, prompt: str) -> Generator[dict, None, Tuple result, state = output.get() print(format(result['response'], color)) """ + self.validate_correct_async_use() call_execute_method_wrapper = _call_execute_method_pre_post(ExecuteMethod.stream_result) call_execute_method_wrapper.call_pre(self) halt_before, halt_after, inputs = self._clean_iterate_params( @@ -1957,6 +1971,29 @@ def _repr_mimebundle_(self, include=None, exclude=None, **kwargs): dot = self.visualize(include_conditions=True, include_state=False) return dot._repr_mimebundle_(include=include, exclude=exclude, **kwargs) + def validate_correct_async_use(self): + """Validates that the application is meant to run async. 
+ This validation is performed in synchronous application methods.""" + + # This is a gentle warning for existing users to use the async application + if self._adapter_set.async_hooks: + logger.warning( + "There are asynchronous hooks present in the application that will be ignored. " + "Please use async methods to run the application and have them executed. " + f"The application has following asynchronous hooks: {self._adapter_set.async_hooks} " + ) + + # We check that if: + # - we have build the application using .abuild() + # - we have async hooks present + # this application is meant to be run in async mode. + if self._builder and self._builder.is_async: + raise ValueError( + "The application was build with async hooks " + "which need to be executed in an asynchronous run. " + "Please use the async run methods to run the application." + ) + def _validate_app_id(app_id: Optional[str]): if app_id is None: @@ -1999,6 +2036,7 @@ def __init__(self): self.typing_system = None self.parallel_executor_factory = None self.state_persister = None + self._is_async: bool = False def with_identifiers( self, app_id: str = None, partition_key: str = None, sequence_id: int = None @@ -2150,7 +2188,9 @@ def with_entrypoint(self, action: str) -> "ApplicationBuilder[StateType]": return self def with_actions( - self, *action_list: Union[Action, Callable], **action_dict: Union[Action, Callable] + self, + *action_list: Union[Action, Callable], + **action_dict: Union[Action, Callable], ) -> "ApplicationBuilder[StateType]": """Adds an action to the application. The actions are granted names (using the with_name) method post-adding, using the kw argument. 
If it already has a name (or you wish to use the function name, raw, and @@ -2168,7 +2208,8 @@ def with_actions( def with_transitions( self, *transitions: Union[ - Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition] + Tuple[Union[str, list[str]], str], + Tuple[Union[str, list[str]], str, Condition], ], ) -> "ApplicationBuilder[StateType]": """Adds transitions to the application. Transitions are specified as tuples of either: @@ -2249,7 +2290,7 @@ def with_tracker( def initialize_from( self, - initializer: BaseStateLoader, + initializer: Union[BaseStateLoader, AsyncBaseStateLoader], resume_at_next_action: bool, default_state: dict, default_entrypoint: str, @@ -2295,7 +2336,9 @@ def initialize_from( return self def with_state_persister( - self, persister: Union[BaseStateSaver, LifecycleAdapter], on_every: str = "step" + self, + persister: Union[BaseStateSaver, AsyncBaseStateSaver, LifecycleAdapter], + on_every: str = "step", ) -> "ApplicationBuilder[StateType]": """Adds a state persister to the application. This is a way to persist state out to a database, file, etc... at the specified interval. This is one of two options: @@ -2312,20 +2355,7 @@ def with_state_persister( if on_every != "step": raise ValueError(f"on_every {on_every} not supported") - if not isinstance(persister, persistence.BaseStateSaver): - self.lifecycle_adapters.append(persister) - else: - # Check if 'is_initialized' exists and is False; raise RuntimeError, else continue if not implemented - try: - if not persister.is_initialized(): - raise RuntimeError( - "RuntimeError: Uninitialized persister. Make sure to call .initialize() before passing it to " - "the ApplicationBuilder." 
- ) - except NotImplementedError: - pass - self.lifecycle_adapters.append(persistence.PersisterHook(persister)) - self.state_persister = persister # track for later + self.state_persister = persister # tracks for later; validates in build / abuild return self def with_spawning_parent( @@ -2349,15 +2379,58 @@ def with_spawning_parent( self.spawn_from_partition_key = partition_key return self - def _load_from_persister(self): - """Loads from the set persister and into this current object. + def _set_sync_state_persister(self): + """Inits the synchronous with_state_persister to save the state (local/DB/custom implementations). + Moved here to mimic the async case. + """ + if self.state_persister.is_async(): + raise ValueError( + "You are building the sync application, but have used an " + "async persister. Please use a sync persister or use the " + ".abuild() method to build an async application." + ) - Mutates: - - self.state - - self.sequence_id - - maybe self.start + if not isinstance(self.state_persister, persistence.BaseStateSaver): + self.lifecycle_adapters.append(self.state_persister) + else: + # Check if 'is_initialized' exists and is False; raise RuntimeError, else continue if not implemented + try: + if not self.state_persister.is_initialized(): + raise RuntimeError( + "RuntimeError: Uninitialized persister. Make sure to call .initialize() before passing it to " + "the ApplicationBuilder." + ) + except NotImplementedError: + pass + self.lifecycle_adapters.append(persistence.PersisterHook(self.state_persister)) + async def _set_async_state_persister(self): + """Inits the asynchronous with_state_persister to save the state (local/DB/custom implementations). + Moved here to be able to chain methods and delay the execution until we can chain coroutines in abuild(). """ + if not self.state_persister.is_async(): + raise ValueError( + "You are building the async application, but have used an " + "sync persister. 
Please use an async persister or use the " + ".build() method to build an sync application." + ) + + if not isinstance(self.state_persister, persistence.AsyncBaseStateSaver): + self.lifecycle_adapters.append(self.state_persister) + else: + # Check if 'is_initialized' exists and is False; raise RuntimeError, else continue if not implemented + try: + if not await self.state_persister.is_initialized(): + raise RuntimeError( + "RuntimeError: Uninitialized persister. Make sure to call .initialize() before passing it to " + "the ApplicationBuilder." + ) + except NotImplementedError: + pass + self.lifecycle_adapters.append(persistence.PersisterHookAsync(self.state_persister)) + + def _identify_state_to_load(self): + """Helper to determine which state to load.""" if self.fork_from_app_id is not None: if self.app_id == self.fork_from_app_id: raise ValueError( @@ -2373,14 +2446,26 @@ def _load_from_persister(self): _partition_key = self.partition_key _app_id = self.app_id _sequence_id = self.sequence_id - # load state from persister - load_result = self.state_initializer.load(_partition_key, _app_id, _sequence_id) + + return _partition_key, _app_id, _sequence_id + + def _init_state_from_persister( + self, + load_result: Optional[persistence.PersistedStateData], + partition_key: str, + app_id: Optional[str], + sequence_id: Optional[int], + ): + """Initializes the state of the application. + + Either there is a loaded configuration provided or we use the default state to initialize. + """ if load_result is None: if self.fork_from_app_id is not None: logger.warning( f"{self.state_initializer.__class__.__name__} returned None while trying to fork from: " - f"partition_key:{_partition_key}, app_id:{_app_id}, " - f"sequence_id:{_sequence_id}. " + f"partition_key:{partition_key}, app_id:{app_id}, " + f"sequence_id:{sequence_id}. " "You explicitly requested to fork from a prior application run, but it does not exist. " "Defaulting to state defaults instead." 
) @@ -2393,8 +2478,8 @@ def _load_from_persister(self): raise ValueError( BASE_ERROR_MESSAGE + f"Error: {self.state_initializer.__class__.__name__} returned {load_result} for " - f"partition_key:{_partition_key}, app_id:{_app_id}, " - f"sequence_id:{_sequence_id}, " + f"partition_key:{partition_key}, app_id:{app_id}, " + f"sequence_id:{sequence_id}, " "but value for state was None! This is not allowed. Please return just None in this case, " "or double check that persisted state can never be a None value." ) @@ -2417,6 +2502,50 @@ def _load_from_persister(self): # self.start is already set to the default. We don't need to do anything. pass + def _load_from_sync_persister(self): + """Loads from the set sync persister and into this current object. + + Mutates: + - self.state + - self.sequence_id + - maybe self.start + + """ + if self.state_initializer.is_async(): + raise ValueError( + "You are building the sync application, but have used an " + "async initializer. Please use a sync initializer or use the " + ".abuild() method to build an async application." + ) + + _partition_key, _app_id, _sequence_id = self._identify_state_to_load() + + # load state from persister + load_result = self.state_initializer.load(_partition_key, _app_id, _sequence_id) + self._init_state_from_persister(load_result, _partition_key, _app_id, _sequence_id) + + async def _load_from_async_persister(self): + """Loads from the set async persister and into this current object. + + Mutates: + - self.state + - self.sequence_id + - maybe self.start + + """ + if not self.state_initializer.is_async(): + raise ValueError( + "You are building the async application, but have used an " + "sync initializer. Please use an async initializer or use the " + ".build() method to build an sync application." 
+ ) + + _partition_key, _app_id, _sequence_id = self._identify_state_to_load() + + # load state from persister + load_result = await self.state_initializer.load(_partition_key, _app_id, _sequence_id) + self._init_state_from_persister(load_result, _partition_key, _app_id, _sequence_id) + def reset_to_entrypoint(self): self.state = self.state.wipe(delete=[PRIOR_STEP]) @@ -2431,21 +2560,7 @@ def _get_built_graph(self) -> Graph: return self.graph_builder.build() return self.prebuilt_graph - @telemetry.capture_function_usage - def build(self) -> Application[StateType]: - """Builds the application. - - This function is a bit messy as we iron out the exact logic and rigor we want around things. - - :return: The application object - """ - - _validate_app_id(self.app_id) - if self.state is None: - self.state = State() - if self.state_initializer: - # sets state, sequence_id, and maybe start - self._load_from_persister() + def _build_common(self) -> Application: graph = self._get_built_graph() _validate_start(self.start, {action.name for action in graph.actions}) typing_system: TypingSystem[StateType] = ( @@ -2484,3 +2599,94 @@ def build(self) -> Application[StateType]: state_persister=self.state_persister, state_initializer=self.state_initializer, ) + + @telemetry.capture_function_usage + def build(self) -> Application[StateType]: + """Builds the application for synchronous runs. + + We support both synchronous and asynchronous applications. In case you are using state initializers + and persisters, the synchronous application should be used in the following cases: + + .. 
table:: When to use .build() + :widths: auto + + +-----------------------------------------+----------+------------------------+ + | Cases (persister and app methods) | Status | Remarks | + +=========================================+==========+========================+ + | Sync and Sync | ✅ | | + +-----------------------------------------+----------+------------------------+ + | Sync and Async | ✅ ⚠️ | Will be deprecated | + +-----------------------------------------+----------+------------------------+ + | Async and Sync | ❌ | | + +-----------------------------------------+----------+------------------------+ + | Async and Async | ❌ | | + +-----------------------------------------+----------+------------------------+ + + We originally only had sync persistence and as such this still can be used when running the + app async. However, we strongly encourage to switch to async persisters if you are running + an async application. + + :return: The application object. + """ + _validate_app_id(self.app_id) + if self.state is None: + self.state = State() + + if self.state_persister: + self._set_sync_state_persister() # this is used to save the state during application run + if self.state_initializer: + # sets state, sequence_id, and maybe start + self._load_from_sync_persister() # this is used to load application from a previously saved state + + return self._build_common() + + @telemetry.capture_function_usage + async def abuild(self) -> Application[StateType]: + """Builds the application for asynchronous runs. + + We support both synchronous and asynchronous applications. To save/load in an asynchronous + manner add the async persister versions and use this method to initialize them. This will + also enforce the application to be run with async methods as an additional safety check that + you are not introducing unnecessary bottlenecks. + + Note: When you run an async application you can still use the normal sync functionalities, i.e. 
+ sync hooks of other adapters, but they will block the async event loop until finished. + + In case you are using state initializers and persisters, the asynchronous application should + be used in the following cases: + + .. table:: When to use .abuild() + :widths: auto + + +-----------------------------------------+----------+ + | Cases (persister and app methods) | Status | + +=========================================+==========+ + | Sync and Sync | ❌ | + +-----------------------------------------+----------+ + | Sync and Async | ❌ | + +-----------------------------------------+----------+ + | Async and Sync | ❌ | + +-----------------------------------------+----------+ + | Async and Async | ✅ | + +-----------------------------------------+----------+ + + :return: The application object. + """ + self._is_async = True + + _validate_app_id(self.app_id) + if self.state is None: + self.state = State() + + if self.state_persister: + await self._set_async_state_persister() # this is used to save the state during application run + + if self.state_initializer: + # sets state, sequence_id, and maybe start + await self._load_from_async_persister() # this is used to load application from a previously saved state + + return self._build_common() + + @property + def is_async(self) -> bool: + return self._is_async diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py index 6c83326d..0c5c2376 100644 --- a/burr/core/parallelism.py +++ b/burr/core/parallelism.py @@ -3,6 +3,7 @@ import dataclasses import hashlib import inspect +import logging from typing import ( Any, AsyncGenerator, @@ -19,7 +20,7 @@ from burr.common import async_utils from burr.common.async_utils import SyncOrAsyncGenerator, SyncOrAsyncGeneratorOrItemOrList -from burr.core import Action, Application, ApplicationBuilder, ApplicationContext, Graph, State +from burr.core import Action, ApplicationBuilder, ApplicationContext, Graph, State from burr.core.action import SingleStepAction from 
burr.core.application import ApplicationIdentifiers from burr.core.graph import GraphBuilder @@ -28,6 +29,7 @@ from burr.tracking.base import TrackingClient SubgraphType = Union[Action, Callable, "RunnableGraph"] +logger = logging.getLogger(__name__) @dataclasses.dataclass @@ -70,7 +72,7 @@ def create(from_: SubgraphType) -> "RunnableGraph": @dataclasses.dataclass class SubGraphTask: - """Task to run a subgraph. Has runtime-spefici information, like inputs, state, and + """Task to run a subgraph. Has runtime-specific information, like inputs, state, and the application ID. This is the lower-level component -- the user will only directly interact with this if they use the TaskBasedParallelAction interface, which produces a generator of these. """ @@ -84,7 +86,7 @@ class SubGraphTask: state_persister: Optional[BaseStateSaver] = None state_initializer: Optional[BaseStateLoader] = None - def _create_app(self, parent_context: ApplicationIdentifiers) -> Application: + def _create_app_builder(self, parent_context: ApplicationIdentifiers) -> ApplicationBuilder: builder = ( ApplicationBuilder() .with_graph(self.graph.graph) @@ -101,6 +103,7 @@ def _create_app(self, parent_context: ApplicationIdentifiers) -> Application: ) if self.tracker is not None: builder = builder.with_tracker(self.tracker) # TODO -- move this into the adapter + # In this case we want to persist the state for the app if self.state_persister is not None: builder = builder.with_state_persister(self.state_persister) @@ -119,14 +122,14 @@ def _create_app(self, parent_context: ApplicationIdentifiers) -> Application: else: builder = builder.with_entrypoint(self.graph.entrypoint).with_state(self.state) - return builder.build() + return builder def run( self, parent_context: ApplicationContext, ) -> State: - """Runs the task -- this simply executes it b y instantiating a sub-application""" - app = self._create_app(parent_context) + """Runs the task -- this simply executes it by instantiating a 
sub-application""" + app = self._create_app_builder(parent_context).build() action, result, state = app.run( halt_after=self.graph.halt_after, inputs={key: value for key, value in self.inputs.items() if not key.startswith("__")}, @@ -134,7 +137,17 @@ def run( return state async def arun(self, parent_context: ApplicationContext): - app = self._create_app(parent_context) + # Here for backwards compatibility, not ideal + if (self.state_initializer is not None and not self.state_initializer.is_async()) or ( + self.state_persister is not None and not self.state_persister.is_async() + ): + logger.warning( + "You are using sync persisters for an async application which is not optimal. " + "Consider switching to an async persister implementation. We will make this an error soon." + ) + app = self._create_app_builder(parent_context).build() + else: + app = await self._create_app_builder(parent_context).abuild() action, result, state = await app.arun( halt_after=self.graph.halt_after, inputs={key: value for key, value in self.inputs.items() if not key.startswith("__")}, diff --git a/burr/core/persistence.py b/burr/core/persistence.py index f2788a2c..7e3450ed 100644 --- a/burr/core/persistence.py +++ b/burr/core/persistence.py @@ -9,7 +9,7 @@ from burr.common.types import BaseCopyable from burr.core import Action from burr.core.state import State, logger -from burr.lifecycle import PostRunStepHook +from burr.lifecycle import PostRunStepHook, PostRunStepHookAsync try: from typing import Self @@ -52,9 +52,41 @@ def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: """Returns list of app IDs for a given primary key""" pass + def is_async(self) -> bool: + return False + + +class AsyncBaseStateLoader(abc.ABC): + """Asynchronous base class for state initialization. 
This goes together with an AsyncBaseStateSaver + to form the database for your application.""" + + @abc.abstractmethod + async def load( + self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs + ) -> Optional[PersistedStateData]: + """Loads the state for a given app_id + + :param partition_key: the partition key. Note this could be None, but it's up to the persistor to whether + that is a valid value it can handle. + :param app_id: the identifier for the app instance being recorded. + :param sequence_id: optional, the state corresponding to a specific point in time. Specifically state at the + end of the action with this sequence_id. If sequence_id is not provided, persistor should return the state + from the latest fully completed action. + :return: PersistedStateData or None + """ + pass + + @abc.abstractmethod + async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: + """Returns list of app IDs for a given primary key""" + pass + + def is_async(self) -> bool: + return True + class BaseStateSaver(abc.ABC): - """Basic Interface for state writing. This goes together with a BaseStateLoader to form the + """Base class for state writing. This goes together with a BaseStateLoader to form the database for your application. """ @@ -94,6 +126,54 @@ def save( """ pass + def is_async(self) -> bool: + return False + + +class AsyncBaseStateSaver(abc.ABC): + """Asynchronous base class for state writing. This goes together with an AsyncBaseStateLoader + to form the database for your application. + """ + + async def initialize(self): + """Initializes the app for saving, set up any databases, etc.. 
you want to here.""" + pass + + async def is_initialized(self) -> bool: + """Check if the persister has been initialized appropriately.""" + raise NotImplementedError("Implement this method in your subclass if you need to.") + + @abc.abstractmethod + async def save( + self, + partition_key: Optional[str], + app_id: str, + sequence_id: int, + position: str, + state: State, + status: Literal["completed", "failed"], + **kwargs, + ): + """Saves the state for a given app_id, sequence_id, position + + (PK, App_id, sequence_id, position) are a unique identifier for the state. Why not just + (PK, App_id, sequence_id)? Because we're over-engineering this here. We're always going to have + a position so might as well make it a quadruple. + + :param partition_key: the partition key. Note this could be None, but it's up to the persistor to whether + that is a valid value it can handle. + :param app_id: Application UID to write with + :param sequence_id: Sequence ID of the last executed step + :param position: The action name that was implemented + :param state: The current state of the application + :param status: The status of this state, either "completed" or "failed". If "failed" the state is what it was + before the action was applied. + """ + pass + + def is_async(self) -> bool: + return True + class BaseStatePersister(BaseStateLoader, BaseStateSaver, metaclass=ABCMeta): """Utility interface for a state reader/writer. This both persists and initializes state. @@ -101,6 +181,12 @@ class BaseStatePersister(BaseStateLoader, BaseStateSaver, metaclass=ABCMeta): """ +class AsyncBaseStatePersister(AsyncBaseStateLoader, AsyncBaseStateSaver, metaclass=ABCMeta): + """Utility interface for an asynchronous state reader/writer. This both persists and initializes state. + Extend this class if you want an easy way to implement custom state storage. + """ + + class PersisterHook(PostRunStepHook): """Wrapper class for bridging the persistence interface with lifecycle hooks. 
This is used internally.""" @@ -124,6 +210,33 @@ def post_run_step( self.persister.save(partition_key, app_id, sequence_id, action.name, state, "failed") +class PersisterHookAsync(PostRunStepHookAsync): + """Wrapper class for bridging the persistence interface with asynchronous lifecycle hooks. This is used internally.""" + + def __init__(self, persister: AsyncBaseStateSaver): + self.persister = persister + + async def post_run_step( + self, + app_id: str, + partition_key: str, + sequence_id: int, + state: "State", + action: "Action", + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + if exception is None: + await self.persister.save( + partition_key, app_id, sequence_id, action.name, state, "completed" + ) + else: + await self.persister.save( + partition_key, app_id, sequence_id, action.name, state, "failed" + ) + + class DevNullPersister(BaseStatePersister): """Does nothing, do not use this. This is for testing only."""