From 884157412a8d4407f80d4f44cf65f052adf72633 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 22:38:21 +0100 Subject: [PATCH] Forming a State protocol --- src/plumpy/base/state_machine.py | 45 ++++--------- src/plumpy/process_states.py | 109 +++++++++++-------------------- src/plumpy/processes.py | 10 +-- tests/base/test_statemachine.py | 16 +++-- 4 files changed, 66 insertions(+), 114 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index a650d0ed..3822a745 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -3,7 +3,7 @@ from __future__ import annotations -import abc +from abc import abstractmethod, ABCMeta import enum import functools import inspect @@ -22,7 +22,6 @@ Optional, Protocol, Sequence, - Set, Type, Union, runtime_checkable, @@ -49,7 +48,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None: + def __init__(self, state: 'State', *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -92,12 +91,12 @@ def event( if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) - if not all(issubclass(state, State) for state in from_states): # type: ignore + if not all(isinstance(state, State) for state in from_states): # type: ignore raise TypeError(f'from_states: {from_states}') if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) - if not all(issubclass(state, State) for state in to_states): # type: ignore + if not all(isinstance(state, State) for state in to_states): # type: ignore raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: @@ -131,32 +130,20 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -class State(abc.ABC): +@runtime_checkable +class State(Protocol): LABEL: ClassVar[LABEL_TYPE] - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): - """ - :param state_machine: The process this state belongs to - """ - self.state_machine = state_machine - - def execute(self) -> Optional['State']: + async def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ ... - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': - return self.state_machine.create_state(state_label, *args, **kwargs) - - @abc.abstractmethod - def enter(self) -> None: - ... + def enter(self) -> None: ... - @abc.abstractmethod - def exit(self) -> None: - ... + def exit(self) -> None: ... class StateEventHook(enum.Enum): @@ -233,7 +220,7 @@ def __ensure_built(cls) -> None: # Build the states map cls._STATES_MAP = {} for state_cls in cls.STATES: - assert issubclass(state_cls, State) + assert isinstance(state_cls, State) label = state_cls.LABEL assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" cls._STATES_MAP[label] = state_cls @@ -293,7 +280,7 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None: + def transition_to(self, new_state: State | None, **kwargs: Any) -> None: """Transite to the new state. The new target state will be create lazily when the state is not yet instantiated, @@ -311,10 +298,6 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> try: self._transitioning = True - if not isinstance(new_state, State): - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, **kwargs) - label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -325,8 +308,7 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> self._enter_next_state(new_state) except StateEntryFailed as exception: # Make sure we have a state instance - if not isinstance(exception.state, State): - new_state = self._create_state_instance(exception.state, **exception.kwargs) + new_state = exception.state label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) @@ -370,7 +352,8 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat # This method should be replaced by `_create_state_instance`. # aiida-core using this method for its Waiting state override. try: - return self.get_states_map()[state_label](self, *args, **kwargs) + state_cls = self.get_states_map()[state_label] + return state_cls(self, *args, **kwargs) except KeyError: raise ValueError(f'{state_label} is not a valid state') diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 450fd90b..e2540c27 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,7 +5,7 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final import yaml from yaml.loader import Loader @@ -146,6 +146,7 @@ class ProcessState(Enum): KILLED: str = 'killed' +@final @auto_persist('args', 'kwargs', 'in_state') class Created(state_machine.State, persistence.Savable): LABEL = ProcessState.CREATED @@ -155,11 +156,12 @@ class Created(state_machine.State, persistence.Savable): is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs + self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -167,30 +169,24 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) - - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + async def execute(self) -> state_machine.State: + return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) def enter(self) -> None: self.in_state = True def exit(self) -> None: if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') self.in_state = False +@final @auto_persist('args', 'kwargs', 'in_state') class Running(state_machine.State, persistence.Savable): LABEL = ProcessState.RUNNING @@ -213,12 +209,13 @@ class Running(state_machine.State, persistence.Savable): is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs self._run_handle = None + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -228,7 +225,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: @@ -237,7 +234,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> state_machine.State: # type: ignore + async def execute(self) -> state_machine.State: if self._command is not None: command = self._command else: @@ -251,7 +248,7 @@ async def execute(self) -> state_machine.State: # type: ignore # Let this bubble up to the caller raise except Exception: - excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) + excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) return cast(state_machine.State, excepted) else: if not isinstance(result, Command): @@ -268,37 +265,31 @@ async def execute(self) -> state_machine.State: # type: ignore def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: if isinstance(command, Kill): - state = self.create_state(ProcessState.KILLED, command.msg) + state = self.process.create_state(ProcessState.KILLED, command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) else: raise ValueError('Unrecognised command') return cast(state_machine.State, state) # casting from base.State to process.State - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - def enter(self) -> None: self.in_state = True def exit(self) -> None: if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') self.in_state = False +@final @auto_persist('msg', 'data', 'in_state') class Waiting(state_machine.State, persistence.Savable): LABEL = ProcessState.WAITING @@ -329,11 +320,12 @@ def __init__( msg: Optional[str] = None, data: Optional[Any] = None, ) -> None: - super().__init__(process) + self.process = process self.done_callback = done_callback self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -342,7 +334,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: @@ -366,9 +358,9 @@ async def execute(self) -> state_machine.State: # type: ignore raise if result == NULL: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result) return cast(state_machine.State, next_state) # casting from base.State to process.State @@ -380,19 +372,12 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - def enter(self) -> None: self.in_state = True def exit(self) -> None: if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') self.in_state = False @@ -405,7 +390,6 @@ class Excepted(state_machine.State, persistence.Savable): :param exception: The exception instance :param trace_back: An optional exception traceback """ - LABEL = ProcessState.EXCEPTED EXC_VALUE = 'ex_value' @@ -424,9 +408,10 @@ def __init__( :param exception: The exception instance :param trace_back: An optional exception traceback """ - super().__init__(process) + self.process = process self.exception = exception self.traceback = trace_back + self.in_state = False def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -440,7 +425,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -463,19 +448,12 @@ def get_exc_info( self.traceback, ) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - def enter(self) -> None: self.in_state = True def exit(self) -> None: if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') self.in_state = False @@ -493,27 +471,21 @@ class Finished(state_machine.State, persistence.Savable): is_terminal = True def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - super().__init__(process) + self.process = process self.result = result self.successful = successful - - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process def enter(self) -> None: self.in_state = True def exit(self) -> None: if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') self.in_state = False @@ -538,26 +510,19 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): :param process: The associated process :param msg: Optional kill message """ - super().__init__(process) + self.process = process self.msg = msg - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process def enter(self) -> None: self.in_state = True def exit(self) -> None: if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') self.in_state = False diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 7e14163b..6bc88bac 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -867,7 +867,7 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.Finished, result=result, successful=False) + raise StateEntryFailed(process_states.Finished(self, result=result, successful=False)) self.future().set_result(self.outputs) @@ -1064,7 +1064,7 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace) + self.transition_to(process_states.Excepted(self, exception=exception, trace_back=trace)) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1127,7 +1127,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: state_machine.State) -> Any: try: - self.transition_to(process_states.Killed, msg=exception.msg) + self.transition_to(process_states.Killed(self, msg=exception.msg)) return True finally: self._killing = None @@ -1179,7 +1179,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace_back) + self.transition_to(process_states.Excepted(self, exception=exception, trace_back=trace_back)) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1207,7 +1207,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.Killed, msg=msg) + self.transition_to(process_states.Killed(self, msg)) return True @property diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index aa55202b..b6b31209 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import time +from typing import final import unittest from plumpy.base import state_machine @@ -25,10 +26,10 @@ class Playing(state_machine.State): def __init__(self, player, track): assert track is not None, 'Must provide a track name' - super().__init__(player) self.track = track self._last_time = None self._played = 0.0 + self.in_state = False def __str__(self): if self.in_state: @@ -64,7 +65,7 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' - super().__init__(player) + self.state_machine = player self.playing_state = playing_state def __str__(self): @@ -72,7 +73,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(self, track=track)) else: self.state_machine.transition_to(self.playing_state) @@ -95,11 +96,14 @@ class Stopped(state_machine.State): is_terminal = False + def __init__(self, player): + self.state_machine = player + def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(self, track=track)) def enter(self) -> None: self.in_state = True @@ -135,12 +139,12 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, playing_state=self._state) + self.transition_to(Paused(self, playing_state=self._state)) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) def stop(self): - self.transition_to(Stopped) + self.transition_to(Stopped(self)) class TestStateMachine(unittest.TestCase):