diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 52316857..bc36bb54 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -20,7 +20,6 @@ Set, Type, Union, - cast, ) from plumpy.futures import Future @@ -44,7 +43,7 @@ class StateEntryFailed(Exception): Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=keyword-arg-before-vararg + def __init__(self, state: type["State"] = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=keyword-arg-before-vararg super().__init__("failed to enter state") self.state = state self.args = args @@ -330,12 +329,12 @@ def on_terminated(self) -> None: """Called when a terminal state is entered""" def transition_to( - self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any + self, new_state: Union[State, Type[State]], *args: Any, **kwargs: Any ) -> None: """Transite to the new state. The new target state will be create lazily when the state - is not yet instantiated, which will happened for states not in the expect path such as + is not yet instantiated, which will happened for states not in the expect path such as pause and kill. """ assert ( @@ -403,6 +402,10 @@ def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic + # because the label is defined after the state and required to be know before calling this function. + # This method should be replaced by `_create_state_instance`. + # aiida-core using this method for its Waiting state override. try: return self.get_states_map()[state_label](self, *args, **kwargs) # pylint: disable=unsubscriptable-object except KeyError: @@ -436,15 +439,9 @@ def _enter_next_state(self, next_state: State) -> None: self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) def _create_state_instance( - self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any + self, state_cls: type[State], *args: Any, **kwargs: Any ) -> State: - # build from state class - if inspect.isclass(state) and issubclass(state, State): - state_cls = state - else: - try: - state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object - except KeyError: - raise ValueError(f"{state} is not a valid state") + if state_cls.LABEL not in self.get_states_map(): + raise ValueError(f"{state_cls.LABEL} is not a valid state") return state_cls(self, *args, **kwargs) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index e1c3e5ee..bf8dae8c 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -833,7 +833,7 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False) + raise StateEntryFailed(process_states.Finished, result, False) self.future().set_result(self.outputs) @@ -1017,7 +1017,7 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace) + self.transition_to(process_states.Excepted, exception, trace) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1082,7 +1082,7 @@ def do_kill(_next_state: process_states.State) -> Any: try: # Ignore the next state # __import__('ipdb').set_trace() - self.transition_to(process_states.ProcessState.KILLED, exception) + self.transition_to(process_states.Killed, exception) return True finally: self._killing = None @@ -1134,7 +1134,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) + self.transition_to(process_states.Excepted, exception, trace_back) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1162,7 +1162,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.ProcessState.KILLED, msg) + self.transition_to(process_states.Killed, msg) return True @property