From e4e0df59c908161a4c7bfdc6feb751eada58444c Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 17:07:29 +0100 Subject: [PATCH] forming the enter/exit for State protocol --- src/plumpy/base/state_machine.py | 66 ++++---------- src/plumpy/process_states.py | 148 ++++++++++++++++++------------- src/plumpy/workchains.py | 25 +++--- tests/base/test_statemachine.py | 44 +++++++-- tests/test_processes.py | 2 +- 5 files changed, 157 insertions(+), 128 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 380f4610..2164737f 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -13,15 +13,17 @@ from typing import ( Any, Callable, + ClassVar, Dict, Hashable, Iterable, List, Optional, + Protocol, Sequence, - Set, Type, Union, + runtime_checkable, ) from plumpy.futures import Future @@ -88,12 +90,12 @@ def event( if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) - if not all(issubclass(state, State) for state in from_states): # type: ignore + if not all(isinstance(state, State) for state in from_states): # type: ignore raise TypeError(f'from_states: {from_states}') if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) - if not all(issubclass(state, State) for state in to_states): # type: ignore + if not all(isinstance(state, State) for state in to_states): # type: ignore raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: @@ -127,53 +129,20 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -class State: - LABEL: LABEL_TYPE = None - # A set containing the labels of states that can be entered - # from this one - ALLOWED: Set[LABEL_TYPE] = set() +@runtime_checkable +class State(Protocol): + LABEL: ClassVar[LABEL_TYPE] - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): - """ - :param state_machine: The process this state belongs to - """ - self.state_machine = state_machine - self.in_state: bool = False - - def __str__(self) -> str: - return str(self.LABEL) - - @property - def label(self) -> LABEL_TYPE: - """Convenience property to get the state label""" - return self.LABEL - - @super_check - def enter(self) -> None: - """Entering the state""" - - def execute(self) -> Optional['State']: + async def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ + ... - @super_check - def exit(self) -> None: - """Exiting the state""" - if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': - return self.state_machine.create_state(state_label, *args, **kwargs) - - def do_enter(self) -> None: - call_with_super_check(self.enter) - self.in_state = True + def enter(self) -> None: ... - def do_exit(self) -> None: - call_with_super_check(self.exit) - self.in_state = False + def exit(self) -> None: ... class StateEventHook(enum.Enum): @@ -250,7 +219,7 @@ def __ensure_built(cls) -> None: # Build the states map cls._STATES_MAP = {} for state_cls in cls.STATES: - assert issubclass(state_cls, State) + assert isinstance(state_cls, State) label = state_cls.LABEL assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" cls._STATES_MAP[label] = state_cls @@ -380,7 +349,8 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat # This method should be replaced by `_create_state_instance`. # aiida-core using this method for its Waiting state override. try: - return self.get_states_map()[state_label](self, *args, **kwargs) + state_cls = self.get_states_map()[state_label] + return state_cls(self, *args, **kwargs) except KeyError: raise ValueError(f'{state_label} is not a valid state') @@ -390,20 +360,20 @@ def _exit_current_state(self, next_state: State) -> None: # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state if self._state is None: - if next_state.label != self.initial_state_label(): + if next_state.LABEL != self.initial_state_label(): raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) - self._state.do_exit() + self._state.exit() def _enter_next_state(self, next_state: State) -> None: last_state = self._state self._fire_state_event(StateEventHook.ENTERING_STATE, next_state) # Enter the new state - next_state.do_enter() + next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 91959c4d..88cab660 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,7 +5,7 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final import yaml from yaml.loader import Loader @@ -146,6 +146,7 @@ class ProcessState(Enum): KILLED: str = 'killed' +@final @auto_persist('args', 'kwargs', 'in_state') class Created(state_machine.State, persistence.Savable): LABEL = ProcessState.CREATED @@ -155,11 +156,12 @@ class Created(state_machine.State, persistence.Savable): is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs + self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -167,21 +169,24 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + async def execute(self) -> state_machine.State: + return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False +@final @auto_persist('args', 'kwargs', 'in_state') class Running(state_machine.State, persistence.Savable): LABEL = ProcessState.RUNNING @@ -204,12 +209,13 @@ class Running(state_machine.State, persistence.Savable): is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs self._run_handle = None + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -219,7 +225,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: @@ -228,7 +234,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> state_machine.State: # type: ignore + async def execute(self) -> state_machine.State: if self._command is not None: command = self._command else: @@ -242,7 +248,7 @@ async def execute(self) -> state_machine.State: # type: ignore # Let this bubble up to the caller raise except Exception: - excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) + excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) return cast(state_machine.State, excepted) else: if not isinstance(result, Command): @@ -259,28 +265,30 @@ async def execute(self) -> state_machine.State: # type: ignore def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: if isinstance(command, Kill): - state = self.create_state(ProcessState.KILLED, command.msg) + state = self.process.create_state(ProcessState.KILLED, command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) else: raise ValueError('Unrecognised command') return cast(state_machine.State, state) # casting from base.State to process.State - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False +@final @auto_persist('msg', 'data', 'in_state') class Waiting(state_machine.State, persistence.Savable): LABEL = ProcessState.WAITING @@ -311,11 +319,12 @@ def __init__( msg: Optional[str] = None, data: Optional[Any] = None, ) -> None: - super().__init__(process) + self.process = process self.done_callback = done_callback self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -324,7 +333,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: @@ -348,9 +357,9 @@ async def execute(self) -> state_machine.State: # type: ignore raise if result == NULL: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result) return cast(state_machine.State, next_state) # casting from base.State to process.State @@ -362,12 +371,14 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('in_state') @@ -397,9 +408,10 @@ def __init__( :param exception: The exception instance :param trace_back: An optional exception traceback """ - super().__init__(process) + self.process = process self.exception = exception self.traceback = trace_back + self.in_state = False def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -413,7 +425,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -436,12 +448,17 @@ def get_exc_info( self.traceback, ) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + async def execute(self) -> state_machine.State: # type: ignore + ... + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('result', 'successful', 'in_state') @@ -457,20 +474,26 @@ class Finished(state_machine.State, persistence.Savable): is_terminal = True def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - super().__init__(process) + self.process = process self.result = result self.successful = successful - - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process + + def enter(self) -> None: + self.in_state = True + + async def execute(self) -> state_machine.State: # type: ignore + ... + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('msg', 'in_state') @@ -493,19 +516,24 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): :param process: The associated process :param msg: Optional kill message """ - super().__init__(process) + self.process = process self.msg = msg - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + async def execute(self) -> state_machine.State: # type: ignore + ... def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False # endregion diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 9eafcb50..eefd57f1 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -26,6 +26,7 @@ import kiwipy from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -87,16 +88,6 @@ def __init__( resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key - def enter(self) -> None: - super().enter() - for awaitable in self._awaiting: - awaitable.add_done_callback(self._awaitable_done) - - def exit(self) -> None: - super().exit() - for awaitable in self._awaiting: - awaitable.remove_done_callback(self._awaitable_done) - def _awaitable_done(self, awaitable: asyncio.Future) -> None: key = self._awaiting.pop(awaitable) try: @@ -107,6 +98,20 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: if not self._awaiting: self._waiting_future.set_result(lang.NULL) + def enter(self) -> None: + for awaitable in self._awaiting: + awaitable.add_done_callback(self._awaitable_done) + + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False + + for awaitable in self._awaiting: + awaitable.remove_done_callback(self._awaitable_done) + class WorkChain(mixins.ContextMixin, processes.Process): """ diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index b6d7e2d3..b6100614 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- import time +from typing import final import unittest from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError # Events PLAY = 'Play' @@ -24,24 +26,16 @@ class Playing(state_machine.State): def __init__(self, player, track): assert track is not None, 'Must provide a track name' - super().__init__(player) self.track = track self._last_time = None self._played = 0.0 + self.in_state = False def __str__(self): if self.in_state: self._update_time() return f'> {self.track} ({self._played}s)' - def enter(self): - super().enter() - self._last_time = time.time() - - def exit(self): - super().exit() - self._update_time() - def play(self, track=None): # pylint: disable=no-self-use, unused-argument return False @@ -50,6 +44,17 @@ def _update_time(self): self._played += current_time - self._last_time self._last_time = current_time + def enter(self) -> None: + self._last_time = time.time() + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self._update_time() + self.in_state = False + class Paused(state_machine.State): LABEL = PAUSED @@ -73,6 +78,15 @@ def play(self, track=None): else: self.state_machine.transition_to(self.playing_state) + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + class Stopped(state_machine.State): LABEL = STOPPED @@ -83,12 +97,24 @@ class Stopped(state_machine.State): is_terminal = False + def __init__(self, player): + self.state_machine = player + def __str__(self): return '[]' def play(self, track): self.state_machine.transition_to(Playing(self.state_machine, track=track)) + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + class CdPlayer(state_machine.StateMachine): STATES = (Stopped, Playing, Paused) diff --git a/tests/test_processes.py b/tests/test_processes.py index eb5bf599..4b8cc606 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -653,7 +653,7 @@ def test_exception_during_on_entered(self): class RaisingProcess(Process): def on_entered(self, from_state): - if from_state is not None and from_state.label == ProcessState.RUNNING: + if from_state is not None and from_state.LABEL == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') super().on_entered(from_state)