Skip to content

Commit

Permalink
WIP: forming Savable protocol
Browse files Browse the repository at this point in the history
- remove persist_config flag of savable
  • Loading branch information
unkcpz committed Dec 9, 2024
1 parent 55bc734 commit 2983dc5
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 124 deletions.
10 changes: 8 additions & 2 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import logging
from typing import TYPE_CHECKING, Any, Callable, Optional

from plumpy.persistence import LoadSaveContext, Savable, auto_load, auto_save, ensure_object_loader
from plumpy.utils import SAVED_STATE_TYPE

from . import persistence
from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load

if TYPE_CHECKING:
from typing import Set, Type

from .process_listener import ProcessListener

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,11 +44,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners
Expand Down
145 changes: 69 additions & 76 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,23 @@
import os
import pickle
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generator,
Iterable,
List,
Optional,
TypeVar,
Union,
)

import yaml

from . import futures, loaders, utils
from .base.utils import call_with_super_check, super_check
from .utils import PID_TYPE, SAVED_STATE_TYPE

__all__ = [
Expand Down Expand Up @@ -100,10 +111,10 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N
:return: The loaded Savable instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
load_context = ensure_object_loader(load_context, saved_state)
assert load_context.loader is not None # required for type checking
try:
class_name = Savable._get_class_name(saved_state)
class_name = SaveUtil.get_class_name(saved_state)
load_cls: Savable = load_context.loader.load_object(class_name)
except KeyError:
raise ValueError('Class name not found in saved state')
Expand Down Expand Up @@ -396,18 +407,19 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None:


def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]:
def wrapped(savable: SavableClsType) -> SavableClsType:
if savable._auto_persist is None:
savable._auto_persist = set()
def wrapped(savable_cls: SavableClsType) -> SavableClsType:
if not hasattr(savable_cls, '_auto_persist') or savable_cls._auto_persist is None:
savable_cls._auto_persist = set()
else:
savable._auto_persist = set(savable._auto_persist)
savable.auto_persist(*members)
return savable
savable_cls._auto_persist = set(savable_cls._auto_persist)

savable_cls._auto_persist.update(members)
return savable_cls

return wrapped


def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext':
def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext':
"""
Given a LoadSaveContext this method will ensure that it has a valid class loader
using the following priorities:
Expand All @@ -429,7 +441,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
# 2) Try getting from saved_state
default_loader = loaders.get_object_loader()
try:
loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER)
loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER)
except ValueError:
# 3) Fall back to default
loader = default_loader
Expand All @@ -448,45 +460,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
META__TYPE__SAVABLE: str = 'S'


class Savable:
CLASS_NAME: str = 'class_name'

_auto_persist: Optional[Set[str]] = None
_persist_configured = False

@classmethod
def auto_persist(cls, *members: str) -> None:
if cls._auto_persist is None:
cls._auto_persist = set()
cls._auto_persist.update(members)

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
...

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

def _ensure_persist_configured(self) -> None:
if not self._persist_configured:
self._persist_configured = True

# region Metadata getter/setters

class SaveUtil:
@staticmethod
def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None:
user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {})
user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {})
user_dict[name] = value

@staticmethod
Expand All @@ -497,43 +474,47 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
raise ValueError(f"Unknown meta key '{name}'")

@staticmethod
def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]:
def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]:
return out_state.setdefault(META, {})

@staticmethod
def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None:
Savable._get_create_meta(out_state)[META__CLASS_NAME] = name
def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None:
SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name

@staticmethod
def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str:
return Savable._get_create_meta(saved_state)[META__CLASS_NAME]
def get_class_name(saved_state: SAVED_STATE_TYPE) -> str:
return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME]

@staticmethod
def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None:
type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {})
def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None:
type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {})
type_dict[name] = type_spec

@staticmethod
def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
try:
return saved_state[META][META__TYPES][name]
except KeyError:
pass

# endregion

def _get_value(
self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext]
) -> Union[MethodType, 'Savable']:
value = saved_state[name]
class Savable:
_auto_persist: ClassVar[set[str] | None] = None

typ = Savable._get_meta_type(saved_state, name)
if typ == META__TYPE__METHOD:
value = getattr(self, value)
elif typ == META__TYPE__SAVABLE:
value = load(value, load_context)
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
return value
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
...

def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ...


@auto_persist('_state', '_result')
Expand Down Expand Up @@ -562,7 +543,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
load_context = ensure_object_loader(load_context, saved_state)

try:
loop = load_context.loop
Expand Down Expand Up @@ -612,24 +593,23 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
# If the user has specified a class loader, then save it in the saved state
if save_context.loader is not None:
loader_class = default_loader.identify_object(save_context.loader.__class__)
Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
loader = save_context.loader
else:
loader = default_loader

Savable._set_class_name(out_state, loader.identify_object(obj.__class__))
SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__))

obj._ensure_persist_configured()
if obj._auto_persist is not None:
for member in obj._auto_persist:
value = getattr(obj, member)
if inspect.ismethod(value):
if value.__self__ is not obj:
raise TypeError('Cannot persist methods of other classes')
Savable._set_meta_type(out_state, member, META__TYPE__METHOD)
SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD)
value = value.__name__
elif isinstance(value, Savable):
Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE)
SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save()
else:
value = copy.deepcopy(value)
Expand All @@ -639,7 +619,20 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S


def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
obj._ensure_persist_configured()
if obj._auto_persist is not None:
for member in obj._auto_persist:
setattr(obj, member, obj._get_value(saved_state, member, load_context))
setattr(obj, member, _get_value(obj, saved_state, member, load_context))


def _get_value(
obj, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext]
) -> Union[MethodType, 'Savable']:
value = saved_state[name]

typ = SaveUtil.get_meta_type(saved_state, name)
if typ == META__TYPE__METHOD:
value = getattr(obj, value)
elif typ == META__TYPE__SAVABLE:
value = load(value, load_context)

return value
5 changes: 3 additions & 2 deletions src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import abc
from typing import TYPE_CHECKING, Any, Dict, Optional

from plumpy.persistence import LoadSaveContext, ensure_object_loader

from . import persistence
from .utils import SAVED_STATE_TYPE
from plumpy.persistence import LoadSaveContext, _ensure_object_loader

__all__ = ['ProcessListener']

Expand Down Expand Up @@ -34,7 +35,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
obj.init(**saved_state['_params'])
return obj
Expand Down
Loading

0 comments on commit 2983dc5

Please sign in to comment.