Skip to content

Commit

Permalink
Forming Interruptable and Proceedable protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 2, 2024
1 parent 6daa938 commit af11b1f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 16 deletions.
20 changes: 15 additions & 5 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,28 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
@runtime_checkable
class State(Protocol):
LABEL: ClassVar[LABEL_TYPE]
is_terminal: ClassVar[bool]

async def execute(self) -> State | None:
def enter(self) -> None: ...

def exit(self) -> None: ...


@runtime_checkable
class Interruptable(Protocol):
def interrupt(self, reason: Exception) -> None: ...


@runtime_checkable
class Proceedable(Protocol):

def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...

def enter(self) -> None: ...

def exit(self) -> None: ...


class StateEventHook(enum.Enum):
"""
Expand Down
12 changes: 2 additions & 10 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def exit(self) -> None:

self.in_state = False


@final
@auto_persist('msg', 'data', 'in_state')
class Waiting(state_machine.State, persistence.Savable):
Expand Down Expand Up @@ -342,7 +343,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
self.done_callback = None
self._waiting_future = futures.Future()

def interrupt(self, reason: Any) -> None:
def interrupt(self, reason: Exception) -> None:
# This will cause the future in execute() to raise the exception
self._waiting_future.set_exception(reason)

Expand Down Expand Up @@ -448,9 +449,6 @@ def get_exc_info(
self.traceback,
)

async def execute(self) -> state_machine.State: # type: ignore
...

def enter(self) -> None:
self.in_state = True

Expand Down Expand Up @@ -486,9 +484,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
def enter(self) -> None:
self.in_state = True

async def execute(self) -> state_machine.State: # type: ignore
...

def exit(self) -> None:
if self.is_terminal:
raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
Expand Down Expand Up @@ -519,9 +514,6 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]):
self.process = process
self.msg = msg

async def execute(self) -> state_machine.State: # type: ignore
...

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process
Expand Down
22 changes: 21 additions & 1 deletion src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,15 @@
utils,
)
from .base import state_machine
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.state_machine import (
Interruptable,
Proceedable,
StateEntryFailed,
StateMachine,
StateMachineError,
TransitionFailed,
event,
)
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_comms import MESSAGE_KEY, KillMessage, MessageType
Expand Down Expand Up @@ -1088,6 +1096,11 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable
return self._pausing

if self._stepping:
if not isinstance(self._state, Interruptable):
raise exceptions.InvalidStateError(
f'cannot interrupt {self._state.__class__}, method `interrupt` not implement'
)

# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.PauseInterruption(msg)
Expand All @@ -1099,6 +1112,10 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable

return self._do_pause(msg)

@staticmethod
def _interrupt(state: Interruptable, reason: Exception) -> None:
state.interrupt(reason)

def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool:
"""Carry out the pause procedure, optionally transitioning to the next state first"""
try:
Expand Down Expand Up @@ -1275,6 +1292,9 @@ async def step(self) -> None:
if self.paused and self._paused is not None:
await self._paused

if not isinstance(self._state, Proceedable):
raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented')

try:
self._stepping = True
next_state = None
Expand Down

0 comments on commit af11b1f

Please sign in to comment.