diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py index 10142eb7..9dfa7539 100644 --- a/src/plumpy/mixins.py +++ b/src/plumpy/mixins.py @@ -23,19 +23,6 @@ def __init__(self, *args: Any, **kwargs: Any): def ctx(self) -> Optional[AttributesDict]: return self._context - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - """Add the instance state to ``out_state``. - .. important:: - - The instance state will contain a pointer to the ``ctx``, - and so should be deep copied or serialised before persisting. - """ - super().save_instance_state(out_state, save_context) - if self._context is not None: - out_state[self.CONTEXT] = self._context.__dict__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) try: diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 6c3849a6..62fdd58f 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -487,43 +487,9 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio for member in self._auto_persist: setattr(self, member, self._get_value(saved_state, member, load_context)) - @super_check - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - for member in self._auto_persist: - value = getattr(self, member) - if inspect.ismethod(value): - if value.__self__ is not self: - raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = {} - - if save_context is None: - save_context = LoadSaveContext() + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - utils.type_check(save_context, LoadSaveContext) - - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader - - Savable._set_class_name(out_state, loader.identify_object(self.__class__)) - call_with_super_check(self.save_instance_state, out_state, save_context) return out_state def _ensure_persist_configured(self) -> None: @@ -593,11 +559,13 @@ class SavableFuture(futures.Future, Savable): .. note: This does not save any assigned done callbacks. """ - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) if self.done() and self.exception() is not None: out_state['exception'] = self.exception() + return out_state + @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -643,3 +611,41 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadS # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list for callback in self._callbacks: self.remove_done_callback(callback) # type: ignore[arg-type] + + +def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = {} + + if save_context is None: + save_context = LoadSaveContext() + + utils.type_check(save_context, LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + Savable._set_class_name(out_state, loader.identify_object(obj.__class__)) + + obj._ensure_persist_configured() + if obj._auto_persist is not None: + for member in obj._auto_persist: + value = getattr(obj, member) + if inspect.ismethod(value): + if value.__self__ is not obj: + raise TypeError('Cannot persist methods of other classes') + Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, Savable): + Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value + + return out_state diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index d2743d81..0f811cb6 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import copy +import inspect import sys import traceback from enum import Enum @@ -23,6 +25,7 @@ import yaml from yaml.loader import Loader +from plumpy import loaders from plumpy.process_comms import KillMessage, MessageType try: @@ -35,7 +38,7 @@ from . import exceptions, futures, persistence, utils from .base import state_machine as st from .lang import NULL -from .persistence import LoadSaveContext, auto_persist +from .persistence import META__OBJECT_LOADER, META__TYPE__METHOD, META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_persist, auto_save from .utils import SAVED_STATE_TYPE __all__ = [ @@ -127,10 +130,12 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process @@ -159,10 +164,9 @@ class ProcessState(Enum): KILLED = 'killed' -@runtime_checkable -class Savable(Protocol): - def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... - +# @runtime_checkable +# class Savable(Protocol): +# def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... @final @auto_persist('args', 'kwargs') @@ -180,10 +184,12 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process @@ -230,12 +236,15 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.kwargs = kwargs self._run_handle = None - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process @@ -351,11 +360,14 @@ def __init__( self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process @@ -438,12 +450,15 @@ def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] return super().__str__() + f'({exception})' - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 56fdc570..61c2ff46 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -9,6 +9,7 @@ import copy import enum import functools +import inspect import logging import re import sys @@ -33,6 +34,8 @@ cast, ) +from plumpy import loaders + try: from aiocontextvars import ContextVar except ModuleNotFoundError: @@ -82,7 +85,7 @@ class BundleKeys: """ String keys used by the process to save its state in the state bundle. - See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`. + See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.load_instance_state`. """ @@ -623,18 +626,14 @@ async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) # region Persistence - def save_instance_state( - self, - out_state: SAVED_STATE_TYPE, - save_context: Optional[persistence.LoadSaveContext], - ) -> None: + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: """ Ask the process to save its current instance state. :param out_state: A bundle to save the state to :param save_context: The save context """ - super().save_instance_state(out_state, save_context) + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) # FIXME: the combined ProcessState protocol should cover the case if isinstance(self._state, process_states.Savable): @@ -650,6 +649,8 @@ def save_instance_state( if self.outputs: out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) + return out_state + @protected def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: """Load the process from its saved instance state. diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 865a5b61..9741f7ed 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import copy import abc import asyncio import collections @@ -29,6 +30,7 @@ from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE +from plumpy import loaders, utils __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] @@ -149,15 +151,69 @@ def on_create(self) -> None: super().on_create() self._stepper = self.spec().get_outline().create_stepper(self) - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + """ + Ask the process to save its current instance state. + + :param out_state: A bundle to save the state to + :param save_context: The save context + """ + out_state: SAVED_STATE_TYPE = {} + + if save_context is None: + save_context = persistence.LoadSaveContext() + + utils.type_check(save_context, persistence.LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__)) + + self._ensure_persist_configured() + if self._auto_persist is not None: + for member in self._auto_persist: + value = getattr(self, member) + if inspect.ismethod(value): + if value.__self__ is not self: + raise TypeError('Cannot persist methods of other classes') + persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, persistence.Savable): + persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value + + if isinstance(self._state, process_states.Savable): + out_state['_state'] = self._state.save() + + # Inputs/outputs + if self.raw_inputs is not None: + out_state[processes.BundleKeys.INPUTS_RAW] = self.encode_input_args(self.raw_inputs) + + if self.inputs is not None: + out_state[processes.BundleKeys.INPUTS_PARSED] = self.encode_input_args(self.inputs) + + if self.outputs: + out_state[processes.BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) # Ask the stepper to save itself if self._stepper is not None: out_state[self._STEPPER_STATE] = self._stepper.save() + if self._context is not None: + out_state[self.CONTEXT] = self._context.__dict__ + + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) @@ -253,10 +309,12 @@ def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): super().__init__(workchain) self._fn = fn - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state['_fn'] = self._fn.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._fn = getattr(self._workchain.__class__, saved_state['_fn']) @@ -326,11 +384,13 @@ def next_instruction(self) -> None: def finished(self) -> bool: return self._pos == len(self._block) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._block = load_context.block_instruction @@ -464,11 +524,13 @@ def step(self) -> Tuple[bool, Any]: def finished(self) -> bool: return self._pos == len(self._if_instruction) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._if_instruction = load_context.if_instruction @@ -558,11 +620,14 @@ def step(self) -> Tuple[bool, Any]: return False, result - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) + if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._while_instruction = load_context.while_instruction diff --git a/tests/test_processes.py b/tests/test_processes.py index 4b8cc606..7fa33bb1 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -700,7 +700,7 @@ def step2(self): class TestProcessSaving(unittest.TestCase): maxDiff = None - def test_running_save_instance_state(self): + def test_running_save(self): loop = asyncio.get_event_loop() nsync_comeback = SavePauseProc()