diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index d99d0705..556760c0 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -8,7 +8,20 @@ import os import sys from types import TracebackType -from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Optional, + Sequence, + Set, + Type, + Union, + cast, +) from plumpy.futures import Future @@ -60,10 +73,10 @@ def __init__( super().__init__(self._format_msg()) def _format_msg(self) -> str: - msg = [f'{self.initial_state} -> {self.final_state}'] + msg = [f"{self.initial_state} -> {self.final_state}"] if self.traceback_str is not None: msg.append(self.traceback_str) - return '\n'.join(msg) + return "\n".join(msg) def event( @@ -71,16 +84,16 @@ def event( to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" - if from_states != '*': + if from_states != "*": if inspect.isclass(from_states): from_states = (from_states,) if not all(issubclass(state, State) for state in from_states): # type: ignore - raise TypeError(f'from_states: {from_states}') - if to_states != '*': + raise TypeError(f"from_states: {from_states}") + if to_states != "*": if inspect.isclass(to_states): to_states = (to_states,) if not all(issubclass(state, State) for state in to_states): # type: ignore - raise TypeError(f'to_states: {to_states}') + raise TypeError(f"to_states: {to_states}") def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: evt_label = wrapped.__name__ @@ -89,14 +102,20 @@ def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: def transition(self: Any, *a: Any, **kw: Any) -> Any: initial = self._state - if from_states != '*' and not any(isinstance(self._state, state) for state in from_states): # type: ignore - raise EventError(evt_label, f'Event {evt_label} invalid in state {initial.LABEL}') + if from_states != "*" and not any( + isinstance(self._state, state) for state in from_states + ): # type: ignore + raise EventError( + evt_label, f"Event {evt_label} invalid in state {initial.LABEL}" + ) result = wrapped(self, *a, **kw) if not (result is False or isinstance(result, Future)): - if to_states != '*' and not any(isinstance(self._state, state) for state in to_states): # type: ignore + if to_states != "*" and not any( + isinstance(self._state, state) for state in to_states + ): # type: ignore if self._state == initial: - raise EventError(evt_label, 'Machine did not transition') + raise EventError(evt_label, "Machine did not transition") raise EventError( evt_label, @@ -142,7 +161,7 @@ def label(self) -> LABEL_TYPE: def enter(self) -> None: """Entering the state""" - def execute(self) -> Optional['State']: + def execute(self) -> Optional["State"]: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. @@ -152,9 +171,9 @@ def execute(self) -> Optional['State']: def exit(self) -> None: """Exiting the state""" if self.is_terminal(): - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + raise InvalidStateError(f"Cannot exit a terminal state {self.LABEL}") - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': + def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> "State": return self.state_machine.create_state(state_label, *args, **kwargs) def do_enter(self) -> None: @@ -211,7 +230,7 @@ def get_states(cls) -> Sequence[Type[State]]: if cls.STATES is not None: return cls.STATES - raise RuntimeError('States not defined') + raise RuntimeError("States not defined") @classmethod def initial_state_label(cls) -> LABEL_TYPE: @@ -229,7 +248,7 @@ def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: def __ensure_built(cls) -> None: try: # Check if it's already been built (and therefore sealed) - if cls.__getattribute__(cls, 'sealed'): + if cls.__getattribute__(cls, "sealed"): return except AttributeError: pass @@ -253,7 +272,9 @@ def __init__(self) -> None: self.__ensure_built() self._state: Optional[State] = None self._exception_handler = None # Note this appears to never be used - self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG')))) + self.set_debug( + (not sys.flags.ignore_environment and bool(os.environ.get("PYTHONSMDEBUG"))) + ) self._transitioning = False self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {} @@ -262,7 +283,7 @@ def init(self) -> None: """Called after entering initial state in `__call__` method of `StateMachineMeta`""" def __str__(self) -> str: - return f'<{self.__class__.__name__}> ({self.state})' + return f"<{self.__class__.__name__}> ({self.state})" def create_initial_state(self) -> State: return self.get_state_class(self.initial_state_label())(self) @@ -273,7 +294,9 @@ def state(self) -> Optional[LABEL_TYPE]: return None return self._state.LABEL - def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: + def add_state_event_callback( + self, hook: Hashable, callback: EVENT_CALLBACK_TYPE + ) -> None: """ Add a callback to be called on a particular state event hook. The callback should have form fn(state_machine, hook, state) @@ -283,8 +306,10 @@ def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE """ self._event_callbacks.setdefault(hook, []).append(callback) - def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: - if getattr(self, '_closed', False): + def remove_state_event_callback( + self, hook: Hashable, callback: EVENT_CALLBACK_TYPE + ) -> None: + if getattr(self, "_closed", False): # if the process is closed, then all callbacks have already been removed return None try: @@ -308,8 +333,10 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A try: self._transitioning = True - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, *args, **kwargs) + if not isinstance(new_state, State): + # Make sure we have a state instance + new_state = self._create_state_instance(new_state, *args, **kwargs) + label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -320,7 +347,10 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._enter_next_state(new_state) except StateEntryFailed as exception: # Make sure we have a state instance - new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs) + if not isinstance(exception.state, State): + new_state = self._create_state_instance( + exception.state, *exception.args, **exception.kwargs + ) label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) @@ -338,7 +368,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._transitioning = False def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: """Called when a state transitions fails. @@ -358,7 +392,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat try: return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: - raise ValueError(f'{state_label} is not a valid state') + raise ValueError(f"{state_label} is not a valid state") def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" @@ -367,11 +401,15 @@ def _exit_current_state(self, next_state: State) -> None: # in which case check the new state is the initial state if self._state is None: if next_state.label != self.initial_state_label(): - raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") + raise RuntimeError( + f"Cannot enter state '{next_state}' as the initial state" + ) return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError( + f"Cannot transition from {self._state.LABEL} to {next_state.label}" + ) self._fire_state_event(StateEventHook.EXITING_STATE, next_state) self._state.do_exit() @@ -383,20 +421,16 @@ def _enter_next_state(self, next_state: State) -> None: self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State: - if isinstance(state, State): - # It's already a state instance - return state - - # OK, have to create it - state_cls = self._ensure_state_class(state) - return state_cls(self, *args, **kwargs) - - def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]: + def _create_state_instance( + self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any + ) -> State: + # build from state class if inspect.isclass(state) and issubclass(state, State): - return state + state_cls = state + else: + try: + state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object + except KeyError: + raise ValueError(f"{state} is not a valid state") - try: - return self.get_states_map()[cast(Hashable, state)] - except KeyError: - raise ValueError(f'{state} is not a valid state') + return state_cls(self, *args, **kwargs) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 25a8f78e..6eef55af 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -865,7 +865,8 @@ def on_kill(self, msg: Optional[MessageType]) -> None: if msg is None: msg_txt = '' else: - msg_txt = msg[MESSAGE_KEY] or '' + # msg_txt = msg[MESSAGE_KEY] or '' + msg_txt = msg self.set_status(msg_txt) self.future().set_exception(exceptions.KilledError(msg_txt)) @@ -1079,7 +1080,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: # Ignore the next state - __import__('ipdb').set_trace() + # __import__('ipdb').set_trace() self.transition_to(process_states.ProcessState.KILLED, exception) return True finally: