From 531c6960baedf7803dc6bb9d9cca5a6c8ef50052 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 3 Dec 2024 00:48:53 +0100 Subject: [PATCH] Just the interface duck typing --- src/plumpy/base/state_machine.py | 1 + src/plumpy/process_states.py | 102 +++++++++++-------------------- 2 files changed, 37 insertions(+), 66 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index c36f04ea..5b0841b6 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -132,6 +132,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): LABEL: ClassVar[LABEL_TYPE] + ALLOWED: ClassVar[set[str]] is_terminal: ClassVar[bool] def enter(self) -> None: ... diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index cc9169c7..4348803e 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -138,17 +138,17 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ - CREATED: str = 'created' - RUNNING: str = 'running' - WAITING: str = 'waiting' - FINISHED: str = 'finished' - EXCEPTED: str = 'excepted' - KILLED: str = 'killed' + CREATED = 'created' + RUNNING = 'running' + WAITING = 'waiting' + FINISHED = 'finished' + EXCEPTED = 'excepted' + KILLED = 'killed' @final -@auto_persist('args', 'kwargs', 'in_state') -class Created(state_machine.State, persistence.Savable): +@auto_persist('args', 'kwargs') +class Created(persistence.Savable): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} @@ -161,7 +161,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.run_fn = run_fn self.args = args self.kwargs = kwargs - self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -176,19 +175,14 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi async def execute(self) -> state_machine.State: return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) - def enter(self) -> None: - self.in_state = True - - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... @final -@auto_persist('args', 'kwargs', 'in_state') -class Running(state_machine.State, persistence.Savable): +@auto_persist('args', 'kwargs') +class Running(persistence.Savable): LABEL = ProcessState.RUNNING ALLOWED = { ProcessState.RUNNING, @@ -215,7 +209,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.args = args self.kwargs = kwargs self._run_handle = None - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -258,7 +251,7 @@ async def execute(self) -> state_machine.State: # Got passed a basic return type result = Stop(result, True) - command = result + command = cast(Stop, result) next_state = self._action_command(command) return next_state @@ -279,19 +272,14 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_m return cast(state_machine.State, state) # casting from base.State to process.State - def enter(self) -> None: - self.in_state = True - - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... @final -@auto_persist('msg', 'data', 'in_state') -class Waiting(state_machine.State, persistence.Savable): +@auto_persist('msg', 'data') +class Waiting(persistence.Savable): LABEL = ProcessState.WAITING ALLOWED = { ProcessState.RUNNING, @@ -325,7 +313,6 @@ def __init__( self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -372,18 +359,13 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) - def enter(self) -> None: - self.in_state = True - - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... -@auto_persist('in_state') -class Excepted(state_machine.State, persistence.Savable): +@final +class Excepted(persistence.Savable): """ Excepted state, can optionally provide exception and trace_back @@ -392,6 +374,7 @@ class Excepted(state_machine.State, persistence.Savable): """ LABEL = ProcessState.EXCEPTED + ALLOWED: set[str] = set() EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' @@ -412,7 +395,6 @@ def __init__( self.process = process self.exception = exception self.traceback = trace_back - self.in_state = False def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -449,18 +431,14 @@ def get_exc_info( self.traceback, ) - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('result', 'successful', 'in_state') -class Finished(state_machine.State, persistence.Savable): +@final +@auto_persist('result', 'successful') +class Finished(persistence.Savable): """State for process is finished. :param result: The result of process @@ -468,6 +446,7 @@ class Finished(state_machine.State, persistence.Savable): """ LABEL = ProcessState.FINISHED + ALLOWED: set[str] = set() is_terminal = True @@ -475,24 +454,19 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None: self.process = process self.result = result self.successful = successful - self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process - def enter(self) -> None: - self.in_state = True - - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... -@auto_persist('msg', 'in_state') -class Killed(state_machine.State, persistence.Savable): +@final +@auto_persist('msg') +class Killed(persistence.Savable): """ Represents a state where a process has been killed. @@ -503,6 +477,7 @@ class Killed(state_machine.State, persistence.Savable): """ LABEL = ProcessState.KILLED + ALLOWED: set[str] = set() is_terminal = True @@ -518,14 +493,9 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi super().load_instance_state(saved_state, load_context) self.process = load_context.process - def enter(self) -> None: - self.in_state = True - - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... # endregion