diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 6aa8c345..f5265734 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator # unavailable forward references py:class plumpy.process_states.Command -py:class plumpy.process_states.State +py:class plumpy.state_machine.State py:class plumpy.base.state_machine.State py:class State py:class Process diff --git a/pyproject.toml b/pyproject.toml index a6307a69..f0eea9b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ 'kiwipy[rmq]~=0.8.5', 'nest_asyncio~=1.5,>=1.5.1', 'pyyaml~=6.0', + 'typing-extensions~=4.12' ] [project.urls] diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 864d2226..2ed1c5df 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -1,42 +1,144 @@ # -*- coding: utf-8 -*- -# mypy: disable-error-code=name-defined __version__ = '0.24.0' import logging +from .base.state_machine import TransitionFailed + # interfaces from .controller import ProcessController from .coordinator import Coordinator -from .events import * -from .exceptions import * -from .futures import * -from .loaders import * -from .message import * -from .mixins import * -from .persistence import * -from .ports import * -from .process_listener import * -from .process_states import * -from .processes import * -from .rmq import * -from .utils import * -from .workchains import * +from .events import ( + PlumpyEventLoopPolicy, + get_event_loop, + new_event_loop, + reset_event_loop_policy, + run_until_complete, + set_event_loop, + set_event_loop_policy, +) +from .exceptions import ( + ClosedError, + CoordinatorConnectionError, + CoordinatorTimeoutError, + InvalidStateError, + KilledError, + PersistenceError, + UnsuccessfulResult, +) +from .futures import CancellableAction, Future, capture_exceptions +from .loaders import DefaultObjectLoader, ObjectLoader, get_object_loader, set_object_loader +from .message import MsgContinue, MsgCreate, MsgKill, MsgLaunch, MsgPause, MsgPlay, MsgStatus, ProcessLauncher +from .persistence import ( + Bundle, + InMemoryPersister, + LoadSaveContext, + PersistedCheckpoint, + Persister, + PicklePersister, + Savable, + SavableFuture, + auto_persist, +) +from .ports import UNSPECIFIED, InputPort, OutputPort, Port, PortNamespace, PortValidationError +from .process_listener import ProcessListener +from .process_spec import ProcessSpec +from .process_states import ( + Continue, + Created, + Excepted, + Finished, + Interruption, + Kill, + Killed, + KillInterruption, + PauseInterruption, + ProcessState, + Running, + Stop, + Wait, + Waiting, +) +from .processes import BundleKeys, Process +from .utils import AttributesDict +from .workchains import ToContext, WorkChain, WorkChainSpec, if_, return_, while_ __all__ = ( - events.__all__ - + exceptions.__all__ - + processes.__all__ - + utils.__all__ - + futures.__all__ - + mixins.__all__ - + persistence.__all__ - + message.__all__ - + process_listener.__all__ - + workchains.__all__ - + loaders.__all__ - + ports.__all__ - + process_states.__all__ -) + ['ProcessController', 'Coordinator'] + 'UNSPECIFIED', + 'AttributesDict', + 'Bundle', + 'BundleKeys', + 'CancellableAction', + 'ClosedError', + 'Continue', + 'Coordinator', + 'CoordinatorConnectionError', + 'CoordinatorTimeoutError', + 'Created', + 'DefaultObjectLoader', + 'Excepted', + 'Finished', + 'Future', + 'InMemoryPersister', + 'InputPort', + 'Interruption', + 'InvalidStateError', + 'Kill', + 'KillInterruption', + 'Killed', + 'KilledError', + 'LoadSaveContext', + 'MsgContinue', + 'MsgCreate', + 'MsgKill', + 'MsgLaunch', + 'MsgPause', + 'MsgPlay', + 'MsgStatus', + 'ObjectLoader', + 'OutputPort', + 'PauseInterruption', + 'PersistedCheckpoint', + 'PersistenceError', + 'Persister', + 'PicklePersister', + 'PlumpyEventLoopPolicy', + 'Port', + 'PortNamespace', + 'PortValidationError', + 'Process', + 'ProcessController', + 'ProcessLauncher', + 'ProcessListener', + 'ProcessSpec', + 'ProcessState', + 'Running', + 'Savable', + 'SavableFuture', + 'Stop', + 'ToContext', + 'TransitionFailed', + 'UnsuccessfulResult', + 'Wait', + 'Waiting', + 'WorkChain', + 'WorkChainSpec', + 'auto_persist', + 'capture_exceptions', + 'create_continue_body', + 'create_launch_body', + 'get_event_loop', + 'get_object_loader', + 'if_', + 'new_event_loop', + 'reset_event_loop_policy', + 'return_', + 'run_until_complete', + 'set_event_loop', + 'set_event_loop_policy', + 'set_object_loader', + 'while_', +) # Do this se we don't get the "No handlers could be found..." warnings that will be produced diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 681858f0..a12981a0 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -13,15 +13,17 @@ from typing import ( Any, Callable, + ClassVar, Dict, Hashable, Iterable, List, Optional, + Protocol, Sequence, - Set, Type, Union, + runtime_checkable, ) from plumpy.futures import Future @@ -32,7 +34,6 @@ _LOGGER = logging.getLogger(__name__) -LABEL_TYPE = Union[None, enum.Enum, str] EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] @@ -88,12 +89,12 @@ def event( if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) - if not all(issubclass(state, State) for state in from_states): # type: ignore + if not all(isinstance(state, State) for state in from_states): # type: ignore raise TypeError(f'from_states: {from_states}') if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) - if not all(issubclass(state, State) for state in to_states): # type: ignore + if not all(isinstance(state, State) for state in to_states): # type: ignore raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: @@ -127,57 +128,40 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -class State: - LABEL: LABEL_TYPE = None - # A set containing the labels of states that can be entered - # from this one - ALLOWED: Set[LABEL_TYPE] = set() +@runtime_checkable +class State(Protocol): + LABEL: ClassVar[Any] + ALLOWED: ClassVar[set[Any]] + is_terminal: ClassVar[bool] - @classmethod - def is_terminal(cls) -> bool: - return not cls.ALLOWED + def __init__(self, *args: Any, **kwargs: Any): ... - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): - """ - :param state_machine: The process this state belongs to - """ - self.state_machine = state_machine - self.in_state: bool = False + def enter(self) -> None: ... - def __str__(self) -> str: - return str(self.LABEL) + def exit(self) -> None: ... - @property - def label(self) -> LABEL_TYPE: - """Convenience property to get the state label""" - return self.LABEL - @super_check - def enter(self) -> None: - """Entering the state""" +@runtime_checkable +class Interruptable(Protocol): + def interrupt(self, reason: Exception) -> None: ... + - def execute(self) -> Optional['State']: +@runtime_checkable +class Proceedable(Protocol): + def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ + ... - @super_check - def exit(self) -> None: - """Exiting the state""" - if self.is_terminal(): - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': - return self.state_machine.create_state(state_label, *args, **kwargs) - def do_enter(self) -> None: - call_with_super_check(self.enter) - self.in_state = True +def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + if state_label not in st.get_states_map(): + raise ValueError(f'{state_label} is not a valid state') - def do_exit(self) -> None: - call_with_super_check(self.exit) - self.in_state = False + state_cls = st.get_states_map()[state_label] + return state_cls(*args, **kwargs) class StateEventHook(enum.Enum): @@ -228,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]: raise RuntimeError('States not defined') @classmethod - def initial_state_label(cls) -> LABEL_TYPE: + def initial_state_label(cls) -> Any: cls.__ensure_built() assert cls.STATES is not None return cls.STATES[0].LABEL @classmethod - def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: + def get_state_class(cls, label: Any) -> Type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None return cls._STATES_MAP[label] @@ -254,7 +238,7 @@ def __ensure_built(cls) -> None: # Build the states map cls._STATES_MAP = {} for state_cls in cls.STATES: - assert issubclass(state_cls, State) + assert isinstance(state_cls, State) label = state_cls.LABEL assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" cls._STATES_MAP[label] = state_cls @@ -278,13 +262,21 @@ def init(self) -> None: def __str__(self) -> str: return f'<{self.__class__.__name__}> ({self.state})' - def create_initial_state(self) -> State: - return self.get_state_class(self.initial_state_label())(self) + def create_initial_state(self, *args: Any, **kwargs: Any) -> State: + return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Optional[LABEL_TYPE]: + def state(self) -> State | None: if self._state is None: return None + return self._state + + @property + def state_label(self) -> Any: + if self._state is None: + return None + # XXX: should not use `.value` to access the printable output from LABEL + # LABEL as the ClassVar should have __str__ return self._state.LABEL def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: @@ -329,7 +321,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: # it can happened when transit from terminal state return None - initial_state_label = self._state.LABEL if self._state is not None else None + initial_state_label = self.state_label label = None try: self._transitioning = True @@ -347,7 +339,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: self._exit_current_state(new_state) self._enter_next_state(new_state) - if self._state is not None and self._state.is_terminal(): + if self._state is not None and self._state.is_terminal: call_with_super_check(self.on_terminated) except Exception: self._transitioning = False @@ -380,43 +372,25 @@ def get_debug(self) -> bool: def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: - # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic - # because the label is defined after the state and required to be know before calling this function. - # This method should be replaced by `_create_state_instance`. - # aiida-core using this method for its Waiting state override. - try: - return self.get_states_map()[state_label](self, *args, **kwargs) - except KeyError: - raise ValueError(f'{state_label} is not a valid state') - def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state if self._state is None: - if next_state.label != self.initial_state_label(): + if next_state.LABEL != self.initial_state_label(): raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) - self._state.do_exit() + self._state.exit() def _enter_next_state(self, next_state: State) -> None: last_state = self._state self._fire_state_event(StateEventHook.ENTERING_STATE, next_state) # Enter the new state - next_state.do_enter() + next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - - def _create_state_instance(self, state_cls: Hashable, **kwargs: Any) -> State: - if state_cls not in self.get_states_map(): - raise ValueError(f'{state_cls} is not a valid state') - - cls = self.get_states_map()[state_cls] - - return cls(self, **kwargs) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 47ad4956..9262f856 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- import logging -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional + +from plumpy.persistence import LoadSaveContext, Savable, auto_load, auto_save, ensure_object_loader +from plumpy.utils import SAVED_STATE_TYPE from . import persistence @@ -13,7 +16,7 @@ @persistence.auto_persist('_listeners', '_listener_type') -class EventHelper(persistence.Savable): +class EventHelper: def __init__(self, listener_type: 'Type[ProcessListener]'): assert listener_type is not None, 'Must provide valid listener type' @@ -30,6 +33,26 @@ def remove_listener(self, listener: 'ProcessListener') -> None: def remove_all_listeners(self) -> None: self._listeners.clear() + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Savable: + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/src/plumpy/events.py b/src/plumpy/events.py index 3de81987..a6e62529 100644 --- a/src/plumpy/events.py +++ b/src/plumpy/events.py @@ -5,16 +5,6 @@ import sys from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence -__all__ = [ - 'PlumpyEventLoopPolicy', - 'get_event_loop', - 'new_event_loop', - 'reset_event_loop_policy', - 'run_until_complete', - 'set_event_loop', - 'set_event_loop_policy', -] - if TYPE_CHECKING: from .processes import Process diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 5d05ea4b..b4358770 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -1,16 +1,6 @@ # -*- coding: utf-8 -*- from typing import Optional -__all__ = [ - 'ClosedError', - 'CoordinatorConnectionError', - 'CoordinatorTimeoutError', - 'InvalidStateError', - 'KilledError', - 'PersistenceError', - 'UnsuccessfulResult', -] - class KilledError(Exception): """The process was killed.""" diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 30e3ac3f..0e9de6d0 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -9,8 +9,6 @@ import contextlib from typing import Any, Callable, Generator -__all__ = ['CancellableAction', 'Future', 'capture_exceptions'] - class InvalidFutureError(Exception): """Exception for when a future or action is in an invalid state""" diff --git a/src/plumpy/loaders.py b/src/plumpy/loaders.py index a01f9b60..bb248d6a 100644 --- a/src/plumpy/loaders.py +++ b/src/plumpy/loaders.py @@ -3,8 +3,6 @@ import importlib from typing import Any, Optional -__all__ = ['DefaultObjectLoader', 'ObjectLoader', 'get_object_loader', 'set_object_loader'] - class ObjectLoader(metaclass=abc.ABCMeta): """ diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py deleted file mode 100644 index 10142eb7..00000000 --- a/src/plumpy/mixins.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Any, Optional - -from . import persistence -from .utils import SAVED_STATE_TYPE, AttributesDict - -__all__ = ['ContextMixin'] - - -class ContextMixin(persistence.Savable): - """ - Add a context to a Process. The contents of the context will be saved - in the instance state unlike standard instance variables. - """ - - CONTEXT: str = '_context' - - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self._context: Optional[AttributesDict] = AttributesDict() - - @property - def ctx(self) -> Optional[AttributesDict]: - return self._context - - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - """Add the instance state to ``out_state``. - .. important:: - - The instance state will contain a pointer to the ``ctx``, - and so should be deep copied or serialised before persisting. - """ - super().save_instance_state(out_state, save_context) - if self._context is not None: - out_state[self.CONTEXT] = self._context.__dict__ - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - try: - self._context = AttributesDict(**saved_state[self.CONTEXT]) - except KeyError: - pass diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index ba755bc5..b2da7eef 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import abc import asyncio import collections @@ -9,34 +11,60 @@ import os import pickle from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + Iterable, + List, + Optional, + Protocol, + TypeVar, + cast, + runtime_checkable, +) import yaml from . import futures, loaders, utils -from .base.utils import call_with_super_check, super_check from .utils import PID_TYPE, SAVED_STATE_TYPE -__all__ = [ - 'Bundle', - 'InMemoryPersister', - 'LoadSaveContext', - 'PersistedCheckpoint', - 'Persister', - 'PicklePersister', - 'Savable', - 'SavableFuture', - 'auto_persist', -] - PersistedCheckpoint = collections.namedtuple('PersistedCheckpoint', ['pid', 'tag']) if TYPE_CHECKING: from .processes import Process +class LoadSaveContext: + def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: + self._values = dict(**kwargs) + self.loader = loader + + def __getattr__(self, item: str) -> Any: + try: + return self._values[item] + except KeyError: + raise AttributeError(f"item '{item}' not found") + + def __iter__(self) -> Iterable[Any]: + return self._value.__iter__() + + def __contains__(self, item: Any) -> bool: + return self._values.__contains__(item) + + def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': + """Add additional information to the context by making a copy with the new values""" + extended = self._values.copy() + extended.update(kwargs) + loader = extended.pop('loader', self.loader) + return LoadSaveContext(loader=loader, **extended) + + class Bundle(dict): - def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False): + def __init__(self, savable: 'Savable', save_context: LoadSaveContext | None = None, dereference: bool = False): """ Create a bundle from a savable. Optionally keep information about the class loader that can be used to load the classes in the bundle. @@ -52,7 +80,7 @@ class loader that can be used to load the classes in the bundle. else: self.update(savable.save(save_context)) - def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable': + def unbundle(self, load_context: LoadSaveContext | None = None) -> 'Savable': """ This method loads the class of the object and calls its recreate_from method passing the positional and keyword arguments. @@ -61,7 +89,29 @@ def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable :return: An instance of the Savable """ - return Savable.load(self, load_context) + return load(self, load_context) + + +def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': + """ + Load a `Savable` from a saved instance state. The load context is a way of passing + runtime data to the object being loaded. + + :param saved_state: The saved state + :param load_context: Additional runtime state that can be passed into when loading. + The type and content (if any) is completely user defined + :return: The loaded Savable instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + assert load_context.loader is not None # required for type checking + try: + class_name = SaveUtil.get_class_name(saved_state) + load_cls: Savable = load_context.loader.load_object(class_name) + except KeyError: + raise ValueError('Class name not found in saved state') + else: + return load_cls.recreate_from(saved_state, load_context) _BUNDLE_TAG = '!plumpy:Bundle' @@ -345,22 +395,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -SavableClsType = TypeVar('SavableClsType', bound='type[Savable]') - - -def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: - if savable._auto_persist is None: - savable._auto_persist = set() - else: - savable._auto_persist = set(savable._auto_persist) - savable.auto_persist(*members) - return savable - - return wrapped - - -def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': +def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': """ Given a LoadSaveContext this method will ensure that it has a valid class loader using the following priorities: @@ -382,7 +417,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV # 2) Try getting from saved_state default_loader = loaders.get_object_loader() try: - loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER) + loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER) except ValueError: # 3) Fall back to default loader = default_loader @@ -392,31 +427,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV return context.copyextend(loader=loader) -class LoadSaveContext: - def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: - self._values = dict(**kwargs) - self.loader = loader - - def __getattr__(self, item: str) -> Any: - try: - return self._values[item] - except KeyError: - raise AttributeError(f"item '{item}' not found") - - def __iter__(self) -> Iterable[Any]: - return self._value.__iter__() - - def __contains__(self, item: Any) -> bool: - return self._values.__contains__(item) - - def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': - """Add additional information to the context by making a copy with the new values""" - extended = self._values.copy() - extended.update(kwargs) - loader = extended.pop('loader', self.loader) - return LoadSaveContext(loader=loader, **extended) - - META: str = '!!meta' META__CLASS_NAME: str = 'class_name' META__OBJECT_LOADER: str = 'object_loader' @@ -426,46 +436,48 @@ def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': META__TYPE__SAVABLE: str = 'S' -class Savable: - CLASS_NAME: str = 'class_name' +class SaveUtil: + @staticmethod + def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: + user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {}) + user_dict[name] = value - _auto_persist: Optional[Set[str]] = None - _persist_configured = False + @staticmethod + def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: + try: + return saved_state[META][name] + except KeyError: + raise ValueError(f"Unknown meta key '{name}'") @staticmethod - def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Load a `Savable` from a saved instance state. The load context is a way of passing - runtime data to the object being loaded. + def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: + return out_state.setdefault(META, {}) - :param saved_state: The saved state - :param load_context: Additional runtime state that can be passed into when loading. - The type and content (if any) is completely user defined - :return: The loaded Savable instance + @staticmethod + def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: + SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name - """ - load_context = _ensure_object_loader(load_context, saved_state) - assert load_context.loader is not None # required for type checking + @staticmethod + def get_class_name(saved_state: SAVED_STATE_TYPE) -> str: + return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME] + + @staticmethod + def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: + type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {}) + type_dict[name] = type_spec + + @staticmethod + def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: try: - class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) + return saved_state[META][META__TYPES][name] except KeyError: - raise ValueError('Class name not found in saved state') - else: - return load_cls.recreate_from(saved_state, load_context) - - @classmethod - def auto_persist(cls, *members: str) -> None: - if cls._auto_persist is None: - cls._auto_persist = set() - cls._auto_persist.update(members) + pass - @classmethod - def persist(cls) -> None: - pass +@runtime_checkable +class Savable(Protocol): @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -475,137 +487,119 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - call_with_super_check(obj.load_instance_state, saved_state, load_context) - return obj + ... - @super_check - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - self.load_members(self._auto_persist, saved_state, load_context) + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... - @super_check - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - self.save_members(self._auto_persist, out_state) - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = {} +@runtime_checkable +class SavableWithAutoPersist(Savable, Protocol): + _auto_persist: ClassVar[set[str]] = set() - if save_context is None: - save_context = LoadSaveContext() - utils.type_check(save_context, LoadSaveContext) +def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = {} - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader + if save_context is None: + save_context = LoadSaveContext() - Savable._set_class_name(out_state, loader.identify_object(self.__class__)) - call_with_super_check(self.save_instance_state, out_state, save_context) - return out_state + utils.type_check(save_context, LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__)) - def save_members(self, members: Iterable[str], out_state: SAVED_STATE_TYPE) -> None: - for member in members: - value = getattr(self, member) + if isinstance(obj, SavableWithAutoPersist): + for member in obj._auto_persist: + value = getattr(obj, member) if inspect.ismethod(value): - if value.__self__ is not self: + if value.__self__ is not obj: raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD) value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + elif isinstance(value, Savable) and not isinstance(value, type): + # persist for a savable obj, call `save` method of obj. + # the rhs branch is for when value is a Savable class, it is true runtime check + # of lhs condition. + SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) value = value.save() else: value = copy.deepcopy(value) out_state[member] = value - def load_members( - self, members: Iterable[str], saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None - ) -> None: - for member in members: - setattr(self, member, self._get_value(saved_state, member, load_context)) + return out_state - def _ensure_persist_configured(self) -> None: - if not self._persist_configured: - self.persist() - self._persist_configured = True - # region Metadata getter/setters +def load_auto_persist_params( + obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None +) -> None: + for member in obj._auto_persist: + setattr(obj, member, _get_value(obj, saved_state, member, load_context)) - @staticmethod - def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: - user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {}) - user_dict[name] = value - @staticmethod - def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: - try: - return saved_state[META][name] - except KeyError: - raise ValueError(f"Unknown meta key '{name}'") +T = TypeVar('T', bound=Savable) - @staticmethod - def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: - return out_state.setdefault(META, {}) - @staticmethod - def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: - Savable._get_create_meta(out_state)[META__CLASS_NAME] = name +def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T: + obj = cls.__new__(cls) - @staticmethod - def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str: - return Savable._get_create_meta(saved_state)[META__CLASS_NAME] + if isinstance(obj, SavableWithAutoPersist): + load_auto_persist_params(obj, saved_state, load_context) - @staticmethod - def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: - type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {}) - type_dict[name] = type_spec + return obj - @staticmethod - def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: - try: - return saved_state[META][META__TYPES][name] - except KeyError: - pass - # endregion +def _get_value( + obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None +) -> MethodType | Savable: + value = saved_state[name] + + typ = SaveUtil.get_meta_type(saved_state, name) + if typ == META__TYPE__METHOD: + value = getattr(obj, value) + elif typ == META__TYPE__SAVABLE: + value = load(value, load_context) + + return value - def _get_value( - self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] - ) -> Union[MethodType, 'Savable']: - value = saved_state[name] - typ = Savable._get_meta_type(saved_state, name) - if typ == META__TYPE__METHOD: - value = getattr(self, value) - elif typ == META__TYPE__SAVABLE: - value = Savable.load(value, load_context) +def auto_persist(*members: str) -> Callable[..., Savable]: + def wrapped(savable_cls: type) -> Savable: + if not hasattr(savable_cls, '_auto_persist') or savable_cls._auto_persist is None: + savable_cls._auto_persist = set() # type: ignore[attr-defined] + else: + savable_cls._auto_persist = set(savable_cls._auto_persist) + + savable_cls._auto_persist.update(members) # type: ignore[attr-defined] + # XXX: validate on `save` and `recreate_from` method?? + return cast(Savable, savable_cls) - return value + return wrapped +# FIXME: move me to another module? savablefuture.py? @auto_persist('_state', '_result') -class SavableFuture(futures.Future, Savable): +class SavableFuture(futures.Future): """ A savable future. .. note: This does not save any assigned done callbacks. """ - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) if self.done() and self.exception() is not None: out_state['exception'] = self.exception() + return out_state + @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -617,7 +611,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) try: loop = load_context.loop @@ -643,11 +637,13 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa obj = cls(loop=loop) obj.cancel() - return obj + # ## XXX: load_instance_state: test not cover + # auto_load(obj, saved_state, load_context) + # + # if obj._callbacks: + # # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list + # for callback in obj._callbacks: + # obj.remove_done_callback(callback) # type: ignore[arg-type] + # ## UNTILHERE XXX: - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - if self._callbacks: - # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list - for callback in self._callbacks: - self.remove_done_callback(callback) # type: ignore[arg-type] + return obj diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index cfbd92d5..8522f061 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -11,8 +11,6 @@ from plumpy.utils import AttributesFrozendict, is_mutable_property, type_check -__all__ = ['UNSPECIFIED', 'InputPort', 'OutputPort', 'Port', 'PortNamespace', 'PortValidationError'] - _LOGGER = logging.getLogger(__name__) UNSPECIFIED = () diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 8e1acf94..8e9673bb 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -2,17 +2,21 @@ import abc from typing import TYPE_CHECKING, Any, Dict, Optional -from . import persistence -from .utils import SAVED_STATE_TYPE, protected +from plumpy.persistence import LoadSaveContext, auto_save, ensure_object_loader -__all__ = ['ProcessListener'] +from . import persistence +from .utils import SAVED_STATE_TYPE if TYPE_CHECKING: + from plumpy.persistence import Savable + from .processes import Process +# FIXME: test any process listener is a savable + @persistence.auto_persist('_params') -class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): +class ProcessListener(metaclass=abc.ABCMeta): # region Persistence methods def __init__(self) -> None: @@ -22,12 +26,26 @@ def __init__(self) -> None: def init(self, **kwargs: Any) -> None: self._params = kwargs - @protected - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] - ) -> None: - super().load_instance_state(saved_state, load_context) - self.init(**saved_state['_params']) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + obj.init(**saved_state['_params']) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state # endregion diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 996886a2..3219328e 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,12 +5,26 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + ClassVar, + Optional, + Tuple, + Type, + Union, + cast, + final, +) import yaml +from typing_extensions import override from yaml.loader import Loader from plumpy.message import Message, MsgKill, MsgPause +from plumpy.persistence import ensure_object_loader try: import tblib @@ -20,9 +34,15 @@ _HAS_TBLIB = False from . import exceptions, futures, persistence, utils -from .base import state_machine +from .base import state_machine as st from .lang import NULL -from .persistence import auto_persist +from .persistence import ( + LoadSaveContext, + Savable, + auto_load, + auto_persist, + auto_save, +) from .utils import SAVED_STATE_TYPE, ensure_coroutine __all__ = [ @@ -70,8 +90,26 @@ def __init__(self, msg_text: str | None): # region Commands -class Command(persistence.Savable): - pass +class Command: + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @auto_persist('msg') @@ -117,17 +155,36 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + @override + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + return out_state + + @override + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + try: - self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) + obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN]) except ValueError: - process = load_context.process - self.continue_fn = getattr(process, saved_state[self.CONTINUE_FN]) + if load_context is not None: + obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN]) + else: + raise + return obj # endregion @@ -140,61 +197,69 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ - CREATED: str = 'created' - RUNNING: str = 'running' - WAITING: str = 'waiting' - FINISHED: str = 'finished' - EXCEPTED: str = 'excepted' - KILLED: str = 'killed' - - -@auto_persist('in_state') -class State(state_machine.State, persistence.Savable): - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process - - def interrupt(self, reason: Any) -> None: - pass + CREATED = 'created' + RUNNING = 'running' + WAITING = 'waiting' + FINISHED = 'finished' + EXCEPTED = 'excepted' + KILLED = 'killed' +@final @auto_persist('args', 'kwargs') -class Created(State): - LABEL = ProcessState.CREATED - ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} +class Created: + LABEL: ClassVar = ProcessState.CREATED + ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + obj.process = load_context.process + obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) + + return obj + + def execute(self) -> st.State: + return st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.run_fn, *self.args, **self.kwargs + ) + + def enter(self) -> None: ... - def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + def exit(self) -> None: ... +@final @auto_persist('args', 'kwargs') -class Running(State): - LABEL = ProcessState.RUNNING - ALLOWED = { +class Running: + LABEL: ClassVar = ProcessState.RUNNING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, @@ -210,36 +275,52 @@ class Running(State): _running: bool = False _run_handle = None + is_terminal: ClassVar[bool] = False + def __init__( self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any ) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = ensure_coroutine(run_fn) - # We wrap `run_fn` to a coroutine so we can apply await on it, - # even it if it was not a coroutine in the first place. - # This allows the same usage of async and non-async function - # with the await syntax while not changing the program logic. self.args = args self.kwargs = kwargs self._run_handle = None - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) - if self.COMMAND in saved_state: - self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + obj.process = load_context.process + + obj.run_fn = ensure_coroutine(getattr(obj.process, saved_state[obj.RUN_FN])) + if obj.COMMAND in saved_state: + obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore + + return obj def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> State: # type: ignore + async def execute(self) -> st.State: if self._command is not None: command = self._command else: @@ -253,8 +334,10 @@ async def execute(self) -> State: # type: ignore # Let this bubble up to the caller raise except Exception: - excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(State, excepted) + _, exception, traceback = sys.exc_info() + # excepted = state_cls(exception=exception, traceback=traceback) + excepted = Excepted(exception=exception, traceback=traceback) + return excepted else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -263,32 +346,52 @@ async def execute(self) -> State: # type: ignore # Got passed a basic return type result = Stop(result, True) - command = result + command = cast(Stop, result) next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State: if isinstance(command, Kill): - state = self.create_state(ProcessState.KILLED, command.msg) + state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = st.create_state( + self.process, ProcessState.FINISHED, result=command.result, successful=command.successful + ) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = st.create_state( + self.process, + ProcessState.WAITING, + process=self.process, + done_callback=command.continue_fn, + msg=command.msg, + data=command.data, + ) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = st.create_state( + self.process, + ProcessState.RUNNING, + process=self.process, + run_fn=command.continue_fn, + *command.args, + **command.kwargs, + ) else: raise ValueError('Unrecognised command') - return cast(State, state) # casting from base.State to process.State + return state + + def enter(self) -> None: ... + + def exit(self) -> None: ... @auto_persist('msg', 'data') -class Waiting(State): - LABEL = ProcessState.WAITING - ALLOWED = { +class Waiting: + LABEL: ClassVar = ProcessState.WAITING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, @@ -300,6 +403,8 @@ class Waiting(State): _interruption = None + is_terminal: ClassVar[bool] = False + def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: @@ -313,31 +418,48 @@ def __init__( msg: Optional[str] = None, data: Optional[Any] = None, ) -> None: - super().__init__(process) + self.process = process self.done_callback = done_callback self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - callback_name = saved_state.get(self.DONE_CALLBACK, None) + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + obj.process = load_context.process + + callback_name = saved_state.get(obj.DONE_CALLBACK, None) if callback_name is not None: - self.done_callback = getattr(self.process, callback_name) + obj.done_callback = getattr(obj.process, callback_name) else: - self.done_callback = None - self._waiting_future = futures.Future() + obj.done_callback = None + obj._waiting_future = futures.Future() + return obj - def interrupt(self, reason: Any) -> None: + def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> State: # type: ignore + async def execute(self) -> st.State: try: result = await self._waiting_future except Interruption: @@ -348,11 +470,15 @@ async def execute(self) -> State: # type: ignore raise if result == NULL: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.done_callback + ) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result + ) - return cast(State, next_state) # casting from base.State to process.State + return next_state def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -362,55 +488,77 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) + def enter(self) -> None: ... + + def exit(self) -> None: ... + -class Excepted(State): +@final +@auto_persist() +class Excepted: """ - Excepted state, can optionally provide exception and trace_back + Excepted state, can optionally provide exception and traceback :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - LABEL = ProcessState.EXCEPTED + LABEL: ClassVar = ProcessState.EXCEPTED + ALLOWED: ClassVar[set[str]] = set() EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' + is_terminal: ClassVar = True + def __init__( self, - process: 'Process', exception: Optional[BaseException], - trace_back: Optional[TracebackType] = None, + traceback: Optional[TracebackType] = None, ): """ - :param process: The associated process :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - super().__init__(process) self.exception = exception - self.traceback = trace_back + self.traceback = traceback def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] return super().__str__() + f'({exception})' - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + + obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: - self.traceback = tblib.Traceback.from_string(saved_state[self.TRACEBACK], strict=False) + obj.traceback = tblib.Traceback.from_string(saved_state[obj.TRACEBACK], strict=False) except KeyError: - self.traceback = None + obj.traceback = None else: - self.traceback = None + obj.traceback = None + return obj def get_exc_info( self, @@ -424,25 +572,57 @@ def get_exc_info( self.traceback, ) + def enter(self) -> None: ... + + def exit(self) -> None: ... + +@final @auto_persist('result', 'successful') -class Finished(State): +class Finished: """State for process is finished. :param result: The result of process :param successful: Boolean for the exit code is ``0`` the process is successful. """ - LABEL = ProcessState.FINISHED + LABEL: ClassVar = ProcessState.FINISHED + ALLOWED: ClassVar[set[str]] = set() - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - super().__init__(process) + is_terminal: ClassVar[bool] = True + + def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + + def enter(self) -> None: ... + + def exit(self) -> None: ... + +@final @auto_persist('msg') -class Killed(State): +class Killed: """ Represents a state where a process has been killed. @@ -452,15 +632,40 @@ class Killed(State): :param msg: An optional message explaining the reason for the process termination. """ - LABEL = ProcessState.KILLED + LABEL: ClassVar = ProcessState.KILLED + ALLOWED: ClassVar[set[str]] = set() - def __init__(self, process: 'Process', msg: Optional[Message]): + is_terminal: ClassVar[bool] = True + + def __init__(self, msg: Optional[Message]): """ - :param process: The associated process :param msg: Optional kill message """ - super().__init__(process) self.msg = msg + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + + def enter(self) -> None: ... + + def exit(self) -> None: ... + # endregion diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index f2b314e2..e3784b21 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -21,6 +21,7 @@ Any, Awaitable, Callable, + ClassVar, Dict, Generator, Hashable, @@ -35,6 +36,7 @@ ) from plumpy.coordinator import BroadcastFilter, Coordinator +from plumpy.persistence import ensure_object_loader try: from aiocontextvars import ContextVar @@ -45,7 +47,15 @@ from . import events, exceptions, message, persistence, ports, process_states, utils from .base import state_machine -from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event +from .base.state_machine import ( + Interruptable, + Proceedable, + StateEntryFailed, + StateMachine, + StateMachineError, + create_state, + event, +) from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper from .futures import CancellableAction, capture_exceptions @@ -56,8 +66,6 @@ T = TypeVar('T') -__all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] - _LOGGER = logging.getLogger(__name__) PROCESS_STACK = ContextVar('process stack', default=[]) @@ -66,7 +74,7 @@ class BundleKeys: """ String keys used by the process to save its state in the state bundle. - See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`. + See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.recreate_from`. """ @@ -107,7 +115,7 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: '_pre_paused_status', '_event_helper', ) -class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): +class Process(StateMachine, metaclass=ProcessStateMachineMeta): """ The Process class is the base for any unit of work in plumpy. @@ -157,6 +165,7 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe _cleanups: Optional[List[Callable[[], None]]] = None __called: bool = False + _auto_persist: ClassVar[set[str]] @classmethod def current(cls) -> Optional['Process']: @@ -172,7 +181,7 @@ def current(cls) -> Optional['Process']: return None @classmethod - def get_states(cls) -> Sequence[Type[process_states.State]]: + def get_states(cls) -> Sequence[Type[state_machine.State]]: """Return all allowed states of the process.""" state_classes = cls.get_state_classes() return ( @@ -181,7 +190,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -248,19 +257,71 @@ def recreate_from( cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None, - ) -> 'Process': - """ - Recreate a process from a saved state, passing any positional and - keyword arguments on to load_instance_state + ) -> Process: + """Recreate a process from a saved state, passing any positional :param saved_state: The saved state to load from :param load_context: The load context to use :return: An instance of the object with its state loaded from the save state. """ - process = cast(Process, super().recreate_from(saved_state, load_context)) - call_with_super_check(process.init) - return process + load_context = ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) + + proc._setup_event_hooks() + + # Runtime variables, set initial states + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._coordinator = None + + if 'loop' in load_context: + proc._loop = load_context.loop + else: + _LOGGER.warning('cannot find `loop` store in load_context, use default event loop') + proc._loop = asyncio.get_event_loop() + + proc._state = proc.recreate_state(saved_state['_state']) + + if 'coordinator' in load_context: + proc._coordinator = load_context.coordinator + else: + _LOGGER.warning('cannot find `coordinator` store in load_context') + + if 'logger' in load_context: + proc._logger = load_context.logger + else: + _LOGGER.warning('cannot find `logger` store in load_context') + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.load_auto_persist_params(proc, saved_state, load_context) + + # Inputs/outputs + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._raw_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._parsed_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.OUTPUTS]) + proc._outputs = decoded + except KeyError: + proc._outputs = {} + + call_with_super_check(proc.init) + return proc def __init__( self, @@ -353,10 +414,10 @@ def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" event_hooks = { state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering( - cast(process_states.State, state) + cast(state_machine.State, state) ), state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( - cast(Optional[process_states.State], from_state) + cast(Optional[state_machine.State], from_state) ), state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } @@ -463,7 +524,9 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal() + if self.state is None: + raise exceptions.InvalidStateError('process is not in state None that is invalid') + return self.state.is_terminal def result(self) -> Any: """ @@ -473,12 +536,12 @@ def result(self) -> Any: If in any other state this will raise an InvalidStateError. :return: The result of the process """ - if isinstance(self._state, process_states.Finished): - return self._state.result - if isinstance(self._state, process_states.Killed): - raise exceptions.KilledError(self._state.msg) - if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception('process excepted')) + if isinstance(self.state, process_states.Finished): + return self.state.result + if isinstance(self.state, process_states.Killed): + raise exceptions.KilledError(self.state.msg) + if isinstance(self.state, process_states.Excepted): + raise (self.state.exception or Exception('process excepted')) raise exceptions.InvalidStateError @@ -488,7 +551,7 @@ def successful(self) -> bool: Will raise if the process is not in the FINISHED state """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError as exception: raise exceptions.InvalidStateError('process is not in the finished state') from exception @@ -499,25 +562,25 @@ def is_successful(self) -> bool: :return: boolean, True if the process is in `Finished` state with `successful` attribute set to `True` """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError: return False def killed(self) -> bool: """Return whether the process is killed.""" - return self.state == process_states.ProcessState.KILLED + return self.state_label == process_states.ProcessState.KILLED def killed_msg(self) -> Optional[Message]: """Return the killed message.""" - if isinstance(self._state, process_states.Killed): - return self._state.msg + if isinstance(self.state, process_states.Killed): + return self.state.msg raise exceptions.InvalidStateError('Has not been killed') def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" - if isinstance(self._state, process_states.Excepted): - return self._state.exception + if isinstance(self.state, process_states.Excepted): + return self.state.exception return None @@ -527,7 +590,7 @@ def is_excepted(self) -> bool: :return: boolean, True if the process is in ``EXCEPTED`` state. """ - return self.state == process_states.ProcessState.EXCEPTED + return self.state_label == process_states.ProcessState.EXCEPTED def done(self) -> bool: """Return True if the call was successfully killed or finished running. @@ -536,7 +599,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal() + return self.has_terminated() # endregion @@ -564,7 +627,7 @@ def callback_excepted( exception: Optional[BaseException], trace: Optional[TracebackType], ) -> None: - if self.state != process_states.ProcessState.EXCEPTED: + if self.state_label != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @contextlib.contextmanager @@ -608,20 +671,17 @@ async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) # region Persistence - def save_instance_state( - self, - out_state: SAVED_STATE_TYPE, - save_context: Optional[persistence.LoadSaveContext], - ) -> None: + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: """ Ask the process to save its current instance state. :param out_state: A bundle to save the state to :param save_context: The save context """ - super().save_instance_state(out_state, save_context) + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - out_state['_state'] = self._state.save() + if isinstance(self.state, persistence.Savable): + out_state['_state'] = self.state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -633,61 +693,7 @@ def save_instance_state( if self.outputs: out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) - @protected - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - """Load the process from its saved instance state. - - :param saved_state: A bundle to load the state from - :param load_context: The load context - - """ - # First make sure the state machine constructor is called - super().__init__() - - self._setup_event_hooks() - - # Runtime variables, set initial states - self._future = persistence.SavableFuture() - self._event_helper = EventHelper(ProcessListener) - self._logger = None - self._coordinator = None - - if 'loop' in load_context: - self._loop = load_context.loop - else: - self._loop = asyncio.get_event_loop() - - self._state: process_states.State = self.recreate_state(saved_state['_state']) - - if 'coordinator' in load_context: - self._coordinator = load_context.coordinator - - if 'logger' in load_context: - self._logger = load_context.logger - - # Need to call this here as things downstream may rely on us having the runtime variable above - super().load_instance_state(saved_state, load_context) - - # Inputs/outputs - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) - self._raw_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._raw_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) - self._parsed_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._parsed_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.OUTPUTS]) - self._outputs = decoded - except KeyError: - self._outputs = {} - - # endregion + return out_state def add_process_listener(self, listener: ProcessListener) -> None: """Add a process listener to the process. @@ -715,7 +721,7 @@ def log_with_pid(self, level: int, msg: str) -> None: # region Events - def on_entering(self, state: process_states.State) -> None: + def on_entering(self, state: state_machine.State) -> None: # Map these onto direct functions that the subclass can implement state_label = state.LABEL if state_label == process_states.ProcessState.CREATED: @@ -731,9 +737,9 @@ def on_entering(self, state: process_states.State) -> None: elif state_label == process_states.ProcessState.EXCEPTED: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore - def on_entered(self, from_state: Optional[process_states.State]) -> None: + def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement - state_label = self._state.LABEL + state_label = self.state_label if state_label == process_states.ProcessState.RUNNING: call_with_super_check(self.on_running) elif state_label == process_states.ProcessState.WAITING: @@ -745,14 +751,16 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: elif state_label == process_states.ProcessState.KILLED: call_with_super_check(self.on_killed) - if self._coordinator and isinstance(self.state, enum.Enum): + if self._coordinator and isinstance(self.state_label, enum.Enum): from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None - subject = f'state_changed.{from_label}.{self.state.value}' + subject = f'state_changed.{from_label}.{self.state_label.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: self._coordinator.broadcast_send(body=None, sender=self.pid, subject=subject) except exceptions.CoordinatorCommunicationError: - message = f'Process<{self.pid}>: cannot broadcast state change from {from_label} to {self.state.value}' + message = ( + f'Process<{self.pid}>: cannot broadcast state change from {from_label} to {self.state_label.value}' + ) self.logger.warning(message) self.logger.debug(message, exc_info=True) except Exception: @@ -760,7 +768,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: raise def on_exiting(self) -> None: - state = self.state + state = self.state_label if state == process_states.ProcessState.WAITING: call_with_super_check(self.on_exit_waiting) elif state == process_states.ProcessState.RUNNING: @@ -867,7 +875,7 @@ def on_finish(self, result: Any, successful: bool) -> None: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] - finished_state = state_cls(self, result=result, successful=False) + finished_state = state_cls(result=result, successful=False) raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -1100,9 +1108,7 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - new_state = self._create_state_instance( - process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace - ) + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) self.transition_to(new_state) def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: @@ -1127,6 +1133,11 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: return self._pausing if self._stepping: + if not isinstance(self.state, Interruptable): + raise exceptions.InvalidStateError( + f'cannot interrupt {self.state.__class__}, method `interrupt` not implement' + ) + # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.PauseInterruption(msg_text) @@ -1139,7 +1150,11 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: msg = MsgPause.new(msg_text) return self._do_pause(state_msg=msg) - def _do_pause(self, state_msg: Optional[Message], next_state: Optional[process_states.State] = None) -> bool: + @staticmethod + def _interrupt(state: Interruptable, reason: Exception) -> None: + state.interrupt(reason) + + def _do_pause(self, state_msg: Optional[Message], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: @@ -1171,11 +1186,13 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> Ca if isinstance(exception, process_states.KillInterruption): - def do_kill(_next_state: process_states.State) -> Any: + def do_kill(_next_state: state_machine.State) -> Any: try: - new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=exception.msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=exception.msg) self.transition_to(new_state) return True + # FIXME: if try block except, will hit deadlock in event loop + # need to know how to debug it, and where to set a timeout. finally: self._killing = None @@ -1217,18 +1234,17 @@ def play(self) -> bool: @event(from_states=process_states.Waiting) def resume(self, *args: Any) -> None: """Start running the process again.""" - return self._state.resume(*args) # type: ignore + return self.state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure - :param trace_back: Optional exception traceback + :param traceback: Optional exception traceback """ - new_state = self._create_state_instance( - process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace_back - ) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback) self.transition_to(new_state) def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: @@ -1236,7 +1252,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: Kill the process :param msg: An optional kill message """ - if self.state == process_states.ProcessState.KILLED: + if self.state_label == process_states.ProcessState.KILLED: # Already killed return True @@ -1248,7 +1264,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: # Already killing return self._killing - if self._stepping: + if self._stepping and isinstance(self.state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg_text) @@ -1258,7 +1274,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: return cast(CancellableAction, self._interrupt_action) msg = MsgKill.new(msg_text) - new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True @@ -1269,19 +1285,16 @@ def is_killing(self) -> bool: # endregion - def create_initial_state(self) -> process_states.State: + def create_initial_state(self) -> state_machine.State: """This method is here to override its superclass. Automatically enter the CREATED state when the process is created. :return: A Created state """ - return cast( - process_states.State, - self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), - ) + return self.get_state_class(process_states.ProcessState.CREATED)(self, self.run) - def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: + def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ Create a state object from a saved state @@ -1289,7 +1302,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) + return cast(state_machine.State, persistence.load(saved_state, load_context)) # endregion @@ -1327,11 +1340,13 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused + if not isinstance(self.state, Proceedable): + raise StateMachineError(f'cannot step from {self.state.__class__}, async method `execute` not implemented') + try: self._stepping = True next_state = None try: - # XXX: debug log when need to step to next state next_state = await self._run_task(self._state.execute) except process_states.Interruption as exception: # If the interruption was caused by a call to a Process method then there should @@ -1348,13 +1363,16 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + next_state = create_state( + self, process_states.ProcessState.EXCEPTED, exception=sys.exc_info()[1], traceback=sys.exc_info()[2] + ) self._set_interrupt_action(None) if self._interrupt_action: self._interrupt_action.run(next_state) else: # Everything nominal so transition to the next state + self.logger.debug(f'Process<{self.pid}>: transfer from {self._state.LABEL} to {next_state.LABEL}') self.transition_to(next_state) finally: diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index bd1b70a7..3c37ce08 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -23,8 +23,6 @@ from . import lang from .settings import check_override, check_protected -__all__ = ['AttributesDict'] - protected = lang.protected(check=check_protected) override = lang.override(check=check_override) diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 5df20bf4..5c459d0e 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -11,11 +11,11 @@ Any, Callable, Dict, - Hashable, List, Mapping, MutableSequence, Optional, + Protocol, Sequence, Tuple, Type, @@ -23,12 +23,17 @@ cast, ) +from plumpy import utils +from plumpy.base import state_machine +from plumpy.base.utils import call_with_super_check from plumpy.coordinator import Coordinator +from plumpy.event_helper import EventHelper +from plumpy.exceptions import InvalidStateError +from plumpy.persistence import LoadSaveContext, Savable, auto_persist, auto_save, ensure_object_loader +from plumpy.process_listener import ProcessListener -from . import lang, mixins, persistence, process_states, processes -from .utils import PID_TYPE, SAVED_STATE_TYPE - -__all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] +from . import lang, persistence, process_spec, process_states, processes +from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict ToContext = dict @@ -37,7 +42,7 @@ EXIT_CODE_TYPE = int -class WorkChainSpec(processes.ProcessSpec): +class WorkChainSpec(process_spec.ProcessSpec): def __init__(self) -> None: super().__init__() self._outline: Optional[Union['_Instruction', '_FunctionCall']] = None @@ -68,6 +73,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: return self._outline +# FIXME: better use composition here @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): """Overwrite the waiting state""" @@ -77,24 +83,14 @@ def __init__( process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, + data: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: - super().__init__(process, done_callback, msg, awaiting) + super().__init__(process, done_callback, msg, data) self._awaiting: Dict[asyncio.Future, str] = {} - for awaitable, key in (awaiting or {}).items(): + for awaitable, key in (data or {}).items(): resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key - def enter(self) -> None: - super().enter() - for awaitable in self._awaiting: - awaitable.add_done_callback(self._awaitable_done) - - def exit(self) -> None: - super().exit() - for awaitable in self._awaiting: - awaitable.remove_done_callback(self._awaitable_done) - def _awaitable_done(self, awaitable: asyncio.Future) -> None: key = self._awaiting.pop(awaitable) try: @@ -105,8 +101,19 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: if not self._awaiting: self._waiting_future.set_result(lang.NULL) + def enter(self) -> None: + for awaitable in self._awaiting: + awaitable.add_done_callback(self._awaitable_done) + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + for awaitable in self._awaiting: + awaitable.remove_done_callback(self._awaitable_done) + -class WorkChain(mixins.ContextMixin, processes.Process): +class WorkChain(processes.Process): """ A WorkChain is a series of instructions carried out with the ability to save state in between. @@ -114,10 +121,10 @@ class WorkChain(mixins.ContextMixin, processes.Process): _spec_class = WorkChainSpec _STEPPER_STATE = 'stepper_state' - _CONTEXT = 'CONTEXT' + CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map @@ -131,9 +138,14 @@ def __init__( coordinator: Optional[Coordinator] = None, ) -> None: super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, coordinator=coordinator) + self._context: Optional[AttributesDict] = AttributesDict() self._stepper: Optional[Stepper] = None self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {} + @property + def ctx(self) -> Optional[AttributesDict]: + return self._context + @classmethod def spec(cls) -> WorkChainSpec: return cast(WorkChainSpec, super().spec()) @@ -142,23 +154,118 @@ def on_create(self) -> None: super().on_create() self._stepper = self.spec().get_outline().create_stepper(self) - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + """ + Ask the process to save its current instance state. + + :param out_state: A bundle to save the state to + :param save_context: The save context + """ + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + if isinstance(self._state, persistence.Savable): + out_state['_state'] = self._state.save() + + # Inputs/outputs + if self.raw_inputs is not None: + out_state[processes.BundleKeys.INPUTS_RAW] = self.encode_input_args(self.raw_inputs) + + if self.inputs is not None: + out_state[processes.BundleKeys.INPUTS_PARSED] = self.encode_input_args(self.inputs) + + if self.outputs: + out_state[processes.BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) # Ask the stepper to save itself - if self._stepper is not None: + if self._stepper is not None and isinstance(self._stepper, Savable): out_state[self._STEPPER_STATE] = self._stepper.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + if self._context is not None: + out_state[self.CONTEXT] = self._context.__dict__ + + return out_state + + @classmethod + def recreate_from( + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, + ) -> WorkChain: + """Recreate a workchain from a saved state, passing any positional + + :param saved_state: The saved state to load from + :param load_context: The load context to use + :return: An instance of the object with its state loaded from the save state. + + """ + ### FIXME: dup from process.create_from + load_context = ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) + + proc._setup_event_hooks() + + # Runtime variables, set initial states + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._coordinator = None + + if 'loop' in load_context: + proc._loop = load_context.loop + else: + proc._loop = asyncio.get_event_loop() + + proc._state = proc.recreate_state(saved_state['_state']) + + if 'coordinator' in load_context: + proc._coordinator = load_context.coordinator + + if 'logger' in load_context: + proc._logger = load_context.logger + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.load_auto_persist_params(proc, saved_state, load_context) + + # Inputs/outputs + try: + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._raw_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._parsed_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) + proc._outputs = decoded + except KeyError: + proc._outputs = {} + ### UNTILHERE FIXME: dup from process.create_from + + # context mixin + try: + proc._context = AttributesDict(**saved_state[proc.CONTEXT]) + except KeyError: + pass + + # end of context mixin # Recreate the stepper - self._stepper = None - stepper_state = saved_state.get(self._STEPPER_STATE, None) + proc._stepper = None + stepper_state = saved_state.get(proc._STEPPER_STATE, None) if stepper_state is not None: - self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) + proc._stepper = proc.spec().get_outline().recreate_stepper(stepper_state, proc) + + call_with_super_check(proc.init) + return proc def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None: """ @@ -195,15 +302,8 @@ def _do_step(self) -> Any: return return_value -class Stepper(persistence.Savable, metaclass=abc.ABCMeta): - def __init__(self, workchain: 'WorkChain') -> None: - self._workchain = workchain - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._workchain = load_context.workchain - - @abc.abstractmethod +# XXX: Stepper is also a Saver with `save` method. +class Stepper(Protocol): def step(self) -> Tuple[bool, Any]: """ Execute on step of the instructions. @@ -212,6 +312,7 @@ def step(self) -> Tuple[bool, Any]: 1. The return value from the executed step """ + ... class _Instruction(metaclass=abc.ABCMeta): @@ -241,18 +342,37 @@ def get_description(self) -> Any: """ -class _FunctionStepper(Stepper): +@auto_persist() +class _FunctionStepper: def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): - super().__init__(workchain) + self._workchain = workchain self._fn = fn - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state['_fn'] = self._fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._fn = getattr(self._workchain.__class__, saved_state['_fn']) + return out_state + + @classmethod + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = persistence.auto_load(cls, saved_state, load_context) + obj._workchain = load_context.workchain + obj._fn = getattr(obj._workchain.__class__, saved_state['_fn']) + + return obj def step(self) -> Tuple[bool, Any]: return True, self._fn(self._workchain) @@ -292,9 +412,9 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') -class _BlockStepper(Stepper): +class _BlockStepper: def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._block = block self._pos: int = 0 self._child_stepper: Optional[Stepper] = self._block[0].create_stepper(self._workchain) @@ -319,18 +439,34 @@ def next_instruction(self) -> None: def finished(self) -> bool: return self._pos == len(self._block) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) - if self._child_stepper is not None: + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._block = load_context.block_instruction + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = persistence.auto_load(cls, saved_state, load_context) + obj._workchain = load_context.workchain + obj._block = load_context.block_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._block[self._pos].recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._block[obj._pos].recreate_stepper(stepper_state, obj._workchain) + + return obj def __str__(self) -> str: return str(self._pos) + ':' + str(self._child_stepper) @@ -423,9 +559,9 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') -class _IfStepper(Stepper): +class _IfStepper: def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._if_instruction = if_instruction self._pos = 0 self._child_stepper: Optional[Stepper] = None @@ -457,18 +593,33 @@ def step(self) -> Tuple[bool, Any]: def finished(self) -> bool: return self._pos == len(self._if_instruction) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) - if self._child_stepper is not None: + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._if_instruction = load_context.if_instruction + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = persistence.auto_load(cls, saved_state, load_context) + obj._workchain = load_context.workchain + obj._if_instruction = load_context.if_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._if_instruction[self._pos].body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._if_instruction[obj._pos].body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._if_instruction[self._pos]) @@ -530,9 +681,9 @@ def get_description(self) -> Mapping[str, Any]: return description -class _WhileStepper(Stepper): +class _WhileStepper: def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._while_instruction = while_instruction self._child_stepper: Optional[_BlockStepper] = None @@ -551,18 +702,36 @@ def step(self) -> Tuple[bool, Any]: return False, result - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) + if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._while_instruction = load_context.while_instruction + return out_state + + @classmethod + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = persistence.auto_load(cls, saved_state, load_context) + obj._workchain = load_context.workchain + obj._while_instruction = load_context.while_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._while_instruction.body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._while_instruction.body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._while_instruction) @@ -600,9 +769,10 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code -class _ReturnStepper(Stepper): +@persistence.auto_persist() +class _ReturnStepper: def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._return_instruction = return_instruction def step(self) -> Tuple[bool, Any]: diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index ddcbb8d9..44a084d4 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- import time +from typing import final import unittest from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError # Events PLAY = 'Play' @@ -15,31 +17,25 @@ STOPPED = 'Stopped' -class Playing(state_machine.State): +class Playing: LABEL = PLAYING ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, track): assert track is not None, 'Must provide a track name' - super().__init__(player) self.track = track self._last_time = None self._played = 0.0 + self.in_state = False def __str__(self): if self.in_state: self._update_time() return f'> {self.track} ({self._played}s)' - def enter(self): - super().enter() - self._last_time = time.time() - - def exit(self): - super().exit() - self._update_time() - def play(self, track=None): return False @@ -48,15 +44,27 @@ def _update_time(self): self._played += current_time - self._last_time self._last_time = current_time + def enter(self) -> None: + self._last_time = time.time() + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self._update_time() + self.in_state = False + -class Paused(state_machine.State): +class Paused: LABEL = PAUSED ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' - super().__init__(player) self._player = player self.playing_state = playing_state @@ -65,23 +73,46 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) + self._player.transition_to(Playing(player=self.state_machine, track=track)) else: - self.state_machine.transition_to(self.playing_state) + self._player.transition_to(self.playing_state) + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False -class Stopped(state_machine.State): + +class Stopped: LABEL = STOPPED ALLOWED = { PLAYING, } TRANSITIONS = {PLAY: PLAYING} + is_terminal = False + + def __init__(self, player): + self._player = player + def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing(self.state_machine, track=track)) + self._player.transition_to(Playing(self._player, track=track)) + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False class CdPlayer(state_machine.StateMachine): @@ -119,22 +150,22 @@ def stop(self): class TestStateMachine(unittest.TestCase): def test_basic(self): cd_player = CdPlayer() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) cd_player.play('Eminem - The Real Slim Shady') - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) time.sleep(1.0) cd_player.pause() - self.assertEqual(cd_player.state, PAUSED) + self.assertEqual(cd_player.state_label, PAUSED) cd_player.play() - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) self.assertEqual(cd_player.play(), False) cd_player.stop() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) def test_invalid_event(self): cd_player = CdPlayer() diff --git a/tests/rmq/test_process_control.py b/tests/rmq/test_process_control.py index 79a98ba3..7c3b431c 100644 --- a/tests/rmq/test_process_control.py +++ b/tests/rmq/test_process_control.py @@ -68,7 +68,7 @@ async def test_play(self, _coordinator, async_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.WAITING + assert proc.state_label == plumpy.ProcessState.WAITING # if not close the background process will raise exception # make sure proc reach the final state @@ -85,7 +85,7 @@ async def test_kill(self, _coordinator, async_controller): # Check the outcome assert result - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_status(self, _coordinator, async_controller): @@ -173,7 +173,7 @@ async def test_play(self, _coordinator, sync_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.CREATED + assert proc.state_label == plumpy.ProcessState.CREATED @pytest.mark.asyncio async def test_kill(self, _coordinator, sync_controller): @@ -187,7 +187,7 @@ async def test_kill(self, _coordinator, sync_controller): # Check the outcome assert result # Occasionally fail - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_kill_all(self, _coordinator, sync_controller): @@ -198,7 +198,7 @@ async def test_kill_all(self, _coordinator, sync_controller): sync_controller.kill_all(msg_text='bang bang, I shot you down') await utils.wait_util(lambda: all([proc.killed() for proc in procs])) - assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) + assert all([proc.state_label == plumpy.ProcessState.KILLED for proc in procs]) @pytest.mark.asyncio async def test_status(self, _coordinator, sync_controller): diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 78724aa0..7f616433 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,16 +5,38 @@ import yaml import plumpy +from plumpy.persistence import auto_load, auto_persist, auto_save, ensure_object_loader +from plumpy.utils import SAVED_STATE_TYPE from . import utils -class SaveEmpty(plumpy.Savable): - pass +@auto_persist() +class SaveEmpty: + + @classmethod + def recreate_from(cls, saved_state, load_context=None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @plumpy.auto_persist('test', 'test_method') -class Save1(plumpy.Savable): +class Save1: def __init__(self): self.test = 'sup yp' self.test_method = self.m @@ -22,12 +44,52 @@ def __init__(self): def m(): pass + @classmethod + def recreate_from(cls, saved_state, load_context=None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test') -class Save(plumpy.Savable): +class Save: def __init__(self): self.test = Save1() + @classmethod + def recreate_from(cls, saved_state, load_context=None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + class TestSavable(unittest.TestCase): def test_empty_savable(self): diff --git a/tests/test_processes.py b/tests/test_processes.py index 6a4376f9..b4bdc333 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -14,6 +14,11 @@ from plumpy.utils import AttributesFrozendict from . import utils +# FIXME: after deabstract on savable into a protocol, test that all state are savable +# FIXME: also that any process is savable +# FIXME: any process listener is savable +# FIXME: any process control commands are savable + class ForgetToCallParent(plumpy.Process): def __init__(self, forget_on): @@ -235,7 +240,7 @@ def test_execute(self): proc.execute() self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertEqual(proc.outputs, {'default': 5}) def test_run_from_class(self): @@ -273,7 +278,7 @@ def test_exception(self): proc = utils.ExceptionProcess() with self.assertRaises(RuntimeError): proc.execute() - self.assertEqual(proc.state, ProcessState.EXCEPTED) + self.assertEqual(proc.state_label, ProcessState.EXCEPTED) def test_run_kill(self): proc = utils.KillProcess() @@ -326,7 +331,7 @@ def test_kill(self): proc.kill(msg_text=msg_text) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg()[MESSAGE_TEXT_KEY], msg_text) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_wait_continue(self): proc = utils.WaitForSignalProcess() @@ -340,7 +345,7 @@ def test_wait_continue(self): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_exc_info(self): proc = utils.ExceptionProcess() @@ -364,7 +369,7 @@ def test_wait_pause_play_resume(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause() self.assertTrue(result) @@ -380,7 +385,7 @@ async def async_test(): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) @@ -401,7 +406,7 @@ def test_pause_play_status_messaging(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause(PAUSE_STATUS) self.assertTrue(result) @@ -421,7 +426,7 @@ async def async_test(): loop.run_until_complete(async_test()) self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_kill_in_run(self): class KillProcess(Process): @@ -439,7 +444,7 @@ def run(self, **kwargs): proc.execute() self.assertTrue(proc.after_kill) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused_in_run(self): class PauseProcess(Process): @@ -451,7 +456,7 @@ def run(self, **kwargs): with self.assertRaises(plumpy.KilledError): proc.execute() - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused(self): loop = asyncio.get_event_loop() @@ -475,7 +480,7 @@ async def async_test(): loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_run_multiple(self): # Create and play some processes @@ -551,7 +556,7 @@ def run(self): loop.run_forever() self.assertTrue(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_pause_play_in_process(self): """Test that we can pause and play that by playing within the process""" @@ -569,7 +574,7 @@ def run(self): proc.execute() self.assertFalse(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_process_stack(self): test_case = self @@ -653,7 +658,7 @@ def test_exception_during_on_entered(self): class RaisingProcess(Process): def on_entered(self, from_state): - if from_state is not None and from_state.label == ProcessState.RUNNING: + if from_state is not None and from_state.LABEL == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') super().on_entered(from_state) @@ -700,7 +705,7 @@ def step2(self): class TestProcessSaving(unittest.TestCase): maxDiff = None - def test_running_save_instance_state(self): + def test_running_save(self): loop = asyncio.get_event_loop() nsync_comeback = SavePauseProc() @@ -780,7 +785,7 @@ def test_saving_each_step(self): proc = proc_class() saver = utils.ProcessSaver(proc) saver.capture() - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertTrue(utils.check_process_against_snapshots(loop, proc_class, saver.snapshots)) def test_restart(self): @@ -795,7 +800,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it loaded_proc.resume() @@ -818,7 +823,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it twice in succession loaded_proc.resume() @@ -860,7 +865,7 @@ async def async_test(): def test_killed(self): proc = utils.DummyProcess() proc.kill() - self.assertEqual(proc.state, plumpy.ProcessState.KILLED) + self.assertEqual(proc.state_label, plumpy.ProcessState.KILLED) self._check_round_trip(proc) def _check_round_trip(self, proc1): @@ -983,40 +988,40 @@ def run(self): self.out(namespace_nested + '.two', 2) # Run the process in default mode which should not add any outputs and therefore fail - process = DummyDynamicProcess() - process.execute() + proc = DummyDynamicProcess() + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertDictEqual(process.outputs, {}) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertDictEqual(proc.outputs, {}) # Attaching only namespaced ports should fail, because the required port is not added - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) # Attaching only the single required top-level port should be fine - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) # Attaching both the required and namespaced ports should result in a successful termination - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) - process.execute() - - self.assertIsNotNone(process.outputs) - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) + proc.execute() + + self.assertIsNotNone(proc.outputs) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) class TestProcessEvents(unittest.TestCase): diff --git a/tests/test_workchains.py b/tests/test_workchains.py index 08c7317a..4e34d2b4 100644 --- a/tests/test_workchains.py +++ b/tests/test_workchains.py @@ -11,6 +11,8 @@ from . import utils +# FIXME: after deabstract on savable into a protocol, test that all stepper are savable +# FIXME: workchani itself is savable class Wf(WorkChain): # Keep track of which steps were completed by the workflow diff --git a/tests/utils.py b/tests/utils.py index 0adbbbeb..74fc8211 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -596,7 +596,7 @@ def run_until_waiting(proc): listener = plumpy.ProcessListener() in_waiting = asyncio.Future() - if proc.state == ProcessState.WAITING: + if proc.state_label == ProcessState.WAITING: in_waiting.set_result(True) else: diff --git a/uv.lock b/uv.lock index 6a9582ba..0df97a28 100644 --- a/uv.lock +++ b/uv.lock @@ -1542,6 +1542,7 @@ dependencies = [ { name = "kiwipy", extra = ["rmq"] }, { name = "nest-asyncio" }, { name = "pyyaml" }, + { name = "typing-extensions" }, ] [package.optional-dependencies] @@ -1593,6 +1594,7 @@ requires-dist = [ { name = "sphinx", marker = "extra == 'docs'", specifier = "~=3.2.0" }, { name = "sphinx-book-theme", marker = "extra == 'docs'", specifier = "~=0.0.39" }, { name = "types-pyyaml", marker = "extra == 'pre-commit'" }, + { name = "typing-extensions", specifier = "~=4.12" }, ] [[package]]