Skip to content

Commit

Permalink
WIP: forming Savable protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 9, 2024
1 parent 55bc734 commit 4836edc
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 17 deletions.
7 changes: 6 additions & 1 deletion src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from plumpy.utils import SAVED_STATE_TYPE

from . import persistence
from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load
from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load, auto_save

if TYPE_CHECKING:
from typing import Set, Type
Expand Down Expand Up @@ -48,6 +48,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners
Expand Down
4 changes: 1 addition & 3 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
...

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state
...

def _ensure_persist_configured(self) -> None:
if not self._persist_configured:
Expand Down
41 changes: 32 additions & 9 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import copy
import inspect
import sys
import traceback
from enum import Enum
Expand All @@ -13,19 +11,17 @@
Callable,
ClassVar,
Optional,
Protocol,
Tuple,
Type,
Union,
cast,
final,
runtime_checkable,
override,
)

import yaml
from yaml.loader import Loader

from plumpy import loaders
from plumpy.process_comms import KillMessage, MessageType
from plumpy.persistence import _ensure_object_loader

Expand All @@ -40,9 +36,6 @@
from .base import state_machine as st
from .lang import NULL
from .persistence import (
META__OBJECT_LOADER,
META__TYPE__METHOD,
META__TYPE__SAVABLE,
LoadSaveContext,
Savable,
auto_load,
Expand Down Expand Up @@ -94,8 +87,26 @@ class PauseInterruption(Interruption):


class Command(persistence.Savable):
pass

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

@auto_persist('msg')
class Kill(Command):
Expand Down Expand Up @@ -140,12 +151,14 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any):
self.args = args
self.kwargs = kwargs

@override
def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context)
out_state[self.CONTINUE_FN] = self.continue_fn.__name__

return out_state

@override
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Expand Down Expand Up @@ -605,6 +618,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

def enter(self) -> None: ...

def exit(self) -> None: ...
Expand Down Expand Up @@ -649,6 +667,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state

def enter(self) -> None: ...

def exit(self) -> None: ...
Expand Down
24 changes: 20 additions & 4 deletions tests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import yaml

import plumpy
from plumpy.persistence import auto_load
from plumpy.persistence import auto_load, auto_save
from plumpy.utils import SAVED_STATE_TYPE

from . import utils

Expand All @@ -14,7 +15,7 @@ class SaveEmpty(plumpy.Savable):
pass

@classmethod
def recreate_from(cls, saved_state, load_context= None):
def recreate_from(cls, saved_state, load_context=None):
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand All @@ -28,6 +29,11 @@ def recreate_from(cls, saved_state, load_context= None):
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context=None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state


@plumpy.auto_persist('test', 'test_method')
class Save1(plumpy.Savable):
Expand All @@ -39,7 +45,7 @@ def m():
pass

@classmethod
def recreate_from(cls, saved_state, load_context= None):
def recreate_from(cls, saved_state, load_context=None):
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand All @@ -53,14 +59,19 @@ def recreate_from(cls, saved_state, load_context= None):
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context=None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state


@plumpy.auto_persist('test')
class Save(plumpy.Savable):
def __init__(self):
self.test = Save1()

@classmethod
def recreate_from(cls, saved_state, load_context= None):
def recreate_from(cls, saved_state, load_context=None):
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand All @@ -74,6 +85,11 @@ def recreate_from(cls, saved_state, load_context= None):
auto_load(obj, saved_state, load_context)
return obj

def save(self, save_context=None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

return out_state


class TestSavable(unittest.TestCase):
def test_empty_savable(self):
Expand Down

0 comments on commit 4836edc

Please sign in to comment.