Skip to content

Commit

Permalink
Waiting state directly from state_machine.state and persistence.Savable
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Nov 28, 2024
1 parent e60278f commit fb43887
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 80 deletions.
126 changes: 52 additions & 74 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,40 +430,13 @@ class Savable:
CLASS_NAME: str = 'class_name'

_auto_persist: Optional[Set[str]] = None
_persist_configured = False

@staticmethod
def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Load a `Savable` from a saved instance state. The load context is a way of passing
runtime data to the object being loaded.
:param saved_state: The saved state
:param load_context: Additional runtime state that can be passed into when loading.
The type and content (if any) is completely user defined
:return: The loaded Savable instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
assert load_context.loader is not None # required for type checking
try:
class_name = Savable._get_class_name(saved_state)
load_cls = load_context.loader.load_object(class_name)
except KeyError:
raise ValueError('Class name not found in saved state')
else:
return load_cls.recreate_from(saved_state, load_context)

@classmethod
def auto_persist(cls, *members: str) -> None:
if cls._auto_persist is None:
cls._auto_persist = set()
cls._auto_persist.update(members)

@classmethod
def persist(cls) -> None:
pass

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Expand All @@ -482,17 +455,63 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa

@super_check
def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None:
self._ensure_persist_configured()
if self._auto_persist is not None:
self.load_members(self._auto_persist, saved_state, load_context)
if self._auto_persist is None:
return None

for member in self._auto_persist:
value = saved_state[member]

typ = Savable._get_meta_type(saved_state, member)
if typ == META__TYPE__METHOD:
value = getattr(self, value)
elif typ == META__TYPE__SAVABLE:
value = Savable.load(value, load_context)

setattr(self, member, value)

@staticmethod
def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Load a `Savable` from a saved instance state. The load context is a way of passing
runtime data to the object being loaded.
:param saved_state: The saved state
:param load_context: Additional runtime state that can be passed into when loading.
The type and content (if any) is completely user defined
:return: The loaded Savable instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
assert load_context.loader is not None # required for type checking
try:
class_name = Savable._get_class_name(saved_state)
load_cls = load_context.loader.load_object(class_name)
except KeyError:
raise ValueError('Class name not found in saved state')
else:
return load_cls.recreate_from(saved_state, load_context)

@super_check
def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None:
self._ensure_persist_configured()
if self._auto_persist is not None:
self.save_members(self._auto_persist, out_state)
if self._auto_persist is None:
return None

for member in self._auto_persist:
value = getattr(self, member)
if inspect.ismethod(value):
if value.__self__ is not self:
raise TypeError('Cannot persist methods of other classes')
Savable._set_meta_type(out_state, member, META__TYPE__METHOD)
value = value.__name__
elif isinstance(value, Savable):
Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save(save_context)
else:
value = copy.deepcopy(value)
out_state[member] = value

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
"""Recursively call ``save`` on the members."""
out_state: SAVED_STATE_TYPE = {}

if save_context is None:
Expand All @@ -513,32 +532,6 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY
call_with_super_check(self.save_instance_state, out_state, save_context)
return out_state

def save_members(self, members: Iterable[str], out_state: SAVED_STATE_TYPE) -> None:
for member in members:
value = getattr(self, member)
if inspect.ismethod(value):
if value.__self__ is not self:
raise TypeError('Cannot persist methods of other classes')
Savable._set_meta_type(out_state, member, META__TYPE__METHOD)
value = value.__name__
elif isinstance(value, Savable):
Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save()
else:
value = copy.deepcopy(value)
out_state[member] = value

def load_members(
self, members: Iterable[str], saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None
) -> None:
for member in members:
setattr(self, member, self._get_value(saved_state, member, load_context))

def _ensure_persist_configured(self) -> None:
if not self._persist_configured:
self.persist()
self._persist_configured = True

# region Metadata getter/setters

@staticmethod
Expand Down Expand Up @@ -577,21 +570,6 @@ def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
except KeyError:
pass

# endregion

def _get_value(
self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext]
) -> Union[MethodType, 'Savable']:
value = saved_state[name]

typ = Savable._get_meta_type(saved_state, name)
if typ == META__TYPE__METHOD:
value = getattr(self, value)
elif typ == META__TYPE__SAVABLE:
value = Savable.load(value, load_context)

return value


@auto_persist('_state', '_result')
class SavableFuture(futures.Future, Savable):
Expand Down
24 changes: 18 additions & 6 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import sys
import traceback
from enum import Enum
Expand All @@ -18,7 +19,7 @@
from . import exceptions, futures, persistence, utils
from .base import state_machine
from .lang import NULL
from .persistence import auto_persist
from .persistence import Savable, auto_persist
from .utils import SAVED_STATE_TYPE

__all__ = [
Expand Down Expand Up @@ -264,8 +265,8 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State:
return cast(State, state) # casting from base.State to process.State


@auto_persist('msg', 'data')
class Waiting(State):
class Waiting(state_machine.State, persistence.Savable):
# class Waiting(state_machine.State):
LABEL = ProcessState.WAITING
ALLOWED = {
ProcessState.RUNNING,
Expand All @@ -278,6 +279,7 @@ class Waiting(State):
DONE_CALLBACK = 'DONE_CALLBACK'

_interruption = None
_auto_persist = {'msg', 'data', 'in_state'}

def __str__(self) -> str:
state_info = super().__str__()
Expand All @@ -288,23 +290,33 @@ def __str__(self) -> str:
def __init__(
self,
process: 'Process',
done_callback: Optional[Callable[..., Any]],
msg: Optional[str] = None,
data: Optional[Any] = None,
done_callback: Callable[..., Any] | None,
msg: str | None = None,
data: Any | None = None,
saver: Savable | None = None,
) -> None:
super().__init__(process)
self.done_callback = done_callback
self.msg = msg
self.data = data
self._waiting_future: futures.Future = futures.Future()

@property
def process(self) -> state_machine.StateMachine:
"""
:return: The process
"""
return self.state_machine

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
if self.done_callback is not None:
out_state[self.DONE_CALLBACK] = self.done_callback.__name__

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)

self.state_machine = load_context.process
callback_name = saved_state.get(self.DONE_CALLBACK, None)
if callback_name is not None:
self.done_callback = getattr(self.process, callback_name)
Expand Down

0 comments on commit fb43887

Please sign in to comment.