From aa4e5ee0f0a8ff6c21f2297b0ee46344665356a2 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 27 Nov 2024 13:55:11 +0100 Subject: [PATCH] Waiting state de-abstraction --- src/plumpy/process_states.py | 228 +++++++++++++++++++++++------------ 1 file changed, 148 insertions(+), 80 deletions(-) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 3407412d..264a8be9 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -22,21 +22,21 @@ from .utils import SAVED_STATE_TYPE __all__ = [ - 'ProcessState', - 'Created', - 'Running', - 'Waiting', - 'Finished', - 'Excepted', - 'Killed', + "ProcessState", + "Created", + "Running", + "Waiting", + "Finished", + "Excepted", + "Killed", # Commands - 'Kill', - 'Stop', - 'Wait', - 'Continue', - 'Interruption', - 'KillInterruption', - 'PauseInterruption', + "Kill", + "Stop", + "Wait", + "Continue", + "Interruption", + "KillInterruption", + "PauseInterruption", ] if TYPE_CHECKING: @@ -62,9 +62,8 @@ class Command(persistence.Savable): pass -@auto_persist('msg') +@auto_persist("msg") class Kill(Command): - def __init__(self, msg: Optional[Any] = None): super().__init__() self.msg = msg @@ -74,11 +73,13 @@ class Pause(Command): pass -@auto_persist('msg', 'data') +@auto_persist("msg", "data") class Wait(Command): - def __init__( - self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None + self, + continue_fn: Optional[Callable[..., Any]] = None, + msg: Optional[Any] = None, + data: Optional[Any] = None, ): super().__init__() self.continue_fn = continue_fn @@ -86,18 +87,17 @@ def __init__( self.data = data -@auto_persist('result') +@auto_persist("result") class Stop(Command): - def __init__(self, result: Any, successful: bool) -> None: super().__init__() self.result = result self.successful = successful -@auto_persist('args', 'kwargs') +@auto_persist("args", "kwargs") class Continue(Command): - CONTINUE_FN = 'continue_fn' + CONTINUE_FN = "continue_fn" def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): super().__init__() @@ -105,11 +105,15 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) try: self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) @@ -127,17 +131,17 @@ class ProcessState(Enum): """ The possible states that a :class:`~plumpy.processes.Process` can be in. """ - CREATED: str = 'created' - RUNNING: str = 'running' - WAITING: str = 'waiting' - FINISHED: str = 'finished' - EXCEPTED: str = 'excepted' - KILLED: str = 'killed' + CREATED: str = "created" + RUNNING: str = "running" + WAITING: str = "waiting" + FINISHED: str = "finished" + EXCEPTED: str = "excepted" + KILLED: str = "killed" -@auto_persist('in_state') -class State(state_machine.State, persistence.Savable): +@auto_persist("in_state") +class State(state_machine.State, persistence.Savable): @property def process(self) -> state_machine.StateMachine: """ @@ -145,7 +149,9 @@ def process(self) -> state_machine.StateMachine: """ return self.state_machine - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process @@ -153,48 +159,62 @@ def interrupt(self, reason: Any) -> None: # pylint: disable=unused-argument pass -@auto_persist('args', 'kwargs') +@auto_persist("args", "kwargs") class Created(State): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} - RUN_FN = 'run_fn' + RUN_FN = "run_fn" - def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + return self.create_state( + ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs + ) -@auto_persist('args', 'kwargs') +@auto_persist("args", "kwargs") class Running(State): LABEL = ProcessState.RUNNING ALLOWED = { - ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED + ProcessState.RUNNING, + ProcessState.WAITING, + ProcessState.FINISHED, + ProcessState.KILLED, + ProcessState.EXCEPTED, } - RUN_FN = 'run_fn' # The key used to store the function to run - COMMAND = 'command' # The key used to store an upcoming command + RUN_FN = "run_fn" # The key used to store the function to run + COMMAND = "command" # The key used to store an upcoming command # Class level defaults _command: Union[None, Kill, Stop, Wait, Continue] = None _running: bool = False _run_handle = None - def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn @@ -202,17 +222,23 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.kwargs = kwargs self._run_handle = None - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: - self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore + self._command = persistence.Savable.load( + saved_state[self.COMMAND], load_context + ) # type: ignore def interrupt(self, reason: Any) -> None: pass @@ -252,40 +278,50 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = self.create_state( + ProcessState.FINISHED, command.result, command.successful + ) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = self.create_state( + ProcessState.WAITING, command.continue_fn, command.msg, command.data + ) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = self.create_state( + ProcessState.RUNNING, command.continue_fn, *command.args + ) else: - raise ValueError('Unrecognised command') + raise ValueError("Unrecognised command") return cast(State, state) # casting from base.State to process.State -@auto_persist('msg', 'data') -class Waiting(State): +@auto_persist("msg", "data", "in_state") +class Waiting(state_machine.State, persistence.Savable): LABEL = ProcessState.WAITING ALLOWED = { - ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, ProcessState.EXCEPTED, ProcessState.FINISHED + ProcessState.RUNNING, + ProcessState.WAITING, + ProcessState.KILLED, + ProcessState.EXCEPTED, + ProcessState.FINISHED, } - DONE_CALLBACK = 'DONE_CALLBACK' + DONE_CALLBACK = "DONE_CALLBACK" _interruption = None def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: - state_info += f' ({self.msg})' + state_info += f" ({self.msg})" return state_info def __init__( self, - process: 'Process', + process: "Process", done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - data: Optional[Any] = None + data: Optional[Any] = None, ) -> None: super().__init__(process) self.done_callback = done_callback @@ -293,13 +329,26 @@ def __init__( self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: self.done_callback = getattr(self.process, callback_name) @@ -324,12 +373,14 @@ async def execute(self) -> State: # type: ignore # pylint: disable=invalid-over if result == NULL: next_state = self.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = self.create_state( + ProcessState.RUNNING, self.done_callback, result + ) return cast(State, next_state) # casting from base.State to process.State def resume(self, value: Any = NULL) -> None: - assert self._waiting_future is not None, 'Not yet waiting' + assert self._waiting_future is not None, "Not yet waiting" if self._waiting_future.done(): return @@ -340,11 +391,14 @@ def resume(self, value: Any = NULL) -> None: class Excepted(State): LABEL = ProcessState.EXCEPTED - EXC_VALUE = 'ex_value' - TRACEBACK = 'traceback' + EXC_VALUE = "ex_value" + TRACEBACK = "traceback" def __init__( - self, process: 'Process', exception: Optional[BaseException], trace_back: Optional[TracebackType] = None + self, + process: "Process", + exception: Optional[BaseException], + trace_back: Optional[TracebackType] = None, ): """ :param process: The associated process @@ -356,50 +410,64 @@ def __init__( self.traceback = trace_back def __str__(self) -> str: - exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] - return super().__str__() + f'({exception})' + exception = traceback.format_exception_only( + type(self.exception) if self.exception else None, self.exception + )[0] + return super().__str__() + f"({exception})" - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: - out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) + out_state[self.TRACEBACK] = "".join(traceback.format_tb(self.traceback)) - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: - self.traceback = \ - tblib.Traceback.from_string(saved_state[self.TRACEBACK], - strict=False) + self.traceback = tblib.Traceback.from_string( + saved_state[self.TRACEBACK], strict=False + ) except KeyError: self.traceback = None else: self.traceback = None - def get_exc_info(self) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: + def get_exc_info( + self, + ) -> Tuple[ + Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType] + ]: """ Recreate the exc_info tuple and return it """ - return type(self.exception) if self.exception else None, self.exception, self.traceback + return ( + type(self.exception) if self.exception else None, + self.exception, + self.traceback, + ) -@auto_persist('result', 'successful') +@auto_persist("result", "successful") class Finished(State): LABEL = ProcessState.FINISHED - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: + def __init__(self, process: "Process", result: Any, successful: bool) -> None: super().__init__(process) self.result = result self.successful = successful -@auto_persist('msg') +@auto_persist("msg") class Killed(State): LABEL = ProcessState.KILLED - def __init__(self, process: 'Process', msg: Optional[str]): + def __init__(self, process: "Process", msg: Optional[str]): """ :param process: The associated process :param msg: Optional kill message