From 080d0364ce160f82ed0477081a9f73c6bc5ec7de Mon Sep 17 00:00:00 2001
From: Jusong Yu <jusong.yeu@gmail.com>
Date: Mon, 2 Dec 2024 23:20:45 +0100
Subject: [PATCH] Forming Interruptable and Proceedable protocol

---
 src/plumpy/base/state_machine.py | 20 +++++++++++++++-----
 src/plumpy/process_states.py     | 12 ++----------
 src/plumpy/processes.py          | 22 +++++++++++++++++++++-
 3 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py
index 2164737f..27b1e5f8 100644
--- a/src/plumpy/base/state_machine.py
+++ b/src/plumpy/base/state_machine.py
@@ -132,18 +132,28 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
 @runtime_checkable
 class State(Protocol):
     LABEL: ClassVar[LABEL_TYPE]
+    is_terminal: ClassVar[bool]
 
-    async def execute(self) -> State | None:
+    def enter(self) -> None: ...
+
+    def exit(self) -> None: ...
+
+
+@runtime_checkable
+class Interruptable(Protocol):
+    def interrupt(self, reason: Exception) -> None: ...
+
+
+@runtime_checkable
+class Proceedable(Protocol):
+
+    def execute(self) -> State | None:
         """
         Execute the state, performing the actions that this state is responsible for.
         :returns: a state to transition to or None if finished.
         """
         ...
 
-    def enter(self) -> None: ...
-
-    def exit(self) -> None: ...
-
 
 class StateEventHook(enum.Enum):
     """
diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py
index 88cab660..cc9169c7 100644
--- a/src/plumpy/process_states.py
+++ b/src/plumpy/process_states.py
@@ -288,6 +288,7 @@ def exit(self) -> None:
 
         self.in_state = False
 
+
 @final
 @auto_persist('msg', 'data', 'in_state')
 class Waiting(state_machine.State, persistence.Savable):
@@ -342,7 +343,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
             self.done_callback = None
         self._waiting_future = futures.Future()
 
-    def interrupt(self, reason: Any) -> None:
+    def interrupt(self, reason: Exception) -> None:
         # This will cause the future in execute() to raise the exception
         self._waiting_future.set_exception(reason)
 
@@ -448,9 +449,6 @@ def get_exc_info(
             self.traceback,
         )
 
-    async def execute(self) -> state_machine.State:  # type: ignore
-        ...
-
     def enter(self) -> None:
         self.in_state = True
 
@@ -486,9 +484,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
     def enter(self) -> None:
         self.in_state = True
 
-    async def execute(self) -> state_machine.State:  # type: ignore
-        ...
-
     def exit(self) -> None:
         if self.is_terminal:
             raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
@@ -519,9 +514,6 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]):
         self.process = process
         self.msg = msg
 
-    async def execute(self) -> state_machine.State:  # type: ignore
-        ...
-
     def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
         super().load_instance_state(saved_state, load_context)
         self.process = load_context.process
diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py
index 1e745437..74808291 100644
--- a/src/plumpy/processes.py
+++ b/src/plumpy/processes.py
@@ -51,7 +51,15 @@
     utils,
 )
 from .base import state_machine
-from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
+from .base.state_machine import (
+    Interruptable,
+    Proceedable,
+    StateEntryFailed,
+    StateMachine,
+    StateMachineError,
+    TransitionFailed,
+    event,
+)
 from .base.utils import call_with_super_check, super_check
 from .event_helper import EventHelper
 from .process_comms import MESSAGE_KEY, KillMessage, MessageType
@@ -1092,6 +1100,11 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable
             return self._pausing
 
         if self._stepping:
+            if not isinstance(self._state, Interruptable):
+                raise exceptions.InvalidStateError(
+                    f'cannot interrupt {self._state.__class__}, method `interrupt` not implement'
+                )
+
             # Ask the step function to pause by setting this flag and giving the
             # caller back a future
             interrupt_exception = process_states.PauseInterruption(msg)
@@ -1103,6 +1116,10 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable
 
         return self._do_pause(msg)
 
+    @staticmethod
+    def _interrupt(state: Interruptable, reason: Exception) -> None:
+        state.interrupt(reason)
+
     def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool:
         """Carry out the pause procedure, optionally transitioning to the next state first"""
         try:
@@ -1285,6 +1302,9 @@ async def step(self) -> None:
         if self.paused and self._paused is not None:
             await self._paused
 
+        if not isinstance(self._state, Proceedable):
+            raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented')
+
         try:
             self._stepping = True
             next_state = None