From 6ce23b2fcc971d8732026b70a10250772b11a3cf Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 18 Jan 2025 01:18:23 +0100 Subject: [PATCH 01/21] Explict module import list in __init__.py (#8) --- src/plumpy/__init__.py | 173 +++++++++++++++++++++++++++------ src/plumpy/events.py | 10 -- src/plumpy/exceptions.py | 10 -- src/plumpy/futures.py | 2 - src/plumpy/loaders.py | 2 - src/plumpy/message.py | 7 -- src/plumpy/mixins.py | 2 - src/plumpy/persistence.py | 12 --- src/plumpy/ports.py | 2 - src/plumpy/process_listener.py | 2 - src/plumpy/process_states.py | 18 ---- src/plumpy/processes.py | 4 +- src/plumpy/utils.py | 2 - src/plumpy/workchains.py | 6 +- 14 files changed, 147 insertions(+), 105 deletions(-) diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 864d2226..2c988cd8 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -1,42 +1,157 @@ # -*- coding: utf-8 -*- -# mypy: disable-error-code=name-defined __version__ = '0.24.0' import logging +from .base.state_machine import TransitionFailed + # interfaces from .controller import ProcessController from .coordinator import Coordinator -from .events import * -from .exceptions import * -from .futures import * -from .loaders import * -from .message import * -from .mixins import * -from .persistence import * -from .ports import * -from .process_listener import * -from .process_states import * -from .processes import * -from .rmq import * -from .utils import * -from .workchains import * +from .events import ( + PlumpyEventLoopPolicy, + get_event_loop, + new_event_loop, + reset_event_loop_policy, + run_until_complete, + set_event_loop, + set_event_loop_policy, +) +from .exceptions import ( + ClosedError, + CoordinatorConnectionError, + CoordinatorTimeoutError, + InvalidStateError, + KilledError, + PersistenceError, + UnsuccessfulResult, +) +from .futures import CancellableAction, Future, capture_exceptions, create_task +from .loaders import DefaultObjectLoader, ObjectLoader, get_object_loader, set_object_loader +from .message import MessageBuilder, ProcessLauncher, create_continue_body, create_launch_body +from .mixins import ContextMixin +from .persistence import ( + Bundle, + InMemoryPersister, + LoadSaveContext, + PersistedCheckpoint, + Persister, + PicklePersister, + Savable, + SavableFuture, + auto_persist, +) +from .ports import UNSPECIFIED, InputPort, OutputPort, Port, PortNamespace, PortValidationError +from .process_listener import ProcessListener +from .process_spec import ProcessSpec +from .process_states import ( + Continue, + Created, + Excepted, + Finished, + Interruption, + Kill, + Killed, + KillInterruption, + PauseInterruption, + ProcessState, + Running, + Stop, + Wait, + Waiting, +) +from .processes import BundleKeys, Process +from .utils import AttributesDict +from .workchains import ToContext, WorkChain, WorkChainSpec, if_, return_, while_ __all__ = ( - events.__all__ - + exceptions.__all__ - + processes.__all__ - + utils.__all__ - + futures.__all__ - + mixins.__all__ - + persistence.__all__ - + message.__all__ - + process_listener.__all__ - + workchains.__all__ - + loaders.__all__ - + ports.__all__ - + process_states.__all__ -) + ['ProcessController', 'Coordinator'] + # ports + 'UNSPECIFIED', + # utils + 'AttributesDict', + # persistence + 'Bundle', + # processes + 'BundleKeys', + # futures + 'CancellableAction', + # exceptions + 'ClosedError', + # mixins + 'ContextMixin', + # process_states/States + 'Continue', + # coordinator + 'Coordinator', + 'CoordinatorConnectionError', + 'CoordinatorTimeoutError', + 'Created', + # loaders + 'DefaultObjectLoader', + 'Excepted', + 'Finished', + 'Future', + 'InMemoryPersister', + 'InputPort', + 'Interruption', + 'InvalidStateError', + # process_states/Commands + 'Kill', + 'KillInterruption', + 'Killed', + 'KilledError', + 'LoadSaveContext', + # message + 'MessageBuilder', + 'ObjectLoader', + 'OutputPort', + 'PauseInterruption', + 'PersistedCheckpoint', + 'PersistenceError', + 'Persister', + 'PicklePersister', + # event + 'PlumpyEventLoopPolicy', + 'Port', + 'PortNamespace', + 'PortValidationError', + 'Process', + # controller + 'ProcessController', + 'ProcessLauncher', + # process_listener + 'ProcessListener', + 'ProcessSpec', + 'ProcessState', + 'Running', + 'Savable', + 'SavableFuture', + 'Stop', + # workchain + 'ToContext', + 'TransitionFailed', + 'UnsuccessfulResult', + 'Wait', + 'Waiting', + 'WorkChain', + 'WorkChainSpec', + 'auto_persist', + 'capture_exceptions', + 'create_continue_body', + 'create_launch_body', + 'create_task', + 'get_event_loop', + 'get_object_loader', + 'if_', + 'new_event_loop', + 'reset_event_loop_policy', + 'return_', + 'run_until_complete', + 'set_event_loop', + 'set_event_loop_policy', + 'set_object_loader', + 'while_', +) # Do this se we don't get the "No handlers could be found..." warnings that will be produced diff --git a/src/plumpy/events.py b/src/plumpy/events.py index 3de81987..a6e62529 100644 --- a/src/plumpy/events.py +++ b/src/plumpy/events.py @@ -5,16 +5,6 @@ import sys from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence -__all__ = [ - 'PlumpyEventLoopPolicy', - 'get_event_loop', - 'new_event_loop', - 'reset_event_loop_policy', - 'run_until_complete', - 'set_event_loop', - 'set_event_loop_policy', -] - if TYPE_CHECKING: from .processes import Process diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 5d05ea4b..b4358770 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -1,16 +1,6 @@ # -*- coding: utf-8 -*- from typing import Optional -__all__ = [ - 'ClosedError', - 'CoordinatorConnectionError', - 'CoordinatorTimeoutError', - 'InvalidStateError', - 'KilledError', - 'PersistenceError', - 'UnsuccessfulResult', -] - class KilledError(Exception): """The process was killed.""" diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 139c6069..3a59351d 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -9,8 +9,6 @@ import contextlib from typing import Any, Awaitable, Callable, Generator, Optional -__all__ = ['CancellableAction', 'Future', 'capture_exceptions', 'create_task', 'create_task'] - class InvalidFutureError(Exception): """Exception for when a future or action is in an invalid state""" diff --git a/src/plumpy/loaders.py b/src/plumpy/loaders.py index a01f9b60..bb248d6a 100644 --- a/src/plumpy/loaders.py +++ b/src/plumpy/loaders.py @@ -3,8 +3,6 @@ import importlib from typing import Any, Optional -__all__ = ['DefaultObjectLoader', 'ObjectLoader', 'get_object_loader', 'set_object_loader'] - class ObjectLoader(metaclass=abc.ABCMeta): """ diff --git a/src/plumpy/message.py b/src/plumpy/message.py index 009f1b26..04f03bd9 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -13,13 +13,6 @@ from . import loaders, persistence from .utils import PID_TYPE -__all__ = [ - 'MessageBuilder', - 'ProcessLauncher', - 'create_continue_body', - 'create_launch_body', -] - if TYPE_CHECKING: from .processes import Process diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py index 10142eb7..4b993dac 100644 --- a/src/plumpy/mixins.py +++ b/src/plumpy/mixins.py @@ -4,8 +4,6 @@ from . import persistence from .utils import SAVED_STATE_TYPE, AttributesDict -__all__ = ['ContextMixin'] - class ContextMixin(persistence.Savable): """ diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index ba755bc5..f7cbad44 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -17,18 +17,6 @@ from .base.utils import call_with_super_check, super_check from .utils import PID_TYPE, SAVED_STATE_TYPE -__all__ = [ - 'Bundle', - 'InMemoryPersister', - 'LoadSaveContext', - 'PersistedCheckpoint', - 'Persister', - 'PicklePersister', - 'Savable', - 'SavableFuture', - 'auto_persist', -] - PersistedCheckpoint = collections.namedtuple('PersistedCheckpoint', ['pid', 'tag']) if TYPE_CHECKING: diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index cfbd92d5..8522f061 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -11,8 +11,6 @@ from plumpy.utils import AttributesFrozendict, is_mutable_property, type_check -__all__ = ['UNSPECIFIED', 'InputPort', 'OutputPort', 'Port', 'PortNamespace', 'PortValidationError'] - _LOGGER = logging.getLogger(__name__) UNSPECIFIED = () diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 8e1acf94..c3ab8e5a 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -5,8 +5,6 @@ from . import persistence from .utils import SAVED_STATE_TYPE, protected -__all__ = ['ProcessListener'] - if TYPE_CHECKING: from .processes import Process diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 723292bf..a4c0788d 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -25,24 +25,6 @@ from .persistence import auto_persist from .utils import SAVED_STATE_TYPE, ensure_coroutine -__all__ = [ - 'Continue', - 'Created', - 'Excepted', - 'Finished', - 'Interruption', - # Commands - 'Kill', - 'KillInterruption', - 'Killed', - 'PauseInterruption', - 'ProcessState', - 'Running', - 'Stop', - 'Wait', - 'Waiting', -] - if TYPE_CHECKING: from .processes import Process diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 4c048d9c..c328caf0 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -47,7 +47,7 @@ from . import events, exceptions, message, persistence, ports, process_states, utils from .base import state_machine -from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event +from .base.state_machine import StateEntryFailed, StateMachine, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper from .futures import CancellableAction, capture_exceptions @@ -58,8 +58,6 @@ T = TypeVar('T') -__all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] - _LOGGER = logging.getLogger(__name__) PROCESS_STACK = ContextVar('process stack', default=[]) diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index bd1b70a7..3c37ce08 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -23,8 +23,6 @@ from . import lang from .settings import check_override, check_protected -__all__ = ['AttributesDict'] - protected = lang.protected(check=check_protected) override = lang.override(check=check_override) diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 5df20bf4..ef96b48f 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -25,11 +25,9 @@ from plumpy.coordinator import Coordinator -from . import lang, mixins, persistence, process_states, processes +from . import lang, mixins, persistence, process_spec, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE -__all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] - ToContext = dict PREDICATE_TYPE = Callable[['WorkChain'], bool] @@ -37,7 +35,7 @@ EXIT_CODE_TYPE = int -class WorkChainSpec(processes.ProcessSpec): +class WorkChainSpec(process_spec.ProcessSpec): def __init__(self) -> None: super().__init__() self._outline: Optional[Union['_Instruction', '_FunctionCall']] = None From dd17ccc1e28b0a25aa66537eef2b25d27d394c22 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 15:53:13 +0100 Subject: [PATCH 02/21] Remove the middle layer of statemachine.State + Savable abstraction --- docs/source/nitpick-exceptions | 2 +- src/plumpy/process_states.py | 111 +++++++++++++++++++++++---------- src/plumpy/processes.py | 26 ++++---- src/plumpy/workchains.py | 4 +- 4 files changed, 94 insertions(+), 49 deletions(-) diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 6aa8c345..f5265734 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator # unavailable forward references py:class plumpy.process_states.Command -py:class plumpy.process_states.State +py:class plumpy.state_machine.State py:class plumpy.base.state_machine.State py:class State py:class Process diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index a4c0788d..777933d0 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -105,6 +105,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process try: self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) except ValueError: @@ -130,25 +131,8 @@ class ProcessState(Enum): KILLED: str = 'killed' -@auto_persist('in_state') -class State(state_machine.State, persistence.Savable): - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process - - def interrupt(self, reason: Any) -> None: - pass - - -@auto_persist('args', 'kwargs') -class Created(State): +@auto_persist('args', 'kwargs', 'in_state') +class Created(state_machine.State, persistence.Savable): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} @@ -167,14 +151,23 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) def execute(self) -> state_machine.State: return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine -@auto_persist('args', 'kwargs') -class Running(State): + +@auto_persist('args', 'kwargs', 'in_state') +class Running(state_machine.State, persistence.Savable): LABEL = ProcessState.RUNNING ALLOWED = { ProcessState.RUNNING, @@ -214,6 +207,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) if self.COMMAND in saved_state: self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore @@ -221,7 +216,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> State: # type: ignore + async def execute(self) -> state_machine.State: # type: ignore if self._command is not None: command = self._command else: @@ -236,7 +231,7 @@ async def execute(self) -> State: # type: ignore raise except Exception: excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(State, excepted) + return cast(state_machine.State, excepted) else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -250,7 +245,7 @@ async def execute(self) -> State: # type: ignore next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: if isinstance(command, Kill): state = self.create_state(ProcessState.KILLED, command.msg) # elif isinstance(command, Pause): @@ -264,11 +259,18 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: else: raise ValueError('Unrecognised command') - return cast(State, state) # casting from base.State to process.State + return cast(state_machine.State, state) # casting from base.State to process.State + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine -@auto_persist('msg', 'data') -class Waiting(State): + +@auto_persist('msg', 'data', 'in_state') +class Waiting(state_machine.State, persistence.Savable): LABEL = ProcessState.WAITING ALLOWED = { ProcessState.RUNNING, @@ -308,6 +310,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: self.done_callback = getattr(self.process, callback_name) @@ -319,7 +323,7 @@ def interrupt(self, reason: Any) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> State: # type: ignore + async def execute(self) -> state_machine.State: # type: ignore try: result = await self._waiting_future except Interruption: @@ -334,7 +338,7 @@ async def execute(self) -> State: # type: ignore else: next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) - return cast(State, next_state) # casting from base.State to process.State + return cast(state_machine.State, next_state) # casting from base.State to process.State def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -344,8 +348,16 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + -class Excepted(State): +@auto_persist('in_state') +class Excepted(state_machine.State, persistence.Savable): """ Excepted state, can optionally provide exception and trace_back @@ -385,6 +397,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: @@ -406,9 +420,16 @@ def get_exc_info( self.traceback, ) + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + -@auto_persist('result', 'successful') -class Finished(State): +@auto_persist('result', 'successful', 'in_state') +class Finished(state_machine.State, persistence.Savable): """State for process is finished. :param result: The result of process @@ -422,9 +443,20 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None: self.result = result self.successful = successful + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + -@auto_persist('msg') -class Killed(State): +@auto_persist('msg', 'in_state') +class Killed(state_machine.State, persistence.Savable): """ Represents a state where a process has been killed. @@ -444,5 +476,16 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): super().__init__(process) self.msg = msg + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + # endregion diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index c328caf0..81b32c1f 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -172,7 +172,7 @@ def current(cls) -> Optional['Process']: return None @classmethod - def get_states(cls) -> Sequence[Type[process_states.State]]: + def get_states(cls) -> Sequence[Type[state_machine.State]]: """Return all allowed states of the process.""" state_classes = cls.get_state_classes() return ( @@ -181,7 +181,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -353,10 +353,10 @@ def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" event_hooks = { state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering( - cast(process_states.State, state) + cast(state_machine.State, state) ), state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( - cast(Optional[process_states.State], from_state) + cast(Optional[state_machine.State], from_state) ), state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } @@ -657,7 +657,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self._loop = asyncio.get_event_loop() - self._state: process_states.State = self.recreate_state(saved_state['_state']) + self._state: state_machine.State = self.recreate_state(saved_state['_state']) if 'coordinator' in load_context: self._coordinator = load_context.coordinator @@ -715,7 +715,7 @@ def log_with_pid(self, level: int, msg: str) -> None: # region Events - def on_entering(self, state: process_states.State) -> None: + def on_entering(self, state: state_machine.State) -> None: # Map these onto direct functions that the subclass can implement state_label = state.LABEL if state_label == process_states.ProcessState.CREATED: @@ -731,7 +731,7 @@ def on_entering(self, state: process_states.State) -> None: elif state_label == process_states.ProcessState.EXCEPTED: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore - def on_entered(self, from_state: Optional[process_states.State]) -> None: + def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement state_label = self._state.LABEL if state_label == process_states.ProcessState.RUNNING: @@ -1139,7 +1139,7 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: msg = MessageBuilder.pause(msg_text) return self._do_pause(state_msg=msg) - def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool: + def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: @@ -1171,7 +1171,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> Ca if isinstance(exception, process_states.KillInterruption): - def do_kill(_next_state: process_states.State) -> Any: + def do_kill(_next_state: state_machine.State) -> Any: try: new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=exception.msg) self.transition_to(new_state) @@ -1269,7 +1269,7 @@ def is_killing(self) -> bool: # endregion - def create_initial_state(self) -> process_states.State: + def create_initial_state(self) -> state_machine.State: """This method is here to override its superclass. Automatically enter the CREATED state when the process is created. @@ -1277,11 +1277,11 @@ def create_initial_state(self) -> process_states.State: :return: A Created state """ return cast( - process_states.State, + state_machine.State, self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), ) - def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: + def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ Create a state object from a saved state @@ -1289,7 +1289,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) + return cast(state_machine.State, persistence.Savable.load(saved_state, load_context)) # endregion diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index ef96b48f..00a711b5 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -25,6 +25,8 @@ from plumpy.coordinator import Coordinator +from plumpy.base import state_machine + from . import lang, mixins, persistence, process_spec, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -115,7 +117,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map From a684daf3ed63562a7eaff7a6644fa381e262350d Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 16:06:39 +0100 Subject: [PATCH 03/21] Move is_terminal as class attribute required --- src/plumpy/base/state_machine.py | 8 ++------ src/plumpy/process_states.py | 10 ++++++++++ src/plumpy/processes.py | 4 ++-- tests/base/test_statemachine.py | 6 ++++++ 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 681858f0..217a7d51 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -133,10 +133,6 @@ class State: # from this one ALLOWED: Set[LABEL_TYPE] = set() - @classmethod - def is_terminal(cls) -> bool: - return not cls.ALLOWED - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): """ :param state_machine: The process this state belongs to @@ -165,7 +161,7 @@ def execute(self) -> Optional['State']: @super_check def exit(self) -> None: """Exiting the state""" - if self.is_terminal(): + if self.is_terminal: raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': @@ -347,7 +343,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: self._exit_current_state(new_state) self._enter_next_state(new_state) - if self._state is not None and self._state.is_terminal(): + if self._state is not None and self._state.is_terminal: call_with_super_check(self.on_terminated) except Exception: self._transitioning = False diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 777933d0..75f6c7f8 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -137,6 +137,7 @@ class Created(state_machine.State, persistence.Savable): ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' + is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: super().__init__(process) @@ -185,6 +186,7 @@ class Running(state_machine.State, persistence.Savable): _running: bool = False _run_handle = None + is_terminal = False def __init__( self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any ) -> None: @@ -284,6 +286,8 @@ class Waiting(state_machine.State, persistence.Savable): _interruption = None + is_terminal = False + def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: @@ -370,6 +374,8 @@ class Excepted(state_machine.State, persistence.Savable): EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' + is_terminal = True + def __init__( self, process: 'Process', @@ -438,6 +444,8 @@ class Finished(state_machine.State, persistence.Savable): LABEL = ProcessState.FINISHED + is_terminal = True + def __init__(self, process: 'Process', result: Any, successful: bool) -> None: super().__init__(process) self.result = result @@ -468,6 +476,8 @@ class Killed(state_machine.State, persistence.Savable): LABEL = ProcessState.KILLED + is_terminal = True + def __init__(self, process: 'Process', msg: Optional[MessageType]): """ :param process: The associated process diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 81b32c1f..4a8e029c 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -463,7 +463,7 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal() + return self._state.is_terminal def result(self) -> Any: """ @@ -536,7 +536,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal() + return self._state.is_terminal # endregion diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index ddcbb8d9..6452be51 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -20,6 +20,8 @@ class Playing(state_machine.State): ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, track): assert track is not None, 'Must provide a track name' super().__init__(player) @@ -54,6 +56,8 @@ class Paused(state_machine.State): ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) @@ -77,6 +81,8 @@ class Stopped(state_machine.State): } TRANSITIONS = {PLAY: PLAYING} + is_terminal = False + def __str__(self): return '[]' From f023e9c1e3420e3c7156705196fbcc6329860484 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 17:07:29 +0100 Subject: [PATCH 04/21] forming the enter/exit for State protocol --- src/plumpy/base/state_machine.py | 66 ++++--------- src/plumpy/process_states.py | 153 ++++++++++++++++++------------- src/plumpy/workchains.py | 25 +++-- tests/base/test_statemachine.py | 44 +++++++-- tests/test_processes.py | 2 +- 5 files changed, 158 insertions(+), 132 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 217a7d51..d224db51 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -13,15 +13,17 @@ from typing import ( Any, Callable, + ClassVar, Dict, Hashable, Iterable, List, Optional, + Protocol, Sequence, - Set, Type, Union, + runtime_checkable, ) from plumpy.futures import Future @@ -88,12 +90,12 @@ def event( if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) - if not all(issubclass(state, State) for state in from_states): # type: ignore + if not all(isinstance(state, State) for state in from_states): # type: ignore raise TypeError(f'from_states: {from_states}') if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) - if not all(issubclass(state, State) for state in to_states): # type: ignore + if not all(isinstance(state, State) for state in to_states): # type: ignore raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: @@ -127,53 +129,20 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -class State: - LABEL: LABEL_TYPE = None - # A set containing the labels of states that can be entered - # from this one - ALLOWED: Set[LABEL_TYPE] = set() +@runtime_checkable +class State(Protocol): + LABEL: ClassVar[LABEL_TYPE] - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): - """ - :param state_machine: The process this state belongs to - """ - self.state_machine = state_machine - self.in_state: bool = False - - def __str__(self) -> str: - return str(self.LABEL) - - @property - def label(self) -> LABEL_TYPE: - """Convenience property to get the state label""" - return self.LABEL - - @super_check - def enter(self) -> None: - """Entering the state""" - - def execute(self) -> Optional['State']: + async def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ + ... - @super_check - def exit(self) -> None: - """Exiting the state""" - if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': - return self.state_machine.create_state(state_label, *args, **kwargs) - - def do_enter(self) -> None: - call_with_super_check(self.enter) - self.in_state = True + def enter(self) -> None: ... - def do_exit(self) -> None: - call_with_super_check(self.exit) - self.in_state = False + def exit(self) -> None: ... class StateEventHook(enum.Enum): @@ -250,7 +219,7 @@ def __ensure_built(cls) -> None: # Build the states map cls._STATES_MAP = {} for state_cls in cls.STATES: - assert issubclass(state_cls, State) + assert isinstance(state_cls, State) label = state_cls.LABEL assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" cls._STATES_MAP[label] = state_cls @@ -382,7 +351,8 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat # This method should be replaced by `_create_state_instance`. # aiida-core using this method for its Waiting state override. try: - return self.get_states_map()[state_label](self, *args, **kwargs) + state_cls = self.get_states_map()[state_label] + return state_cls(self, *args, **kwargs) except KeyError: raise ValueError(f'{state_label} is not a valid state') @@ -392,20 +362,20 @@ def _exit_current_state(self, next_state: State) -> None: # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state if self._state is None: - if next_state.label != self.initial_state_label(): + if next_state.LABEL != self.initial_state_label(): raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) - self._state.do_exit() + self._state.exit() def _enter_next_state(self, next_state: State) -> None: last_state = self._state self._fire_state_event(StateEventHook.ENTERING_STATE, next_state) # Enter the new state - next_state.do_enter() + next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 75f6c7f8..a2ba5177 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -131,6 +131,7 @@ class ProcessState(Enum): KILLED: str = 'killed' +@final @auto_persist('args', 'kwargs', 'in_state') class Created(state_machine.State, persistence.Savable): LABEL = ProcessState.CREATED @@ -140,11 +141,12 @@ class Created(state_machine.State, persistence.Savable): is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs + self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -152,21 +154,24 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + async def execute(self) -> state_machine.State: + return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False +@final @auto_persist('args', 'kwargs', 'in_state') class Running(state_machine.State, persistence.Savable): LABEL = ProcessState.RUNNING @@ -187,19 +192,17 @@ class Running(state_machine.State, persistence.Savable): _run_handle = None is_terminal = False + def __init__( self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any ) -> None: - super().__init__(process) assert run_fn is not None - self.run_fn = ensure_coroutine(run_fn) - # We wrap `run_fn` to a coroutine so we can apply await on it, - # even it if it was not a coroutine in the first place. - # This allows the same usage of async and non-async function - # with the await syntax while not changing the program logic. + self.process = process + self.run_fn = run_fn self.args = args self.kwargs = kwargs self._run_handle = None + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -209,7 +212,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) if self.COMMAND in saved_state: @@ -218,7 +221,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> state_machine.State: # type: ignore + async def execute(self) -> state_machine.State: if self._command is not None: command = self._command else: @@ -232,7 +235,7 @@ async def execute(self) -> state_machine.State: # type: ignore # Let this bubble up to the caller raise except Exception: - excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) + excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) return cast(state_machine.State, excepted) else: if not isinstance(result, Command): @@ -249,28 +252,30 @@ async def execute(self) -> state_machine.State: # type: ignore def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: if isinstance(command, Kill): - state = self.create_state(ProcessState.KILLED, command.msg) + state = self.process.create_state(ProcessState.KILLED, command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) else: raise ValueError('Unrecognised command') return cast(state_machine.State, state) # casting from base.State to process.State - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False +@final @auto_persist('msg', 'data', 'in_state') class Waiting(state_machine.State, persistence.Savable): LABEL = ProcessState.WAITING @@ -301,11 +306,12 @@ def __init__( msg: Optional[str] = None, data: Optional[Any] = None, ) -> None: - super().__init__(process) + self.process = process self.done_callback = done_callback self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -314,7 +320,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: @@ -338,9 +344,9 @@ async def execute(self) -> state_machine.State: # type: ignore raise if result == NULL: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result) return cast(state_machine.State, next_state) # casting from base.State to process.State @@ -352,12 +358,14 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('in_state') @@ -387,9 +395,10 @@ def __init__( :param exception: The exception instance :param trace_back: An optional exception traceback """ - super().__init__(process) + self.process = process self.exception = exception self.traceback = trace_back + self.in_state = False def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -403,7 +412,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -426,12 +435,17 @@ def get_exc_info( self.traceback, ) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + async def execute(self) -> state_machine.State: # type: ignore + ... + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('result', 'successful', 'in_state') @@ -447,20 +461,26 @@ class Finished(state_machine.State, persistence.Savable): is_terminal = True def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - super().__init__(process) + self.process = process self.result = result self.successful = successful - - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process + + def enter(self) -> None: + self.in_state = True + + async def execute(self) -> state_machine.State: # type: ignore + ... + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('msg', 'in_state') @@ -483,19 +503,24 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): :param process: The associated process :param msg: Optional kill message """ - super().__init__(process) + self.process = process self.msg = msg - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + async def execute(self) -> state_machine.State: # type: ignore + ... def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False # endregion diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 00a711b5..f6369c28 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -26,6 +26,7 @@ from plumpy.coordinator import Coordinator from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError from . import lang, mixins, persistence, process_spec, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -85,16 +86,6 @@ def __init__( resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key - def enter(self) -> None: - super().enter() - for awaitable in self._awaiting: - awaitable.add_done_callback(self._awaitable_done) - - def exit(self) -> None: - super().exit() - for awaitable in self._awaiting: - awaitable.remove_done_callback(self._awaitable_done) - def _awaitable_done(self, awaitable: asyncio.Future) -> None: key = self._awaiting.pop(awaitable) try: @@ -105,6 +96,20 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: if not self._awaiting: self._waiting_future.set_result(lang.NULL) + def enter(self) -> None: + for awaitable in self._awaiting: + awaitable.add_done_callback(self._awaitable_done) + + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False + + for awaitable in self._awaiting: + awaitable.remove_done_callback(self._awaitable_done) + class WorkChain(mixins.ContextMixin, processes.Process): """ diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 6452be51..07a2dc80 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- import time +from typing import final import unittest from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError # Events PLAY = 'Play' @@ -24,24 +26,16 @@ class Playing(state_machine.State): def __init__(self, player, track): assert track is not None, 'Must provide a track name' - super().__init__(player) self.track = track self._last_time = None self._played = 0.0 + self.in_state = False def __str__(self): if self.in_state: self._update_time() return f'> {self.track} ({self._played}s)' - def enter(self): - super().enter() - self._last_time = time.time() - - def exit(self): - super().exit() - self._update_time() - def play(self, track=None): return False @@ -50,6 +44,17 @@ def _update_time(self): self._played += current_time - self._last_time self._last_time = current_time + def enter(self) -> None: + self._last_time = time.time() + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self._update_time() + self.in_state = False + class Paused(state_machine.State): LABEL = PAUSED @@ -73,6 +78,15 @@ def play(self, track=None): else: self.state_machine.transition_to(self.playing_state) + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + class Stopped(state_machine.State): LABEL = STOPPED @@ -83,12 +97,24 @@ class Stopped(state_machine.State): is_terminal = False + def __init__(self, player): + self.state_machine = player + def __str__(self): return '[]' def play(self, track): self.state_machine.transition_to(Playing(self.state_machine, track=track)) + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + class CdPlayer(state_machine.StateMachine): STATES = (Stopped, Playing, Paused) diff --git a/tests/test_processes.py b/tests/test_processes.py index a05d09a3..c6576b7e 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -653,7 +653,7 @@ def test_exception_during_on_entered(self): class RaisingProcess(Process): def on_entered(self, from_state): - if from_state is not None and from_state.label == ProcessState.RUNNING: + if from_state is not None and from_state.LABEL == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') super().on_entered(from_state) From 50d7f650a0d28a5827856e00e1af4c09443a15bd Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 23:20:45 +0100 Subject: [PATCH 05/21] Forming Interruptable and Proceedable protocol --- src/plumpy/base/state_machine.py | 20 +++++++++++++++----- src/plumpy/process_states.py | 12 ++---------- src/plumpy/processes.py | 22 +++++++++++++++++++++- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index d224db51..ee238a83 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -132,18 +132,28 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): LABEL: ClassVar[LABEL_TYPE] + is_terminal: ClassVar[bool] - async def execute(self) -> State | None: + def enter(self) -> None: ... + + def exit(self) -> None: ... + + +@runtime_checkable +class Interruptable(Protocol): + def interrupt(self, reason: Exception) -> None: ... + + +@runtime_checkable +class Proceedable(Protocol): + + def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ ... - def enter(self) -> None: ... - - def exit(self) -> None: ... - class StateEventHook(enum.Enum): """ diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index a2ba5177..c82c34a3 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -275,6 +275,7 @@ def exit(self) -> None: self.in_state = False + @final @auto_persist('msg', 'data', 'in_state') class Waiting(state_machine.State, persistence.Savable): @@ -329,7 +330,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.done_callback = None self._waiting_future = futures.Future() - def interrupt(self, reason: Any) -> None: + def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) @@ -435,9 +436,6 @@ def get_exc_info( self.traceback, ) - async def execute(self) -> state_machine.State: # type: ignore - ... - def enter(self) -> None: self.in_state = True @@ -473,9 +471,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def enter(self) -> None: self.in_state = True - async def execute(self) -> state_machine.State: # type: ignore - ... - def exit(self) -> None: if self.is_terminal: raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') @@ -506,9 +501,6 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): self.process = process self.msg = msg - async def execute(self) -> state_machine.State: # type: ignore - ... - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 4a8e029c..c6f45b03 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -47,7 +47,15 @@ from . import events, exceptions, message, persistence, ports, process_states, utils from .base import state_machine -from .base.state_machine import StateEntryFailed, StateMachine, event +from .base.state_machine import ( + Interruptable, + Proceedable, + StateEntryFailed, + StateMachine, + StateMachineError, + TransitionFailed, + event, +) from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper from .futures import CancellableAction, capture_exceptions @@ -1127,6 +1135,11 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: return self._pausing if self._stepping: + if not isinstance(self._state, Interruptable): + raise exceptions.InvalidStateError( + f'cannot interrupt {self._state.__class__}, method `interrupt` not implement' + ) + # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.PauseInterruption(msg_text) @@ -1139,6 +1152,10 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: msg = MessageBuilder.pause(msg_text) return self._do_pause(state_msg=msg) + @staticmethod + def _interrupt(state: Interruptable, reason: Exception) -> None: + state.interrupt(reason) + def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: @@ -1327,6 +1344,9 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused + if not isinstance(self._state, Proceedable): + raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented') + try: self._stepping = True next_state = None From a111cd88653cf9ba0f2878c1bc0abb97e6450a9b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 3 Dec 2024 00:48:53 +0100 Subject: [PATCH 06/21] Refactoring create_state as static function initialize state from label create_state refact Hashable initialized + parameters passed to Hashable Fix pre-commit errors --- src/plumpy/base/state_machine.py | 47 +++--- src/plumpy/process_states.py | 238 ++++++++++++++++--------------- src/plumpy/processes.py | 41 +++--- src/plumpy/workchains.py | 13 +- tests/base/test_statemachine.py | 15 +- 5 files changed, 174 insertions(+), 180 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index ee238a83..e3912b6f 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -34,7 +34,6 @@ _LOGGER = logging.getLogger(__name__) -LABEL_TYPE = Union[None, enum.Enum, str] EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] @@ -131,9 +130,12 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): - LABEL: ClassVar[LABEL_TYPE] + LABEL: ClassVar[Any] + ALLOWED: ClassVar[set[Any]] is_terminal: ClassVar[bool] + def __init__(self, *args: Any, **kwargs: Any): ... + def enter(self) -> None: ... def exit(self) -> None: ... @@ -146,7 +148,6 @@ def interrupt(self, reason: Exception) -> None: ... @runtime_checkable class Proceedable(Protocol): - def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. @@ -155,6 +156,14 @@ def execute(self) -> State | None: ... +def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + if state_label not in st.get_states_map(): + raise ValueError(f'{state_label} is not a valid state') + + state_cls = st.get_states_map()[state_label] + return state_cls(*args, **kwargs) + + class StateEventHook(enum.Enum): """ Hooks that can be used to register callback at various points in the state transition @@ -203,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]: raise RuntimeError('States not defined') @classmethod - def initial_state_label(cls) -> LABEL_TYPE: + def initial_state_label(cls) -> Any: cls.__ensure_built() assert cls.STATES is not None return cls.STATES[0].LABEL @classmethod - def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: + def get_state_class(cls, label: Any) -> Type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None return cls._STATES_MAP[label] @@ -253,11 +262,11 @@ def init(self) -> None: def __str__(self) -> str: return f'<{self.__class__.__name__}> ({self.state})' - def create_initial_state(self) -> State: - return self.get_state_class(self.initial_state_label())(self) + def create_initial_state(self, *args: Any, **kwargs: Any) -> State: + return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Optional[LABEL_TYPE]: + def state(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -297,6 +306,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: The arguments are passed to the state class to create state instance. (process arg does not need to pass since it will always call with 'self' as process) """ + print(f'try: {self._state} -> {new_state}') assert not self._transitioning, 'Cannot call transition_to when already transitioning state' if new_state is None: @@ -355,17 +365,6 @@ def get_debug(self) -> bool: def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: - # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic - # because the label is defined after the state and required to be know before calling this function. - # This method should be replaced by `_create_state_instance`. - # aiida-core using this method for its Waiting state override. - try: - state_cls = self.get_states_map()[state_label] - return state_cls(self, *args, **kwargs) - except KeyError: - raise ValueError(f'{state_label} is not a valid state') - def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" @@ -377,7 +376,7 @@ def _exit_current_state(self, next_state: State) -> None: return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) self._state.exit() @@ -388,11 +387,3 @@ def _enter_next_state(self, next_state: State) -> None: next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - - def _create_state_instance(self, state_cls: Hashable, **kwargs: Any) -> State: - if state_cls not in self.get_states_map(): - raise ValueError(f'{state_cls} is not a valid state') - - cls = self.get_states_map()[state_cls] - - return cls(self, **kwargs) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index c82c34a3..e5065203 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,7 +5,21 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + ClassVar, + Optional, + Protocol, + Tuple, + Type, + Union, + cast, + final, + runtime_checkable, +) import yaml from yaml.loader import Loader @@ -20,9 +34,9 @@ _HAS_TBLIB = False from . import exceptions, futures, persistence, utils -from .base import state_machine +from .base import state_machine as st from .lang import NULL -from .persistence import auto_persist +from .persistence import LoadSaveContext, auto_persist from .utils import SAVED_STATE_TYPE, ensure_coroutine if TYPE_CHECKING: @@ -123,22 +137,28 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ - CREATED: str = 'created' - RUNNING: str = 'running' - WAITING: str = 'waiting' - FINISHED: str = 'finished' - EXCEPTED: str = 'excepted' - KILLED: str = 'killed' + # FIXME: see LSP error of return a exception, the type is Literal[str] which is invariant, tricky + CREATED = 'created' + RUNNING = 'running' + WAITING = 'waiting' + FINISHED = 'finished' + EXCEPTED = 'excepted' + KILLED = 'killed' + + +@runtime_checkable +class Savable(Protocol): + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... @final -@auto_persist('args', 'kwargs', 'in_state') -class Created(state_machine.State, persistence.Savable): - LABEL = ProcessState.CREATED - ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} +@auto_persist('args', 'kwargs') +class Created(persistence.Savable): + LABEL: ClassVar = ProcessState.CREATED + ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' - is_terminal = False + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: assert run_fn is not None @@ -146,7 +166,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.run_fn = run_fn self.args = args self.kwargs = kwargs - self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -158,24 +177,21 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - async def execute(self) -> state_machine.State: - return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) - - def enter(self) -> None: - self.in_state = True + def execute(self) -> st.State: + return st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.run_fn, *self.args, **self.kwargs + ) - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... @final -@auto_persist('args', 'kwargs', 'in_state') -class Running(state_machine.State, persistence.Savable): - LABEL = ProcessState.RUNNING - ALLOWED = { +@auto_persist('args', 'kwargs') +class Running(persistence.Savable): + LABEL: ClassVar = ProcessState.RUNNING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, @@ -191,18 +207,17 @@ class Running(state_machine.State, persistence.Savable): _running: bool = False _run_handle = None - is_terminal = False + is_terminal: ClassVar[bool] = False def __init__( self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any ) -> None: assert run_fn is not None self.process = process - self.run_fn = run_fn + self.run_fn = ensure_coroutine(run_fn) self.args = args self.kwargs = kwargs self._run_handle = None - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -221,7 +236,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> state_machine.State: + async def execute(self) -> st.State: if self._command is not None: command = self._command else: @@ -235,8 +250,10 @@ async def execute(self) -> state_machine.State: # Let this bubble up to the caller raise except Exception: - excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(state_machine.State, excepted) + _, exception, traceback = sys.exc_info() + # excepted = state_cls(exception=exception, traceback=traceback) + excepted = Excepted(exception=exception, traceback=traceback) + return excepted else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -245,42 +262,52 @@ async def execute(self) -> state_machine.State: # Got passed a basic return type result = Stop(result, True) - command = result + command = cast(Stop, result) next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State: if isinstance(command, Kill): - state = self.process.create_state(ProcessState.KILLED, command.msg) + state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful) + state = st.create_state( + self.process, ProcessState.FINISHED, result=command.result, successful=command.successful + ) elif isinstance(command, Wait): - state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = st.create_state( + self.process, + ProcessState.WAITING, + process=self.process, + done_callback=command.continue_fn, + msg=command.msg, + data=command.data, + ) elif isinstance(command, Continue): - state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = st.create_state( + self.process, + ProcessState.RUNNING, + process=self.process, + run_fn=command.continue_fn, + *command.args, + **command.kwargs, + ) else: raise ValueError('Unrecognised command') - return cast(state_machine.State, state) # casting from base.State to process.State + return state - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@final -@auto_persist('msg', 'data', 'in_state') -class Waiting(state_machine.State, persistence.Savable): - LABEL = ProcessState.WAITING - ALLOWED = { +@auto_persist('msg', 'data') +class Waiting(persistence.Savable): + LABEL: ClassVar = ProcessState.WAITING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, @@ -292,7 +319,7 @@ class Waiting(state_machine.State, persistence.Savable): _interruption = None - is_terminal = False + is_terminal: ClassVar[bool] = False def __str__(self) -> str: state_info = super().__str__() @@ -312,7 +339,6 @@ def __init__( self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -334,7 +360,7 @@ def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> state_machine.State: # type: ignore + async def execute(self) -> st.State: try: result = await self._waiting_future except Interruption: @@ -345,11 +371,15 @@ async def execute(self) -> state_machine.State: # type: ignore raise if result == NULL: - next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.done_callback + ) else: - next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result + ) - return cast(state_machine.State, next_state) # casting from base.State to process.State + return next_state def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -359,47 +389,39 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('in_state') -class Excepted(state_machine.State, persistence.Savable): +@final +class Excepted(persistence.Savable): """ - Excepted state, can optionally provide exception and trace_back + Excepted state, can optionally provide exception and traceback :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - LABEL = ProcessState.EXCEPTED + LABEL: ClassVar = ProcessState.EXCEPTED + ALLOWED: ClassVar[set[str]] = set() EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' - is_terminal = True + is_terminal: ClassVar = True def __init__( self, - process: 'Process', exception: Optional[BaseException], - trace_back: Optional[TracebackType] = None, + traceback: Optional[TracebackType] = None, ): """ - :param process: The associated process :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - self.process = process self.exception = exception - self.traceback = trace_back - self.in_state = False + self.traceback = traceback def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -413,7 +435,6 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -436,50 +457,40 @@ def get_exc_info( self.traceback, ) - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('result', 'successful', 'in_state') -class Finished(state_machine.State, persistence.Savable): +@final +@auto_persist('result', 'successful') +class Finished(persistence.Savable): """State for process is finished. :param result: The result of process :param successful: Boolean for the exit code is ``0`` the process is successful. """ - LABEL = ProcessState.FINISHED + LABEL: ClassVar = ProcessState.FINISHED + ALLOWED: ClassVar[set[str]] = set() - is_terminal = True + is_terminal: ClassVar[bool] = True - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - self.process = process + def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful - self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('msg', 'in_state') -class Killed(state_machine.State, persistence.Savable): +@final +@auto_persist('msg') +class Killed(persistence.Savable): """ Represents a state where a process has been killed. @@ -489,30 +500,23 @@ class Killed(state_machine.State, persistence.Savable): :param msg: An optional message explaining the reason for the process termination. """ - LABEL = ProcessState.KILLED + LABEL: ClassVar = ProcessState.KILLED + ALLOWED: ClassVar[set[str]] = set() - is_terminal = True + is_terminal: ClassVar[bool] = True - def __init__(self, process: 'Process', msg: Optional[MessageType]): + def __init__(self, msg: Optional[MessageType]): """ - :param process: The associated process :param msg: Optional kill message """ - self.process = process self.msg = msg def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process - - def enter(self) -> None: - self.in_state = True - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... # endregion diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index c6f45b03..45c64024 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -53,7 +53,7 @@ StateEntryFailed, StateMachine, StateMachineError, - TransitionFailed, + create_state, event, ) from .base.utils import call_with_super_check, super_check @@ -189,7 +189,7 @@ def get_states(cls) -> Sequence[Type[state_machine.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: + def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -629,7 +629,9 @@ def save_instance_state( """ super().save_instance_state(out_state, save_context) - out_state['_state'] = self._state.save() + # FIXME: the combined ProcessState protocol should cover the case + if isinstance(self._state, process_states.Savable): + out_state['_state'] = self._state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -875,7 +877,7 @@ def on_finish(self, result: Any, successful: bool) -> None: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] - finished_state = state_cls(self, result=result, successful=False) + finished_state = state_cls(result=result, successful=False) raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -1108,9 +1110,8 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - new_state = self._create_state_instance( - process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace - ) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) self.transition_to(new_state) def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: @@ -1190,9 +1191,11 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> Ca def do_kill(_next_state: state_machine.State) -> Any: try: - new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=exception.msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=exception.msg) self.transition_to(new_state) return True + # FIXME: if try block except, will hit deadlock in event loop + # need to know how to debug it, and where to set a timeout. finally: self._killing = None @@ -1237,15 +1240,14 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure - :param trace_back: Optional exception traceback + :param traceback: Optional exception traceback """ - new_state = self._create_state_instance( - process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace_back - ) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback) self.transition_to(new_state) def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: @@ -1265,7 +1267,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: # Already killing return self._killing - if self._stepping: + if self._stepping and isinstance(self._state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg_text) @@ -1275,7 +1277,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: return cast(CancellableAction, self._interrupt_action) msg = MessageBuilder.kill(msg_text) - new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True @@ -1293,10 +1295,7 @@ def create_initial_state(self) -> state_machine.State: :return: A Created state """ - return cast( - state_machine.State, - self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), - ) + return self.get_state_class(process_states.ProcessState.CREATED)(self, self.run) def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ @@ -1368,7 +1367,9 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + next_state = create_state( + self, process_states.ProcessState.EXCEPTED, exception=sys.exc_info()[1], traceback=sys.exc_info()[2] + ) self._set_interrupt_action(None) if self._interrupt_action: diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index f6369c28..cf5108c8 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -11,7 +11,6 @@ Any, Callable, Dict, - Hashable, List, Mapping, MutableSequence, @@ -23,9 +22,8 @@ cast, ) -from plumpy.coordinator import Coordinator - from plumpy.base import state_machine +from plumpy.coordinator import Coordinator from plumpy.exceptions import InvalidStateError from . import lang, mixins, persistence, process_spec, process_states, processes @@ -69,6 +67,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: return self._outline +# FIXME: better use composition here @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): """Overwrite the waiting state""" @@ -78,11 +77,11 @@ def __init__( process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, + data: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: - super().__init__(process, done_callback, msg, awaiting) + super().__init__(process, done_callback, msg, data) self._awaiting: Dict[asyncio.Future, str] = {} - for awaitable, key in (awaiting or {}).items(): + for awaitable, key in (data or {}).items(): resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key @@ -122,7 +121,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: + def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 07a2dc80..f046aaa8 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -17,7 +17,7 @@ STOPPED = 'Stopped' -class Playing(state_machine.State): +class Playing: LABEL = PLAYING ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} @@ -56,7 +56,7 @@ def exit(self) -> None: self.in_state = False -class Paused(state_machine.State): +class Paused: LABEL = PAUSED ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} @@ -65,7 +65,6 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' - super().__init__(player) self._player = player self.playing_state = playing_state @@ -74,9 +73,9 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) + self._player.transition_to(Playing(player=self.state_machine, track=track)) else: - self.state_machine.transition_to(self.playing_state) + self._player.transition_to(self.playing_state) def enter(self) -> None: self.in_state = True @@ -88,7 +87,7 @@ def exit(self) -> None: self.in_state = False -class Stopped(state_machine.State): +class Stopped: LABEL = STOPPED ALLOWED = { PLAYING, @@ -98,13 +97,13 @@ class Stopped(state_machine.State): is_terminal = False def __init__(self, player): - self.state_machine = player + self._player = player def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing(self.state_machine, track=track)) + self._player.transition_to(Playing(self._player, track=track)) def enter(self) -> None: self.in_state = True From 11f7518c11680f017c92225f1de45bcd6761f89b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 4 Dec 2024 16:28:43 +0100 Subject: [PATCH 07/21] To lenthy for rethinking --- src/plumpy/persistence.py | 82 ++++++++++++++++-------------------- src/plumpy/process_states.py | 1 - 2 files changed, 37 insertions(+), 46 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index f7cbad44..4bc33158 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -420,28 +420,6 @@ class Savable: _auto_persist: Optional[Set[str]] = None _persist_configured = False - @staticmethod - def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Load a `Savable` from a saved instance state. The load context is a way of passing - runtime data to the object being loaded. - - :param saved_state: The saved state - :param load_context: Additional runtime state that can be passed into when loading. - The type and content (if any) is completely user defined - :return: The loaded Savable instance - - """ - load_context = _ensure_object_loader(load_context, saved_state) - assert load_context.loader is not None # required for type checking - try: - class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) - except KeyError: - raise ValueError('Class name not found in saved state') - else: - return load_cls.recreate_from(saved_state, load_context) - @classmethod def auto_persist(cls, *members: str) -> None: if cls._auto_persist is None: @@ -472,13 +450,48 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: self._ensure_persist_configured() if self._auto_persist is not None: - self.load_members(self._auto_persist, saved_state, load_context) + for member in self._auto_persist: + setattr(self, member, self._get_value(saved_state, member, load_context)) + + @staticmethod + def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Load a `Savable` from a saved instance state. The load context is a way of passing + runtime data to the object being loaded. + + :param saved_state: The saved state + :param load_context: Additional runtime state that can be passed into when loading. + The type and content (if any) is completely user defined + :return: The loaded Savable instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + assert load_context.loader is not None # required for type checking + try: + class_name = Savable._get_class_name(saved_state) + load_cls = load_context.loader.load_object(class_name) + except KeyError: + raise ValueError('Class name not found in saved state') + else: + return load_cls.recreate_from(saved_state, load_context) @super_check def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: self._ensure_persist_configured() if self._auto_persist is not None: - self.save_members(self._auto_persist, out_state) + for member in self._auto_persist: + value = getattr(self, member) + if inspect.ismethod(value): + if value.__self__ is not self: + raise TypeError('Cannot persist methods of other classes') + Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, Savable): + Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = {} @@ -501,27 +514,6 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY call_with_super_check(self.save_instance_state, out_state, save_context) return out_state - def save_members(self, members: Iterable[str], out_state: SAVED_STATE_TYPE) -> None: - for member in members: - value = getattr(self, member) - if inspect.ismethod(value): - if value.__self__ is not self: - raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - - def load_members( - self, members: Iterable[str], saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None - ) -> None: - for member in members: - setattr(self, member, self._get_value(saved_state, member, load_context)) - def _ensure_persist_configured(self) -> None: if not self._persist_configured: self.persist() diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index e5065203..acb83931 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -137,7 +137,6 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ - # FIXME: see LSP error of return a exception, the type is Literal[str] which is invariant, tricky CREATED = 'created' RUNNING = 'running' WAITING = 'waiting' From da45d60e20f0c67d8f7ecd1b4acdfe43462f052f Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 4 Dec 2024 23:33:13 +0100 Subject: [PATCH 08/21] Move static method load outside --- src/plumpy/persistence.py | 102 +++++++++++++++++++------------------- src/plumpy/processes.py | 2 +- 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 4bc33158..fe761e1d 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -23,8 +23,33 @@ from .processes import Process +class LoadSaveContext: + def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: + self._values = dict(**kwargs) + self.loader = loader + + def __getattr__(self, item: str) -> Any: + try: + return self._values[item] + except KeyError: + raise AttributeError(f"item '{item}' not found") + + def __iter__(self) -> Iterable[Any]: + return self._value.__iter__() + + def __contains__(self, item: Any) -> bool: + return self._values.__contains__(item) + + def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': + """Add additional information to the context by making a copy with the new values""" + extended = self._values.copy() + extended.update(kwargs) + loader = extended.pop('loader', self.loader) + return LoadSaveContext(loader=loader, **extended) + + class Bundle(dict): - def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False): + def __init__(self, savable: 'Savable', save_context: LoadSaveContext | None = None, dereference: bool = False): """ Create a bundle from a savable. Optionally keep information about the class loader that can be used to load the classes in the bundle. @@ -40,7 +65,7 @@ class loader that can be used to load the classes in the bundle. else: self.update(savable.save(save_context)) - def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable': + def unbundle(self, load_context: LoadSaveContext | None = None) -> 'Savable': """ This method loads the class of the object and calls its recreate_from method passing the positional and keyword arguments. @@ -49,7 +74,29 @@ def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable :return: An instance of the Savable """ - return Savable.load(self, load_context) + return load(self, load_context) + + +def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': + """ + Load a `Savable` from a saved instance state. The load context is a way of passing + runtime data to the object being loaded. + + :param saved_state: The saved state + :param load_context: Additional runtime state that can be passed into when loading. + The type and content (if any) is completely user defined + :return: The loaded Savable instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + assert load_context.loader is not None # required for type checking + try: + class_name = Savable._get_class_name(saved_state) + load_cls = load_context.loader.load_object(class_name) + except KeyError: + raise ValueError('Class name not found in saved state') + else: + return load_cls.recreate_from(saved_state, load_context) _BUNDLE_TAG = '!plumpy:Bundle' @@ -380,31 +427,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV return context.copyextend(loader=loader) -class LoadSaveContext: - def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: - self._values = dict(**kwargs) - self.loader = loader - - def __getattr__(self, item: str) -> Any: - try: - return self._values[item] - except KeyError: - raise AttributeError(f"item '{item}' not found") - - def __iter__(self) -> Iterable[Any]: - return self._value.__iter__() - - def __contains__(self, item: Any) -> bool: - return self._values.__contains__(item) - - def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': - """Add additional information to the context by making a copy with the new values""" - extended = self._values.copy() - extended.update(kwargs) - loader = extended.pop('loader', self.loader) - return LoadSaveContext(loader=loader, **extended) - - META: str = '!!meta' META__CLASS_NAME: str = 'class_name' META__OBJECT_LOADER: str = 'object_loader' @@ -453,28 +475,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio for member in self._auto_persist: setattr(self, member, self._get_value(saved_state, member, load_context)) - @staticmethod - def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Load a `Savable` from a saved instance state. The load context is a way of passing - runtime data to the object being loaded. - - :param saved_state: The saved state - :param load_context: Additional runtime state that can be passed into when loading. - The type and content (if any) is completely user defined - :return: The loaded Savable instance - - """ - load_context = _ensure_object_loader(load_context, saved_state) - assert load_context.loader is not None # required for type checking - try: - class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) - except KeyError: - raise ValueError('Class name not found in saved state') - else: - return load_cls.recreate_from(saved_state, load_context) - @super_check def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: self._ensure_persist_configured() @@ -568,7 +568,7 @@ def _get_value( if typ == META__TYPE__METHOD: value = getattr(self, value) elif typ == META__TYPE__SAVABLE: - value = Savable.load(value, load_context) + value = load(value, load_context) return value diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 45c64024..d088a5f8 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1305,7 +1305,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(state_machine.State, persistence.Savable.load(saved_state, load_context)) + return cast(state_machine.State, persistence.load(saved_state, load_context)) # endregion From 3993e4c79901908922a8f42478044ebb6d2e6777 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 11:01:32 +0100 Subject: [PATCH 09/21] save_instance_state simplify to only has save interface For the auto_persist attributes, the fn auto_save will take care of save the state --- src/plumpy/mixins.py | 13 ------ src/plumpy/persistence.py | 82 ++++++++++++++++++--------------- src/plumpy/process_states.py | 44 ++++++++++++------ src/plumpy/processes.py | 13 +++--- src/plumpy/workchains.py | 89 +++++++++++++++++++++++++++++++----- tests/test_processes.py | 2 +- 6 files changed, 157 insertions(+), 86 deletions(-) diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py index 4b993dac..0e3bb0c0 100644 --- a/src/plumpy/mixins.py +++ b/src/plumpy/mixins.py @@ -21,19 +21,6 @@ def __init__(self, *args: Any, **kwargs: Any): def ctx(self) -> Optional[AttributesDict]: return self._context - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - """Add the instance state to ``out_state``. - .. important:: - - The instance state will contain a pointer to the ``ctx``, - and so should be deep copied or serialised before persisting. - """ - super().save_instance_state(out_state, save_context) - if self._context is not None: - out_state[self.CONTEXT] = self._context.__dict__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) try: diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index fe761e1d..389acc27 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -92,7 +92,7 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N assert load_context.loader is not None # required for type checking try: class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) + load_cls: Savable = load_context.loader.load_object(class_name) except KeyError: raise ValueError('Class name not found in saved state') else: @@ -475,43 +475,9 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio for member in self._auto_persist: setattr(self, member, self._get_value(saved_state, member, load_context)) - @super_check - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - for member in self._auto_persist: - value = getattr(self, member) - if inspect.ismethod(value): - if value.__self__ is not self: - raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = {} - - if save_context is None: - save_context = LoadSaveContext() + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - utils.type_check(save_context, LoadSaveContext) - - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader - - Savable._set_class_name(out_state, loader.identify_object(self.__class__)) - call_with_super_check(self.save_instance_state, out_state, save_context) return out_state def _ensure_persist_configured(self) -> None: @@ -581,11 +547,13 @@ class SavableFuture(futures.Future, Savable): .. note: This does not save any assigned done callbacks. """ - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) if self.done() and self.exception() is not None: out_state['exception'] = self.exception() + return out_state + @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -631,3 +599,41 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadS # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list for callback in self._callbacks: self.remove_done_callback(callback) # type: ignore[arg-type] + + +def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = {} + + if save_context is None: + save_context = LoadSaveContext() + + utils.type_check(save_context, LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + Savable._set_class_name(out_state, loader.identify_object(obj.__class__)) + + obj._ensure_persist_configured() + if obj._auto_persist is not None: + for member in obj._auto_persist: + value = getattr(obj, member) + if inspect.ismethod(value): + if value.__self__ is not obj: + raise TypeError('Cannot persist methods of other classes') + Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, Savable): + Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value + + return out_state diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index acb83931..1b4f610c 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import copy +import inspect import sys import traceback from enum import Enum @@ -36,7 +38,7 @@ from . import exceptions, futures, persistence, utils from .base import state_machine as st from .lang import NULL -from .persistence import LoadSaveContext, auto_persist +from .persistence import META__OBJECT_LOADER, META__TYPE__METHOD, META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_persist, auto_save from .utils import SAVED_STATE_TYPE, ensure_coroutine if TYPE_CHECKING: @@ -113,10 +115,12 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process @@ -145,10 +149,9 @@ class ProcessState(Enum): KILLED = 'killed' -@runtime_checkable -class Savable(Protocol): - def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... - +# @runtime_checkable +# class Savable(Protocol): +# def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... @final @auto_persist('args', 'kwargs') @@ -166,10 +169,12 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process @@ -218,12 +223,15 @@ def __init__( self.kwargs = kwargs self._run_handle = None - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process @@ -339,11 +347,14 @@ def __init__( self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process @@ -426,12 +437,15 @@ def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] return super().__str__() + f'({exception})' - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index d088a5f8..6aba39a3 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -10,6 +10,7 @@ import copy import enum import functools +import inspect import logging import re import sys @@ -74,7 +75,7 @@ class BundleKeys: """ String keys used by the process to save its state in the state bundle. - See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`. + See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.load_instance_state`. """ @@ -616,18 +617,14 @@ async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) # region Persistence - def save_instance_state( - self, - out_state: SAVED_STATE_TYPE, - save_context: Optional[persistence.LoadSaveContext], - ) -> None: + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: """ Ask the process to save its current instance state. :param out_state: A bundle to save the state to :param save_context: The save context """ - super().save_instance_state(out_state, save_context) + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) # FIXME: the combined ProcessState protocol should cover the case if isinstance(self._state, process_states.Savable): @@ -643,6 +640,8 @@ def save_instance_state( if self.outputs: out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) + return out_state + @protected def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: """Load the process from its saved instance state. diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index cf5108c8..a1009acd 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import copy import abc import asyncio import collections @@ -28,6 +29,7 @@ from . import lang, mixins, persistence, process_spec, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE +from plumpy import loaders, utils ToContext = dict @@ -146,15 +148,69 @@ def on_create(self) -> None: super().on_create() self._stepper = self.spec().get_outline().create_stepper(self) - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + """ + Ask the process to save its current instance state. + + :param out_state: A bundle to save the state to + :param save_context: The save context + """ + out_state: SAVED_STATE_TYPE = {} + + if save_context is None: + save_context = persistence.LoadSaveContext() + + utils.type_check(save_context, persistence.LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__)) + + self._ensure_persist_configured() + if self._auto_persist is not None: + for member in self._auto_persist: + value = getattr(self, member) + if inspect.ismethod(value): + if value.__self__ is not self: + raise TypeError('Cannot persist methods of other classes') + persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, persistence.Savable): + persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value + + if isinstance(self._state, process_states.Savable): + out_state['_state'] = self._state.save() + + # Inputs/outputs + if self.raw_inputs is not None: + out_state[processes.BundleKeys.INPUTS_RAW] = self.encode_input_args(self.raw_inputs) + + if self.inputs is not None: + out_state[processes.BundleKeys.INPUTS_PARSED] = self.encode_input_args(self.inputs) + + if self.outputs: + out_state[processes.BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) # Ask the stepper to save itself if self._stepper is not None: out_state[self._STEPPER_STATE] = self._stepper.save() + if self._context is not None: + out_state[self.CONTEXT] = self._context.__dict__ + + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) @@ -250,10 +306,12 @@ def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): super().__init__(workchain) self._fn = fn - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state['_fn'] = self._fn.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._fn = getattr(self._workchain.__class__, saved_state['_fn']) @@ -323,11 +381,13 @@ def next_instruction(self) -> None: def finished(self) -> bool: return self._pos == len(self._block) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._block = load_context.block_instruction @@ -461,11 +521,13 @@ def step(self) -> Tuple[bool, Any]: def finished(self) -> bool: return self._pos == len(self._if_instruction) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._if_instruction = load_context.if_instruction @@ -555,11 +617,14 @@ def step(self) -> Tuple[bool, Any]: return False, result - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) + if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._while_instruction = load_context.while_instruction diff --git a/tests/test_processes.py b/tests/test_processes.py index c6576b7e..a634d4e5 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -700,7 +700,7 @@ def step2(self): class TestProcessSaving(unittest.TestCase): maxDiff = None - def test_running_save_instance_state(self): + def test_running_save(self): loop = asyncio.get_event_loop() nsync_comeback = SavePauseProc() From dab66ccffe1c806b321ea2e93f884ed5c55dccae Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 17:07:58 +0100 Subject: [PATCH 10/21] load_instance_state deabstract simplify - stepper de-abstract - remove ContextMixin - Stepper all using recreate_from --- src/plumpy/persistence.py | 17 +-- src/plumpy/process_states.py | 163 ++++++++++++++++++++++------ src/plumpy/processes.py | 15 +-- src/plumpy/workchains.py | 202 +++++++++++++++++++++++++++-------- 4 files changed, 308 insertions(+), 89 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 389acc27..0cb7b800 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -465,15 +465,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = _ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) - call_with_super_check(obj.load_instance_state, saved_state, load_context) + obj.load_instance_state(saved_state, load_context) return obj - @super_check def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - for member in self._auto_persist: - setattr(self, member, self._get_value(saved_state, member, load_context)) + auto_load(self, saved_state, load_context) def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = auto_save(self, save_context) @@ -594,7 +590,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa return obj def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + auto_load(self, saved_state, load_context) + if self._callbacks: # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list for callback in self._callbacks: @@ -637,3 +634,9 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S out_state[member] = value return out_state + +def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: + obj._ensure_persist_configured() + if obj._auto_persist is not None: + for member in obj._auto_persist: + setattr(obj, member, obj._get_value(saved_state, member, load_context)) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 1b4f610c..c03fd78d 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -121,14 +121,28 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + + obj.state_machine = load_context.process try: - self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) + obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN]) except ValueError: process = load_context.process - self.continue_fn = getattr(process, saved_state[self.CONTINUE_FN]) + obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN]) + return obj # endregion @@ -153,6 +167,7 @@ class ProcessState(Enum): # class Savable(Protocol): # def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... + @final @auto_persist('args', 'kwargs') class Created(persistence.Savable): @@ -175,11 +190,27 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.process = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + + auto_load(obj, saved_state, load_context) + + obj.process = load_context.process + + obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) + + return obj def execute(self) -> st.State: return st.create_state( @@ -232,13 +263,28 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.process = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + + obj.process = load_context.process - self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) - if self.COMMAND in saved_state: - self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore + obj.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) + if obj.COMMAND in saved_state: + # FIXME: typing + obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore + return obj def interrupt(self, reason: Any) -> None: pass @@ -355,16 +401,30 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.process = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + + obj.process = load_context.process - callback_name = saved_state.get(self.DONE_CALLBACK, None) + callback_name = saved_state.get(obj.DONE_CALLBACK, None) if callback_name is not None: - self.done_callback = getattr(self.process, callback_name) + obj.done_callback = getattr(obj.process, callback_name) else: - self.done_callback = None - self._waiting_future = futures.Future() + obj.done_callback = None + obj._waiting_future = futures.Future() + return obj def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception @@ -446,17 +506,30 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) - self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) + obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: - self.traceback = tblib.Traceback.from_string(saved_state[self.TRACEBACK], strict=False) + obj.traceback = tblib.Traceback.from_string(saved_state[obj.TRACEBACK], strict=False) except KeyError: - self.traceback = None + obj.traceback = None else: - self.traceback = None + obj.traceback = None + return obj def get_exc_info( self, @@ -493,8 +566,21 @@ def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj def enter(self) -> None: ... @@ -524,8 +610,21 @@ def __init__(self, msg: Optional[MessageType]): """ self.msg = msg - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj def enter(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 6aba39a3..fd8c761a 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -10,7 +10,6 @@ import copy import enum import functools -import inspect import logging import re import sys @@ -38,6 +37,7 @@ import kiwipy from plumpy.coordinator import Coordinator +from plumpy.persistence import _ensure_object_loader try: from aiocontextvars import ContextVar @@ -267,9 +267,12 @@ def recreate_from( :return: An instance of the object with its state loaded from the save state. """ - process = cast(Process, super().recreate_from(saved_state, load_context)) - call_with_super_check(process.init) - return process + load_context = _ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + proc.load_instance_state(saved_state, load_context) + + call_with_super_check(proc.init) + return proc def __init__( self, @@ -651,7 +654,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi """ # First make sure the state machine constructor is called - super().__init__() + state_machine.StateMachine.__init__(self) self._setup_event_hooks() @@ -675,7 +678,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - super().load_instance_state(saved_state, load_context) + persistence.auto_load(self, saved_state, load_context) # Inputs/outputs try: diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index a1009acd..66446619 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -16,6 +16,7 @@ Mapping, MutableSequence, Optional, + Protocol, Sequence, Tuple, Type, @@ -25,11 +26,14 @@ from plumpy.base import state_machine from plumpy.coordinator import Coordinator +from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError +from plumpy.process_listener import ProcessListener from . import lang, mixins, persistence, process_spec, process_states, processes -from .utils import PID_TYPE, SAVED_STATE_TYPE +from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict from plumpy import loaders, utils +from plumpy.persistence import _ensure_object_loader ToContext = dict @@ -101,18 +105,15 @@ def enter(self) -> None: for awaitable in self._awaiting: awaitable.add_done_callback(self._awaitable_done) - self.in_state = True - def exit(self) -> None: if self.is_terminal: raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - self.in_state = False for awaitable in self._awaiting: awaitable.remove_done_callback(self._awaitable_done) -class WorkChain(mixins.ContextMixin, processes.Process): +class WorkChain(processes.Process): """ A WorkChain is a series of instructions carried out with the ability to save state in between. @@ -120,7 +121,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _spec_class = WorkChainSpec _STEPPER_STATE = 'stepper_state' - _CONTEXT = 'CONTEXT' + CONTEXT = 'CONTEXT' @classmethod def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: @@ -137,9 +138,14 @@ def __init__( coordinator: Optional[Coordinator] = None, ) -> None: super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, coordinator=coordinator) + self._context: Optional[AttributesDict] = AttributesDict() self._stepper: Optional[Stepper] = None self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {} + @property + def ctx(self) -> Optional[AttributesDict]: + return self._context + @classmethod def spec(cls) -> WorkChainSpec: return cast(WorkChainSpec, super().spec()) @@ -212,7 +218,63 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + ######### + # FIXME: dup of Process.load_instance_state + state_machine.StateMachine.__init__(self) + + self._setup_event_hooks() + + # Runtime variables, set initial states + self._future = persistence.SavableFuture() + self._event_helper = EventHelper(ProcessListener) + self._logger = None + self._communicator = None + + if 'loop' in load_context: + self._loop = load_context.loop + else: + self._loop = asyncio.get_event_loop() + + self._state: state_machine.State = self.recreate_state(saved_state['_state']) + + if 'communicator' in load_context: + self._communicator = load_context.communicator + + if 'logger' in load_context: + self._logger = load_context.logger + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.auto_load(self, saved_state, load_context) + + # Inputs/outputs + try: + decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) + self._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + self._raw_inputs = None + + try: + decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) + self._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + self._parsed_inputs = None + + try: + decoded = self.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) + self._outputs = decoded + except KeyError: + self._outputs = {} + + # + ######### + + # context mixin + try: + self._context = AttributesDict(**saved_state[self.CONTEXT]) + except KeyError: + pass + + # end of context mixin # Recreate the stepper self._stepper = None @@ -255,15 +317,8 @@ def _do_step(self) -> Any: return return_value -class Stepper(persistence.Savable, metaclass=abc.ABCMeta): - def __init__(self, workchain: 'WorkChain') -> None: - self._workchain = workchain - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._workchain = load_context.workchain - - @abc.abstractmethod +# XXX: Stepper is also a Saver with `save` method. +class Stepper(Protocol): def step(self) -> Tuple[bool, Any]: """ Execute on step of the instructions. @@ -272,6 +327,7 @@ def step(self) -> Tuple[bool, Any]: 1. The return value from the executed step """ + ... class _Instruction(metaclass=abc.ABCMeta): @@ -301,9 +357,9 @@ def get_description(self) -> Any: """ -class _FunctionStepper(Stepper): +class _FunctionStepper(persistence.Savable): def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): - super().__init__(workchain) + self._workchain = workchain self._fn = fn def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -312,9 +368,24 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._fn = getattr(self._workchain.__class__, saved_state['_fn']) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._fn = getattr(obj._workchain.__class__, saved_state['_fn']) + + return obj def step(self) -> Tuple[bool, Any]: return True, self._fn(self._workchain) @@ -354,9 +425,9 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') -class _BlockStepper(Stepper): +class _BlockStepper(persistence.Savable): def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._block = block self._pos: int = 0 self._child_stepper: Optional[Stepper] = self._block[0].create_stepper(self._workchain) @@ -388,13 +459,28 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._block = load_context.block_instruction + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._block = load_context.block_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._block[self._pos].recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._block[obj._pos].recreate_stepper(stepper_state, obj._workchain) + + return obj def __str__(self) -> str: return str(self._pos) + ':' + str(self._child_stepper) @@ -487,9 +573,9 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') -class _IfStepper(Stepper): +class _IfStepper(persistence.Savable): def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._if_instruction = if_instruction self._pos = 0 self._child_stepper: Optional[Stepper] = None @@ -528,13 +614,27 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._if_instruction = load_context.if_instruction + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._if_instruction = load_context.if_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._if_instruction[self._pos].body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._if_instruction[obj._pos].body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._if_instruction[self._pos]) @@ -596,9 +696,9 @@ def get_description(self) -> Mapping[str, Any]: return description -class _WhileStepper(Stepper): +class _WhileStepper(persistence.Savable): def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._while_instruction = while_instruction self._child_stepper: Optional[_BlockStepper] = None @@ -625,13 +725,27 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._while_instruction = load_context.while_instruction + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._while_instruction = load_context.while_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._while_instruction.body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._while_instruction.body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._while_instruction) @@ -669,9 +783,9 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code -class _ReturnStepper(Stepper): +class _ReturnStepper(persistence.Savable): def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._return_instruction = return_instruction def step(self) -> Tuple[bool, Any]: From 4bf5e995082b8ce6f819e8bba341ba0e9830eb3a Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 22:07:15 +0100 Subject: [PATCH 11/21] ProcessListener recreate_from --- src/plumpy/process_listener.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index c3ab8e5a..e84b504d 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING, Any, Dict, Optional from . import persistence -from .utils import SAVED_STATE_TYPE, protected +from .utils import SAVED_STATE_TYPE +from plumpy.persistence import LoadSaveContext, _ensure_object_loader if TYPE_CHECKING: from .processes import Process @@ -20,12 +21,21 @@ def __init__(self) -> None: def init(self, **kwargs: Any) -> None: self._params = kwargs - @protected - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] - ) -> None: - super().load_instance_state(saved_state, load_context) - self.init(**saved_state['_params']) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + obj.init(**saved_state['_params']) + return obj # endregion From cc80a0b116e202f493ed6784de27848e113944ef Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 22:24:38 +0100 Subject: [PATCH 12/21] Absorb all load_instance_state into recreate_from --- src/plumpy/persistence.py | 23 ++++---- src/plumpy/processes.py | 112 +++++++++++++++++--------------------- src/plumpy/workchains.py | 80 ++++++++++++++++----------- 3 files changed, 110 insertions(+), 105 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 0cb7b800..85b2d96d 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -465,12 +465,9 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = _ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) - obj.load_instance_state(saved_state, load_context) + auto_load(obj, saved_state, load_context) return obj - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: - auto_load(self, saved_state, load_context) - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = auto_save(self, save_context) @@ -587,15 +584,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa obj = cls(loop=loop) obj.cancel() - return obj - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - auto_load(self, saved_state, load_context) + # ## XXX: load_instance_state: test not cover + # auto_load(obj, saved_state, load_context) + # + # if obj._callbacks: + # # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list + # for callback in obj._callbacks: + # obj.remove_done_callback(callback) # type: ignore[arg-type] + # ## UNTILHERE XXX: - if self._callbacks: - # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list - for callback in self._callbacks: - self.remove_done_callback(callback) # type: ignore[arg-type] + return obj def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -635,6 +633,7 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S return out_state + def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: obj._ensure_persist_configured() if obj._auto_persist is not None: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index fd8c761a..f591ba1a 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -75,7 +75,7 @@ class BundleKeys: """ String keys used by the process to save its state in the state bundle. - See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.load_instance_state`. + See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.recreate_from`. """ @@ -257,10 +257,8 @@ def recreate_from( cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None, - ) -> 'Process': - """ - Recreate a process from a saved state, passing any positional and - keyword arguments on to load_instance_state + ) -> Process: + """Recreate a process from a saved state, passing any positional :param saved_state: The saved state to load from :param load_context: The load context to use @@ -269,7 +267,53 @@ def recreate_from( """ load_context = _ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) - proc.load_instance_state(saved_state, load_context) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) + + proc._setup_event_hooks() + + # Runtime variables, set initial states + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._communicator = None + + if 'loop' in load_context: + proc._loop = load_context.loop + else: + proc._loop = asyncio.get_event_loop() + + proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + + if 'communicator' in load_context: + proc._communicator = load_context.communicator + + if 'logger' in load_context: + proc._logger = load_context.logger + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.auto_load(proc, saved_state, load_context) + + # Inputs/outputs + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._raw_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._parsed_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.OUTPUTS]) + proc._outputs = decoded + except KeyError: + proc._outputs = {} call_with_super_check(proc.init) return proc @@ -645,62 +689,6 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - @protected - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - """Load the process from its saved instance state. - - :param saved_state: A bundle to load the state from - :param load_context: The load context - - """ - # First make sure the state machine constructor is called - state_machine.StateMachine.__init__(self) - - self._setup_event_hooks() - - # Runtime variables, set initial states - self._future = persistence.SavableFuture() - self._event_helper = EventHelper(ProcessListener) - self._logger = None - self._coordinator = None - - if 'loop' in load_context: - self._loop = load_context.loop - else: - self._loop = asyncio.get_event_loop() - - self._state: state_machine.State = self.recreate_state(saved_state['_state']) - - if 'coordinator' in load_context: - self._coordinator = load_context.coordinator - - if 'logger' in load_context: - self._logger = load_context.logger - - # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(self, saved_state, load_context) - - # Inputs/outputs - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) - self._raw_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._raw_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) - self._parsed_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._parsed_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.OUTPUTS]) - self._outputs = decoded - except KeyError: - self._outputs = {} - - # endregion - def add_process_listener(self, listener: ProcessListener) -> None: """Add a process listener to the process. diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 66446619..3df4c7f6 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -26,6 +26,7 @@ from plumpy.base import state_machine from plumpy.coordinator import Coordinator +from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError from plumpy.process_listener import ProcessListener @@ -217,70 +218,87 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - ######### - # FIXME: dup of Process.load_instance_state - state_machine.StateMachine.__init__(self) + @classmethod + def recreate_from( + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, + ) -> WorkChain: + """Recreate a workchain from a saved state, passing any positional + + :param saved_state: The saved state to load from + :param load_context: The load context to use + :return: An instance of the object with its state loaded from the save state. + + """ + ### FIXME: dup from process.create_from + load_context = _ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) - self._setup_event_hooks() + proc._setup_event_hooks() # Runtime variables, set initial states - self._future = persistence.SavableFuture() - self._event_helper = EventHelper(ProcessListener) - self._logger = None - self._communicator = None + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._communicator = None if 'loop' in load_context: - self._loop = load_context.loop + proc._loop = load_context.loop else: - self._loop = asyncio.get_event_loop() + proc._loop = asyncio.get_event_loop() - self._state: state_machine.State = self.recreate_state(saved_state['_state']) + proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: - self._communicator = load_context.communicator + proc._communicator = load_context.communicator if 'logger' in load_context: - self._logger = load_context.logger + proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(self, saved_state, load_context) + persistence.auto_load(proc, saved_state, load_context) # Inputs/outputs try: - decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) - self._raw_inputs = utils.AttributesFrozendict(decoded) + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) except KeyError: - self._raw_inputs = None + proc._raw_inputs = None try: - decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) - self._parsed_inputs = utils.AttributesFrozendict(decoded) + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) except KeyError: - self._parsed_inputs = None + proc._parsed_inputs = None try: - decoded = self.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) - self._outputs = decoded + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) + proc._outputs = decoded except KeyError: - self._outputs = {} - - # - ######### + proc._outputs = {} + ### UNTILHERE FIXME: dup from process.create_from # context mixin try: - self._context = AttributesDict(**saved_state[self.CONTEXT]) + proc._context = AttributesDict(**saved_state[proc.CONTEXT]) except KeyError: pass # end of context mixin # Recreate the stepper - self._stepper = None - stepper_state = saved_state.get(self._STEPPER_STATE, None) + proc._stepper = None + stepper_state = saved_state.get(proc._STEPPER_STATE, None) if stepper_state is not None: - self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) + proc._stepper = proc.spec().get_outline().recreate_stepper(stepper_state, proc) + + call_with_super_check(proc.init) + return proc def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None: """ From efb322bfe4623fa8023e325ca124032952c1ccac Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 23:01:53 +0100 Subject: [PATCH 13/21] Remove useless persist method of Savable class --- src/plumpy/persistence.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 85b2d96d..0963445e 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -448,10 +448,6 @@ def auto_persist(cls, *members: str) -> None: cls._auto_persist = set() cls._auto_persist.update(members) - @classmethod - def persist(cls) -> None: - pass - @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -463,10 +459,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - return obj + ... def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = auto_save(self, save_context) @@ -475,7 +468,6 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY def _ensure_persist_configured(self) -> None: if not self._persist_configured: - self.persist() self._persist_configured = True # region Metadata getter/setters From 94590a6d2da7109428c4ba1fcbee005e3f633655 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 23:02:09 +0100 Subject: [PATCH 14/21] Explicity recreate_from implementation --- src/plumpy/event_helper.py | 22 ++++++++++++++++-- tests/test_persistence.py | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 47ad4956..e20dae3f 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- import logging -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional + +from plumpy.utils import SAVED_STATE_TYPE from . import persistence +from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load if TYPE_CHECKING: from typing import Set, Type - from .process_listener import ProcessListener _LOGGER = logging.getLogger(__name__) @@ -30,6 +32,22 @@ def remove_listener(self, listener: 'ProcessListener') -> None: def remove_all_listeners(self) -> None: self._listeners.clear() + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Savable: + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 78724aa0..65ef3226 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,6 +5,7 @@ import yaml import plumpy +from plumpy.persistence import auto_load from . import utils @@ -12,6 +13,21 @@ class SaveEmpty(plumpy.Savable): pass + @classmethod + def recreate_from(cls, saved_state, load_context= None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + @plumpy.auto_persist('test', 'test_method') class Save1(plumpy.Savable): @@ -22,12 +38,42 @@ def __init__(self): def m(): pass + @classmethod + def recreate_from(cls, saved_state, load_context= None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + @plumpy.auto_persist('test') class Save(plumpy.Savable): def __init__(self): self.test = Save1() + @classmethod + def recreate_from(cls, saved_state, load_context= None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + class TestSavable(unittest.TestCase): def test_empty_savable(self): From 90848fa26a674d36974ebd4e58e3971be51af16b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 23:14:28 +0100 Subject: [PATCH 15/21] forming Savable protocol - remove persist_config flag of savable --- src/plumpy/event_helper.py | 12 +- src/plumpy/persistence.py | 245 ++++++++++++++++----------------- src/plumpy/process_listener.py | 16 ++- src/plumpy/process_states.py | 100 ++++++++++---- src/plumpy/processes.py | 9 +- src/plumpy/workchains.py | 77 ++++------- tests/test_persistence.py | 32 +++-- tests/test_processes.py | 4 + tests/test_workchains.py | 2 + 9 files changed, 278 insertions(+), 219 deletions(-) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index e20dae3f..abc2b24b 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -2,20 +2,21 @@ import logging from typing import TYPE_CHECKING, Any, Callable, Optional +from plumpy.persistence import LoadSaveContext, Savable, auto_load, auto_save, ensure_object_loader from plumpy.utils import SAVED_STATE_TYPE from . import persistence -from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load if TYPE_CHECKING: from typing import Set, Type + from .process_listener import ProcessListener _LOGGER = logging.getLogger(__name__) @persistence.auto_persist('_listeners', '_listener_type') -class EventHelper(persistence.Savable): +class EventHelper: def __init__(self, listener_type: 'Type[ProcessListener]'): assert listener_type is not None, 'Must provide valid listener type' @@ -43,11 +44,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 0963445e..afe82439 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -9,12 +9,24 @@ import os import pickle from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + Iterable, + List, + Optional, + Protocol, + cast, + runtime_checkable, +) import yaml from . import futures, loaders, utils -from .base.utils import call_with_super_check, super_check from .utils import PID_TYPE, SAVED_STATE_TYPE PersistedCheckpoint = collections.namedtuple('PersistedCheckpoint', ['pid', 'tag']) @@ -88,10 +100,10 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N :return: The loaded Savable instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) assert load_context.loader is not None # required for type checking try: - class_name = Savable._get_class_name(saved_state) + class_name = SaveUtil.get_class_name(saved_state) load_cls: Savable = load_context.loader.load_object(class_name) except KeyError: raise ValueError('Class name not found in saved state') @@ -380,22 +392,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -SavableClsType = TypeVar('SavableClsType', bound='type[Savable]') - - -def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: - if savable._auto_persist is None: - savable._auto_persist = set() - else: - savable._auto_persist = set(savable._auto_persist) - savable.auto_persist(*members) - return savable - - return wrapped - - -def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': +def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': """ Given a LoadSaveContext this method will ensure that it has a valid class loader using the following priorities: @@ -417,7 +414,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV # 2) Try getting from saved_state default_loader = loaders.get_object_loader() try: - loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER) + loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER) except ValueError: # 3) Fall back to default loader = default_loader @@ -436,45 +433,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV META__TYPE__SAVABLE: str = 'S' -class Savable: - CLASS_NAME: str = 'class_name' - - _auto_persist: Optional[Set[str]] = None - _persist_configured = False - - @classmethod - def auto_persist(cls, *members: str) -> None: - if cls._auto_persist is None: - cls._auto_persist = set() - cls._auto_persist.update(members) - - @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Recreate a :class:`Savable` from a saved state using an optional load context. - - :param saved_state: The saved state - :param load_context: An optional load context - - :return: The recreated instance - - """ - ... - - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - - return out_state - - def _ensure_persist_configured(self) -> None: - if not self._persist_configured: - self._persist_configured = True - - # region Metadata getter/setters - +class SaveUtil: @staticmethod def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: - user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {}) + user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {}) user_dict[name] = value @staticmethod @@ -485,47 +447,127 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: raise ValueError(f"Unknown meta key '{name}'") @staticmethod - def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: + def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: return out_state.setdefault(META, {}) @staticmethod - def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: - Savable._get_create_meta(out_state)[META__CLASS_NAME] = name + def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: + SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name @staticmethod - def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str: - return Savable._get_create_meta(saved_state)[META__CLASS_NAME] + def get_class_name(saved_state: SAVED_STATE_TYPE) -> str: + return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME] @staticmethod - def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: - type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {}) + def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: + type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {}) type_dict[name] = type_spec @staticmethod - def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: + def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: try: return saved_state[META][META__TYPES][name] except KeyError: pass - # endregion - def _get_value( - self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] - ) -> Union[MethodType, 'Savable']: - value = saved_state[name] +@runtime_checkable +class Savable(Protocol): + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + ... + + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... + + +@runtime_checkable +class SavableWithAutoPersist(Savable, Protocol): + _auto_persist: ClassVar[set[str]] = set() + + +def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = {} + + if save_context is None: + save_context = LoadSaveContext() + + utils.type_check(save_context, LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__)) + + if isinstance(obj, SavableWithAutoPersist): + for member in obj._auto_persist: + value = getattr(obj, member) + if inspect.ismethod(value): + if value.__self__ is not obj: + raise TypeError('Cannot persist methods of other classes') + SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, Savable) and not isinstance(value, type): + # persist for a savable obj, call `save` method of obj. + SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value + + return out_state + + +def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: + for member in obj._auto_persist: + setattr(obj, member, _get_value(obj, saved_state, member, load_context)) + - typ = Savable._get_meta_type(saved_state, name) - if typ == META__TYPE__METHOD: - value = getattr(self, value) - elif typ == META__TYPE__SAVABLE: - value = load(value, load_context) +def _get_value( + obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None +) -> MethodType | Savable: + value = saved_state[name] - return value + typ = SaveUtil.get_meta_type(saved_state, name) + if typ == META__TYPE__METHOD: + value = getattr(obj, value) + elif typ == META__TYPE__SAVABLE: + value = load(value, load_context) + + return value + + +def auto_persist(*members: str) -> Callable[..., Savable]: + def wrapped(savable_cls: type) -> Savable: + if not hasattr(savable_cls, '_auto_persist') or savable_cls._auto_persist is None: + savable_cls._auto_persist = set() # type: ignore[attr-defined] + else: + savable_cls._auto_persist = set(savable_cls._auto_persist) + savable_cls._auto_persist.update(members) # type: ignore[attr-defined] + # XXX: validate on `save` and `recreate_from` method?? + return cast(Savable, savable_cls) + return wrapped + + +# FIXME: move me to another module? savablefuture.py? @auto_persist('_state', '_result') -class SavableFuture(futures.Future, Savable): +class SavableFuture(futures.Future): """ A savable future. @@ -550,7 +592,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) try: loop = load_context.loop @@ -586,48 +628,3 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa # ## UNTILHERE XXX: return obj - - -def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = {} - - if save_context is None: - save_context = LoadSaveContext() - - utils.type_check(save_context, LoadSaveContext) - - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader - - Savable._set_class_name(out_state, loader.identify_object(obj.__class__)) - - obj._ensure_persist_configured() - if obj._auto_persist is not None: - for member in obj._auto_persist: - value = getattr(obj, member) - if inspect.ismethod(value): - if value.__self__ is not obj: - raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - - return out_state - - -def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - obj._ensure_persist_configured() - if obj._auto_persist is not None: - for member in obj._auto_persist: - setattr(obj, member, obj._get_value(saved_state, member, load_context)) diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index e84b504d..8e9673bb 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -2,16 +2,21 @@ import abc from typing import TYPE_CHECKING, Any, Dict, Optional +from plumpy.persistence import LoadSaveContext, auto_save, ensure_object_loader + from . import persistence from .utils import SAVED_STATE_TYPE -from plumpy.persistence import LoadSaveContext, _ensure_object_loader if TYPE_CHECKING: + from plumpy.persistence import Savable + from .processes import Process +# FIXME: test any process listener is a savable + @persistence.auto_persist('_params') -class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): +class ProcessListener(metaclass=abc.ABCMeta): # region Persistence methods def __init__(self) -> None: @@ -32,11 +37,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) obj.init(**saved_state['_params']) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + # endregion def on_process_created(self, process: 'Process') -> None: diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index c03fd78d..08a4d24c 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy -import inspect import sys import traceback from enum import Enum @@ -14,19 +12,19 @@ Callable, ClassVar, Optional, - Protocol, Tuple, Type, Union, cast, final, - runtime_checkable, + override, ) import yaml from yaml.loader import Loader from plumpy.message import MessageBuilder, MessageType +from plumpy.persistence import ensure_object_loader try: import tblib @@ -38,8 +36,32 @@ from . import exceptions, futures, persistence, utils from .base import state_machine as st from .lang import NULL -from .persistence import META__OBJECT_LOADER, META__TYPE__METHOD, META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_persist, auto_save -from .utils import SAVED_STATE_TYPE, ensure_coroutine +from .persistence import ( + LoadSaveContext, + Savable, + auto_load, + auto_persist, + auto_save, +) +from .utils import SAVED_STATE_TYPE + +__all__ = [ + 'Continue', + 'Created', + 'Excepted', + 'Finished', + 'Interruption', + # Commands + 'Kill', + 'KillInterruption', + 'Killed', + 'PauseInterruption', + 'ProcessState', + 'Running', + 'Stop', + 'Wait', + 'Waiting', +] if TYPE_CHECKING: from .processes import Process @@ -68,8 +90,26 @@ def __init__(self, msg_text: str | None): # region Commands -class Command(persistence.Savable): - pass +class Command: + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @auto_persist('msg') @@ -115,12 +155,14 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs + @override def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ return out_state + @override @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -132,7 +174,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -163,14 +205,9 @@ class ProcessState(Enum): KILLED = 'killed' -# @runtime_checkable -# class Savable(Protocol): -# def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... - - @final @auto_persist('args', 'kwargs') -class Created(persistence.Savable): +class Created: LABEL: ClassVar = ProcessState.CREATED ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} @@ -201,7 +238,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -224,7 +261,7 @@ def exit(self) -> None: ... @final @auto_persist('args', 'kwargs') -class Running(persistence.Savable): +class Running: LABEL: ClassVar = ProcessState.RUNNING ALLOWED: ClassVar = { ProcessState.RUNNING, @@ -274,7 +311,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -358,7 +395,7 @@ def exit(self) -> None: ... @auto_persist('msg', 'data') -class Waiting(persistence.Savable): +class Waiting: LABEL: ClassVar = ProcessState.WAITING ALLOWED: ClassVar = { ProcessState.RUNNING, @@ -412,7 +449,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -465,7 +502,8 @@ def exit(self) -> None: ... @final -class Excepted(persistence.Savable): +@auto_persist() +class Excepted: """ Excepted state, can optionally provide exception and traceback @@ -517,7 +555,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -550,7 +588,7 @@ def exit(self) -> None: ... @final @auto_persist('result', 'successful') -class Finished(persistence.Savable): +class Finished: """State for process is finished. :param result: The result of process @@ -577,11 +615,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... @@ -589,7 +632,7 @@ def exit(self) -> None: ... @final @auto_persist('msg') -class Killed(persistence.Savable): +class Killed: """ Represents a state where a process has been killed. @@ -621,11 +664,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index f591ba1a..09964ae7 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -37,7 +37,7 @@ import kiwipy from plumpy.coordinator import Coordinator -from plumpy.persistence import _ensure_object_loader +from plumpy.persistence import ensure_object_loader try: from aiocontextvars import ContextVar @@ -116,7 +116,7 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: '_pre_paused_status', '_event_helper', ) -class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): +class Process(StateMachine, metaclass=ProcessStateMachineMeta): """ The Process class is the base for any unit of work in plumpy. @@ -265,7 +265,7 @@ def recreate_from( :return: An instance of the object with its state loaded from the save state. """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state @@ -673,8 +673,7 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA """ out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - # FIXME: the combined ProcessState protocol should cover the case - if isinstance(self._state, process_states.Savable): + if isinstance(self._state, persistence.Savable): out_state['_state'] = self._state.save() # Inputs/outputs diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 3df4c7f6..9a241b72 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy import abc import asyncio import collections @@ -9,6 +8,7 @@ import logging import re from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -24,17 +24,20 @@ cast, ) +import kiwipy + +from plumpy import utils from plumpy.base import state_machine from plumpy.coordinator import Coordinator from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError +from plumpy.persistence import LoadSaveContext, auto_persist, auto_save, ensure_object_loader, Savable from plumpy.process_listener import ProcessListener from . import lang, mixins, persistence, process_spec, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict -from plumpy import loaders, utils -from plumpy.persistence import _ensure_object_loader + ToContext = dict @@ -162,41 +165,9 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA :param out_state: A bundle to save the state to :param save_context: The save context """ - out_state: SAVED_STATE_TYPE = {} - - if save_context is None: - save_context = persistence.LoadSaveContext() + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - utils.type_check(save_context, persistence.LoadSaveContext) - - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader - - persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__)) - - self._ensure_persist_configured() - if self._auto_persist is not None: - for member in self._auto_persist: - value = getattr(self, member) - if inspect.ismethod(value): - if value.__self__ is not self: - raise TypeError('Cannot persist methods of other classes') - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, persistence.Savable): - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - - if isinstance(self._state, process_states.Savable): + if isinstance(self._state, persistence.Savable): out_state['_state'] = self._state.save() # Inputs/outputs @@ -210,7 +181,7 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA out_state[processes.BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) # Ask the stepper to save itself - if self._stepper is not None: + if self._stepper is not None and isinstance(self._stepper, Savable): out_state[self._STEPPER_STATE] = self._stepper.save() if self._context is not None: @@ -232,7 +203,7 @@ def recreate_from( """ ### FIXME: dup from process.create_from - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state @@ -375,7 +346,8 @@ def get_description(self) -> Any: """ -class _FunctionStepper(persistence.Savable): +@auto_persist() +class _FunctionStepper: def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): self._workchain = workchain self._fn = fn @@ -387,7 +359,9 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -397,7 +371,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -443,7 +417,7 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') -class _BlockStepper(persistence.Savable): +class _BlockStepper: def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: self._workchain = workchain self._block = block @@ -488,7 +462,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -591,7 +565,7 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') -class _IfStepper(persistence.Savable): +class _IfStepper: def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: self._workchain = workchain self._if_instruction = if_instruction @@ -643,7 +617,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -714,7 +688,7 @@ def get_description(self) -> Mapping[str, Any]: return description -class _WhileStepper(persistence.Savable): +class _WhileStepper: def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: self._workchain = workchain self._while_instruction = while_instruction @@ -744,7 +718,9 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -754,7 +730,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -801,7 +777,8 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code -class _ReturnStepper(persistence.Savable): +@persistence.auto_persist() +class _ReturnStepper: def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: self._workchain = workchain self._return_instruction = return_instruction diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 65ef3226..4ec4c1a5 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,16 +5,17 @@ import yaml import plumpy -from plumpy.persistence import auto_load +from plumpy.persistence import auto_load, auto_persist, auto_save +from plumpy.utils import SAVED_STATE_TYPE from . import utils -class SaveEmpty(plumpy.Savable): - pass +@auto_persist() +class SaveEmpty: @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -28,9 +29,14 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test', 'test_method') -class Save1(plumpy.Savable): +class Save1: def __init__(self): self.test = 'sup yp' self.test_method = self.m @@ -39,7 +45,7 @@ def m(): pass @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -53,14 +59,19 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test') -class Save(plumpy.Savable): +class Save: def __init__(self): self.test = Save1() @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -74,6 +85,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + class TestSavable(unittest.TestCase): def test_empty_savable(self): diff --git a/tests/test_processes.py b/tests/test_processes.py index a634d4e5..d354508f 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -14,6 +14,10 @@ from plumpy.utils import AttributesFrozendict from . import utils +# FIXME: after deabstract on savable into a protocol, test that all state are savable +# FIXME: also that any process is savable +# FIXME: any process listener is savable +# FIXME: any process control commands are savable class ForgetToCallParent(plumpy.Process): def __init__(self, forget_on): diff --git a/tests/test_workchains.py b/tests/test_workchains.py index 08c7317a..4e34d2b4 100644 --- a/tests/test_workchains.py +++ b/tests/test_workchains.py @@ -11,6 +11,8 @@ from . import utils +# FIXME: after deabstract on savable into a protocol, test that all stepper are savable +# FIXME: workchani itself is savable class Wf(WorkChain): # Keep track of which steps were completed by the workflow From d5680d72acec0b26bf9dd35e494c848ab3d5fef6 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 11 Dec 2024 00:54:27 +0100 Subject: [PATCH 16/21] Make auto_load symmetry with auto_save and state/state_label distinguish --- src/plumpy/base/state_machine.py | 10 +++- src/plumpy/event_helper.py | 3 +- src/plumpy/persistence.py | 19 ++++++- src/plumpy/process_states.py | 40 ++++++--------- src/plumpy/processes.py | 67 +++++++++++++------------ src/plumpy/workchains.py | 23 ++++----- tests/base/test_statemachine.py | 10 ++-- tests/rmq/test_process_control.py | 10 ++-- tests/test_persistence.py | 14 +++--- tests/test_processes.py | 83 ++++++++++++++++--------------- tests/utils.py | 2 +- 11 files changed, 146 insertions(+), 135 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index e3912b6f..1eae4789 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -266,7 +266,13 @@ def create_initial_state(self, *args: Any, **kwargs: Any) -> State: return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Any: + def state(self) -> State | None: + if self._state is None: + return None + return self._state + + @property + def state_label(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -314,7 +320,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: # it can happened when transit from terminal state return None - initial_state_label = self._state.LABEL if self._state is not None else None + initial_state_label = self.state_label label = None try: self._transitioning = True diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index abc2b24b..9262f856 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -45,8 +45,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index afe82439..31bbc67c 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -20,6 +20,7 @@ List, Optional, Protocol, + TypeVar, cast, runtime_checkable, ) @@ -523,6 +524,8 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S value = value.__name__ elif isinstance(value, Savable) and not isinstance(value, type): # persist for a savable obj, call `save` method of obj. + # the rhs branch is for when value is a Savable class, it is true runtime check + # of lhs condition. SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) value = value.save() else: @@ -532,11 +535,25 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S return out_state -def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: +def load_auto_persist_params( + obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None +) -> None: for member in obj._auto_persist: setattr(obj, member, _get_value(obj, saved_state, member, load_context)) +T = TypeVar('T', bound=Savable) + + +def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T: + obj = cls.__new__(cls) + + if isinstance(obj, SavableWithAutoPersist): + load_auto_persist_params(obj, saved_state, load_context) + + return obj + + def _get_value( obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None ) -> MethodType | Savable: diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 08a4d24c..49d76e46 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -42,6 +42,7 @@ auto_load, auto_persist, auto_save, + ensure_object_loader, ) from .utils import SAVED_STATE_TYPE @@ -102,8 +103,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -175,15 +176,15 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) - obj.state_machine = load_context.process try: obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN]) except ValueError: - process = load_context.process - obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN]) + if load_context is not None: + obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN]) + else: + raise return obj @@ -239,12 +240,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process - obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) return obj @@ -312,15 +309,13 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process obj.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) if obj.COMMAND in saved_state: - # FIXME: typing obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore + return obj def interrupt(self, reason: Any) -> None: @@ -450,9 +445,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process callback_name = saved_state.get(obj.DONE_CALLBACK, None) @@ -556,8 +549,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -616,8 +608,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -665,8 +656,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 09964ae7..aadd9290 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -21,6 +21,7 @@ Any, Awaitable, Callable, + ClassVar, Dict, Generator, Hashable, @@ -166,6 +167,7 @@ class Process(StateMachine, metaclass=ProcessStateMachineMeta): _cleanups: Optional[List[Callable[[], None]]] = None __called: bool = False + _auto_persist: ClassVar[set[str]] @classmethod def current(cls) -> Optional['Process']: @@ -285,7 +287,7 @@ def recreate_from( else: proc._loop = asyncio.get_event_loop() - proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + proc._state = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: proc._communicator = load_context.communicator @@ -294,7 +296,7 @@ def recreate_from( proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(proc, saved_state, load_context) + persistence.load_auto_persist_params(proc, saved_state, load_context) # Inputs/outputs try: @@ -519,7 +521,9 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal + if self.state is None: + raise exceptions.InvalidStateError('process is not in state None that is invalid') + return self.state.is_terminal def result(self) -> Any: """ @@ -529,12 +533,12 @@ def result(self) -> Any: If in any other state this will raise an InvalidStateError. :return: The result of the process """ - if isinstance(self._state, process_states.Finished): - return self._state.result - if isinstance(self._state, process_states.Killed): - raise exceptions.KilledError(self._state.msg) - if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception('process excepted')) + if isinstance(self.state, process_states.Finished): + return self.state.result + if isinstance(self.state, process_states.Killed): + raise exceptions.KilledError(self.state.msg) + if isinstance(self.state, process_states.Excepted): + raise (self.state.exception or Exception('process excepted')) raise exceptions.InvalidStateError @@ -544,7 +548,7 @@ def successful(self) -> bool: Will raise if the process is not in the FINISHED state """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError as exception: raise exceptions.InvalidStateError('process is not in the finished state') from exception @@ -555,25 +559,25 @@ def is_successful(self) -> bool: :return: boolean, True if the process is in `Finished` state with `successful` attribute set to `True` """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError: return False def killed(self) -> bool: """Return whether the process is killed.""" - return self.state == process_states.ProcessState.KILLED + return self.state_label == process_states.ProcessState.KILLED def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" - if isinstance(self._state, process_states.Killed): - return self._state.msg + if isinstance(self.state, process_states.Killed): + return self.state.msg raise exceptions.InvalidStateError('Has not been killed') def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" - if isinstance(self._state, process_states.Excepted): - return self._state.exception + if isinstance(self.state, process_states.Excepted): + return self.state.exception return None @@ -583,7 +587,7 @@ def is_excepted(self) -> bool: :return: boolean, True if the process is in ``EXCEPTED`` state. """ - return self.state == process_states.ProcessState.EXCEPTED + return self.state_label == process_states.ProcessState.EXCEPTED def done(self) -> bool: """Return True if the call was successfully killed or finished running. @@ -592,7 +596,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal + return self.has_terminated() # endregion @@ -620,7 +624,7 @@ def callback_excepted( exception: Optional[BaseException], trace: Optional[TracebackType], ) -> None: - if self.state != process_states.ProcessState.EXCEPTED: + if self.state_label != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @contextlib.contextmanager @@ -673,8 +677,8 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA """ out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if isinstance(self._state, persistence.Savable): - out_state['_state'] = self._state.save() + if isinstance(self.state, persistence.Savable): + out_state['_state'] = self.state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -732,7 +736,7 @@ def on_entering(self, state: state_machine.State) -> None: def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement - state_label = self._state.LABEL + state_label = self.state_label if state_label == process_states.ProcessState.RUNNING: call_with_super_check(self.on_running) elif state_label == process_states.ProcessState.WAITING: @@ -746,7 +750,7 @@ def on_entered(self, from_state: Optional[state_machine.State]) -> None: if self._coordinator and isinstance(self.state, enum.Enum): from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None - subject = f'state_changed.{from_label}.{self.state.value}' + subject = f'state_changed.{from_label}.{self.state_label.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: self._coordinator.broadcast_send(body=None, sender=self.pid, subject=subject) @@ -759,7 +763,7 @@ def on_entered(self, from_state: Optional[state_machine.State]) -> None: raise def on_exiting(self) -> None: - state = self.state + state = self.state_label if state == process_states.ProcessState.WAITING: call_with_super_check(self.on_exit_waiting) elif state == process_states.ProcessState.RUNNING: @@ -1099,7 +1103,6 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) self.transition_to(new_state) @@ -1125,9 +1128,9 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: return self._pausing if self._stepping: - if not isinstance(self._state, Interruptable): + if not isinstance(self.state, Interruptable): raise exceptions.InvalidStateError( - f'cannot interrupt {self._state.__class__}, method `interrupt` not implement' + f'cannot interrupt {self.state.__class__}, method `interrupt` not implement' ) # Ask the step function to pause by setting this flag and giving the @@ -1226,7 +1229,7 @@ def play(self) -> bool: @event(from_states=process_states.Waiting) def resume(self, *args: Any) -> None: """Start running the process again.""" - return self._state.resume(*args) # type: ignore + return self.state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: @@ -1244,7 +1247,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: Kill the process :param msg: An optional kill message """ - if self.state == process_states.ProcessState.KILLED: + if self.state_label == process_states.ProcessState.KILLED: # Already killed return True @@ -1256,7 +1259,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: # Already killing return self._killing - if self._stepping and isinstance(self._state, Interruptable): + if self._stepping and isinstance(self.state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg_text) @@ -1332,8 +1335,8 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused - if not isinstance(self._state, Proceedable): - raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented') + if not isinstance(self.state, Proceedable): + raise StateMachineError(f'cannot step from {self.state.__class__}, async method `execute` not implemented') try: self._stepping = True diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 9a241b72..0926273c 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -8,7 +8,6 @@ import logging import re from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -32,7 +31,7 @@ from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError -from plumpy.persistence import LoadSaveContext, auto_persist, auto_save, ensure_object_loader, Savable +from plumpy.persistence import LoadSaveContext, Savable, auto_persist, auto_save, ensure_object_loader from plumpy.process_listener import ProcessListener from . import lang, mixins, persistence, process_spec, process_states, processes @@ -223,7 +222,7 @@ def recreate_from( else: proc._loop = asyncio.get_event_loop() - proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + proc._state = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: proc._communicator = load_context.communicator @@ -232,7 +231,7 @@ def recreate_from( proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(proc, saved_state, load_context) + persistence.load_auto_persist_params(proc, saved_state, load_context) # Inputs/outputs try: @@ -372,8 +371,7 @@ def recreate_from( """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._fn = getattr(obj._workchain.__class__, saved_state['_fn']) @@ -446,7 +444,7 @@ def finished(self) -> bool: def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if self._child_stepper is not None: + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() return out_state @@ -463,8 +461,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._block = load_context.block_instruction stepper_state = saved_state.get(STEPPER_STATE, None) @@ -601,7 +598,7 @@ def finished(self) -> bool: def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if self._child_stepper is not None: + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() return out_state @@ -618,8 +615,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._if_instruction = load_context.if_instruction stepper_state = saved_state.get(STEPPER_STATE, None) @@ -731,8 +727,7 @@ def recreate_from( """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._while_instruction = load_context.while_instruction stepper_state = saved_state.get(STEPPER_STATE, None) diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index f046aaa8..44a084d4 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -150,22 +150,22 @@ def stop(self): class TestStateMachine(unittest.TestCase): def test_basic(self): cd_player = CdPlayer() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) cd_player.play('Eminem - The Real Slim Shady') - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) time.sleep(1.0) cd_player.pause() - self.assertEqual(cd_player.state, PAUSED) + self.assertEqual(cd_player.state_label, PAUSED) cd_player.play() - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) self.assertEqual(cd_player.play(), False) cd_player.stop() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) def test_invalid_event(self): cd_player = CdPlayer() diff --git a/tests/rmq/test_process_control.py b/tests/rmq/test_process_control.py index 79a98ba3..7c3b431c 100644 --- a/tests/rmq/test_process_control.py +++ b/tests/rmq/test_process_control.py @@ -68,7 +68,7 @@ async def test_play(self, _coordinator, async_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.WAITING + assert proc.state_label == plumpy.ProcessState.WAITING # if not close the background process will raise exception # make sure proc reach the final state @@ -85,7 +85,7 @@ async def test_kill(self, _coordinator, async_controller): # Check the outcome assert result - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_status(self, _coordinator, async_controller): @@ -173,7 +173,7 @@ async def test_play(self, _coordinator, sync_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.CREATED + assert proc.state_label == plumpy.ProcessState.CREATED @pytest.mark.asyncio async def test_kill(self, _coordinator, sync_controller): @@ -187,7 +187,7 @@ async def test_kill(self, _coordinator, sync_controller): # Check the outcome assert result # Occasionally fail - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_kill_all(self, _coordinator, sync_controller): @@ -198,7 +198,7 @@ async def test_kill_all(self, _coordinator, sync_controller): sync_controller.kill_all(msg_text='bang bang, I shot you down') await utils.wait_util(lambda: all([proc.killed() for proc in procs])) - assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) + assert all([proc.state_label == plumpy.ProcessState.KILLED for proc in procs]) @pytest.mark.asyncio async def test_status(self, _coordinator, sync_controller): diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 4ec4c1a5..7f616433 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,7 +5,7 @@ import yaml import plumpy -from plumpy.persistence import auto_load, auto_persist, auto_save +from plumpy.persistence import auto_load, auto_persist, auto_save, ensure_object_loader from plumpy.utils import SAVED_STATE_TYPE from . import utils @@ -25,8 +25,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: @@ -55,8 +55,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: @@ -81,8 +81,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: diff --git a/tests/test_processes.py b/tests/test_processes.py index d354508f..61d73054 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -19,6 +19,7 @@ # FIXME: any process listener is savable # FIXME: any process control commands are savable + class ForgetToCallParent(plumpy.Process): def __init__(self, forget_on): super().__init__() @@ -239,7 +240,7 @@ def test_execute(self): proc.execute() self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertEqual(proc.outputs, {'default': 5}) def test_run_from_class(self): @@ -277,7 +278,7 @@ def test_exception(self): proc = utils.ExceptionProcess() with self.assertRaises(RuntimeError): proc.execute() - self.assertEqual(proc.state, ProcessState.EXCEPTED) + self.assertEqual(proc.state_label, ProcessState.EXCEPTED) def test_run_kill(self): proc = utils.KillProcess() @@ -344,7 +345,7 @@ def test_wait_continue(self): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_exc_info(self): proc = utils.ExceptionProcess() @@ -368,7 +369,7 @@ def test_wait_pause_play_resume(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause() self.assertTrue(result) @@ -384,7 +385,7 @@ async def async_test(): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) @@ -405,7 +406,7 @@ def test_pause_play_status_messaging(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause(PAUSE_STATUS) self.assertTrue(result) @@ -425,7 +426,7 @@ async def async_test(): loop.run_until_complete(async_test()) self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_kill_in_run(self): class KillProcess(Process): @@ -443,7 +444,7 @@ def run(self, **kwargs): proc.execute() self.assertTrue(proc.after_kill) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused_in_run(self): class PauseProcess(Process): @@ -455,7 +456,7 @@ def run(self, **kwargs): with self.assertRaises(plumpy.KilledError): proc.execute() - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused(self): loop = asyncio.get_event_loop() @@ -479,7 +480,7 @@ async def async_test(): loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_run_multiple(self): # Create and play some processes @@ -555,7 +556,7 @@ def run(self): loop.run_forever() self.assertTrue(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_pause_play_in_process(self): """Test that we can pause and play that by playing within the process""" @@ -573,7 +574,7 @@ def run(self): proc.execute() self.assertFalse(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_process_stack(self): test_case = self @@ -784,7 +785,7 @@ def test_saving_each_step(self): proc = proc_class() saver = utils.ProcessSaver(proc) saver.capture() - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertTrue(utils.check_process_against_snapshots(loop, proc_class, saver.snapshots)) def test_restart(self): @@ -799,7 +800,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it loaded_proc.resume() @@ -822,7 +823,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it twice in succession loaded_proc.resume() @@ -864,7 +865,7 @@ async def async_test(): def test_killed(self): proc = utils.DummyProcess() proc.kill() - self.assertEqual(proc.state, plumpy.ProcessState.KILLED) + self.assertEqual(proc.state_label, plumpy.ProcessState.KILLED) self._check_round_trip(proc) def _check_round_trip(self, proc1): @@ -987,40 +988,40 @@ def run(self): self.out(namespace_nested + '.two', 2) # Run the process in default mode which should not add any outputs and therefore fail - process = DummyDynamicProcess() - process.execute() + proc = DummyDynamicProcess() + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertDictEqual(process.outputs, {}) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertDictEqual(proc.outputs, {}) # Attaching only namespaced ports should fail, because the required port is not added - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) # Attaching only the single required top-level port should be fine - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) # Attaching both the required and namespaced ports should result in a successful termination - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) - process.execute() - - self.assertIsNotNone(process.outputs) - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) + proc.execute() + + self.assertIsNotNone(proc.outputs) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) class TestProcessEvents(unittest.TestCase): diff --git a/tests/utils.py b/tests/utils.py index 3d4458f4..18082fd4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -596,7 +596,7 @@ def run_until_waiting(proc): listener = plumpy.ProcessListener() in_waiting = asyncio.Future() - if proc.state == ProcessState.WAITING: + if proc.state_label == ProcessState.WAITING: in_waiting.set_result(True) else: From 4195247147bab425310f11230c5eb4bc5a67fba6 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 18 Jan 2025 02:59:20 +0100 Subject: [PATCH 17/21] misc rebase --- src/plumpy/__init__.py | 3 --- src/plumpy/mixins.py | 29 ----------------------------- src/plumpy/process_states.py | 5 ++--- src/plumpy/processes.py | 4 ++-- src/plumpy/workchains.py | 13 +++++-------- 5 files changed, 9 insertions(+), 45 deletions(-) delete mode 100644 src/plumpy/mixins.py diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 2c988cd8..4cb50820 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -29,7 +29,6 @@ from .futures import CancellableAction, Future, capture_exceptions, create_task from .loaders import DefaultObjectLoader, ObjectLoader, get_object_loader, set_object_loader from .message import MessageBuilder, ProcessLauncher, create_continue_body, create_launch_body -from .mixins import ContextMixin from .persistence import ( Bundle, InMemoryPersister, @@ -77,8 +76,6 @@ 'CancellableAction', # exceptions 'ClosedError', - # mixins - 'ContextMixin', # process_states/States 'Continue', # coordinator diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py deleted file mode 100644 index 0e3bb0c0..00000000 --- a/src/plumpy/mixins.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Any, Optional - -from . import persistence -from .utils import SAVED_STATE_TYPE, AttributesDict - - -class ContextMixin(persistence.Savable): - """ - Add a context to a Process. The contents of the context will be saved - in the instance state unlike standard instance variables. - """ - - CONTEXT: str = '_context' - - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self._context: Optional[AttributesDict] = AttributesDict() - - @property - def ctx(self) -> Optional[AttributesDict]: - return self._context - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - try: - self._context = AttributesDict(**saved_state[self.CONTEXT]) - except KeyError: - pass diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 49d76e46..704da6eb 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -42,9 +42,8 @@ auto_load, auto_persist, auto_save, - ensure_object_loader, ) -from .utils import SAVED_STATE_TYPE +from .utils import SAVED_STATE_TYPE, ensure_coroutine __all__ = [ 'Continue', @@ -312,7 +311,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process - obj.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) + obj.run_fn = ensure_coroutine(getattr(obj.process, saved_state[obj.RUN_FN])) if obj.COMMAND in saved_state: obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index aadd9290..611004ff 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -280,7 +280,7 @@ def recreate_from( proc._future = persistence.SavableFuture() proc._event_helper = EventHelper(ProcessListener) proc._logger = None - proc._communicator = None + proc._coordinator = None if 'loop' in load_context: proc._loop = load_context.loop @@ -290,7 +290,7 @@ def recreate_from( proc._state = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: - proc._communicator = load_context.communicator + proc._coordinator = load_context.coordinator if 'logger' in load_context: proc._logger = load_context.logger diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 0926273c..5c459d0e 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -23,21 +23,18 @@ cast, ) -import kiwipy - from plumpy import utils from plumpy.base import state_machine -from plumpy.coordinator import Coordinator from plumpy.base.utils import call_with_super_check +from plumpy.coordinator import Coordinator from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError from plumpy.persistence import LoadSaveContext, Savable, auto_persist, auto_save, ensure_object_loader from plumpy.process_listener import ProcessListener -from . import lang, mixins, persistence, process_spec, process_states, processes +from . import lang, persistence, process_spec, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict - ToContext = dict PREDICATE_TYPE = Callable[['WorkChain'], bool] @@ -215,7 +212,7 @@ def recreate_from( proc._future = persistence.SavableFuture() proc._event_helper = EventHelper(ProcessListener) proc._logger = None - proc._communicator = None + proc._coordinator = None if 'loop' in load_context: proc._loop = load_context.loop @@ -224,8 +221,8 @@ def recreate_from( proc._state = proc.recreate_state(saved_state['_state']) - if 'communicator' in load_context: - proc._communicator = load_context.communicator + if 'coordinator' in load_context: + proc._coordinator = load_context.coordinator if 'logger' in load_context: proc._logger = load_context.logger From 4f6a2bb3d94b11623a3cd656f7a0b2b8328d93a2 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 18 Jan 2025 03:59:05 +0100 Subject: [PATCH 18/21] debug logger when state change --- src/plumpy/base/state_machine.py | 1 - src/plumpy/processes.py | 4 ++-- tests/test_processes.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 1eae4789..814c1491 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -312,7 +312,6 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: The arguments are passed to the state class to create state instance. (process arg does not need to pass since it will always call with 'self' as process) """ - print(f'try: {self._state} -> {new_state}') assert not self._transitioning, 'Cannot call transition_to when already transitioning state' if new_state is None: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 611004ff..9c14fa86 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -748,7 +748,7 @@ def on_entered(self, from_state: Optional[state_machine.State]) -> None: elif state_label == process_states.ProcessState.KILLED: call_with_super_check(self.on_killed) - if self._coordinator and isinstance(self.state, enum.Enum): + if self._coordinator and isinstance(self.state_label, enum.Enum): from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None subject = f'state_changed.{from_label}.{self.state_label.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) @@ -1342,7 +1342,6 @@ async def step(self) -> None: self._stepping = True next_state = None try: - # XXX: debug log when need to step to next state next_state = await self._run_task(self._state.execute) except process_states.Interruption as exception: # If the interruption was caused by a call to a Process method then there should @@ -1368,6 +1367,7 @@ async def step(self) -> None: self._interrupt_action.run(next_state) else: # Everything nominal so transition to the next state + self.logger.debug(f'Process<{self.pid}>: transfer from {self._state.LABEL} to {next_state.LABEL}') self.transition_to(next_state) finally: diff --git a/tests/test_processes.py b/tests/test_processes.py index 61d73054..91d98e40 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -331,7 +331,7 @@ def test_kill(self): proc.kill(msg_text=msg_text) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg()[MESSAGE_TEXT_KEY], msg_text) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_wait_continue(self): proc = utils.WaitForSignalProcess() From 15ce0fe1e6c030393c6f000732891a5c34e09132 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 18 Jan 2025 04:03:11 +0100 Subject: [PATCH 19/21] logger for load process from context --- src/plumpy/processes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 9c14fa86..2cf20a2a 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -285,15 +285,20 @@ def recreate_from( if 'loop' in load_context: proc._loop = load_context.loop else: + _LOGGER.warning(f'cannot find `loop` store in load_context, use default event loop') proc._loop = asyncio.get_event_loop() proc._state = proc.recreate_state(saved_state['_state']) - if 'communicator' in load_context: + if 'coordinator' in load_context: proc._coordinator = load_context.coordinator + else: + _LOGGER.warning(f'cannot find `coordinator` store in load_context') if 'logger' in load_context: proc._logger = load_context.logger + else: + _LOGGER.warning(f'cannot find `logger` store in load_context') # Need to call this here as things downstream may rely on us having the runtime variable above persistence.load_auto_persist_params(proc, saved_state, load_context) From 845a7d635cd7c81ee3f4605bf55b9c735ea48487 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 22 Jan 2025 10:59:02 +0100 Subject: [PATCH 20/21] Using typing-extensions for 3.9 support of @override --- pyproject.toml | 1 + src/plumpy/base/state_machine.py | 2 ++ src/plumpy/persistence.py | 2 ++ src/plumpy/process_states.py | 2 +- src/plumpy/processes.py | 10 ++++++---- uv.lock | 2 ++ 6 files changed, 14 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f0877103..2bc829bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ 'kiwipy[rmq]~=0.8.5', 'nest_asyncio~=1.5,>=1.5.1', 'pyyaml~=6.0', + 'typing-extensions~=4.12' ] [project.urls] diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 814c1491..a12981a0 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -275,6 +275,8 @@ def state(self) -> State | None: def state_label(self) -> Any: if self._state is None: return None + # XXX: should not use `.value` to access the printable output from LABEL + # LABEL as the ClassVar should have __str__ return self._state.LABEL def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 31bbc67c..b2da7eef 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import abc import asyncio import collections diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 704da6eb..c245d72a 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -17,10 +17,10 @@ Union, cast, final, - override, ) import yaml +from typing_extensions import override from yaml.loader import Loader from plumpy.message import MessageBuilder, MessageType diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 2cf20a2a..95f73719 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -285,7 +285,7 @@ def recreate_from( if 'loop' in load_context: proc._loop = load_context.loop else: - _LOGGER.warning(f'cannot find `loop` store in load_context, use default event loop') + _LOGGER.warning('cannot find `loop` store in load_context, use default event loop') proc._loop = asyncio.get_event_loop() proc._state = proc.recreate_state(saved_state['_state']) @@ -293,12 +293,12 @@ def recreate_from( if 'coordinator' in load_context: proc._coordinator = load_context.coordinator else: - _LOGGER.warning(f'cannot find `coordinator` store in load_context') + _LOGGER.warning('cannot find `coordinator` store in load_context') if 'logger' in load_context: proc._logger = load_context.logger else: - _LOGGER.warning(f'cannot find `logger` store in load_context') + _LOGGER.warning('cannot find `logger` store in load_context') # Need to call this here as things downstream may rely on us having the runtime variable above persistence.load_auto_persist_params(proc, saved_state, load_context) @@ -760,7 +760,9 @@ def on_entered(self, from_state: Optional[state_machine.State]) -> None: try: self._coordinator.broadcast_send(body=None, sender=self.pid, subject=subject) except exceptions.CoordinatorCommunicationError: - message = f'Process<{self.pid}>: cannot broadcast state change from {from_label} to {self.state.value}' + message = ( + f'Process<{self.pid}>: cannot broadcast state change from {from_label} to {self.state_label.value}' + ) self.logger.warning(message) self.logger.debug(message, exc_info=True) except Exception: diff --git a/uv.lock b/uv.lock index 2af8adbd..d8fc89f5 100644 --- a/uv.lock +++ b/uv.lock @@ -1862,6 +1862,7 @@ dependencies = [ { name = "kiwipy", extra = ["rmq"] }, { name = "nest-asyncio" }, { name = "pyyaml" }, + { name = "typing-extensions" }, ] [package.optional-dependencies] @@ -1913,6 +1914,7 @@ requires-dist = [ { name = "sphinx", marker = "extra == 'docs'", specifier = "~=3.2.0" }, { name = "sphinx-book-theme", marker = "extra == 'docs'", specifier = "~=0.0.39" }, { name = "types-pyyaml", marker = "extra == 'pre-commit'" }, + { name = "typing-extensions", specifier = "~=4.12" }, ] [[package]] From 76c887f03479b42e932166626316f2ef8ca7641d Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 12 Feb 2025 01:03:36 +0100 Subject: [PATCH 21/21] mis --- src/plumpy/__init__.py | 28 +++++++++------------------- src/plumpy/process_states.py | 3 +-- src/plumpy/processes.py | 3 +-- 3 files changed, 11 insertions(+), 23 deletions(-) diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 4cb50820..2ed1c5df 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -26,9 +26,9 @@ PersistenceError, UnsuccessfulResult, ) -from .futures import CancellableAction, Future, capture_exceptions, create_task +from .futures import CancellableAction, Future, capture_exceptions from .loaders import DefaultObjectLoader, ObjectLoader, get_object_loader, set_object_loader -from .message import MessageBuilder, ProcessLauncher, create_continue_body, create_launch_body +from .message import MsgContinue, MsgCreate, MsgKill, MsgLaunch, MsgPause, MsgPlay, MsgStatus, ProcessLauncher from .persistence import ( Bundle, InMemoryPersister, @@ -64,26 +64,17 @@ from .workchains import ToContext, WorkChain, WorkChainSpec, if_, return_, while_ __all__ = ( - # ports 'UNSPECIFIED', - # utils 'AttributesDict', - # persistence 'Bundle', - # processes 'BundleKeys', - # futures 'CancellableAction', - # exceptions 'ClosedError', - # process_states/States 'Continue', - # coordinator 'Coordinator', 'CoordinatorConnectionError', 'CoordinatorTimeoutError', 'Created', - # loaders 'DefaultObjectLoader', 'Excepted', 'Finished', @@ -92,14 +83,18 @@ 'InputPort', 'Interruption', 'InvalidStateError', - # process_states/Commands 'Kill', 'KillInterruption', 'Killed', 'KilledError', 'LoadSaveContext', - # message - 'MessageBuilder', + 'MsgContinue', + 'MsgCreate', + 'MsgKill', + 'MsgLaunch', + 'MsgPause', + 'MsgPlay', + 'MsgStatus', 'ObjectLoader', 'OutputPort', 'PauseInterruption', @@ -107,16 +102,13 @@ 'PersistenceError', 'Persister', 'PicklePersister', - # event 'PlumpyEventLoopPolicy', 'Port', 'PortNamespace', 'PortValidationError', 'Process', - # controller 'ProcessController', 'ProcessLauncher', - # process_listener 'ProcessListener', 'ProcessSpec', 'ProcessState', @@ -124,7 +116,6 @@ 'Savable', 'SavableFuture', 'Stop', - # workchain 'ToContext', 'TransitionFailed', 'UnsuccessfulResult', @@ -136,7 +127,6 @@ 'capture_exceptions', 'create_continue_body', 'create_launch_body', - 'create_task', 'get_event_loop', 'get_object_loader', 'if_', diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index b25ab49c..3219328e 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -23,8 +23,8 @@ from typing_extensions import override from yaml.loader import Loader -from plumpy.persistence import ensure_object_loader from plumpy.message import Message, MsgKill, MsgPause +from plumpy.persistence import ensure_object_loader try: import tblib @@ -638,7 +638,6 @@ class Killed: is_terminal: ClassVar[bool] = True def __init__(self, msg: Optional[Message]): - """ :param msg: Optional kill message """ diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 9e714ac9..e3784b21 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1149,13 +1149,12 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: msg = MsgPause.new(msg_text) return self._do_pause(state_msg=msg) - + @staticmethod def _interrupt(state: Interruptable, reason: Exception) -> None: state.interrupt(reason) def _do_pause(self, state_msg: Optional[Message], next_state: Optional[state_machine.State] = None) -> bool: - """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: