diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index c36f04ea..5b0841b6 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -132,6 +132,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): LABEL: ClassVar[LABEL_TYPE] + ALLOWED: ClassVar[set[str]] is_terminal: ClassVar[bool] def enter(self) -> None: ... diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index cc9169c7..a58287ef 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -147,8 +147,8 @@ class ProcessState(Enum): @final -@auto_persist('args', 'kwargs', 'in_state') -class Created(state_machine.State, persistence.Savable): +@auto_persist('args', 'kwargs') +class Created(persistence.Savable): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} @@ -161,7 +161,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.run_fn = run_fn self.args = args self.kwargs = kwargs - self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -177,18 +176,15 @@ async def execute(self) -> state_machine.State: return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) def enter(self) -> None: - self.in_state = True + ... def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - self.in_state = False + ... @final -@auto_persist('args', 'kwargs', 'in_state') -class Running(state_machine.State, persistence.Savable): +@auto_persist('args', 'kwargs') +class Running(persistence.Savable): LABEL = ProcessState.RUNNING ALLOWED = { ProcessState.RUNNING, @@ -215,7 +211,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.args = args self.kwargs = kwargs self._run_handle = None - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -280,18 +275,15 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_m return cast(state_machine.State, state) # casting from base.State to process.State def enter(self) -> None: - self.in_state = True + ... def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - self.in_state = False + ... @final -@auto_persist('msg', 'data', 'in_state') -class Waiting(state_machine.State, persistence.Savable): +@auto_persist('msg', 'data') +class Waiting(persistence.Savable): LABEL = ProcessState.WAITING ALLOWED = { ProcessState.RUNNING, @@ -325,7 +317,6 @@ def __init__( self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -373,25 +364,21 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) def enter(self) -> None: - self.in_state = True + ... def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - self.in_state = False + ... - -@auto_persist('in_state') -class Excepted(state_machine.State, persistence.Savable): +@final +class Excepted(persistence.Savable): """ Excepted state, can optionally provide exception and trace_back :param exception: The exception instance :param trace_back: An optional exception traceback """ - LABEL = ProcessState.EXCEPTED + ALLOWED: set[str] = set() EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' @@ -412,7 +399,6 @@ def __init__( self.process = process self.exception = exception self.traceback = trace_back - self.in_state = False def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -450,24 +436,22 @@ def get_exc_info( ) def enter(self) -> None: - self.in_state = True + ... def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - self.in_state = False + ... -@auto_persist('result', 'successful', 'in_state') -class Finished(state_machine.State, persistence.Savable): +@final +@auto_persist('result', 'successful') +class Finished(persistence.Savable): """State for process is finished. :param result: The result of process :param successful: Boolean for the exit code is ``0`` the process is successful. """ - LABEL = ProcessState.FINISHED + ALLOWED: set[str] = set() is_terminal = True @@ -475,24 +459,21 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None: self.process = process self.result = result self.successful = successful - self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process def enter(self) -> None: - self.in_state = True + ... def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + ... - self.in_state = False - -@auto_persist('msg', 'in_state') -class Killed(state_machine.State, persistence.Savable): +@final +@auto_persist('msg') +class Killed(persistence.Savable): """ Represents a state where a process has been killed. @@ -503,6 +484,7 @@ class Killed(state_machine.State, persistence.Savable): """ LABEL = ProcessState.KILLED + ALLOWED: set[str] = set() is_terminal = True @@ -519,13 +501,11 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.process = load_context.process def enter(self) -> None: - self.in_state = True + ... def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + ... - self.in_state = False # endregion