diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 556760c0..3397c40d 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -20,7 +20,6 @@ Set, Type, Union, - cast, ) from plumpy.futures import Future @@ -325,8 +324,18 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: - assert not self._transitioning, 'Cannot call transition_to when already transitioning state' + def transition_to( + self, new_state: Union[State, Type[State]], *args: Any, **kwargs: Any + ) -> None: + """Transite to the new state. + + The new target state will be create lazily when the state + is not yet instantiated, which will happened for states not in the expect path such as + pause and kill. + """ + assert ( + not self._transitioning + ), "Cannot call transition_to when already transitioning state" initial_state_label = self._state.LABEL if self._state is not None else None label = None @@ -389,6 +398,10 @@ def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic + # because the label is defined after the state and required to be know before calling this function. + # This method should be replaced by `_create_state_instance`. + # aiida-core using this method for its Waiting state override. try: return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: @@ -422,15 +435,9 @@ def _enter_next_state(self, next_state: State) -> None: self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) def _create_state_instance( - self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any + self, state_cls: type[State], *args: Any, **kwargs: Any ) -> State: - # build from state class - if inspect.isclass(state) and issubclass(state, State): - state_cls = state - else: - try: - state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object - except KeyError: - raise ValueError(f"{state} is not a valid state") + if state_cls.LABEL not in self.get_states_map(): + raise ValueError(f"{state_cls.LABEL} is not a valid state") return state_cls(self, *args, **kwargs) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 6eef55af..a4b3b017 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -831,7 +831,7 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False) + raise StateEntryFailed(process_states.Finished, result, False) self.future().set_result(self.outputs) @@ -1016,7 +1016,7 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace) + self.transition_to(process_states.Excepted, exception, trace) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1081,7 +1081,7 @@ def do_kill(_next_state: process_states.State) -> Any: try: # Ignore the next state # __import__('ipdb').set_trace() - self.transition_to(process_states.ProcessState.KILLED, exception) + self.transition_to(process_states.Killed, exception) return True finally: self._killing = None @@ -1133,7 +1133,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) + self.transition_to(process_states.Excepted, exception, trace_back) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1161,7 +1161,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.ProcessState.KILLED, msg) + self.transition_to(process_states.Killed, msg) return True @property