Skip to content

Commit

Permalink
WIP: forming Savable protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 9, 2024
1 parent 55bc734 commit 3cc21c3
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 76 deletions.
7 changes: 6 additions & 1 deletion src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from plumpy.utils import SAVED_STATE_TYPE

from . import persistence
from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load
from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load, auto_save

if TYPE_CHECKING:
from typing import Set, Type
Expand Down Expand Up @@ -48,6 +48,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners
Expand Down
110 changes: 52 additions & 58 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N
load_context = _ensure_object_loader(load_context, saved_state)
assert load_context.loader is not None # required for type checking
try:
class_name = Savable._get_class_name(saved_state)
class_name = SaveUtil._get_class_name(saved_state)
load_cls: Savable = load_context.loader.load_object(class_name)
except KeyError:
raise ValueError('Class name not found in saved state')
Expand Down Expand Up @@ -429,7 +429,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
# 2) Try getting from saved_state
default_loader = loaders.get_object_loader()
try:
loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER)
loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER)
except ValueError:
# 3) Fall back to default
loader = default_loader
Expand All @@ -448,45 +448,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
META__TYPE__SAVABLE: str = 'S'


class Savable:
CLASS_NAME: str = 'class_name'

_auto_persist: Optional[Set[str]] = None
_persist_configured = False

@classmethod
def auto_persist(cls, *members: str) -> None:
if cls._auto_persist is None:
cls._auto_persist = set()
cls._auto_persist.update(members)

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
...

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

def _ensure_persist_configured(self) -> None:
if not self._persist_configured:
self._persist_configured = True

# region Metadata getter/setters

class SaveUtil:
@staticmethod
def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None:
user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {})
user_dict = SaveUtil._get_create_meta(out_state).setdefault(META__USER, {})
user_dict[name] = value

@staticmethod
Expand All @@ -502,15 +467,15 @@ def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]:

@staticmethod
def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None:
Savable._get_create_meta(out_state)[META__CLASS_NAME] = name
SaveUtil._get_create_meta(out_state)[META__CLASS_NAME] = name

@staticmethod
def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str:
return Savable._get_create_meta(saved_state)[META__CLASS_NAME]
return SaveUtil._get_create_meta(saved_state)[META__CLASS_NAME]

@staticmethod
def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None:
type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {})
type_dict = SaveUtil._get_create_meta(out_state).setdefault(META__TYPES, {})
type_dict[name] = type_spec

@staticmethod
Expand All @@ -520,21 +485,37 @@ def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
except KeyError:
pass

# endregion

def _get_value(
self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext]
) -> Union[MethodType, 'Savable']:
value = saved_state[name]
class Savable:
CLASS_NAME: str = 'class_name'

_auto_persist: Optional[Set[str]] = None
_persist_configured = False

@classmethod
def auto_persist(cls, *members: str) -> None:
if cls._auto_persist is None:
cls._auto_persist = set()
cls._auto_persist.update(members)

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
typ = Savable._get_meta_type(saved_state, name)
if typ == META__TYPE__METHOD:
value = getattr(self, value)
elif typ == META__TYPE__SAVABLE:
value = load(value, load_context)
"""
...

return value
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: ...

def _ensure_persist_configured(self) -> None:
if not self._persist_configured:
self._persist_configured = True

@auto_persist('_state', '_result')
class SavableFuture(futures.Future, Savable):
Expand Down Expand Up @@ -612,12 +593,12 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
# If the user has specified a class loader, then save it in the saved state
if save_context.loader is not None:
loader_class = default_loader.identify_object(save_context.loader.__class__)
Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
loader = save_context.loader
else:
loader = default_loader

Savable._set_class_name(out_state, loader.identify_object(obj.__class__))
SaveUtil._set_class_name(out_state, loader.identify_object(obj.__class__))

obj._ensure_persist_configured()
if obj._auto_persist is not None:
Expand All @@ -626,10 +607,10 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
if inspect.ismethod(value):
if value.__self__ is not obj:
raise TypeError('Cannot persist methods of other classes')
Savable._set_meta_type(out_state, member, META__TYPE__METHOD)
SaveUtil._set_meta_type(out_state, member, META__TYPE__METHOD)
value = value.__name__
elif isinstance(value, Savable):
Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE)
SaveUtil._set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save()
else:
value = copy.deepcopy(value)
Expand All @@ -642,4 +623,17 @@ def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSav
obj._ensure_persist_configured()
if obj._auto_persist is not None:
for member in obj._auto_persist:
setattr(obj, member, obj._get_value(saved_state, member, load_context))
setattr(obj, member, _get_value(obj, saved_state, member, load_context))

def _get_value(
obj, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext]
) -> Union[MethodType, 'Savable']:
value = saved_state[name]

typ = SaveUtil._get_meta_type(saved_state, name)
if typ == META__TYPE__METHOD:
value = getattr(obj, value)
elif typ == META__TYPE__SAVABLE:
value = load(value, load_context)

return value
41 changes: 32 additions & 9 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import copy
import inspect
import sys
import traceback
from enum import Enum
Expand All @@ -13,19 +11,17 @@
Callable,
ClassVar,
Optional,
Protocol,
Tuple,
Type,
Union,
cast,
final,
runtime_checkable,
override,
)

import yaml
from yaml.loader import Loader

from plumpy import loaders
from plumpy.process_comms import KillMessage, MessageType
from plumpy.persistence import _ensure_object_loader

Expand All @@ -40,9 +36,6 @@
from .base import state_machine as st
from .lang import NULL
from .persistence import (
META__OBJECT_LOADER,
META__TYPE__METHOD,
META__TYPE__SAVABLE,
LoadSaveContext,
Savable,
auto_load,
Expand Down Expand Up @@ -94,8 +87,26 @@ class PauseInterruption(Interruption):


class Command(persistence.Savable):
pass

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

@auto_persist('msg')
class Kill(Command):
Expand Down Expand Up @@ -140,12 +151,14 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any):
self.args = args
self.kwargs = kwargs

@override
def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context)
out_state[self.CONTINUE_FN] = self.continue_fn.__name__

return out_state

@override
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Expand Down Expand Up @@ -605,6 +618,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

def enter(self) -> None: ...

def exit(self) -> None: ...
Expand Down Expand Up @@ -649,6 +667,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

def enter(self) -> None: ...

def exit(self) -> None: ...
Expand Down
8 changes: 4 additions & 4 deletions src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,12 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA
# If the user has specified a class loader, then save it in the saved state
if save_context.loader is not None:
loader_class = default_loader.identify_object(save_context.loader.__class__)
persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class)
persistence.SaveUtil.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class)
loader = save_context.loader
else:
loader = default_loader

persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__))
persistence.SaveUtil._set_class_name(out_state, loader.identify_object(self.__class__))

self._ensure_persist_configured()
if self._auto_persist is not None:
Expand All @@ -190,10 +190,10 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA
if inspect.ismethod(value):
if value.__self__ is not self:
raise TypeError('Cannot persist methods of other classes')
persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD)
persistence.SaveUtil._set_meta_type(out_state, member, persistence.META__TYPE__METHOD)
value = value.__name__
elif isinstance(value, persistence.Savable):
persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE)
persistence.SaveUtil._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE)
value = value.save()
else:
value = copy.deepcopy(value)
Expand Down
Loading

0 comments on commit 3cc21c3

Please sign in to comment.