From 989f995764a8d88c9fabbf345e70e2ee146d04c1 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sun, 1 Dec 2024 00:10:00 +0100 Subject: [PATCH] Simplify _create_state_instance so it only need to do real create --- src/plumpy/base/state_machine.py | 174 +++++++++++++++++++------------ src/plumpy/processes.py | 5 +- 2 files changed, 112 insertions(+), 67 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index b62825e1..52316857 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """The state machine for processes""" + import enum import functools import inspect @@ -7,18 +8,31 @@ import os import sys from types import TracebackType -from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Optional, + Sequence, + Set, + Type, + Union, + cast, +) from plumpy.futures import Future from .utils import call_with_super_check, super_check -__all__ = ['StateMachine', 'StateMachineMeta', 'event', 'TransitionFailed'] +__all__ = ["StateMachine", "StateMachineMeta", "event", "TransitionFailed"] _LOGGER = logging.getLogger(__name__) LABEL_TYPE = Union[None, enum.Enum, str] # pylint: disable=invalid-name -EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] # pylint: disable=invalid-name +EVENT_CALLBACK_TYPE = Callable[["StateMachine", Hashable, Optional["State"]], None] # pylint: disable=invalid-name class StateMachineError(Exception): @@ -31,7 +45,7 @@ class StateEntryFailed(Exception): """ def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=keyword-arg-before-vararg - super().__init__('failed to enter state') + super().__init__("failed to enter state") self.state = state self.args = args self.kwargs = kwargs @@ -42,7 +56,6 @@ class InvalidStateError(Exception): class EventError(StateMachineError): - def __init__(self, evt: str, msg: str): super().__init__(msg) self.event = evt @@ -53,9 +66,9 @@ class TransitionFailed(Exception): def __init__( self, - initial_state: 'State', - final_state: Optional['State'] = None, - traceback_str: Optional[str] = None + initial_state: "State", + final_state: Optional["State"] = None, + traceback_str: Optional[str] = None, ) -> None: self.initial_state = initial_state self.final_state = final_state @@ -63,27 +76,27 @@ def __init__( super().__init__(self._format_msg()) def _format_msg(self) -> str: - msg = [f'{self.initial_state} -> {self.final_state}'] + msg = [f"{self.initial_state} -> {self.final_state}"] if self.traceback_str is not None: msg.append(self.traceback_str) - return '\n'.join(msg) + return "\n".join(msg) def event( - from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', - to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*' + from_states: Union[str, Type["State"], Iterable[Type["State"]]] = "*", + to_states: Union[str, Type["State"], Iterable[Type["State"]]] = "*", ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" - if from_states != '*': + if from_states != "*": if inspect.isclass(from_states): from_states = (from_states,) if not all(issubclass(state, State) for state in from_states): # type: ignore - raise TypeError(f'from_states: {from_states}') - if to_states != '*': + raise TypeError(f"from_states: {from_states}") + if to_states != "*": if inspect.isclass(to_states): to_states = (to_states,) if not all(issubclass(state, State) for state in to_states): # type: ignore - raise TypeError(f'to_states: {to_states}') + raise TypeError(f"to_states: {to_states}") def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: evt_label = wrapped.__name__ @@ -92,18 +105,25 @@ def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: def transition(self: Any, *a: Any, **kw: Any) -> Any: initial = self._state - if from_states != '*' and not any(isinstance(self._state, state) for state in from_states): # type: ignore - raise EventError(evt_label, f'Event {evt_label} invalid in state {initial.LABEL}') + if from_states != "*" and not any( + isinstance(self._state, state) for state in from_states + ): # type: ignore + raise EventError( + evt_label, f"Event {evt_label} invalid in state {initial.LABEL}" + ) result = wrapped(self, *a, **kw) if not (result is False or isinstance(result, Future)): - if to_states != '*' and not any(isinstance(self._state, state) for state in to_states): # type: ignore + if to_states != "*" and not any( + isinstance(self._state, state) for state in to_states + ): # type: ignore if self._state == initial: - raise EventError(evt_label, 'Machine did not transition') + raise EventError(evt_label, "Machine did not transition") raise EventError( - evt_label, 'Event produced invalid state transition from ' - f'{initial.LABEL} to {self._state.LABEL}' + evt_label, + "Event produced invalid state transition from " + f"{initial.LABEL} to {self._state.LABEL}", ) return result @@ -126,7 +146,7 @@ class State: def is_terminal(cls) -> bool: return not cls.ALLOWED - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): # pylint: disable=unused-argument + def __init__(self, state_machine: "StateMachine", *args: Any, **kwargs: Any): # pylint: disable=unused-argument """ :param state_machine: The process this state belongs to """ @@ -138,14 +158,14 @@ def __str__(self) -> str: @property def label(self) -> LABEL_TYPE: - """ Convenience property to get the state label """ + """Convenience property to get the state label""" return self.LABEL @super_check def enter(self) -> None: - """ Entering the state """ + """Entering the state""" - def execute(self) -> Optional['State']: + def execute(self) -> Optional["State"]: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. @@ -153,11 +173,11 @@ def execute(self) -> Optional['State']: @super_check def exit(self) -> None: - """ Exiting the state """ + """Exiting the state""" if self.is_terminal(): - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + raise InvalidStateError(f"Cannot exit a terminal state {self.LABEL}") - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': + def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> "State": return self.state_machine.create_state(state_label, *args, **kwargs) def do_enter(self) -> None: @@ -175,14 +195,14 @@ class StateEventHook(enum.Enum): procedure. The callback will be passed a state instance whose meaning will differ depending on the hook as commented below. """ + ENTERING_STATE: int = 0 # State passed will be the state that is being entered ENTERED_STATE: int = 1 # State passed will be the last state that we entered from EXITING_STATE: int = 2 # State passed will be the next state that will be entered (or None for terminal) class StateMachineMeta(type): - - def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': + def __call__(cls, *args: Any, **kwargs: Any) -> "StateMachine": """ Create the state machine and enter the initial state. @@ -214,7 +234,7 @@ def get_states(cls) -> Sequence[Type[State]]: if cls.STATES is not None: return cls.STATES - raise RuntimeError('States not defined') + raise RuntimeError("States not defined") @classmethod def initial_state_label(cls) -> LABEL_TYPE: @@ -232,7 +252,7 @@ def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: def __ensure_built(cls) -> None: try: # Check if it's already been built (and therefore sealed) - if cls.__getattribute__(cls, 'sealed'): + if cls.__getattribute__(cls, "sealed"): return except AttributeError: pass @@ -256,7 +276,9 @@ def __init__(self) -> None: self.__ensure_built() self._state: Optional[State] = None self._exception_handler = None # Note this appears to never be used - self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG')))) + self.set_debug( + (not sys.flags.ignore_environment and bool(os.environ.get("PYTHONSMDEBUG"))) + ) self._transitioning = False self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {} @@ -265,7 +287,7 @@ def init(self) -> None: """Called after entering initial state in `__call__` method of `StateMachineMeta`""" def __str__(self) -> str: - return f'<{self.__class__.__name__}> ({self.state})' + return f"<{self.__class__.__name__}> ({self.state})" def create_initial_state(self) -> State: return self.get_state_class(self.initial_state_label())(self) @@ -276,7 +298,9 @@ def state(self) -> Optional[LABEL_TYPE]: return None return self._state.LABEL - def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: + def add_state_event_callback( + self, hook: Hashable, callback: EVENT_CALLBACK_TYPE + ) -> None: """ Add a callback to be called on a particular state event hook. The callback should have form fn(state_machine, hook, state) @@ -286,8 +310,10 @@ def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE """ self._event_callbacks.setdefault(hook, []).append(callback) - def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: - if getattr(self, '_closed', False): + def remove_state_event_callback( + self, hook: Hashable, callback: EVENT_CALLBACK_TYPE + ) -> None: + if getattr(self, "_closed", False): # if the process is closed, then all callbacks have already been removed return None try: @@ -301,19 +327,30 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: @super_check def on_terminated(self) -> None: - """ Called when a terminal state is entered """ + """Called when a terminal state is entered""" - def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: - assert not self._transitioning, \ - 'Cannot call transition_to when already transitioning state' + def transition_to( + self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any + ) -> None: + """Transite to the new state. + + The new target state will be create lazily when the state + is not yet instantiated, which will happened for states not in the expect path such as + pause and kill. + """ + assert ( + not self._transitioning + ), "Cannot call transition_to when already transitioning state" initial_state_label = self._state.LABEL if self._state is not None else None label = None try: self._transitioning = True - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, *args, **kwargs) + if not isinstance(new_state, State): + # Make sure we have a state instance + new_state = self._create_state_instance(new_state, *args, **kwargs) + label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -324,7 +361,10 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._enter_next_state(new_state) except StateEntryFailed as exception: # Make sure we have a state instance - new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs) + if not isinstance(exception.state, State): + new_state = self._create_state_instance( + exception.state, *exception.args, **exception.kwargs + ) label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) @@ -342,7 +382,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._transitioning = False def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: """Called when a state transitions fails. @@ -362,20 +406,24 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat try: return self.get_states_map()[state_label](self, *args, **kwargs) # pylint: disable=unsubscriptable-object except KeyError: - raise ValueError(f'{state_label} is not a valid state') + raise ValueError(f"{state_label} is not a valid state") def _exit_current_state(self, next_state: State) -> None: - """ Exit the given state """ + """Exit the given state""" # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state if self._state is None: if next_state.label != self.initial_state_label(): - raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") + raise RuntimeError( + f"Cannot enter state '{next_state}' as the initial state" + ) return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError( + f"Cannot transition from {self._state.LABEL} to {next_state.label}" + ) self._fire_state_event(StateEventHook.EXITING_STATE, next_state) self._state.do_exit() @@ -387,20 +435,16 @@ def _enter_next_state(self, next_state: State) -> None: self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State: - if isinstance(state, State): - # It's already a state instance - return state - - # OK, have to create it - state_cls = self._ensure_state_class(state) - return state_cls(self, *args, **kwargs) - - def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]: + def _create_state_instance( + self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any + ) -> State: + # build from state class if inspect.isclass(state) and issubclass(state, State): - return state + state_cls = state + else: + try: + state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object + except KeyError: + raise ValueError(f"{state} is not a valid state") - try: - return self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object - except KeyError: - raise ValueError(f'{state} is not a valid state') + return state_cls(self, *args, **kwargs) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 7ebe9586..e1c3e5ee 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -867,7 +867,8 @@ def on_kill(self, msg: Optional[MessageType]) -> None: if msg is None: msg_txt = '' else: - msg_txt = msg[MESSAGE_KEY] or '' + # msg_txt = msg[MESSAGE_KEY] or '' + msg_txt = msg self.set_status(msg_txt) self.future().set_exception(exceptions.KilledError(msg_txt)) @@ -1080,7 +1081,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: # Ignore the next state - __import__('ipdb').set_trace() + # __import__('ipdb').set_trace() self.transition_to(process_states.ProcessState.KILLED, exception) return True finally: