Skip to content

Commit

Permalink
Explicity recreate_from implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 9, 2024
1 parent 61c7fb8 commit 55bc734
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
22 changes: 20 additions & 2 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# -*- coding: utf-8 -*-
import logging
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional

from plumpy.utils import SAVED_STATE_TYPE

from . import persistence
from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load

if TYPE_CHECKING:
from typing import Set, Type

from .process_listener import ProcessListener

_LOGGER = logging.getLogger(__name__)
Expand All @@ -30,6 +32,22 @@ def remove_listener(self, listener: 'ProcessListener') -> None:
def remove_all_listeners(self) -> None:
self._listeners.clear()

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Savable:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners
Expand Down
46 changes: 46 additions & 0 deletions tests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,29 @@
import yaml

import plumpy
from plumpy.persistence import auto_load

from . import utils


class SaveEmpty(plumpy.Savable):
pass

@classmethod
def recreate_from(cls, saved_state, load_context= None):
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj


@plumpy.auto_persist('test', 'test_method')
class Save1(plumpy.Savable):
Expand All @@ -22,12 +38,42 @@ def __init__(self):
def m():
pass

@classmethod
def recreate_from(cls, saved_state, load_context= None):
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj


@plumpy.auto_persist('test')
class Save(plumpy.Savable):
def __init__(self):
self.test = Save1()

@classmethod
def recreate_from(cls, saved_state, load_context= None):
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj


class TestSavable(unittest.TestCase):
def test_empty_savable(self):
Expand Down

0 comments on commit 55bc734

Please sign in to comment.