diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 6f94b5bf..5aa23401 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -9,7 +9,6 @@ from .exceptions import * from .futures import * from .loaders import * -from .mixins import * from .persistence import * from .ports import * from .process_comms import * @@ -25,7 +24,6 @@ + processes.__all__ + utils.__all__ + futures.__all__ - + mixins.__all__ + persistence.__all__ + communications.__all__ + process_comms.__all__ diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py deleted file mode 100644 index 9dfa7539..00000000 --- a/src/plumpy/mixins.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Any, Optional - -from . import persistence -from .utils import SAVED_STATE_TYPE, AttributesDict - -__all__ = ['ContextMixin'] - - -class ContextMixin(persistence.Savable): - """ - Add a context to a Process. The contents of the context will be saved - in the instance state unlike standard instance variables. - """ - - CONTEXT: str = '_context' - - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self._context: Optional[AttributesDict] = AttributesDict() - - @property - def ctx(self) -> Optional[AttributesDict]: - return self._context - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - try: - self._context = AttributesDict(**saved_state[self.CONTEXT]) - except KeyError: - pass diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index ccdeef26..d33afaa1 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -477,15 +477,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = _ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) - call_with_super_check(obj.load_instance_state, saved_state, load_context) + obj.load_instance_state(saved_state, load_context) return obj - @super_check def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - for member in self._auto_persist: - setattr(self, member, self._get_value(saved_state, member, load_context)) + auto_load(self, saved_state, load_context) def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = auto_save(self, save_context) @@ -606,7 +602,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa return obj def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + auto_load(self, saved_state, load_context) + if self._callbacks: # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list for callback in self._callbacks: @@ -649,3 +646,9 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S out_state[member] = value return out_state + +def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: + obj._ensure_persist_configured() + if obj._auto_persist is not None: + for member in obj._auto_persist: + setattr(obj, member, obj._get_value(saved_state, member, load_context)) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 0f811cb6..1d7f2350 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -27,6 +27,7 @@ from plumpy import loaders from plumpy.process_comms import KillMessage, MessageType +from plumpy.persistence import _ensure_object_loader try: import tblib @@ -38,7 +39,16 @@ from . import exceptions, futures, persistence, utils from .base import state_machine as st from .lang import NULL -from .persistence import META__OBJECT_LOADER, META__TYPE__METHOD, META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_persist, auto_save +from .persistence import ( + META__OBJECT_LOADER, + META__TYPE__METHOD, + META__TYPE__SAVABLE, + LoadSaveContext, + Savable, + auto_load, + auto_persist, + auto_save, +) from .utils import SAVED_STATE_TYPE __all__ = [ @@ -136,14 +146,28 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + + obj.state_machine = load_context.process try: - self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) + obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN]) except ValueError: process = load_context.process - self.continue_fn = getattr(process, saved_state[self.CONTINUE_FN]) + obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN]) + return obj # endregion @@ -168,6 +192,7 @@ class ProcessState(Enum): # class Savable(Protocol): # def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... + @final @auto_persist('args', 'kwargs') class Created(persistence.Savable): @@ -190,11 +215,27 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.process = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + + auto_load(obj, saved_state, load_context) + + obj.process = load_context.process + + obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) + + return obj def execute(self) -> st.State: return st.create_state( @@ -245,13 +286,28 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.process = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + + obj.process = load_context.process - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - if self.COMMAND in saved_state: - self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore + obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) + if obj.COMMAND in saved_state: + # FIXME: typing + obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore + return obj def interrupt(self, reason: Any) -> None: pass @@ -368,16 +424,30 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.process = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + + obj.process = load_context.process - callback_name = saved_state.get(self.DONE_CALLBACK, None) + callback_name = saved_state.get(obj.DONE_CALLBACK, None) if callback_name is not None: - self.done_callback = getattr(self.process, callback_name) + obj.done_callback = getattr(obj.process, callback_name) else: - self.done_callback = None - self._waiting_future = futures.Future() + obj.done_callback = None + obj._waiting_future = futures.Future() + return obj def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception @@ -459,17 +529,30 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) - self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) + obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: - self.traceback = tblib.Traceback.from_string(saved_state[self.TRACEBACK], strict=False) + obj.traceback = tblib.Traceback.from_string(saved_state[obj.TRACEBACK], strict=False) except KeyError: - self.traceback = None + obj.traceback = None else: - self.traceback = None + obj.traceback = None + return obj def get_exc_info( self, @@ -506,8 +589,21 @@ def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj def enter(self) -> None: ... @@ -537,8 +633,21 @@ def __init__(self, msg: Optional[MessageType]): """ self.msg = msg - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj def enter(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 61c2ff46..f1a3f1f7 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -9,7 +9,6 @@ import copy import enum import functools -import inspect import logging import re import sys @@ -34,7 +33,7 @@ cast, ) -from plumpy import loaders +from plumpy.persistence import _ensure_object_loader try: from aiocontextvars import ContextVar @@ -277,9 +276,12 @@ def recreate_from( :return: An instance of the object with its state loaded from the save state. """ - process = cast(Process, super().recreate_from(saved_state, load_context)) - call_with_super_check(process.init) - return process + load_context = _ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + proc.load_instance_state(saved_state, load_context) + + call_with_super_check(proc.init) + return proc def __init__( self, @@ -660,7 +662,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi """ # First make sure the state machine constructor is called - super().__init__() + state_machine.StateMachine.__init__(self) self._setup_event_hooks() @@ -684,7 +686,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - super().load_instance_state(saved_state, load_context) + persistence.auto_load(self, saved_state, load_context) # Inputs/outputs try: diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 9741f7ed..cdb3b00e 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -16,6 +16,7 @@ Mapping, MutableSequence, Optional, + Protocol, Sequence, Tuple, Type, @@ -26,11 +27,14 @@ import kiwipy from plumpy.base import state_machine +from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError +from plumpy.process_listener import ProcessListener -from . import lang, mixins, persistence, process_states, processes -from .utils import PID_TYPE, SAVED_STATE_TYPE +from . import lang, persistence, process_states, processes +from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict from plumpy import loaders, utils +from plumpy.persistence import _ensure_object_loader __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] @@ -104,18 +108,15 @@ def enter(self) -> None: for awaitable in self._awaiting: awaitable.add_done_callback(self._awaitable_done) - self.in_state = True - def exit(self) -> None: if self.is_terminal: raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - self.in_state = False for awaitable in self._awaiting: awaitable.remove_done_callback(self._awaitable_done) -class WorkChain(mixins.ContextMixin, processes.Process): +class WorkChain(processes.Process): """ A WorkChain is a series of instructions carried out with the ability to save state in between. @@ -123,7 +124,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _spec_class = WorkChainSpec _STEPPER_STATE = 'stepper_state' - _CONTEXT = 'CONTEXT' + CONTEXT = 'CONTEXT' @classmethod def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: @@ -140,9 +141,14 @@ def __init__( communicator: Optional[kiwipy.Communicator] = None, ) -> None: super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) + self._context: Optional[AttributesDict] = AttributesDict() self._stepper: Optional[Stepper] = None self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {} + @property + def ctx(self) -> Optional[AttributesDict]: + return self._context + @classmethod def spec(cls) -> WorkChainSpec: return cast(WorkChainSpec, super().spec()) @@ -215,7 +221,63 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + ######### + # FIXME: dup of Process.load_instance_state + state_machine.StateMachine.__init__(self) + + self._setup_event_hooks() + + # Runtime variables, set initial states + self._future = persistence.SavableFuture() + self._event_helper = EventHelper(ProcessListener) + self._logger = None + self._communicator = None + + if 'loop' in load_context: + self._loop = load_context.loop + else: + self._loop = asyncio.get_event_loop() + + self._state: state_machine.State = self.recreate_state(saved_state['_state']) + + if 'communicator' in load_context: + self._communicator = load_context.communicator + + if 'logger' in load_context: + self._logger = load_context.logger + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.auto_load(self, saved_state, load_context) + + # Inputs/outputs + try: + decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) + self._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + self._raw_inputs = None + + try: + decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) + self._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + self._parsed_inputs = None + + try: + decoded = self.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) + self._outputs = decoded + except KeyError: + self._outputs = {} + + # + ######### + + # context mixin + try: + self._context = AttributesDict(**saved_state[self.CONTEXT]) + except KeyError: + pass + + # end of context mixin # Recreate the stepper self._stepper = None @@ -258,15 +320,8 @@ def _do_step(self) -> Any: return return_value -class Stepper(persistence.Savable, metaclass=abc.ABCMeta): - def __init__(self, workchain: 'WorkChain') -> None: - self._workchain = workchain - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._workchain = load_context.workchain - - @abc.abstractmethod +# XXX: Stepper is also a Saver with `save` method. +class Stepper(Protocol): def step(self) -> Tuple[bool, Any]: """ Execute on step of the instructions. @@ -275,6 +330,7 @@ def step(self) -> Tuple[bool, Any]: 1. The return value from the executed step """ + ... class _Instruction(metaclass=abc.ABCMeta): @@ -304,9 +360,9 @@ def get_description(self) -> Any: """ -class _FunctionStepper(Stepper): +class _FunctionStepper(persistence.Savable): def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): - super().__init__(workchain) + self._workchain = workchain self._fn = fn def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -315,9 +371,24 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._fn = getattr(self._workchain.__class__, saved_state['_fn']) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._fn = getattr(obj._workchain.__class__, saved_state['_fn']) + + return obj def step(self) -> Tuple[bool, Any]: return True, self._fn(self._workchain) @@ -357,9 +428,9 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') -class _BlockStepper(Stepper): +class _BlockStepper(persistence.Savable): def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._block = block self._pos: int = 0 self._child_stepper: Optional[Stepper] = self._block[0].create_stepper(self._workchain) @@ -391,13 +462,28 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._block = load_context.block_instruction + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._block = load_context.block_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._block[self._pos].recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._block[obj._pos].recreate_stepper(stepper_state, obj._workchain) + + return obj def __str__(self) -> str: return str(self._pos) + ':' + str(self._child_stepper) @@ -490,9 +576,9 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') -class _IfStepper(Stepper): +class _IfStepper(persistence.Savable): def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._if_instruction = if_instruction self._pos = 0 self._child_stepper: Optional[Stepper] = None @@ -531,13 +617,27 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._if_instruction = load_context.if_instruction + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._if_instruction = load_context.if_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._if_instruction[self._pos].body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._if_instruction[obj._pos].body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._if_instruction[self._pos]) @@ -599,9 +699,9 @@ def get_description(self) -> Mapping[str, Any]: return description -class _WhileStepper(Stepper): +class _WhileStepper(persistence.Savable): def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._while_instruction = while_instruction self._child_stepper: Optional[_BlockStepper] = None @@ -628,13 +728,27 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._while_instruction = load_context.while_instruction + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._while_instruction = load_context.while_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._while_instruction.body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._while_instruction.body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._while_instruction) @@ -672,9 +786,9 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code -class _ReturnStepper(Stepper): +class _ReturnStepper(persistence.Savable): def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._return_instruction = return_instruction def step(self) -> Tuple[bool, Any]: