diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index e20dae3f..b0f440ab 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -2,13 +2,14 @@ import logging from typing import TYPE_CHECKING, Any, Callable, Optional +from plumpy.persistence import LoadSaveContext, Savable, auto_load, auto_save, ensure_object_loader from plumpy.utils import SAVED_STATE_TYPE from . import persistence -from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load if TYPE_CHECKING: from typing import Set, Type + from .process_listener import ProcessListener _LOGGER = logging.getLogger(__name__) @@ -43,11 +44,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 2367c759..a53948ad 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -9,12 +9,23 @@ import os import pickle from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + Iterable, + List, + Optional, + TypeVar, + Union, +) import yaml from . import futures, loaders, utils -from .base.utils import call_with_super_check, super_check from .utils import PID_TYPE, SAVED_STATE_TYPE __all__ = [ @@ -100,10 +111,10 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N :return: The loaded Savable instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) assert load_context.loader is not None # required for type checking try: - class_name = Savable._get_class_name(saved_state) + class_name = SaveUtil.get_class_name(saved_state) load_cls: Savable = load_context.loader.load_object(class_name) except KeyError: raise ValueError('Class name not found in saved state') @@ -396,18 +407,19 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: - if savable._auto_persist is None: - savable._auto_persist = set() + def wrapped(savable_cls: SavableClsType) -> SavableClsType: + if not hasattr(savable_cls, '_auto_persist') or savable_cls._auto_persist is None: + savable_cls._auto_persist = set() else: - savable._auto_persist = set(savable._auto_persist) - savable.auto_persist(*members) - return savable + savable_cls._auto_persist = set(savable_cls._auto_persist) + + savable_cls._auto_persist.update(members) + return savable_cls return wrapped -def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': +def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': """ Given a LoadSaveContext this method will ensure that it has a valid class loader using the following priorities: @@ -429,7 +441,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV # 2) Try getting from saved_state default_loader = loaders.get_object_loader() try: - loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER) + loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER) except ValueError: # 3) Fall back to default loader = default_loader @@ -448,45 +460,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV META__TYPE__SAVABLE: str = 'S' -class Savable: - CLASS_NAME: str = 'class_name' - - _auto_persist: Optional[Set[str]] = None - _persist_configured = False - - @classmethod - def auto_persist(cls, *members: str) -> None: - if cls._auto_persist is None: - cls._auto_persist = set() - cls._auto_persist.update(members) - - @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Recreate a :class:`Savable` from a saved state using an optional load context. - - :param saved_state: The saved state - :param load_context: An optional load context - - :return: The recreated instance - - """ - ... - - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - - return out_state - - def _ensure_persist_configured(self) -> None: - if not self._persist_configured: - self._persist_configured = True - - # region Metadata getter/setters - +class SaveUtil: @staticmethod def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: - user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {}) + user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {}) user_dict[name] = value @staticmethod @@ -497,43 +474,47 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: raise ValueError(f"Unknown meta key '{name}'") @staticmethod - def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: + def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: return out_state.setdefault(META, {}) @staticmethod - def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: - Savable._get_create_meta(out_state)[META__CLASS_NAME] = name + def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: + SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name @staticmethod - def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str: - return Savable._get_create_meta(saved_state)[META__CLASS_NAME] + def get_class_name(saved_state: SAVED_STATE_TYPE) -> str: + return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME] @staticmethod - def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: - type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {}) + def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: + type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {}) type_dict[name] = type_spec @staticmethod - def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: + def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: try: return saved_state[META][META__TYPES][name] except KeyError: pass - # endregion - def _get_value( - self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] - ) -> Union[MethodType, 'Savable']: - value = saved_state[name] +class Savable: + _auto_persist: ClassVar[set[str] | None] = None - typ = Savable._get_meta_type(saved_state, name) - if typ == META__TYPE__METHOD: - value = getattr(self, value) - elif typ == META__TYPE__SAVABLE: - value = load(value, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. - return value + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + ... + + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... @auto_persist('_state', '_result') @@ -562,7 +543,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) try: loop = load_context.loop @@ -612,24 +593,23 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S # If the user has specified a class loader, then save it in the saved state if save_context.loader is not None: loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) loader = save_context.loader else: loader = default_loader - Savable._set_class_name(out_state, loader.identify_object(obj.__class__)) + SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__)) - obj._ensure_persist_configured() if obj._auto_persist is not None: for member in obj._auto_persist: value = getattr(obj, member) if inspect.ismethod(value): if value.__self__ is not obj: raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD) value = value.__name__ elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) value = value.save() else: value = copy.deepcopy(value) @@ -639,7 +619,20 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - obj._ensure_persist_configured() if obj._auto_persist is not None: for member in obj._auto_persist: - setattr(obj, member, obj._get_value(saved_state, member, load_context)) + setattr(obj, member, _get_value(obj, saved_state, member, load_context)) + + +def _get_value( + obj, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] +) -> Union[MethodType, 'Savable']: + value = saved_state[name] + + typ = SaveUtil.get_meta_type(saved_state, name) + if typ == META__TYPE__METHOD: + value = getattr(obj, value) + elif typ == META__TYPE__SAVABLE: + value = load(value, load_context) + + return value diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 166a811a..d72dd16d 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -2,9 +2,10 @@ import abc from typing import TYPE_CHECKING, Any, Dict, Optional +from plumpy.persistence import LoadSaveContext, ensure_object_loader + from . import persistence from .utils import SAVED_STATE_TYPE -from plumpy.persistence import LoadSaveContext, _ensure_object_loader __all__ = ['ProcessListener'] @@ -34,7 +35,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) obj.init(**saved_state['_params']) return obj diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 1d7f2350..7f88d392 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy -import inspect import sys import traceback from enum import Enum @@ -13,21 +11,19 @@ Callable, ClassVar, Optional, - Protocol, Tuple, Type, Union, cast, final, - runtime_checkable, + override, ) import yaml from yaml.loader import Loader -from plumpy import loaders +from plumpy.persistence import ensure_object_loader from plumpy.process_comms import KillMessage, MessageType -from plumpy.persistence import _ensure_object_loader try: import tblib @@ -40,9 +36,6 @@ from .base import state_machine as st from .lang import NULL from .persistence import ( - META__OBJECT_LOADER, - META__TYPE__METHOD, - META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_load, @@ -94,7 +87,25 @@ class PauseInterruption(Interruption): class Command(persistence.Savable): - pass + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @auto_persist('msg') @@ -140,12 +151,14 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs + @override def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ return out_state + @override @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -157,7 +170,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -188,11 +201,6 @@ class ProcessState(Enum): KILLED = 'killed' -# @runtime_checkable -# class Savable(Protocol): -# def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... - - @final @auto_persist('args', 'kwargs') class Created(persistence.Savable): @@ -226,7 +234,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -297,7 +305,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -435,7 +443,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -488,6 +496,7 @@ def exit(self) -> None: ... @final +@auto_persist() class Excepted(persistence.Savable): """ Excepted state, can optionally provide exception and traceback @@ -540,7 +549,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -600,11 +609,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... @@ -644,11 +658,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 96689024..201f62b2 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -33,7 +33,7 @@ cast, ) -from plumpy.persistence import _ensure_object_loader +from plumpy.persistence import ensure_object_loader try: from aiocontextvars import ContextVar @@ -274,7 +274,7 @@ def recreate_from( :return: An instance of the object with its state loaded from the save state. """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state @@ -681,8 +681,7 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA """ out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - # FIXME: the combined ProcessState protocol should cover the case - if isinstance(self._state, process_states.Savable): + if isinstance(self._state, persistence.Savable): out_state['_state'] = self._state.save() # Inputs/outputs diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index cf7ad81f..28a0b0f0 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy import abc import asyncio import collections +import copy import inspect import logging import re @@ -26,16 +26,16 @@ import kiwipy +from plumpy import loaders, utils from plumpy.base import state_machine from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError +from plumpy.persistence import ensure_object_loader from plumpy.process_listener import ProcessListener from . import lang, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict -from plumpy import loaders, utils -from plumpy.persistence import _ensure_object_loader __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] @@ -176,30 +176,29 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA # If the user has specified a class loader, then save it in the saved state if save_context.loader is not None: loader_class = default_loader.identify_object(save_context.loader.__class__) - persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) + persistence.SaveUtil.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) loader = save_context.loader else: loader = default_loader - persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__)) + persistence.SaveUtil.set_class_name(out_state, loader.identify_object(self.__class__)) - self._ensure_persist_configured() if self._auto_persist is not None: for member in self._auto_persist: value = getattr(self, member) if inspect.ismethod(value): if value.__self__ is not self: raise TypeError('Cannot persist methods of other classes') - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) + persistence.SaveUtil.set_meta_type(out_state, member, persistence.META__TYPE__METHOD) value = value.__name__ elif isinstance(value, persistence.Savable): - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) + persistence.SaveUtil.set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) value = value.save() else: value = copy.deepcopy(value) out_state[member] = value - if isinstance(self._state, process_states.Savable): + if isinstance(self._state, persistence.Savable): out_state['_state'] = self._state.save() # Inputs/outputs @@ -235,7 +234,7 @@ def recreate_from( """ ### FIXME: dup from process.create_from - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state @@ -378,6 +377,7 @@ def get_description(self) -> Any: """ +@persistence.auto_persist() class _FunctionStepper(persistence.Savable): def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): self._workchain = workchain @@ -400,7 +400,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -491,7 +491,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -646,7 +646,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -717,6 +717,7 @@ def get_description(self) -> Mapping[str, Any]: return description +@persistence.auto_persist() class _WhileStepper(persistence.Savable): def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: self._workchain = workchain @@ -757,7 +758,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -804,6 +805,7 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code +@persistence.auto_persist() class _ReturnStepper(persistence.Savable): def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: self._workchain = workchain diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 65ef3226..0ca45cc1 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,16 +5,18 @@ import yaml import plumpy -from plumpy.persistence import auto_load +from plumpy.persistence import auto_load, auto_persist, auto_save +from plumpy.utils import SAVED_STATE_TYPE from . import utils +@auto_persist() class SaveEmpty(plumpy.Savable): pass @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -28,6 +30,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test', 'test_method') class Save1(plumpy.Savable): @@ -39,7 +46,7 @@ def m(): pass @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -53,6 +60,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test') class Save(plumpy.Savable): @@ -60,7 +72,7 @@ def __init__(self): self.test = Save1() @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -74,6 +86,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + class TestSavable(unittest.TestCase): def test_empty_savable(self):