diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index ba755bc5..496290fb 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -430,29 +430,6 @@ class Savable: CLASS_NAME: str = 'class_name' _auto_persist: Optional[Set[str]] = None - _persist_configured = False - - @staticmethod - def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Load a `Savable` from a saved instance state. The load context is a way of passing - runtime data to the object being loaded. - - :param saved_state: The saved state - :param load_context: Additional runtime state that can be passed into when loading. - The type and content (if any) is completely user defined - :return: The loaded Savable instance - - """ - load_context = _ensure_object_loader(load_context, saved_state) - assert load_context.loader is not None # required for type checking - try: - class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) - except KeyError: - raise ValueError('Class name not found in saved state') - else: - return load_cls.recreate_from(saved_state, load_context) @classmethod def auto_persist(cls, *members: str) -> None: @@ -460,10 +437,6 @@ def auto_persist(cls, *members: str) -> None: cls._auto_persist = set() cls._auto_persist.update(members) - @classmethod - def persist(cls) -> None: - pass - @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -482,17 +455,63 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa @super_check def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - self.load_members(self._auto_persist, saved_state, load_context) + if self._auto_persist is None: + return None + + for member in self._auto_persist: + value = saved_state[member] + + typ = Savable._get_meta_type(saved_state, member) + if typ == META__TYPE__METHOD: + value = getattr(self, value) + elif typ == META__TYPE__SAVABLE: + value = Savable.load(value, load_context) + + setattr(self, member, value) + + @staticmethod + def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Load a `Savable` from a saved instance state. The load context is a way of passing + runtime data to the object being loaded. + + :param saved_state: The saved state + :param load_context: Additional runtime state that can be passed into when loading. + The type and content (if any) is completely user defined + :return: The loaded Savable instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + assert load_context.loader is not None # required for type checking + try: + class_name = Savable._get_class_name(saved_state) + load_cls = load_context.loader.load_object(class_name) + except KeyError: + raise ValueError('Class name not found in saved state') + else: + return load_cls.recreate_from(saved_state, load_context) @super_check def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - self.save_members(self._auto_persist, out_state) + if self._auto_persist is None: + return None + + for member in self._auto_persist: + value = getattr(self, member) + if inspect.ismethod(value): + if value.__self__ is not self: + raise TypeError('Cannot persist methods of other classes') + Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, Savable): + Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + value = value.save(save_context) + else: + value = copy.deepcopy(value) + out_state[member] = value def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + """Recursively call ``save`` on the members.""" out_state: SAVED_STATE_TYPE = {} if save_context is None: @@ -513,32 +532,6 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY call_with_super_check(self.save_instance_state, out_state, save_context) return out_state - def save_members(self, members: Iterable[str], out_state: SAVED_STATE_TYPE) -> None: - for member in members: - value = getattr(self, member) - if inspect.ismethod(value): - if value.__self__ is not self: - raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - - def load_members( - self, members: Iterable[str], saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None - ) -> None: - for member in members: - setattr(self, member, self._get_value(saved_state, member, load_context)) - - def _ensure_persist_configured(self) -> None: - if not self._persist_configured: - self.persist() - self._persist_configured = True - # region Metadata getter/setters @staticmethod @@ -577,21 +570,6 @@ def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: except KeyError: pass - # endregion - - def _get_value( - self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] - ) -> Union[MethodType, 'Savable']: - value = saved_state[name] - - typ = Savable._get_meta_type(saved_state, name) - if typ == META__TYPE__METHOD: - value = getattr(self, value) - elif typ == META__TYPE__SAVABLE: - value = Savable.load(value, load_context) - - return value - @auto_persist('_state', '_result') class SavableFuture(futures.Future, Savable): diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 7ae6e9bd..8e4390e4 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from __future__ import annotations import sys import traceback from enum import Enum @@ -18,7 +19,7 @@ from . import exceptions, futures, persistence, utils from .base import state_machine from .lang import NULL -from .persistence import auto_persist +from .persistence import Savable, auto_persist from .utils import SAVED_STATE_TYPE __all__ = [ @@ -264,8 +265,8 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: return cast(State, state) # casting from base.State to process.State -@auto_persist('msg', 'data') -class Waiting(State): +class Waiting(state_machine.State, persistence.Savable): +# class Waiting(state_machine.State): LABEL = ProcessState.WAITING ALLOWED = { ProcessState.RUNNING, @@ -278,6 +279,7 @@ class Waiting(State): DONE_CALLBACK = 'DONE_CALLBACK' _interruption = None + _auto_persist = {'msg', 'data', 'in_state'} def __str__(self) -> str: state_info = super().__str__() @@ -288,9 +290,10 @@ def __str__(self) -> str: def __init__( self, process: 'Process', - done_callback: Optional[Callable[..., Any]], - msg: Optional[str] = None, - data: Optional[Any] = None, + done_callback: Callable[..., Any] | None, + msg: str | None = None, + data: Any | None = None, + saver: Savable | None = None, ) -> None: super().__init__(process) self.done_callback = done_callback @@ -298,6 +301,13 @@ def __init__( self.data = data self._waiting_future: futures.Future = futures.Future() + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) if self.done_callback is not None: @@ -305,6 +315,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + + self.state_machine = load_context.process callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: self.done_callback = getattr(self.process, callback_name)