Skip to content

Commit

Permalink
Absorb all load_instance_state into recreate_from
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 9, 2024
1 parent 484ae87 commit c910d62
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 105 deletions.
23 changes: 11 additions & 12 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,9 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
obj.load_instance_state(saved_state, load_context)
auto_load(obj, saved_state, load_context)
return obj

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None:
auto_load(self, saved_state, load_context)

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)

Expand Down Expand Up @@ -599,15 +596,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
obj = cls(loop=loop)
obj.cancel()

return obj

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
auto_load(self, saved_state, load_context)
# ## XXX: load_instance_state: test not cover
# auto_load(obj, saved_state, load_context)
#
# if obj._callbacks:
# # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list
# for callback in obj._callbacks:
# obj.remove_done_callback(callback) # type: ignore[arg-type]
# ## UNTILHERE XXX:

if self._callbacks:
# typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list
for callback in self._callbacks:
self.remove_done_callback(callback) # type: ignore[arg-type]
return obj


def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
Expand Down Expand Up @@ -647,6 +645,7 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S

return out_state


def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
obj._ensure_persist_configured()
if obj._auto_persist is not None:
Expand Down
112 changes: 50 additions & 62 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class BundleKeys:
"""
String keys used by the process to save its state in the state bundle.
See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.load_instance_state`.
See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.recreate_from`.
"""

Expand Down Expand Up @@ -266,10 +266,8 @@ def recreate_from(
cls,
saved_state: SAVED_STATE_TYPE,
load_context: Optional[persistence.LoadSaveContext] = None,
) -> 'Process':
"""
Recreate a process from a saved state, passing any positional and
keyword arguments on to load_instance_state
) -> Process:
"""Recreate a process from a saved state, passing any positional
:param saved_state: The saved state to load from
:param load_context: The load context to use
Expand All @@ -278,7 +276,53 @@ def recreate_from(
"""
load_context = _ensure_object_loader(load_context, saved_state)
proc = cls.__new__(cls)
proc.load_instance_state(saved_state, load_context)

# XXX: load_instance_state
# First make sure the state machine constructor is called
state_machine.StateMachine.__init__(proc)

proc._setup_event_hooks()

# Runtime variables, set initial states
proc._future = persistence.SavableFuture()
proc._event_helper = EventHelper(ProcessListener)
proc._logger = None
proc._communicator = None

if 'loop' in load_context:
proc._loop = load_context.loop
else:
proc._loop = asyncio.get_event_loop()

proc._state: state_machine.State = proc.recreate_state(saved_state['_state'])

if 'communicator' in load_context:
proc._communicator = load_context.communicator

if 'logger' in load_context:
proc._logger = load_context.logger

# Need to call this here as things downstream may rely on us having the runtime variable above
persistence.auto_load(proc, saved_state, load_context)

# Inputs/outputs
try:
decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_RAW])
proc._raw_inputs = utils.AttributesFrozendict(decoded)
except KeyError:
proc._raw_inputs = None

try:
decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED])
proc._parsed_inputs = utils.AttributesFrozendict(decoded)
except KeyError:
proc._parsed_inputs = None

try:
decoded = proc.decode_input_args(saved_state[BundleKeys.OUTPUTS])
proc._outputs = decoded
except KeyError:
proc._outputs = {}

call_with_super_check(proc.init)
return proc
Expand Down Expand Up @@ -653,62 +697,6 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA

return out_state

@protected
def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
"""Load the process from its saved instance state.
:param saved_state: A bundle to load the state from
:param load_context: The load context
"""
# First make sure the state machine constructor is called
state_machine.StateMachine.__init__(self)

self._setup_event_hooks()

# Runtime variables, set initial states
self._future = persistence.SavableFuture()
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None

if 'loop' in load_context:
self._loop = load_context.loop
else:
self._loop = asyncio.get_event_loop()

self._state: state_machine.State = self.recreate_state(saved_state['_state'])

if 'communicator' in load_context:
self._communicator = load_context.communicator

if 'logger' in load_context:
self._logger = load_context.logger

# Need to call this here as things downstream may rely on us having the runtime variable above
persistence.auto_load(self, saved_state, load_context)

# Inputs/outputs
try:
decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_RAW])
self._raw_inputs = utils.AttributesFrozendict(decoded)
except KeyError:
self._raw_inputs = None

try:
decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED])
self._parsed_inputs = utils.AttributesFrozendict(decoded)
except KeyError:
self._parsed_inputs = None

try:
decoded = self.decode_input_args(saved_state[BundleKeys.OUTPUTS])
self._outputs = decoded
except KeyError:
self._outputs = {}

# endregion

def add_process_listener(self, listener: ProcessListener) -> None:
"""Add a process listener to the process.
Expand Down
80 changes: 49 additions & 31 deletions src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import kiwipy

from plumpy.base import state_machine
from plumpy.base.utils import call_with_super_check
from plumpy.event_helper import EventHelper
from plumpy.exceptions import InvalidStateError
from plumpy.process_listener import ProcessListener
Expand Down Expand Up @@ -220,70 +221,87 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
#########
# FIXME: dup of Process.load_instance_state
state_machine.StateMachine.__init__(self)
@classmethod
def recreate_from(
cls,
saved_state: SAVED_STATE_TYPE,
load_context: Optional[persistence.LoadSaveContext] = None,
) -> WorkChain:
"""Recreate a workchain from a saved state, passing any positional
:param saved_state: The saved state to load from
:param load_context: The load context to use
:return: An instance of the object with its state loaded from the save state.
"""
### FIXME: dup from process.create_from
load_context = _ensure_object_loader(load_context, saved_state)
proc = cls.__new__(cls)

# XXX: load_instance_state
# First make sure the state machine constructor is called
state_machine.StateMachine.__init__(proc)

self._setup_event_hooks()
proc._setup_event_hooks()

# Runtime variables, set initial states
self._future = persistence.SavableFuture()
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None
proc._future = persistence.SavableFuture()
proc._event_helper = EventHelper(ProcessListener)
proc._logger = None
proc._communicator = None

if 'loop' in load_context:
self._loop = load_context.loop
proc._loop = load_context.loop
else:
self._loop = asyncio.get_event_loop()
proc._loop = asyncio.get_event_loop()

self._state: state_machine.State = self.recreate_state(saved_state['_state'])
proc._state: state_machine.State = proc.recreate_state(saved_state['_state'])

if 'communicator' in load_context:
self._communicator = load_context.communicator
proc._communicator = load_context.communicator

if 'logger' in load_context:
self._logger = load_context.logger
proc._logger = load_context.logger

# Need to call this here as things downstream may rely on us having the runtime variable above
persistence.auto_load(self, saved_state, load_context)
persistence.auto_load(proc, saved_state, load_context)

# Inputs/outputs
try:
decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW])
self._raw_inputs = utils.AttributesFrozendict(decoded)
decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW])
proc._raw_inputs = utils.AttributesFrozendict(decoded)
except KeyError:
self._raw_inputs = None
proc._raw_inputs = None

try:
decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED])
self._parsed_inputs = utils.AttributesFrozendict(decoded)
decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED])
proc._parsed_inputs = utils.AttributesFrozendict(decoded)
except KeyError:
self._parsed_inputs = None
proc._parsed_inputs = None

try:
decoded = self.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS])
self._outputs = decoded
decoded = proc.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS])
proc._outputs = decoded
except KeyError:
self._outputs = {}

#
#########
proc._outputs = {}
### UNTILHERE FIXME: dup from process.create_from

# context mixin
try:
self._context = AttributesDict(**saved_state[self.CONTEXT])
proc._context = AttributesDict(**saved_state[proc.CONTEXT])
except KeyError:
pass

# end of context mixin

# Recreate the stepper
self._stepper = None
stepper_state = saved_state.get(self._STEPPER_STATE, None)
proc._stepper = None
stepper_state = saved_state.get(proc._STEPPER_STATE, None)
if stepper_state is not None:
self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self)
proc._stepper = proc.spec().get_outline().recreate_stepper(stepper_state, proc)

call_with_super_check(proc.init)
return proc

def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None:
"""
Expand Down

0 comments on commit c910d62

Please sign in to comment.