From c910d62838fd3ac494b1fa9f0c806eccaf5e8771 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 22:24:38 +0100 Subject: [PATCH] Absorb all load_instance_state into recreate_from --- src/plumpy/persistence.py | 23 ++++---- src/plumpy/processes.py | 112 +++++++++++++++++--------------------- src/plumpy/workchains.py | 80 ++++++++++++++++----------- 3 files changed, 110 insertions(+), 105 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index d33afaa1..13a21c61 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -477,12 +477,9 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = _ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) - obj.load_instance_state(saved_state, load_context) + auto_load(obj, saved_state, load_context) return obj - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: - auto_load(self, saved_state, load_context) - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = auto_save(self, save_context) @@ -599,15 +596,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa obj = cls(loop=loop) obj.cancel() - return obj - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - auto_load(self, saved_state, load_context) + # ## XXX: load_instance_state: test not cover + # auto_load(obj, saved_state, load_context) + # + # if obj._callbacks: + # # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list + # for callback in obj._callbacks: + # obj.remove_done_callback(callback) # type: ignore[arg-type] + # ## UNTILHERE XXX: - if self._callbacks: - # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list - for callback in self._callbacks: - self.remove_done_callback(callback) # type: ignore[arg-type] + return obj def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -647,6 +645,7 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S return out_state + def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: obj._ensure_persist_configured() if obj._auto_persist is not None: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index f1a3f1f7..96689024 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -84,7 +84,7 @@ class BundleKeys: """ String keys used by the process to save its state in the state bundle. - See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.load_instance_state`. + See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.recreate_from`. """ @@ -266,10 +266,8 @@ def recreate_from( cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None, - ) -> 'Process': - """ - Recreate a process from a saved state, passing any positional and - keyword arguments on to load_instance_state + ) -> Process: + """Recreate a process from a saved state, passing any positional :param saved_state: The saved state to load from :param load_context: The load context to use @@ -278,7 +276,53 @@ def recreate_from( """ load_context = _ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) - proc.load_instance_state(saved_state, load_context) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) + + proc._setup_event_hooks() + + # Runtime variables, set initial states + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._communicator = None + + if 'loop' in load_context: + proc._loop = load_context.loop + else: + proc._loop = asyncio.get_event_loop() + + proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + + if 'communicator' in load_context: + proc._communicator = load_context.communicator + + if 'logger' in load_context: + proc._logger = load_context.logger + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.auto_load(proc, saved_state, load_context) + + # Inputs/outputs + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._raw_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._parsed_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.OUTPUTS]) + proc._outputs = decoded + except KeyError: + proc._outputs = {} call_with_super_check(proc.init) return proc @@ -653,62 +697,6 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - @protected - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - """Load the process from its saved instance state. - - :param saved_state: A bundle to load the state from - :param load_context: The load context - - """ - # First make sure the state machine constructor is called - state_machine.StateMachine.__init__(self) - - self._setup_event_hooks() - - # Runtime variables, set initial states - self._future = persistence.SavableFuture() - self._event_helper = EventHelper(ProcessListener) - self._logger = None - self._communicator = None - - if 'loop' in load_context: - self._loop = load_context.loop - else: - self._loop = asyncio.get_event_loop() - - self._state: state_machine.State = self.recreate_state(saved_state['_state']) - - if 'communicator' in load_context: - self._communicator = load_context.communicator - - if 'logger' in load_context: - self._logger = load_context.logger - - # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(self, saved_state, load_context) - - # Inputs/outputs - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) - self._raw_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._raw_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) - self._parsed_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._parsed_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.OUTPUTS]) - self._outputs = decoded - except KeyError: - self._outputs = {} - - # endregion - def add_process_listener(self, listener: ProcessListener) -> None: """Add a process listener to the process. diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index cdb3b00e..cf7ad81f 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -27,6 +27,7 @@ import kiwipy from plumpy.base import state_machine +from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError from plumpy.process_listener import ProcessListener @@ -220,70 +221,87 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - ######### - # FIXME: dup of Process.load_instance_state - state_machine.StateMachine.__init__(self) + @classmethod + def recreate_from( + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, + ) -> WorkChain: + """Recreate a workchain from a saved state, passing any positional + + :param saved_state: The saved state to load from + :param load_context: The load context to use + :return: An instance of the object with its state loaded from the save state. + + """ + ### FIXME: dup from process.create_from + load_context = _ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) - self._setup_event_hooks() + proc._setup_event_hooks() # Runtime variables, set initial states - self._future = persistence.SavableFuture() - self._event_helper = EventHelper(ProcessListener) - self._logger = None - self._communicator = None + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._communicator = None if 'loop' in load_context: - self._loop = load_context.loop + proc._loop = load_context.loop else: - self._loop = asyncio.get_event_loop() + proc._loop = asyncio.get_event_loop() - self._state: state_machine.State = self.recreate_state(saved_state['_state']) + proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: - self._communicator = load_context.communicator + proc._communicator = load_context.communicator if 'logger' in load_context: - self._logger = load_context.logger + proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(self, saved_state, load_context) + persistence.auto_load(proc, saved_state, load_context) # Inputs/outputs try: - decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) - self._raw_inputs = utils.AttributesFrozendict(decoded) + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) except KeyError: - self._raw_inputs = None + proc._raw_inputs = None try: - decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) - self._parsed_inputs = utils.AttributesFrozendict(decoded) + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) except KeyError: - self._parsed_inputs = None + proc._parsed_inputs = None try: - decoded = self.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) - self._outputs = decoded + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) + proc._outputs = decoded except KeyError: - self._outputs = {} - - # - ######### + proc._outputs = {} + ### UNTILHERE FIXME: dup from process.create_from # context mixin try: - self._context = AttributesDict(**saved_state[self.CONTEXT]) + proc._context = AttributesDict(**saved_state[proc.CONTEXT]) except KeyError: pass # end of context mixin # Recreate the stepper - self._stepper = None - stepper_state = saved_state.get(self._STEPPER_STATE, None) + proc._stepper = None + stepper_state = saved_state.get(proc._STEPPER_STATE, None) if stepper_state is not None: - self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) + proc._stepper = proc.spec().get_outline().recreate_stepper(stepper_state, proc) + + call_with_super_check(proc.init) + return proc def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None: """