Skip to content

Commit

Permalink
WIP: load_instance_state deabstract simplify
Browse files Browse the repository at this point in the history
- stepper de-abstract
- remove ContextMixin
- Stepper all using recreate_from
  • Loading branch information
unkcpz committed Dec 9, 2024
1 parent 304f3ba commit ce6beae
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 125 deletions.
2 changes: 0 additions & 2 deletions src/plumpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .exceptions import *
from .futures import *
from .loaders import *
from .mixins import *
from .persistence import *
from .ports import *
from .process_comms import *
Expand All @@ -25,7 +24,6 @@
+ processes.__all__
+ utils.__all__
+ futures.__all__
+ mixins.__all__
+ persistence.__all__
+ communications.__all__
+ process_comms.__all__
Expand Down
31 changes: 0 additions & 31 deletions src/plumpy/mixins.py

This file was deleted.

17 changes: 10 additions & 7 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,15 +477,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
call_with_super_check(obj.load_instance_state, saved_state, load_context)
obj.load_instance_state(saved_state, load_context)
return obj

@super_check
def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None:
self._ensure_persist_configured()
if self._auto_persist is not None:
for member in self._auto_persist:
setattr(self, member, self._get_value(saved_state, member, load_context))
auto_load(self, saved_state, load_context)

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
out_state: SAVED_STATE_TYPE = auto_save(self, save_context)
Expand Down Expand Up @@ -606,7 +602,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
return obj

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
auto_load(self, saved_state, load_context)

if self._callbacks:
# typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list
for callback in self._callbacks:
Expand Down Expand Up @@ -649,3 +646,9 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
out_state[member] = value

return out_state

def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
obj._ensure_persist_configured()
if obj._auto_persist is not None:
for member in obj._auto_persist:
setattr(obj, member, obj._get_value(saved_state, member, load_context))
175 changes: 142 additions & 33 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from plumpy import loaders
from plumpy.process_comms import KillMessage, MessageType
from plumpy.persistence import _ensure_object_loader

try:
import tblib
Expand All @@ -38,7 +39,16 @@
from . import exceptions, futures, persistence, utils
from .base import state_machine as st
from .lang import NULL
from .persistence import META__OBJECT_LOADER, META__TYPE__METHOD, META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_persist, auto_save
from .persistence import (
META__OBJECT_LOADER,
META__TYPE__METHOD,
META__TYPE__SAVABLE,
LoadSaveContext,
Savable,
auto_load,
auto_persist,
auto_save,
)
from .utils import SAVED_STATE_TYPE

__all__ = [
Expand Down Expand Up @@ -136,14 +146,28 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)

obj.state_machine = load_context.process
try:
self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN])
obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN])
except ValueError:
process = load_context.process
self.continue_fn = getattr(process, saved_state[self.CONTINUE_FN])
obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN])
return obj


# endregion
Expand All @@ -168,6 +192,7 @@ class ProcessState(Enum):
# class Savable(Protocol):
# def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ...


@final
@auto_persist('args', 'kwargs')
class Created(persistence.Savable):
Expand All @@ -190,11 +215,27 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
self.run_fn = getattr(self.process, saved_state[self.RUN_FN])
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)

auto_load(obj, saved_state, load_context)

obj.process = load_context.process

obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN])

return obj

def execute(self) -> st.State:
return st.create_state(
Expand Down Expand Up @@ -245,13 +286,28 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)

obj.process = load_context.process

self.run_fn = getattr(self.process, saved_state[self.RUN_FN])
if self.COMMAND in saved_state:
self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore
obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN])
if obj.COMMAND in saved_state:
# FIXME: typing
obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore
return obj

def interrupt(self, reason: Any) -> None:
pass
Expand Down Expand Up @@ -368,16 +424,30 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)

obj.process = load_context.process

callback_name = saved_state.get(self.DONE_CALLBACK, None)
callback_name = saved_state.get(obj.DONE_CALLBACK, None)
if callback_name is not None:
self.done_callback = getattr(self.process, callback_name)
obj.done_callback = getattr(obj.process, callback_name)
else:
self.done_callback = None
self._waiting_future = futures.Future()
obj.done_callback = None
obj._waiting_future = futures.Future()
return obj

def interrupt(self, reason: Exception) -> None:
# This will cause the future in execute() to raise the exception
Expand Down Expand Up @@ -459,17 +529,30 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY

return out_state

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)

self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader)
obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader)
if _HAS_TBLIB:
try:
self.traceback = tblib.Traceback.from_string(saved_state[self.TRACEBACK], strict=False)
obj.traceback = tblib.Traceback.from_string(saved_state[obj.TRACEBACK], strict=False)
except KeyError:
self.traceback = None
obj.traceback = None
else:
self.traceback = None
obj.traceback = None
return obj

def get_exc_info(
self,
Expand Down Expand Up @@ -506,8 +589,21 @@ def __init__(self, result: Any, successful: bool) -> None:
self.result = result
self.successful = successful

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj

def enter(self) -> None: ...

Expand Down Expand Up @@ -537,8 +633,21 @@ def __init__(self, msg: Optional[MessageType]):
"""
self.msg = msg

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
:param saved_state: The saved state
:param load_context: An optional load context
:return: The recreated instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
return obj

def enter(self) -> None: ...

Expand Down
16 changes: 9 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import copy
import enum
import functools
import inspect
import logging
import re
import sys
Expand All @@ -34,7 +33,7 @@
cast,
)

from plumpy import loaders
from plumpy.persistence import _ensure_object_loader

try:
from aiocontextvars import ContextVar
Expand Down Expand Up @@ -277,9 +276,12 @@ def recreate_from(
:return: An instance of the object with its state loaded from the save state.
"""
process = cast(Process, super().recreate_from(saved_state, load_context))
call_with_super_check(process.init)
return process
load_context = _ensure_object_loader(load_context, saved_state)
proc = cls.__new__(cls)
proc.load_instance_state(saved_state, load_context)

call_with_super_check(proc.init)
return proc

def __init__(
self,
Expand Down Expand Up @@ -660,7 +662,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
"""
# First make sure the state machine constructor is called
super().__init__()
state_machine.StateMachine.__init__(self)

self._setup_event_hooks()

Expand All @@ -684,7 +686,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
self._logger = load_context.logger

# Need to call this here as things downstream may rely on us having the runtime variable above
super().load_instance_state(saved_state, load_context)
persistence.auto_load(self, saved_state, load_context)

# Inputs/outputs
try:
Expand Down
Loading

0 comments on commit ce6beae

Please sign in to comment.