From 17a541a94937b5eb0be979fce8753664e56b0985 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 3 Dec 2024 02:25:27 +0100 Subject: [PATCH 1/6] Mapping states from state name --- src/plumpy/base/state_machine.py | 13 +++---------- src/plumpy/processes.py | 20 +++++++++++++++----- tests/base/test_statemachine.py | 9 +++++---- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 2035d4ab..be27e0cd 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -45,7 +45,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None: + def __init__(self, state: State, *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -314,7 +314,7 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None: + def transition_to(self, new_state: State | None, **kwargs: Any) -> None: """Transite to the new state. The new target state will be create lazily when the state is not yet instantiated, @@ -331,11 +331,6 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> label = None try: self._transitioning = True - - if not isinstance(new_state, State): - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, **kwargs) - label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -345,9 +340,7 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> try: self._enter_next_state(new_state) except StateEntryFailed as exception: - # Make sure we have a state instance - if not isinstance(exception.state, State): - new_state = self._create_state_instance(exception.state, **exception.kwargs) + new_state = exception.state label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ef558fa1..5e2f4cbd 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -867,7 +867,9 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.Finished, result=result, successful=False) + state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] + finished_state = state_cls(self, result=result, successful=False) + raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -1064,7 +1066,9 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace) + state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace) + self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1127,7 +1131,9 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: - self.transition_to(process_states.Killed, msg=exception.msg) + state_class = self.get_states_map()[process_states.ProcessState.KILLED] + new_state = self._create_state_instance(state_class, msg=exception.msg) + self.transition_to(new_state) return True finally: self._killing = None @@ -1179,7 +1185,9 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace_back) + state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back) + self.transition_to(new_state) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1207,7 +1215,9 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.Killed, msg=msg) + state_class = self.get_states_map()[process_states.ProcessState.KILLED] + new_state = self._create_state_instance(state_class, msg=msg) + self.transition_to(new_state) return True @property diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 5b4b73d8..3a1621a2 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -57,6 +57,7 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) + self._player = player self.playing_state = playing_state def __str__(self): @@ -64,7 +65,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) else: self.state_machine.transition_to(self.playing_state) @@ -80,7 +81,7 @@ def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(self.state_machine, track=track)) class CdPlayer(state_machine.StateMachine): @@ -107,12 +108,12 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, playing_state=self._state) + self.transition_to(Paused(self, playing_state=self._state)) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) def stop(self): - self.transition_to(Stopped) + self.transition_to(Stopped(self)) class TestStateMachine(unittest.TestCase): From f44835c76c438abd80f0437c926eabc6082333a8 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 15:53:13 +0100 Subject: [PATCH 2/6] Remove the middle layer of statemachine.State + Savable abstraction --- docs/source/nitpick-exceptions | 2 +- src/plumpy/process_states.py | 111 +++++++++++++++++++++++---------- src/plumpy/processes.py | 26 ++++---- src/plumpy/workchains.py | 4 +- 4 files changed, 94 insertions(+), 49 deletions(-) diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 2f354987..e1d6d969 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator # unavailable forward references py:class plumpy.process_states.Command -py:class plumpy.process_states.State +py:class plumpy.state_machine.State py:class plumpy.base.state_machine.State py:class State py:class Process diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index dbbb7bef..44a916e9 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -120,6 +120,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process try: self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) except ValueError: @@ -145,25 +146,8 @@ class ProcessState(Enum): KILLED: str = 'killed' -@auto_persist('in_state') -class State(state_machine.State, persistence.Savable): - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process - - def interrupt(self, reason: Any) -> None: - pass - - -@auto_persist('args', 'kwargs') -class Created(State): +@auto_persist('args', 'kwargs', 'in_state') +class Created(state_machine.State, persistence.Savable): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} @@ -182,14 +166,23 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) def execute(self) -> state_machine.State: return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine -@auto_persist('args', 'kwargs') -class Running(State): + +@auto_persist('args', 'kwargs', 'in_state') +class Running(state_machine.State, persistence.Savable): LABEL = ProcessState.RUNNING ALLOWED = { ProcessState.RUNNING, @@ -223,6 +216,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore @@ -230,7 +225,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> State: # type: ignore + async def execute(self) -> state_machine.State: # type: ignore if self._command is not None: command = self._command else: @@ -245,7 +240,7 @@ async def execute(self) -> State: # type: ignore raise except Exception: excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(State, excepted) + return cast(state_machine.State, excepted) else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -259,7 +254,7 @@ async def execute(self) -> State: # type: ignore next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: if isinstance(command, Kill): state = self.create_state(ProcessState.KILLED, command.msg) # elif isinstance(command, Pause): @@ -273,11 +268,18 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: else: raise ValueError('Unrecognised command') - return cast(State, state) # casting from base.State to process.State + return cast(state_machine.State, state) # casting from base.State to process.State + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine -@auto_persist('msg', 'data') -class Waiting(State): + +@auto_persist('msg', 'data', 'in_state') +class Waiting(state_machine.State, persistence.Savable): LABEL = ProcessState.WAITING ALLOWED = { ProcessState.RUNNING, @@ -317,6 +319,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: self.done_callback = getattr(self.process, callback_name) @@ -328,7 +332,7 @@ def interrupt(self, reason: Any) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> State: # type: ignore + async def execute(self) -> state_machine.State: # type: ignore try: result = await self._waiting_future except Interruption: @@ -343,7 +347,7 @@ async def execute(self) -> State: # type: ignore else: next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) - return cast(State, next_state) # casting from base.State to process.State + return cast(state_machine.State, next_state) # casting from base.State to process.State def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -353,8 +357,16 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + -class Excepted(State): +@auto_persist('in_state') +class Excepted(state_machine.State, persistence.Savable): """ Excepted state, can optionally provide exception and trace_back @@ -394,6 +406,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: @@ -415,9 +429,16 @@ def get_exc_info( self.traceback, ) + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + -@auto_persist('result', 'successful') -class Finished(State): +@auto_persist('result', 'successful', 'in_state') +class Finished(state_machine.State, persistence.Savable): """State for process is finished. :param result: The result of process @@ -431,9 +452,20 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None: self.result = result self.successful = successful + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + -@auto_persist('msg') -class Killed(State): +@auto_persist('msg', 'in_state') +class Killed(state_machine.State, persistence.Savable): """ Represents a state where a process has been killed. @@ -453,5 +485,16 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): super().__init__(process) self.msg = msg + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + # endregion diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 5e2f4cbd..1fe05470 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -177,7 +177,7 @@ def current(cls) -> Optional['Process']: return None @classmethod - def get_states(cls) -> Sequence[Type[process_states.State]]: + def get_states(cls) -> Sequence[Type[state_machine.State]]: """Return all allowed states of the process.""" state_classes = cls.get_state_classes() return ( @@ -186,7 +186,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -357,10 +357,10 @@ def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" event_hooks = { state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering( - cast(process_states.State, state) + cast(state_machine.State, state) ), state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( - cast(Optional[process_states.State], from_state) + cast(Optional[state_machine.State], from_state) ), state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } @@ -661,7 +661,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self._loop = asyncio.get_event_loop() - self._state: process_states.State = self.recreate_state(saved_state['_state']) + self._state: state_machine.State = self.recreate_state(saved_state['_state']) if 'communicator' in load_context: self._communicator = load_context.communicator @@ -719,7 +719,7 @@ def log_with_pid(self, level: int, msg: str) -> None: # region Events - def on_entering(self, state: process_states.State) -> None: + def on_entering(self, state: state_machine.State) -> None: # Map these onto direct functions that the subclass can implement state_label = state.LABEL if state_label == process_states.ProcessState.CREATED: @@ -735,7 +735,7 @@ def on_entering(self, state: process_states.State) -> None: elif state_label == process_states.ProcessState.EXCEPTED: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore - def on_entered(self, from_state: Optional[process_states.State]) -> None: + def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement state_label = self._state.LABEL if state_label == process_states.ProcessState.RUNNING: @@ -1103,7 +1103,7 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._do_pause(msg) - def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool: + def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: @@ -1129,7 +1129,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu if isinstance(exception, process_states.KillInterruption): - def do_kill(_next_state: process_states.State) -> Any: + def do_kill(_next_state: state_machine.State) -> Any: try: state_class = self.get_states_map()[process_states.ProcessState.KILLED] new_state = self._create_state_instance(state_class, msg=exception.msg) @@ -1227,7 +1227,7 @@ def is_killing(self) -> bool: # endregion - def create_initial_state(self) -> process_states.State: + def create_initial_state(self) -> state_machine.State: """This method is here to override its superclass. Automatically enter the CREATED state when the process is created. @@ -1235,11 +1235,11 @@ def create_initial_state(self) -> process_states.State: :return: A Created state """ return cast( - process_states.State, + state_machine.State, self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), ) - def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: + def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ Create a state object from a saved state @@ -1247,7 +1247,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) + return cast(state_machine.State, persistence.Savable.load(saved_state, load_context)) # endregion diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 748a44d7..9eafcb50 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -25,6 +25,8 @@ import kiwipy +from plumpy.base import state_machine + from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -117,7 +119,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map From 985c2ed6455891f62f6987bf1f168f077c7fb3c5 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 16:06:39 +0100 Subject: [PATCH 3/6] Move is_terminal as class attribute required --- src/plumpy/base/state_machine.py | 8 ++------ src/plumpy/process_states.py | 11 +++++++++++ src/plumpy/processes.py | 4 ++-- tests/base/test_statemachine.py | 6 ++++++ 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index be27e0cd..380f4610 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -133,10 +133,6 @@ class State: # from this one ALLOWED: Set[LABEL_TYPE] = set() - @classmethod - def is_terminal(cls) -> bool: - return not cls.ALLOWED - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): """ :param state_machine: The process this state belongs to @@ -165,7 +161,7 @@ def execute(self) -> Optional['State']: @super_check def exit(self) -> None: """Exiting the state""" - if self.is_terminal(): + if self.is_terminal: raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': @@ -345,7 +341,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: self._exit_current_state(new_state) self._enter_next_state(new_state) - if self._state is not None and self._state.is_terminal(): + if self._state is not None and self._state.is_terminal: call_with_super_check(self.on_terminated) except Exception: self._transitioning = False diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 44a916e9..91959c4d 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -152,6 +152,7 @@ class Created(state_machine.State, persistence.Savable): ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' + is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: super().__init__(process) @@ -200,6 +201,8 @@ class Running(state_machine.State, persistence.Savable): _running: bool = False _run_handle = None + is_terminal = False + def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: super().__init__(process) assert run_fn is not None @@ -293,6 +296,8 @@ class Waiting(state_machine.State, persistence.Savable): _interruption = None + is_terminal = False + def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: @@ -379,6 +384,8 @@ class Excepted(state_machine.State, persistence.Savable): EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' + is_terminal = True + def __init__( self, process: 'Process', @@ -447,6 +454,8 @@ class Finished(state_machine.State, persistence.Savable): LABEL = ProcessState.FINISHED + is_terminal = True + def __init__(self, process: 'Process', result: Any, successful: bool) -> None: super().__init__(process) self.result = result @@ -477,6 +486,8 @@ class Killed(state_machine.State, persistence.Savable): LABEL = ProcessState.KILLED + is_terminal = True + def __init__(self, process: 'Process', msg: Optional[MessageType]): """ :param process: The associated process diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 1fe05470..1e745437 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -467,7 +467,7 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal() + return self._state.is_terminal def result(self) -> Any: """ @@ -540,7 +540,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal() + return self._state.is_terminal # endregion diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 3a1621a2..b6d7e2d3 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -20,6 +20,8 @@ class Playing(state_machine.State): ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, track): assert track is not None, 'Must provide a track name' super().__init__(player) @@ -54,6 +56,8 @@ class Paused(state_machine.State): ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) @@ -77,6 +81,8 @@ class Stopped(state_machine.State): } TRANSITIONS = {PLAY: PLAYING} + is_terminal = False + def __str__(self): return '[]' From e4e0df59c908161a4c7bfdc6feb751eada58444c Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 17:07:29 +0100 Subject: [PATCH 4/6] forming the enter/exit for State protocol --- src/plumpy/base/state_machine.py | 66 ++++---------- src/plumpy/process_states.py | 148 ++++++++++++++++++------------- src/plumpy/workchains.py | 25 +++--- tests/base/test_statemachine.py | 44 +++++++-- tests/test_processes.py | 2 +- 5 files changed, 157 insertions(+), 128 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 380f4610..2164737f 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -13,15 +13,17 @@ from typing import ( Any, Callable, + ClassVar, Dict, Hashable, Iterable, List, Optional, + Protocol, Sequence, - Set, Type, Union, + runtime_checkable, ) from plumpy.futures import Future @@ -88,12 +90,12 @@ def event( if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) - if not all(issubclass(state, State) for state in from_states): # type: ignore + if not all(isinstance(state, State) for state in from_states): # type: ignore raise TypeError(f'from_states: {from_states}') if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) - if not all(issubclass(state, State) for state in to_states): # type: ignore + if not all(isinstance(state, State) for state in to_states): # type: ignore raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: @@ -127,53 +129,20 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -class State: - LABEL: LABEL_TYPE = None - # A set containing the labels of states that can be entered - # from this one - ALLOWED: Set[LABEL_TYPE] = set() +@runtime_checkable +class State(Protocol): + LABEL: ClassVar[LABEL_TYPE] - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): - """ - :param state_machine: The process this state belongs to - """ - self.state_machine = state_machine - self.in_state: bool = False - - def __str__(self) -> str: - return str(self.LABEL) - - @property - def label(self) -> LABEL_TYPE: - """Convenience property to get the state label""" - return self.LABEL - - @super_check - def enter(self) -> None: - """Entering the state""" - - def execute(self) -> Optional['State']: + async def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ + ... - @super_check - def exit(self) -> None: - """Exiting the state""" - if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': - return self.state_machine.create_state(state_label, *args, **kwargs) - - def do_enter(self) -> None: - call_with_super_check(self.enter) - self.in_state = True + def enter(self) -> None: ... - def do_exit(self) -> None: - call_with_super_check(self.exit) - self.in_state = False + def exit(self) -> None: ... class StateEventHook(enum.Enum): @@ -250,7 +219,7 @@ def __ensure_built(cls) -> None: # Build the states map cls._STATES_MAP = {} for state_cls in cls.STATES: - assert issubclass(state_cls, State) + assert isinstance(state_cls, State) label = state_cls.LABEL assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" cls._STATES_MAP[label] = state_cls @@ -380,7 +349,8 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat # This method should be replaced by `_create_state_instance`. # aiida-core using this method for its Waiting state override. try: - return self.get_states_map()[state_label](self, *args, **kwargs) + state_cls = self.get_states_map()[state_label] + return state_cls(self, *args, **kwargs) except KeyError: raise ValueError(f'{state_label} is not a valid state') @@ -390,20 +360,20 @@ def _exit_current_state(self, next_state: State) -> None: # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state if self._state is None: - if next_state.label != self.initial_state_label(): + if next_state.LABEL != self.initial_state_label(): raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) - self._state.do_exit() + self._state.exit() def _enter_next_state(self, next_state: State) -> None: last_state = self._state self._fire_state_event(StateEventHook.ENTERING_STATE, next_state) # Enter the new state - next_state.do_enter() + next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 91959c4d..88cab660 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,7 +5,7 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final import yaml from yaml.loader import Loader @@ -146,6 +146,7 @@ class ProcessState(Enum): KILLED: str = 'killed' +@final @auto_persist('args', 'kwargs', 'in_state') class Created(state_machine.State, persistence.Savable): LABEL = ProcessState.CREATED @@ -155,11 +156,12 @@ class Created(state_machine.State, persistence.Savable): is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs + self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -167,21 +169,24 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + async def execute(self) -> state_machine.State: + return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False +@final @auto_persist('args', 'kwargs', 'in_state') class Running(state_machine.State, persistence.Savable): LABEL = ProcessState.RUNNING @@ -204,12 +209,13 @@ class Running(state_machine.State, persistence.Savable): is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs self._run_handle = None + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -219,7 +225,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: @@ -228,7 +234,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> state_machine.State: # type: ignore + async def execute(self) -> state_machine.State: if self._command is not None: command = self._command else: @@ -242,7 +248,7 @@ async def execute(self) -> state_machine.State: # type: ignore # Let this bubble up to the caller raise except Exception: - excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) + excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) return cast(state_machine.State, excepted) else: if not isinstance(result, Command): @@ -259,28 +265,30 @@ async def execute(self) -> state_machine.State: # type: ignore def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: if isinstance(command, Kill): - state = self.create_state(ProcessState.KILLED, command.msg) + state = self.process.create_state(ProcessState.KILLED, command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) else: raise ValueError('Unrecognised command') return cast(state_machine.State, state) # casting from base.State to process.State - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False +@final @auto_persist('msg', 'data', 'in_state') class Waiting(state_machine.State, persistence.Savable): LABEL = ProcessState.WAITING @@ -311,11 +319,12 @@ def __init__( msg: Optional[str] = None, data: Optional[Any] = None, ) -> None: - super().__init__(process) + self.process = process self.done_callback = done_callback self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -324,7 +333,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: @@ -348,9 +357,9 @@ async def execute(self) -> state_machine.State: # type: ignore raise if result == NULL: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result) return cast(state_machine.State, next_state) # casting from base.State to process.State @@ -362,12 +371,14 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('in_state') @@ -397,9 +408,10 @@ def __init__( :param exception: The exception instance :param trace_back: An optional exception traceback """ - super().__init__(process) + self.process = process self.exception = exception self.traceback = trace_back + self.in_state = False def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -413,7 +425,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -436,12 +448,17 @@ def get_exc_info( self.traceback, ) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + async def execute(self) -> state_machine.State: # type: ignore + ... + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('result', 'successful', 'in_state') @@ -457,20 +474,26 @@ class Finished(state_machine.State, persistence.Savable): is_terminal = True def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - super().__init__(process) + self.process = process self.result = result self.successful = successful - - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process + + def enter(self) -> None: + self.in_state = True + + async def execute(self) -> state_machine.State: # type: ignore + ... + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('msg', 'in_state') @@ -493,19 +516,24 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): :param process: The associated process :param msg: Optional kill message """ - super().__init__(process) + self.process = process self.msg = msg - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + async def execute(self) -> state_machine.State: # type: ignore + ... def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False # endregion diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 9eafcb50..eefd57f1 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -26,6 +26,7 @@ import kiwipy from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -87,16 +88,6 @@ def __init__( resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key - def enter(self) -> None: - super().enter() - for awaitable in self._awaiting: - awaitable.add_done_callback(self._awaitable_done) - - def exit(self) -> None: - super().exit() - for awaitable in self._awaiting: - awaitable.remove_done_callback(self._awaitable_done) - def _awaitable_done(self, awaitable: asyncio.Future) -> None: key = self._awaiting.pop(awaitable) try: @@ -107,6 +98,20 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: if not self._awaiting: self._waiting_future.set_result(lang.NULL) + def enter(self) -> None: + for awaitable in self._awaiting: + awaitable.add_done_callback(self._awaitable_done) + + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False + + for awaitable in self._awaiting: + awaitable.remove_done_callback(self._awaitable_done) + class WorkChain(mixins.ContextMixin, processes.Process): """ diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index b6d7e2d3..b6100614 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- import time +from typing import final import unittest from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError # Events PLAY = 'Play' @@ -24,24 +26,16 @@ class Playing(state_machine.State): def __init__(self, player, track): assert track is not None, 'Must provide a track name' - super().__init__(player) self.track = track self._last_time = None self._played = 0.0 + self.in_state = False def __str__(self): if self.in_state: self._update_time() return f'> {self.track} ({self._played}s)' - def enter(self): - super().enter() - self._last_time = time.time() - - def exit(self): - super().exit() - self._update_time() - def play(self, track=None): # pylint: disable=no-self-use, unused-argument return False @@ -50,6 +44,17 @@ def _update_time(self): self._played += current_time - self._last_time self._last_time = current_time + def enter(self) -> None: + self._last_time = time.time() + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self._update_time() + self.in_state = False + class Paused(state_machine.State): LABEL = PAUSED @@ -73,6 +78,15 @@ def play(self, track=None): else: self.state_machine.transition_to(self.playing_state) + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + class Stopped(state_machine.State): LABEL = STOPPED @@ -83,12 +97,24 @@ class Stopped(state_machine.State): is_terminal = False + def __init__(self, player): + self.state_machine = player + def __str__(self): return '[]' def play(self, track): self.state_machine.transition_to(Playing(self.state_machine, track=track)) + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + class CdPlayer(state_machine.StateMachine): STATES = (Stopped, Playing, Paused) diff --git a/tests/test_processes.py b/tests/test_processes.py index eb5bf599..4b8cc606 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -653,7 +653,7 @@ def test_exception_during_on_entered(self): class RaisingProcess(Process): def on_entered(self, from_state): - if from_state is not None and from_state.label == ProcessState.RUNNING: + if from_state is not None and from_state.LABEL == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') super().on_entered(from_state) From 50a2804a5e0ed93a7ec908a61821cba73991092b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 23:20:45 +0100 Subject: [PATCH 5/6] Forming Interruptable and Proceedable protocol --- src/plumpy/base/state_machine.py | 20 +++++++++++++++----- src/plumpy/process_states.py | 12 ++---------- src/plumpy/processes.py | 22 +++++++++++++++++++++- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 2164737f..27b1e5f8 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -132,18 +132,28 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): LABEL: ClassVar[LABEL_TYPE] + is_terminal: ClassVar[bool] - async def execute(self) -> State | None: + def enter(self) -> None: ... + + def exit(self) -> None: ... + + +@runtime_checkable +class Interruptable(Protocol): + def interrupt(self, reason: Exception) -> None: ... + + +@runtime_checkable +class Proceedable(Protocol): + + def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ ... - def enter(self) -> None: ... - - def exit(self) -> None: ... - class StateEventHook(enum.Enum): """ diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 88cab660..cc9169c7 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -288,6 +288,7 @@ def exit(self) -> None: self.in_state = False + @final @auto_persist('msg', 'data', 'in_state') class Waiting(state_machine.State, persistence.Savable): @@ -342,7 +343,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.done_callback = None self._waiting_future = futures.Future() - def interrupt(self, reason: Any) -> None: + def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) @@ -448,9 +449,6 @@ def get_exc_info( self.traceback, ) - async def execute(self) -> state_machine.State: # type: ignore - ... - def enter(self) -> None: self.in_state = True @@ -486,9 +484,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def enter(self) -> None: self.in_state = True - async def execute(self) -> state_machine.State: # type: ignore - ... - def exit(self) -> None: if self.is_terminal: raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') @@ -519,9 +514,6 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): self.process = process self.msg = msg - async def execute(self) -> state_machine.State: # type: ignore - ... - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 1e745437..74808291 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -51,7 +51,15 @@ utils, ) from .base import state_machine -from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event +from .base.state_machine import ( + Interruptable, + Proceedable, + StateEntryFailed, + StateMachine, + StateMachineError, + TransitionFailed, + event, +) from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper from .process_comms import MESSAGE_KEY, KillMessage, MessageType @@ -1092,6 +1100,11 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._pausing if self._stepping: + if not isinstance(self._state, Interruptable): + raise exceptions.InvalidStateError( + f'cannot interrupt {self._state.__class__}, method `interrupt` not implement' + ) + # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.PauseInterruption(msg) @@ -1103,6 +1116,10 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._do_pause(msg) + @staticmethod + def _interrupt(state: Interruptable, reason: Exception) -> None: + state.interrupt(reason) + def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: @@ -1285,6 +1302,9 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused + if not isinstance(self._state, Proceedable): + raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented') + try: self._stepping = True next_state = None From cacf0d92eeb3107984aa6ebc229d3ebe4e5b923b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 3 Dec 2024 00:48:53 +0100 Subject: [PATCH 6/6] Refactoring create_state as static function initialize state from label create_state refact Hashable initialized + parameters passed to Hashable Fix pre-commit errors --- src/plumpy/base/state_machine.py | 45 +++--- src/plumpy/process_states.py | 235 ++++++++++++++++--------------- src/plumpy/processes.py | 42 +++--- src/plumpy/workchains.py | 10 +- tests/base/test_statemachine.py | 15 +- 5 files changed, 173 insertions(+), 174 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 27b1e5f8..fc926008 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -34,7 +34,6 @@ _LOGGER = logging.getLogger(__name__) -LABEL_TYPE = Union[None, enum.Enum, str] EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] @@ -131,9 +130,12 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): - LABEL: ClassVar[LABEL_TYPE] + LABEL: ClassVar[Any] + ALLOWED: ClassVar[set[Any]] is_terminal: ClassVar[bool] + def __init__(self, *args: Any, **kwargs: Any): ... + def enter(self) -> None: ... def exit(self) -> None: ... @@ -146,7 +148,6 @@ def interrupt(self, reason: Exception) -> None: ... @runtime_checkable class Proceedable(Protocol): - def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. @@ -155,6 +156,14 @@ def execute(self) -> State | None: ... +def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + if state_label not in st.get_states_map(): + raise ValueError(f'{state_label} is not a valid state') + + state_cls = st.get_states_map()[state_label] + return state_cls(*args, **kwargs) + + class StateEventHook(enum.Enum): """ Hooks that can be used to register callback at various points in the state transition @@ -203,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]: raise RuntimeError('States not defined') @classmethod - def initial_state_label(cls) -> LABEL_TYPE: + def initial_state_label(cls) -> Any: cls.__ensure_built() assert cls.STATES is not None return cls.STATES[0].LABEL @classmethod - def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: + def get_state_class(cls, label: Any) -> Type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None return cls._STATES_MAP[label] @@ -253,11 +262,11 @@ def init(self) -> None: def __str__(self) -> str: return f'<{self.__class__.__name__}> ({self.state})' - def create_initial_state(self) -> State: - return self.get_state_class(self.initial_state_label())(self) + def create_initial_state(self, *args: Any, **kwargs: Any) -> State: + return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Optional[LABEL_TYPE]: + def state(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -297,6 +306,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: The arguments are passed to the state class to create state instance. (process arg does not need to pass since it will always call with 'self' as process) """ + print(f'try: {self._state} -> {new_state}') assert not self._transitioning, 'Cannot call transition_to when already transitioning state' if new_state is None: @@ -353,17 +363,6 @@ def get_debug(self) -> bool: def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: - # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic - # because the label is defined after the state and required to be know before calling this function. - # This method should be replaced by `_create_state_instance`. - # aiida-core using this method for its Waiting state override. - try: - state_cls = self.get_states_map()[state_label] - return state_cls(self, *args, **kwargs) - except KeyError: - raise ValueError(f'{state_label} is not a valid state') - def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" @@ -375,7 +374,7 @@ def _exit_current_state(self, next_state: State) -> None: return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) self._state.exit() @@ -386,9 +385,3 @@ def _enter_next_state(self, next_state: State) -> None: next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - - def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State: - if state_cls.LABEL not in self.get_states_map(): - raise ValueError(f'{state_cls.LABEL} is not a valid state') - - return state_cls(self, **kwargs) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index cc9169c7..5f3e8237 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,7 +5,20 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Optional, + Protocol, + Tuple, + Type, + Union, + cast, + final, + runtime_checkable, +) import yaml from yaml.loader import Loader @@ -20,9 +33,9 @@ _HAS_TBLIB = False from . import exceptions, futures, persistence, utils -from .base import state_machine +from .base import state_machine as st from .lang import NULL -from .persistence import auto_persist +from .persistence import LoadSaveContext, auto_persist from .utils import SAVED_STATE_TYPE __all__ = [ @@ -138,22 +151,28 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ - CREATED: str = 'created' - RUNNING: str = 'running' - WAITING: str = 'waiting' - FINISHED: str = 'finished' - EXCEPTED: str = 'excepted' - KILLED: str = 'killed' + # FIXME: see LSP error of return a exception, the type is Literal[str] which is invariant, tricky + CREATED = 'created' + RUNNING = 'running' + WAITING = 'waiting' + FINISHED = 'finished' + EXCEPTED = 'excepted' + KILLED = 'killed' + + +@runtime_checkable +class Savable(Protocol): + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... @final -@auto_persist('args', 'kwargs', 'in_state') -class Created(state_machine.State, persistence.Savable): - LABEL = ProcessState.CREATED - ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} +@auto_persist('args', 'kwargs') +class Created(persistence.Savable): + LABEL: ClassVar = ProcessState.CREATED + ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' - is_terminal = False + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: assert run_fn is not None @@ -161,7 +180,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.run_fn = run_fn self.args = args self.kwargs = kwargs - self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -173,24 +191,21 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - async def execute(self) -> state_machine.State: - return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) - - def enter(self) -> None: - self.in_state = True + def execute(self) -> st.State: + return st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.run_fn, *self.args, **self.kwargs + ) - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... @final -@auto_persist('args', 'kwargs', 'in_state') -class Running(state_machine.State, persistence.Savable): - LABEL = ProcessState.RUNNING - ALLOWED = { +@auto_persist('args', 'kwargs') +class Running(persistence.Savable): + LABEL: ClassVar = ProcessState.RUNNING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, @@ -206,7 +221,7 @@ class Running(state_machine.State, persistence.Savable): _running: bool = False _run_handle = None - is_terminal = False + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: assert run_fn is not None @@ -215,7 +230,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.args = args self.kwargs = kwargs self._run_handle = None - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -234,7 +248,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> state_machine.State: + def execute(self) -> st.State: if self._command is not None: command = self._command else: @@ -248,8 +262,10 @@ async def execute(self) -> state_machine.State: # Let this bubble up to the caller raise except Exception: - excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(state_machine.State, excepted) + _, exception, traceback = sys.exc_info() + # excepted = state_cls(exception=exception, traceback=traceback) + excepted = Excepted(exception=exception, traceback=traceback) + return excepted else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -258,42 +274,52 @@ async def execute(self) -> state_machine.State: # Got passed a basic return type result = Stop(result, True) - command = result + command = cast(Stop, result) next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State: if isinstance(command, Kill): - state = self.process.create_state(ProcessState.KILLED, command.msg) + state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful) + state = st.create_state( + self.process, ProcessState.FINISHED, result=command.result, successful=command.successful + ) elif isinstance(command, Wait): - state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = st.create_state( + self.process, + ProcessState.WAITING, + process=self.process, + done_callback=command.continue_fn, + msg=command.msg, + data=command.data, + ) elif isinstance(command, Continue): - state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = st.create_state( + self.process, + ProcessState.RUNNING, + process=self.process, + run_fn=command.continue_fn, + *command.args, + **command.kwargs, + ) else: raise ValueError('Unrecognised command') - return cast(state_machine.State, state) # casting from base.State to process.State - - def enter(self) -> None: - self.in_state = True + return state - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... -@final -@auto_persist('msg', 'data', 'in_state') -class Waiting(state_machine.State, persistence.Savable): - LABEL = ProcessState.WAITING - ALLOWED = { +@auto_persist('msg', 'data') +class Waiting(persistence.Savable): + LABEL: ClassVar = ProcessState.WAITING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, @@ -305,7 +331,7 @@ class Waiting(state_machine.State, persistence.Savable): _interruption = None - is_terminal = False + is_terminal: ClassVar[bool] = False def __str__(self) -> str: state_info = super().__str__() @@ -325,7 +351,6 @@ def __init__( self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -347,7 +372,7 @@ def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> state_machine.State: # type: ignore + async def execute(self) -> st.State: try: result = await self._waiting_future except Interruption: @@ -358,11 +383,15 @@ async def execute(self) -> state_machine.State: # type: ignore raise if result == NULL: - next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.done_callback + ) else: - next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result + ) - return cast(state_machine.State, next_state) # casting from base.State to process.State + return next_state def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -372,47 +401,39 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('in_state') -class Excepted(state_machine.State, persistence.Savable): +@final +class Excepted(persistence.Savable): """ - Excepted state, can optionally provide exception and trace_back + Excepted state, can optionally provide exception and traceback :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - LABEL = ProcessState.EXCEPTED + LABEL: ClassVar = ProcessState.EXCEPTED + ALLOWED: ClassVar[set[str]] = set() EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' - is_terminal = True + is_terminal: ClassVar = True def __init__( self, - process: 'Process', exception: Optional[BaseException], - trace_back: Optional[TracebackType] = None, + traceback: Optional[TracebackType] = None, ): """ - :param process: The associated process :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - self.process = process self.exception = exception - self.traceback = trace_back - self.in_state = False + self.traceback = traceback def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -426,7 +447,6 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -449,50 +469,40 @@ def get_exc_info( self.traceback, ) - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('result', 'successful', 'in_state') -class Finished(state_machine.State, persistence.Savable): +@final +@auto_persist('result', 'successful') +class Finished(persistence.Savable): """State for process is finished. :param result: The result of process :param successful: Boolean for the exit code is ``0`` the process is successful. """ - LABEL = ProcessState.FINISHED + LABEL: ClassVar = ProcessState.FINISHED + ALLOWED: ClassVar[set[str]] = set() - is_terminal = True + is_terminal: ClassVar[bool] = True - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - self.process = process + def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful - self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('msg', 'in_state') -class Killed(state_machine.State, persistence.Savable): +@final +@auto_persist('msg') +class Killed(persistence.Savable): """ Represents a state where a process has been killed. @@ -502,30 +512,23 @@ class Killed(state_machine.State, persistence.Savable): :param msg: An optional message explaining the reason for the process termination. """ - LABEL = ProcessState.KILLED + LABEL: ClassVar = ProcessState.KILLED + ALLOWED: ClassVar[set[str]] = set() - is_terminal = True + is_terminal: ClassVar[bool] = True - def __init__(self, process: 'Process', msg: Optional[MessageType]): + def __init__(self, msg: Optional[MessageType]): """ - :param process: The associated process :param msg: Optional kill message """ - self.process = process self.msg = msg def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process - - def enter(self) -> None: - self.in_state = True - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... # endregion diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 74808291..bae08dd4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The main Process module""" +from __future__ import annotations + import abc import asyncio import contextlib @@ -58,6 +60,7 @@ StateMachine, StateMachineError, TransitionFailed, + create_state, event, ) from .base.utils import call_with_super_check, super_check @@ -194,7 +197,7 @@ def get_states(cls) -> Sequence[Type[state_machine.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: + def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -633,7 +636,9 @@ def save_instance_state( """ super().save_instance_state(out_state, save_context) - out_state['_state'] = self._state.save() + # FIXME: the combined ProcessState protocol should cover the case + if isinstance(self._state, process_states.Savable): + out_state['_state'] = self._state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -876,7 +881,7 @@ def on_finish(self, result: Any, successful: bool) -> None: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] - finished_state = state_cls(self, result=result, successful=False) + finished_state = state_cls(result=result, successful=False) raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -1074,8 +1079,8 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] - new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: @@ -1148,10 +1153,11 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: state_machine.State) -> Any: try: - state_class = self.get_states_map()[process_states.ProcessState.KILLED] - new_state = self._create_state_instance(state_class, msg=exception.msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=exception.msg) self.transition_to(new_state) return True + # FIXME: if try block except, will hit deadlock in event loop + # need to know how to debug it, and where to set a timeout. finally: self._killing = None @@ -1196,14 +1202,14 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure - :param trace_back: Optional exception traceback + :param traceback: Optional exception traceback """ - state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] - new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback) self.transition_to(new_state) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: @@ -1223,7 +1229,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] # Already killing return self._killing - if self._stepping: + if self._stepping and isinstance(self._state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg) @@ -1232,8 +1238,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - state_class = self.get_states_map()[process_states.ProcessState.KILLED] - new_state = self._create_state_instance(state_class, msg=msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True @@ -1251,10 +1256,7 @@ def create_initial_state(self) -> state_machine.State: :return: A Created state """ - return cast( - state_machine.State, - self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), - ) + return self.get_state_class(process_states.ProcessState.CREATED)(self, self.run) def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ @@ -1325,7 +1327,9 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + next_state = create_state( + self, process_states.ProcessState.EXCEPTED, exception=sys.exc_info()[1], traceback=sys.exc_info()[2] + ) self._set_interrupt_action(None) if self._interrupt_action: diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index eefd57f1..865a5b61 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -11,7 +11,6 @@ Any, Callable, Dict, - Hashable, List, Mapping, MutableSequence, @@ -71,6 +70,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: return self._outline +# FIXME: better use composition here @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): """Overwrite the waiting state""" @@ -80,11 +80,11 @@ def __init__( process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, + data: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: - super().__init__(process, done_callback, msg, awaiting) + super().__init__(process, done_callback, msg, data) self._awaiting: Dict[asyncio.Future, str] = {} - for awaitable, key in (awaiting or {}).items(): + for awaitable, key in (data or {}).items(): resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key @@ -124,7 +124,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: + def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index b6100614..6a61fe00 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -17,7 +17,7 @@ STOPPED = 'Stopped' -class Playing(state_machine.State): +class Playing: LABEL = PLAYING ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} @@ -56,7 +56,7 @@ def exit(self) -> None: self.in_state = False -class Paused(state_machine.State): +class Paused: LABEL = PAUSED ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} @@ -65,7 +65,6 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' - super().__init__(player) self._player = player self.playing_state = playing_state @@ -74,9 +73,9 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) + self._player.transition_to(Playing(player=self.state_machine, track=track)) else: - self.state_machine.transition_to(self.playing_state) + self._player.transition_to(self.playing_state) def enter(self) -> None: self.in_state = True @@ -88,7 +87,7 @@ def exit(self) -> None: self.in_state = False -class Stopped(state_machine.State): +class Stopped: LABEL = STOPPED ALLOWED = { PLAYING, @@ -98,13 +97,13 @@ class Stopped(state_machine.State): is_terminal = False def __init__(self, player): - self.state_machine = player + self._player = player def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing(self.state_machine, track=track)) + self._player.transition_to(Playing(self._player, track=track)) def enter(self) -> None: self.in_state = True