Skip to content

Commit

Permalink
Make auto_load symmetry with auto_save and state/state_label distinguish
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 11, 2024
1 parent 9b9a5b7 commit 3e6a2dd
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 145 deletions.
10 changes: 8 additions & 2 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,13 @@ def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)

@property
def state(self) -> Any:
def state(self) -> State | None:
if self._state is None:
return None
return self._state

@property
def state_label(self) -> Any:
if self._state is None:
return None
return self._state.LABEL
Expand Down Expand Up @@ -312,7 +318,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
if new_state is None:
return None

initial_state_label = self._state.LABEL if self._state is not None else None
initial_state_label = self.state_label
label = None
try:
self._transitioning = True
Expand Down
3 changes: 1 addition & 2 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
Expand Down
19 changes: 18 additions & 1 deletion src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
List,
Optional,
Protocol,
TypeVar,
cast,
runtime_checkable,
)
Expand Down Expand Up @@ -535,6 +536,8 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
value = value.__name__
elif isinstance(value, Savable) and not isinstance(value, type):
# persist for a savable obj, call `save` method of obj.
# the rhs branch is for when value is a Savable class, it is true runtime check
# of lhs condition.
SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save()
else:
Expand All @@ -544,11 +547,25 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
return out_state


def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
def load_auto_persist_params(
obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None
) -> None:
for member in obj._auto_persist:
setattr(obj, member, _get_value(obj, saved_state, member, load_context))


T = TypeVar('T', bound=Savable)


def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T:
obj = cls.__new__(cls)

if isinstance(obj, SavableWithAutoPersist):
load_auto_persist_params(obj, saved_state, load_context)

return obj


def _get_value(
obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None
) -> MethodType | Savable:
Expand Down
42 changes: 15 additions & 27 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import yaml
from yaml.loader import Loader

from plumpy.persistence import ensure_object_loader
from plumpy.process_comms import KillMessage, MessageType

try:
Expand All @@ -41,6 +40,7 @@
auto_load,
auto_persist,
auto_save,
ensure_object_loader,
)
from .utils import SAVED_STATE_TYPE

Expand Down Expand Up @@ -98,8 +98,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
:return: The recreated instance
"""
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
load_context = ensure_object_loader(load_context, saved_state)
obj = auto_load(cls, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
Expand Down Expand Up @@ -171,15 +171,15 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)

obj.state_machine = load_context.process
try:
obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN])
except ValueError:
process = load_context.process
obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN])
if load_context is not None:
obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN])
else:
raise
return obj


Expand Down Expand Up @@ -235,12 +235,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)

auto_load(obj, saved_state, load_context)

obj = auto_load(cls, saved_state, load_context)
obj.process = load_context.process

obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN])

return obj
Expand Down Expand Up @@ -306,15 +302,12 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)

obj = auto_load(cls, saved_state, load_context)
obj.process = load_context.process

obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN])
if obj.COMMAND in saved_state:
# FIXME: typing
obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore

return obj

def interrupt(self, reason: Any) -> None:
Expand Down Expand Up @@ -444,9 +437,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)

obj = auto_load(cls, saved_state, load_context)
obj.process = load_context.process

callback_name = saved_state.get(obj.DONE_CALLBACK, None)
Expand Down Expand Up @@ -550,8 +541,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)

obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader)
if _HAS_TBLIB:
Expand Down Expand Up @@ -610,8 +600,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
Expand Down Expand Up @@ -659,8 +648,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
Expand Down
Loading

0 comments on commit 3e6a2dd

Please sign in to comment.