From 103193ca90af62822e07e6bd18c69815d4eb41a4 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Thu, 28 Nov 2024 22:14:56 +0100 Subject: [PATCH] Finished state de-abstraction --- src/plumpy/base/state_machine.py | 23 +++++++++++++++- src/plumpy/process_states.py | 45 +++++++++++++++++++++++++++++--- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 36ca73ba..ec19e2d5 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -8,7 +8,7 @@ import os import sys from types import TracebackType -from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast +from typing import Any, Callable, ClassVar, Dict, Hashable, Iterable, List, Optional, Protocol, Sequence, Set, Type, Union, cast, runtime_checkable from plumpy.futures import Future @@ -165,6 +165,27 @@ def do_exit(self) -> None: self.in_state = False +@runtime_checkable +class StateP(Protocol): + LABEL: ClassVar[str] + + # FIXME: fix the LABEL_TYPE + ALLOWED: ClassVar[set[LABEL_TYPE]] + + def do_enter(self) -> None: + ... + + def do_exit(self) -> None: + ... + + def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': + ... + + @classmethod + def is_terminal(cls) -> bool: + ... + + class StateEventHook(enum.Enum): """ Hooks that can be used to register callback at various points in the state transition diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 2446b044..41655678 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -267,7 +267,6 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: return cast(State, state) # casting from base.State to process.State -# class Waiting(state_machine.State): class Waiting(state_machine.State, persistence.Savable): """The basic waiting state.""" @@ -292,6 +291,7 @@ def __str__(self) -> str: state_info += f' ({self.msg})' return state_info + # FIXME: fully get rid of state_machine.State as parent class (as a protocol with contract) def __init__( self, process: 'Process', @@ -314,6 +314,7 @@ def process(self) -> state_machine.StateMachine: """ return self._process + # FIXME: this is shared by all states def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'state_machine.State': return self._process.create_state(state_label, *args, **kwargs) @@ -424,15 +425,51 @@ def get_exc_info(self) -> Tuple[Optional[Type[BaseException]], Optional[BaseExce return type(self.exception) if self.exception else None, self.exception, self.traceback -@auto_persist('result', 'successful') -class Finished(State): +class Finished(state_machine.State, persistence.Savable): LABEL = ProcessState.FINISHED + is_terminal_state = True + + _auto_persist = {'result', 'successful'} def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - super().__init__(process) + self._process = process + self.in_state: bool = False self.result = result self.successful = successful + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self._process + + # FIXME: this is shared by all states + def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'state_machine.State': + return self._process.create_state(state_label, *args, **kwargs) + + # FIXME: this is shared + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + super().save_instance_state(out_state, save_context) + + # FIXME: this is shared + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + super().load_instance_state(saved_state, load_context) + + def do_enter(self) -> None: + self.in_state = True + + def do_exit(self) -> None: + if self.is_terminal(): + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + + @classmethod + def is_terminal(cls) -> bool: + # deprecated using class attribute `is_terminal_state` directly. + return cls.is_terminal_state + @auto_persist('msg') class Killed(State):