diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index e20dae3f..f4472859 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -5,7 +5,7 @@ from plumpy.utils import SAVED_STATE_TYPE from . import persistence -from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load +from plumpy.persistence import Savable, LoadSaveContext, ensure_object_loader, auto_load, auto_save if TYPE_CHECKING: from typing import Set, Type @@ -43,11 +43,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 2367c759..faa084d9 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -9,12 +9,11 @@ import os import pickle from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union import yaml from . import futures, loaders, utils -from .base.utils import call_with_super_check, super_check from .utils import PID_TYPE, SAVED_STATE_TYPE __all__ = [ @@ -100,10 +99,10 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N :return: The loaded Savable instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) assert load_context.loader is not None # required for type checking try: - class_name = Savable._get_class_name(saved_state) + class_name = SaveUtil.get_class_name(saved_state) load_cls: Savable = load_context.loader.load_object(class_name) except KeyError: raise ValueError('Class name not found in saved state') @@ -396,18 +395,21 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: - if savable._auto_persist is None: - savable._auto_persist = set() + def wrapped(savable_cls: SavableClsType) -> SavableClsType: + if savable_cls._auto_persist is None: + savable_cls._auto_persist = set() else: - savable._auto_persist = set(savable._auto_persist) - savable.auto_persist(*members) - return savable + savable_cls._auto_persist = set(savable_cls._auto_persist) + + if savable_cls._auto_persist is None: + savable_cls._auto_persist = set() + savable_cls._auto_persist.update(members) + return savable_cls return wrapped -def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': +def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': """ Given a LoadSaveContext this method will ensure that it has a valid class loader using the following priorities: @@ -429,7 +431,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV # 2) Try getting from saved_state default_loader = loaders.get_object_loader() try: - loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER) + loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER) except ValueError: # 3) Fall back to default loader = default_loader @@ -448,45 +450,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV META__TYPE__SAVABLE: str = 'S' -class Savable: - CLASS_NAME: str = 'class_name' - - _auto_persist: Optional[Set[str]] = None - _persist_configured = False - - @classmethod - def auto_persist(cls, *members: str) -> None: - if cls._auto_persist is None: - cls._auto_persist = set() - cls._auto_persist.update(members) - - @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Recreate a :class:`Savable` from a saved state using an optional load context. - - :param saved_state: The saved state - :param load_context: An optional load context - - :return: The recreated instance - - """ - ... - - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - - return out_state - - def _ensure_persist_configured(self) -> None: - if not self._persist_configured: - self._persist_configured = True - - # region Metadata getter/setters - +class SaveUtil: @staticmethod def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: - user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {}) + user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {}) user_dict[name] = value @staticmethod @@ -497,43 +464,47 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: raise ValueError(f"Unknown meta key '{name}'") @staticmethod - def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: + def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: return out_state.setdefault(META, {}) @staticmethod - def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: - Savable._get_create_meta(out_state)[META__CLASS_NAME] = name + def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: + SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name @staticmethod - def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str: - return Savable._get_create_meta(saved_state)[META__CLASS_NAME] + def get_class_name(saved_state: SAVED_STATE_TYPE) -> str: + return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME] @staticmethod - def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: - type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {}) + def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: + type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {}) type_dict[name] = type_spec @staticmethod - def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: + def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: try: return saved_state[META][META__TYPES][name] except KeyError: pass - # endregion - def _get_value( - self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] - ) -> Union[MethodType, 'Savable']: - value = saved_state[name] +class Savable: + _auto_persist: ClassVar[set[str] | None] = None - typ = Savable._get_meta_type(saved_state, name) - if typ == META__TYPE__METHOD: - value = getattr(self, value) - elif typ == META__TYPE__SAVABLE: - value = load(value, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. - return value + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + ... + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: ... @auto_persist('_state', '_result') @@ -562,7 +533,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) try: loop = load_context.loop @@ -612,24 +583,23 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S # If the user has specified a class loader, then save it in the saved state if save_context.loader is not None: loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) loader = save_context.loader else: loader = default_loader - Savable._set_class_name(out_state, loader.identify_object(obj.__class__)) + SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__)) - obj._ensure_persist_configured() if obj._auto_persist is not None: for member in obj._auto_persist: value = getattr(obj, member) if inspect.ismethod(value): if value.__self__ is not obj: raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD) value = value.__name__ elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) value = value.save() else: value = copy.deepcopy(value) @@ -639,7 +609,20 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - obj._ensure_persist_configured() if obj._auto_persist is not None: for member in obj._auto_persist: - setattr(obj, member, obj._get_value(saved_state, member, load_context)) + setattr(obj, member, _get_value(obj, saved_state, member, load_context)) + + +def _get_value( + obj, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] +) -> Union[MethodType, 'Savable']: + value = saved_state[name] + + typ = SaveUtil.get_meta_type(saved_state, name) + if typ == META__TYPE__METHOD: + value = getattr(obj, value) + elif typ == META__TYPE__SAVABLE: + value = load(value, load_context) + + return value diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 166a811a..b5bac21d 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -4,7 +4,7 @@ from . import persistence from .utils import SAVED_STATE_TYPE -from plumpy.persistence import LoadSaveContext, _ensure_object_loader +from plumpy.persistence import LoadSaveContext, ensure_object_loader __all__ = ['ProcessListener'] @@ -34,7 +34,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) obj.init(**saved_state['_params']) return obj diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 1d7f2350..c5566384 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy -import inspect import sys import traceback from enum import Enum @@ -13,21 +11,19 @@ Callable, ClassVar, Optional, - Protocol, Tuple, Type, Union, cast, final, - runtime_checkable, + override, ) import yaml from yaml.loader import Loader -from plumpy import loaders from plumpy.process_comms import KillMessage, MessageType -from plumpy.persistence import _ensure_object_loader +from plumpy.persistence import ensure_object_loader try: import tblib @@ -40,9 +36,6 @@ from .base import state_machine as st from .lang import NULL from .persistence import ( - META__OBJECT_LOADER, - META__TYPE__METHOD, - META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_load, @@ -94,7 +87,25 @@ class PauseInterruption(Interruption): class Command(persistence.Savable): - pass + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @auto_persist('msg') @@ -140,12 +151,14 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs + @override def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ return out_state + @override @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -157,7 +170,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -226,7 +239,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -297,7 +310,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -435,7 +448,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -540,7 +553,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -600,11 +613,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... @@ -644,11 +662,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 96689024..4467a457 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -33,7 +33,7 @@ cast, ) -from plumpy.persistence import _ensure_object_loader +from plumpy.persistence import ensure_object_loader try: from aiocontextvars import ContextVar @@ -274,7 +274,7 @@ def recreate_from( :return: An instance of the object with its state loaded from the save state. """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index cf7ad81f..0993b0d9 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -35,7 +35,7 @@ from . import lang, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict from plumpy import loaders, utils -from plumpy.persistence import _ensure_object_loader +from plumpy.persistence import ensure_object_loader __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] @@ -176,24 +176,23 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA # If the user has specified a class loader, then save it in the saved state if save_context.loader is not None: loader_class = default_loader.identify_object(save_context.loader.__class__) - persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) + persistence.SaveUtil.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) loader = save_context.loader else: loader = default_loader - persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__)) + persistence.SaveUtil.set_class_name(out_state, loader.identify_object(self.__class__)) - self._ensure_persist_configured() if self._auto_persist is not None: for member in self._auto_persist: value = getattr(self, member) if inspect.ismethod(value): if value.__self__ is not self: raise TypeError('Cannot persist methods of other classes') - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) + persistence.SaveUtil.set_meta_type(out_state, member, persistence.META__TYPE__METHOD) value = value.__name__ elif isinstance(value, persistence.Savable): - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) + persistence.SaveUtil.set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) value = value.save() else: value = copy.deepcopy(value) @@ -235,7 +234,7 @@ def recreate_from( """ ### FIXME: dup from process.create_from - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state @@ -400,7 +399,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -491,7 +490,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -646,7 +645,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -757,7 +756,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 65ef3226..b4100391 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,7 +5,8 @@ import yaml import plumpy -from plumpy.persistence import auto_load +from plumpy.persistence import auto_load, auto_save +from plumpy.utils import SAVED_STATE_TYPE from . import utils @@ -14,7 +15,7 @@ class SaveEmpty(plumpy.Savable): pass @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -28,6 +29,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test', 'test_method') class Save1(plumpy.Savable): @@ -39,7 +45,7 @@ def m(): pass @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -53,6 +59,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test') class Save(plumpy.Savable): @@ -60,7 +71,7 @@ def __init__(self): self.test = Save1() @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -74,6 +85,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + class TestSavable(unittest.TestCase): def test_empty_savable(self):