diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 47ad4956..e20dae3f 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- import logging -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional + +from plumpy.utils import SAVED_STATE_TYPE from . import persistence +from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load if TYPE_CHECKING: from typing import Set, Type - from .process_listener import ProcessListener _LOGGER = logging.getLogger(__name__) @@ -30,6 +32,22 @@ def remove_listener(self, listener: 'ProcessListener') -> None: def remove_all_listeners(self) -> None: self._listeners.clear() + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Savable: + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 78724aa0..65ef3226 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,6 +5,7 @@ import yaml import plumpy +from plumpy.persistence import auto_load from . import utils @@ -12,6 +13,21 @@ class SaveEmpty(plumpy.Savable): pass + @classmethod + def recreate_from(cls, saved_state, load_context= None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + @plumpy.auto_persist('test', 'test_method') class Save1(plumpy.Savable): @@ -22,12 +38,42 @@ def __init__(self): def m(): pass + @classmethod + def recreate_from(cls, saved_state, load_context= None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + @plumpy.auto_persist('test') class Save(plumpy.Savable): def __init__(self): self.test = Save1() + @classmethod + def recreate_from(cls, saved_state, load_context= None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + class TestSavable(unittest.TestCase): def test_empty_savable(self):