diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 620b0d3a..c36f04ea 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -132,18 +132,28 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): LABEL: ClassVar[LABEL_TYPE] + is_terminal: ClassVar[bool] - async def execute(self) -> State | None: + def enter(self) -> None: ... + + def exit(self) -> None: ... + + +@runtime_checkable +class Interruptable(Protocol): + def interrupt(self, reason: Exception) -> None: ... + + +@runtime_checkable +class Proceedable(Protocol): + + def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ ... - def enter(self) -> None: ... - - def exit(self) -> None: ... - class StateEventHook(enum.Enum): """ diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 88cab660..cc9169c7 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -288,6 +288,7 @@ def exit(self) -> None: self.in_state = False + @final @auto_persist('msg', 'data', 'in_state') class Waiting(state_machine.State, persistence.Savable): @@ -342,7 +343,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.done_callback = None self._waiting_future = futures.Future() - def interrupt(self, reason: Any) -> None: + def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) @@ -448,9 +449,6 @@ def get_exc_info( self.traceback, ) - async def execute(self) -> state_machine.State: # type: ignore - ... - def enter(self) -> None: self.in_state = True @@ -486,9 +484,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def enter(self) -> None: self.in_state = True - async def execute(self) -> state_machine.State: # type: ignore - ... - def exit(self) -> None: if self.is_terminal: raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') @@ -519,9 +514,6 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): self.process = process self.msg = msg - async def execute(self) -> state_machine.State: # type: ignore - ... - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 6bc88bac..08899f9e 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -51,7 +51,15 @@ utils, ) from .base import state_machine -from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event +from .base.state_machine import ( + Interruptable, + Proceedable, + StateEntryFailed, + StateMachine, + StateMachineError, + TransitionFailed, + event, +) from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper from .process_comms import MESSAGE_KEY, KillMessage, MessageType @@ -1088,6 +1096,11 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._pausing if self._stepping: + if not isinstance(self._state, Interruptable): + raise exceptions.InvalidStateError( + f'cannot interrupt {self._state.__class__}, method `interrupt` not implement' + ) + # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.PauseInterruption(msg) @@ -1099,6 +1112,10 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._do_pause(msg) + @staticmethod + def _interrupt(state: Interruptable, reason: Exception) -> None: + state.interrupt(reason) + def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: @@ -1275,6 +1292,9 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused + if not isinstance(self._state, Proceedable): + raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented') + try: self._stepping = True next_state = None