Skip to content

Commit

Permalink
save_instance_state simplify to only has save interface
Browse files Browse the repository at this point in the history
For the auto_persist attributes, the fn auto_save will take care of save
the state
  • Loading branch information
unkcpz committed Dec 9, 2024
1 parent 937ad01 commit 6558ff1
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 85 deletions.
13 changes: 0 additions & 13 deletions src/plumpy/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,6 @@ def __init__(self, *args: Any, **kwargs: Any):
def ctx(self) -> Optional[AttributesDict]:
return self._context

def save_instance_state(
self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext]
) -> None:
"""Add the instance state to ``out_state``.
.. important::
The instance state will contain a pointer to the ``ctx``,
and so should be deep copied or serialised before persisting.
"""
super().save_instance_state(out_state, save_context)
if self._context is not None:
out_state[self.CONTEXT] = self._context.__dict__

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
try:
Expand Down
80 changes: 43 additions & 37 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,43 +487,9 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio
for member in self._auto_persist:
setattr(self, member, self._get_value(saved_state, member, load_context))

@super_check
def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None:
self._ensure_persist_configured()
if self._auto_persist is not None:
for member in self._auto_persist:
value = getattr(self, member)
if inspect.ismethod(value):
if value.__self__ is not self:
raise TypeError('Cannot persist methods of other classes')
Savable._set_meta_type(out_state, member, META__TYPE__METHOD)
value = value.__name__
elif isinstance(value, Savable):
Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save()
else:
value = copy.deepcopy(value)
out_state[member] = value

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = {}

if save_context is None:
save_context = LoadSaveContext()
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

utils.type_check(save_context, LoadSaveContext)

default_loader = loaders.get_object_loader()
# If the user has specified a class loader, then save it in the saved state
if save_context.loader is not None:
loader_class = default_loader.identify_object(save_context.loader.__class__)
Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
loader = save_context.loader
else:
loader = default_loader

Savable._set_class_name(out_state, loader.identify_object(self.__class__))
call_with_super_check(self.save_instance_state, out_state, save_context)
return out_state

def _ensure_persist_configured(self) -> None:
Expand Down Expand Up @@ -593,11 +559,13 @@ class SavableFuture(futures.Future, Savable):
.. note: This does not save any assigned done callbacks.
"""

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)
if self.done() and self.exception() is not None:
out_state['exception'] = self.exception()

return out_state

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Expand Down Expand Up @@ -643,3 +611,41 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadS
# typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list
for callback in self._callbacks:
self.remove_done_callback(callback) # type: ignore[arg-type]


def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = {}

if save_context is None:
save_context = LoadSaveContext()

utils.type_check(save_context, LoadSaveContext)

default_loader = loaders.get_object_loader()
# If the user has specified a class loader, then save it in the saved state
if save_context.loader is not None:
loader_class = default_loader.identify_object(save_context.loader.__class__)
Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class)
loader = save_context.loader
else:
loader = default_loader

Savable._set_class_name(out_state, loader.identify_object(obj.__class__))

obj._ensure_persist_configured()
if obj._auto_persist is not None:
for member in obj._auto_persist:
value = getattr(obj, member)
if inspect.ismethod(value):
if value.__self__ is not obj:
raise TypeError('Cannot persist methods of other classes')
Savable._set_meta_type(out_state, member, META__TYPE__METHOD)
value = value.__name__
elif isinstance(value, Savable):
Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save()
else:
value = copy.deepcopy(value)
out_state[member] = value

return out_state
45 changes: 30 additions & 15 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import copy
import inspect
import sys
import traceback
from enum import Enum
Expand All @@ -23,6 +25,7 @@
import yaml
from yaml.loader import Loader

from plumpy import loaders
from plumpy.process_comms import KillMessage, MessageType

try:
Expand All @@ -35,7 +38,7 @@
from . import exceptions, futures, persistence, utils
from .base import state_machine as st
from .lang import NULL
from .persistence import LoadSaveContext, auto_persist
from .persistence import META__OBJECT_LOADER, META__TYPE__METHOD, META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_persist, auto_save
from .utils import SAVED_STATE_TYPE

__all__ = [
Expand Down Expand Up @@ -127,10 +130,12 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any):
self.args = args
self.kwargs = kwargs

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context)
out_state[self.CONTINUE_FN] = self.continue_fn.__name__

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process
Expand Down Expand Up @@ -159,10 +164,9 @@ class ProcessState(Enum):
KILLED = 'killed'


@runtime_checkable
class Savable(Protocol):
def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ...

# @runtime_checkable
# class Savable(Protocol):
# def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ...

@final
@auto_persist('args', 'kwargs')
Expand All @@ -180,10 +184,12 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, *
self.args = args
self.kwargs = kwargs

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)
out_state[self.RUN_FN] = self.run_fn.__name__

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process
Expand Down Expand Up @@ -230,12 +236,15 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, *
self.kwargs = kwargs
self._run_handle = None

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

out_state[self.RUN_FN] = self.run_fn.__name__
if self._command is not None:
out_state[self.COMMAND] = self._command.save()

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process
Expand Down Expand Up @@ -351,11 +360,14 @@ def __init__(
self.data = data
self._waiting_future: futures.Future = futures.Future()

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

if self.done_callback is not None:
out_state[self.DONE_CALLBACK] = self.done_callback.__name__

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process
Expand Down Expand Up @@ -438,12 +450,15 @@ def __str__(self) -> str:
exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0]
return super().__str__() + f'({exception})'

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

out_state[self.EXC_VALUE] = yaml.dump(self.exception)
if self.traceback is not None:
out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback))

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)

Expand Down
15 changes: 8 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import copy
import enum
import functools
import inspect
import logging
import re
import sys
Expand All @@ -33,6 +34,8 @@
cast,
)

from plumpy import loaders

try:
from aiocontextvars import ContextVar
except ModuleNotFoundError:
Expand Down Expand Up @@ -82,7 +85,7 @@ class BundleKeys:
"""
String keys used by the process to save its state in the state bundle.
See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`.
See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.load_instance_state`.
"""

Expand Down Expand Up @@ -623,18 +626,14 @@ async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any)

# region Persistence

def save_instance_state(
self,
out_state: SAVED_STATE_TYPE,
save_context: Optional[persistence.LoadSaveContext],
) -> None:
def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE:
"""
Ask the process to save its current instance state.
:param out_state: A bundle to save the state to
:param save_context: The save context
"""
super().save_instance_state(out_state, save_context)
out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context)

# FIXME: the combined ProcessState protocol should cover the case
if isinstance(self._state, process_states.Savable):
Expand All @@ -650,6 +649,8 @@ def save_instance_state(
if self.outputs:
out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs)

return out_state

@protected
def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
"""Load the process from its saved instance state.
Expand Down
Loading

0 comments on commit 6558ff1

Please sign in to comment.