diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 2035d4ab..be27e0cd 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -45,7 +45,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None: + def __init__(self, state: State, *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -314,7 +314,7 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None: + def transition_to(self, new_state: State | None, **kwargs: Any) -> None: """Transite to the new state. The new target state will be create lazily when the state is not yet instantiated, @@ -331,11 +331,6 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> label = None try: self._transitioning = True - - if not isinstance(new_state, State): - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, **kwargs) - label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -345,9 +340,7 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> try: self._enter_next_state(new_state) except StateEntryFailed as exception: - # Make sure we have a state instance - if not isinstance(exception.state, State): - new_state = self._create_state_instance(exception.state, **exception.kwargs) + new_state = exception.state label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ef558fa1..5e2f4cbd 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -867,7 +867,9 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.Finished, result=result, successful=False) + state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] + finished_state = state_cls(self, result=result, successful=False) + raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -1064,7 +1066,9 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace) + state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace) + self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1127,7 +1131,9 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: - self.transition_to(process_states.Killed, msg=exception.msg) + state_class = self.get_states_map()[process_states.ProcessState.KILLED] + new_state = self._create_state_instance(state_class, msg=exception.msg) + self.transition_to(new_state) return True finally: self._killing = None @@ -1179,7 +1185,9 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace_back) + state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back) + self.transition_to(new_state) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1207,7 +1215,9 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.Killed, msg=msg) + state_class = self.get_states_map()[process_states.ProcessState.KILLED] + new_state = self._create_state_instance(state_class, msg=msg) + self.transition_to(new_state) return True @property diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 5b4b73d8..3a1621a2 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -57,6 +57,7 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) + self._player = player self.playing_state = playing_state def __str__(self): @@ -64,7 +65,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) else: self.state_machine.transition_to(self.playing_state) @@ -80,7 +81,7 @@ def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(self.state_machine, track=track)) class CdPlayer(state_machine.StateMachine): @@ -107,12 +108,12 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, playing_state=self._state) + self.transition_to(Paused(self, playing_state=self._state)) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) def stop(self): - self.transition_to(Stopped) + self.transition_to(Stopped(self)) class TestStateMachine(unittest.TestCase):