Skip to content

Commit

Permalink
Simplify _create_state_instance so it only need to do real create
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 1, 2024
1 parent ea1d655 commit 28db202
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 45 deletions.
120 changes: 77 additions & 43 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
import os
import sys
from types import TracebackType
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast
from typing import (
Any,
Callable,
Dict,
Hashable,
Iterable,
List,
Optional,
Sequence,
Set,
Type,
Union,
cast,
)

from plumpy.futures import Future

Expand Down Expand Up @@ -60,27 +73,27 @@ def __init__(
super().__init__(self._format_msg())

def _format_msg(self) -> str:
msg = [f'{self.initial_state} -> {self.final_state}']
msg = [f"{self.initial_state} -> {self.final_state}"]
if self.traceback_str is not None:
msg.append(self.traceback_str)
return '\n'.join(msg)
return "\n".join(msg)


def event(
from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*',
to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*',
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""A decorator to check for correct transitions, raising ``EventError`` on invalid transitions."""
if from_states != '*':
if from_states != "*":
if inspect.isclass(from_states):
from_states = (from_states,)
if not all(issubclass(state, State) for state in from_states): # type: ignore
raise TypeError(f'from_states: {from_states}')
if to_states != '*':
raise TypeError(f"from_states: {from_states}")
if to_states != "*":
if inspect.isclass(to_states):
to_states = (to_states,)
if not all(issubclass(state, State) for state in to_states): # type: ignore
raise TypeError(f'to_states: {to_states}')
raise TypeError(f"to_states: {to_states}")

def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
evt_label = wrapped.__name__
Expand All @@ -89,14 +102,20 @@ def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
def transition(self: Any, *a: Any, **kw: Any) -> Any:
initial = self._state

if from_states != '*' and not any(isinstance(self._state, state) for state in from_states): # type: ignore
raise EventError(evt_label, f'Event {evt_label} invalid in state {initial.LABEL}')
if from_states != "*" and not any(
isinstance(self._state, state) for state in from_states
): # type: ignore
raise EventError(
evt_label, f"Event {evt_label} invalid in state {initial.LABEL}"
)

result = wrapped(self, *a, **kw)
if not (result is False or isinstance(result, Future)):
if to_states != '*' and not any(isinstance(self._state, state) for state in to_states): # type: ignore
if to_states != "*" and not any(
isinstance(self._state, state) for state in to_states
): # type: ignore
if self._state == initial:
raise EventError(evt_label, 'Machine did not transition')
raise EventError(evt_label, "Machine did not transition")

raise EventError(
evt_label,
Expand Down Expand Up @@ -142,7 +161,7 @@ def label(self) -> LABEL_TYPE:
def enter(self) -> None:
"""Entering the state"""

def execute(self) -> Optional['State']:
def execute(self) -> Optional["State"]:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
Expand All @@ -152,9 +171,9 @@ def execute(self) -> Optional['State']:
def exit(self) -> None:
"""Exiting the state"""
if self.is_terminal():
raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
raise InvalidStateError(f"Cannot exit a terminal state {self.LABEL}")

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> "State":
return self.state_machine.create_state(state_label, *args, **kwargs)

def do_enter(self) -> None:
Expand Down Expand Up @@ -211,7 +230,7 @@ def get_states(cls) -> Sequence[Type[State]]:
if cls.STATES is not None:
return cls.STATES

raise RuntimeError('States not defined')
raise RuntimeError("States not defined")

@classmethod
def initial_state_label(cls) -> LABEL_TYPE:
Expand All @@ -229,7 +248,7 @@ def get_state_class(cls, label: LABEL_TYPE) -> Type[State]:
def __ensure_built(cls) -> None:
try:
# Check if it's already been built (and therefore sealed)
if cls.__getattribute__(cls, 'sealed'):
if cls.__getattribute__(cls, "sealed"):
return
except AttributeError:
pass
Expand All @@ -253,7 +272,9 @@ def __init__(self) -> None:
self.__ensure_built()
self._state: Optional[State] = None
self._exception_handler = None # Note this appears to never be used
self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG'))))
self.set_debug(
(not sys.flags.ignore_environment and bool(os.environ.get("PYTHONSMDEBUG")))
)
self._transitioning = False
self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {}

Expand All @@ -262,7 +283,7 @@ def init(self) -> None:
"""Called after entering initial state in `__call__` method of `StateMachineMeta`"""

def __str__(self) -> str:
return f'<{self.__class__.__name__}> ({self.state})'
return f"<{self.__class__.__name__}> ({self.state})"

def create_initial_state(self) -> State:
return self.get_state_class(self.initial_state_label())(self)
Expand All @@ -273,7 +294,9 @@ def state(self) -> Optional[LABEL_TYPE]:
return None
return self._state.LABEL

def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None:
def add_state_event_callback(
self, hook: Hashable, callback: EVENT_CALLBACK_TYPE
) -> None:
"""
Add a callback to be called on a particular state event hook.
The callback should have form fn(state_machine, hook, state)
Expand All @@ -283,8 +306,10 @@ def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE
"""
self._event_callbacks.setdefault(hook, []).append(callback)

def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None:
if getattr(self, '_closed', False):
def remove_state_event_callback(
self, hook: Hashable, callback: EVENT_CALLBACK_TYPE
) -> None:
if getattr(self, "_closed", False):
# if the process is closed, then all callbacks have already been removed
return None
try:
Expand All @@ -308,8 +333,10 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
try:
self._transitioning = True

# Make sure we have a state instance
new_state = self._create_state_instance(new_state, *args, **kwargs)
if not isinstance(new_state, State):
# Make sure we have a state instance
new_state = self._create_state_instance(new_state, *args, **kwargs)

label = new_state.LABEL

# If the previous transition failed, do not try to exit it but go straight to next state
Expand All @@ -320,7 +347,10 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
self._enter_next_state(new_state)
except StateEntryFailed as exception:
# Make sure we have a state instance
new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs)
if not isinstance(exception.state, State):
new_state = self._create_state_instance(
exception.state, *exception.args, **exception.kwargs
)
label = new_state.LABEL
self._exit_current_state(new_state)
self._enter_next_state(new_state)
Expand All @@ -338,7 +368,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
self._transitioning = False

def transition_failed(
self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType
self,
initial_state: Hashable,
final_state: Hashable,
exception: Exception,
trace: TracebackType,
) -> None:
"""Called when a state transitions fails.
Expand All @@ -358,7 +392,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
except KeyError:
raise ValueError(f'{state_label} is not a valid state')
raise ValueError(f"{state_label} is not a valid state")

def _exit_current_state(self, next_state: State) -> None:
"""Exit the given state"""
Expand All @@ -367,11 +401,15 @@ def _exit_current_state(self, next_state: State) -> None:
# in which case check the new state is the initial state
if self._state is None:
if next_state.label != self.initial_state_label():
raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state")
raise RuntimeError(
f"Cannot enter state '{next_state}' as the initial state"
)
return # Nothing to exit

if next_state.LABEL not in self._state.ALLOWED:
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
raise RuntimeError(
f"Cannot transition from {self._state.LABEL} to {next_state.label}"
)
self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
self._state.do_exit()

Expand All @@ -383,20 +421,16 @@ def _enter_next_state(self, next_state: State) -> None:
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State:
if isinstance(state, State):
# It's already a state instance
return state

# OK, have to create it
state_cls = self._ensure_state_class(state)
return state_cls(self, *args, **kwargs)

def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]:
def _create_state_instance(
self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any
) -> State:
# build from state class
if inspect.isclass(state) and issubclass(state, State):
return state
state_cls = state
else:
try:
state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object
except KeyError:
raise ValueError(f"{state} is not a valid state")

try:
return self.get_states_map()[cast(Hashable, state)]
except KeyError:
raise ValueError(f'{state} is not a valid state')
return state_cls(self, *args, **kwargs)
5 changes: 3 additions & 2 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,8 @@ def on_kill(self, msg: Optional[MessageType]) -> None:
if msg is None:
msg_txt = ''
else:
msg_txt = msg[MESSAGE_KEY] or ''
# msg_txt = msg[MESSAGE_KEY] or ''
msg_txt = msg

self.set_status(msg_txt)
self.future().set_exception(exceptions.KilledError(msg_txt))
Expand Down Expand Up @@ -1079,7 +1080,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu
def do_kill(_next_state: process_states.State) -> Any:
try:
# Ignore the next state
__import__('ipdb').set_trace()
# __import__('ipdb').set_trace()
self.transition_to(process_states.ProcessState.KILLED, exception)
return True
finally:
Expand Down

0 comments on commit 28db202

Please sign in to comment.