Skip to content

Commit

Permalink
WIP: forming Savable protocol
Browse files Browse the repository at this point in the history
- remove persist_config flag of savable
  • Loading branch information
unkcpz committed Dec 9, 2024
1 parent 55bc734 commit d5ba107
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 106 deletions.
9 changes: 7 additions & 2 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from plumpy.utils import SAVED_STATE_TYPE

from . import persistence
from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load
from plumpy.persistence import Savable, LoadSaveContext, ensure_object_loader, auto_load, auto_save

if TYPE_CHECKING:
from typing import Set, Type
Expand Down Expand Up @@ -43,11 +43,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners
Expand Down
121 changes: 53 additions & 68 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N
:return: The loaded Savable instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
load_context = ensure_object_loader(load_context, saved_state)
assert load_context.loader is not None # required for type checking
try:
class_name = Savable._get_class_name(saved_state)
class_name = SaveUtil.get_class_name(saved_state)
load_cls: Savable = load_context.loader.load_object(class_name)
except KeyError:
raise ValueError('Class name not found in saved state')
Expand Down Expand Up @@ -407,7 +407,7 @@ def wrapped(savable: SavableClsType) -> SavableClsType:
return wrapped


def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext':
def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext':
"""
Given a LoadSaveContext this method will ensure that it has a valid class loader
using the following priorities:
Expand All @@ -429,7 +429,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
# 2) Try getting from saved_state
default_loader = loaders.get_object_loader()
try:
loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER)
loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER)
except ValueError:
# 3) Fall back to default
loader = default_loader
Expand All @@ -448,45 +448,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
META__TYPE__SAVABLE: str = 'S'


class Savable:
CLASS_NAME: str = 'class_name'

_auto_persist: Optional[Set[str]] = None
_persist_configured = False

@classmethod
def auto_persist(cls, *members: str) -> None:
if cls._auto_persist is None:
cls._auto_persist = set()
cls._auto_persist.update(members)

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
...

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

def _ensure_persist_configured(self) -> None:
if not self._persist_configured:
self._persist_configured = True

# region Metadata getter/setters

class SaveUtil:
@staticmethod
def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None:
user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {})
user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {})
user_dict[name] = value

@staticmethod
Expand All @@ -497,44 +462,53 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
raise ValueError(f"Unknown meta key '{name}'")

@staticmethod
def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]:
def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]:
return out_state.setdefault(META, {})

@staticmethod
def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None:
Savable._get_create_meta(out_state)[META__CLASS_NAME] = name
def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None:
SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name

@staticmethod
def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str:
return Savable._get_create_meta(saved_state)[META__CLASS_NAME]
def get_class_name(saved_state: SAVED_STATE_TYPE) -> str:
return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME]

@staticmethod
def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None:
type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {})
def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None:
type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {})
type_dict[name] = type_spec

@staticmethod
def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
try:
return saved_state[META][META__TYPES][name]
except KeyError:
pass

# endregion

def _get_value(
self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext]
) -> Union[MethodType, 'Savable']:
value = saved_state[name]
class Savable:
_auto_persist: Optional[Set[str]] = None

@classmethod
def auto_persist(cls, *members: str) -> None:
if cls._auto_persist is None:
cls._auto_persist = set()
cls._auto_persist.update(members)

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
typ = Savable._get_meta_type(saved_state, name)
if typ == META__TYPE__METHOD:
value = getattr(self, value)
elif typ == META__TYPE__SAVABLE:
value = load(value, load_context)
:return: The recreated instance
return value
"""
...

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: ...

@auto_persist('_state', '_result')
class SavableFuture(futures.Future, Savable):
Expand Down Expand Up @@ -562,7 +536,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
load_context = ensure_object_loader(load_context, saved_state)

try:
loop = load_context.loop
Expand Down Expand Up @@ -612,24 +586,23 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
# If the user has specified a class loader, then save it in the saved state
if save_context.loader is not None:
loader_class = default_loader.identify_object(save_context.loader.__class__)
Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
loader = save_context.loader
else:
loader = default_loader

Savable._set_class_name(out_state, loader.identify_object(obj.__class__))
SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__))

obj._ensure_persist_configured()
if obj._auto_persist is not None:
for member in obj._auto_persist:
value = getattr(obj, member)
if inspect.ismethod(value):
if value.__self__ is not obj:
raise TypeError('Cannot persist methods of other classes')
Savable._set_meta_type(out_state, member, META__TYPE__METHOD)
SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD)
value = value.__name__
elif isinstance(value, Savable):
Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE)
SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save()
else:
value = copy.deepcopy(value)
Expand All @@ -639,7 +612,19 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S


def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
obj._ensure_persist_configured()
if obj._auto_persist is not None:
for member in obj._auto_persist:
setattr(obj, member, obj._get_value(saved_state, member, load_context))
setattr(obj, member, _get_value(obj, saved_state, member, load_context))

def _get_value(
obj, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext]
) -> Union[MethodType, 'Savable']:
value = saved_state[name]

typ = SaveUtil.get_meta_type(saved_state, name)
if typ == META__TYPE__METHOD:
value = getattr(obj, value)
elif typ == META__TYPE__SAVABLE:
value = load(value, load_context)

return value
4 changes: 2 additions & 2 deletions src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from . import persistence
from .utils import SAVED_STATE_TYPE
from plumpy.persistence import LoadSaveContext, _ensure_object_loader
from plumpy.persistence import LoadSaveContext, ensure_object_loader

__all__ = ['ProcessListener']

Expand Down Expand Up @@ -34,7 +34,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
obj.init(**saved_state['_params'])
return obj
Expand Down
Loading

0 comments on commit d5ba107

Please sign in to comment.