Skip to content

Commit

Permalink
Move static method load outside
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 4, 2024
1 parent ef964ed commit 937ad01
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 52 deletions.
102 changes: 51 additions & 51 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,33 @@
from .processes import Process


class LoadSaveContext:
def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None:
self._values = dict(**kwargs)
self.loader = loader

def __getattr__(self, item: str) -> Any:
try:
return self._values[item]
except KeyError:
raise AttributeError(f"item '{item}' not found")

def __iter__(self) -> Iterable[Any]:
return self._value.__iter__()

def __contains__(self, item: Any) -> bool:
return self._values.__contains__(item)

def copyextend(self, **kwargs: Any) -> 'LoadSaveContext':
"""Add additional information to the context by making a copy with the new values"""
extended = self._values.copy()
extended.update(kwargs)
loader = extended.pop('loader', self.loader)
return LoadSaveContext(loader=loader, **extended)


class Bundle(dict):
def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False):
def __init__(self, savable: 'Savable', save_context: LoadSaveContext | None = None, dereference: bool = False):
"""
Create a bundle from a savable. Optionally keep information about the
class loader that can be used to load the classes in the bundle.
Expand All @@ -52,7 +77,7 @@ class loader that can be used to load the classes in the bundle.
else:
self.update(savable.save(save_context))

def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable':
def unbundle(self, load_context: LoadSaveContext | None = None) -> 'Savable':
"""
This method loads the class of the object and calls its recreate_from
method passing the positional and keyword arguments.
Expand All @@ -61,7 +86,29 @@ def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable
:return: An instance of the Savable
"""
return Savable.load(self, load_context)
return load(self, load_context)


def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable':
"""
Load a `Savable` from a saved instance state. The load context is a way of passing
runtime data to the object being loaded.
:param saved_state: The saved state
:param load_context: Additional runtime state that can be passed into when loading.
The type and content (if any) is completely user defined
:return: The loaded Savable instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
assert load_context.loader is not None # required for type checking
try:
class_name = Savable._get_class_name(saved_state)
load_cls = load_context.loader.load_object(class_name)
except KeyError:
raise ValueError('Class name not found in saved state')
else:
return load_cls.recreate_from(saved_state, load_context)


_BUNDLE_TAG = '!plumpy:Bundle'
Expand Down Expand Up @@ -392,31 +439,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
return context.copyextend(loader=loader)


class LoadSaveContext:
def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None:
self._values = dict(**kwargs)
self.loader = loader

def __getattr__(self, item: str) -> Any:
try:
return self._values[item]
except KeyError:
raise AttributeError(f"item '{item}' not found")

def __iter__(self) -> Iterable[Any]:
return self._value.__iter__()

def __contains__(self, item: Any) -> bool:
return self._values.__contains__(item)

def copyextend(self, **kwargs: Any) -> 'LoadSaveContext':
"""Add additional information to the context by making a copy with the new values"""
extended = self._values.copy()
extended.update(kwargs)
loader = extended.pop('loader', self.loader)
return LoadSaveContext(loader=loader, **extended)


META: str = '!!meta'
META__CLASS_NAME: str = 'class_name'
META__OBJECT_LOADER: str = 'object_loader'
Expand Down Expand Up @@ -465,28 +487,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio
for member in self._auto_persist:
setattr(self, member, self._get_value(saved_state, member, load_context))

@staticmethod
def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Load a `Savable` from a saved instance state. The load context is a way of passing
runtime data to the object being loaded.
:param saved_state: The saved state
:param load_context: Additional runtime state that can be passed into when loading.
The type and content (if any) is completely user defined
:return: The loaded Savable instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
assert load_context.loader is not None # required for type checking
try:
class_name = Savable._get_class_name(saved_state)
load_cls = load_context.loader.load_object(class_name)
except KeyError:
raise ValueError('Class name not found in saved state')
else:
return load_cls.recreate_from(saved_state, load_context)

@super_check
def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None:
self._ensure_persist_configured()
Expand Down Expand Up @@ -580,7 +580,7 @@ def _get_value(
if typ == META__TYPE__METHOD:
value = getattr(self, value)
elif typ == META__TYPE__SAVABLE:
value = Savable.load(value, load_context)
value = load(value, load_context)

return value

Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State
:return: An instance of the object with its state loaded from the save state.
"""
load_context = persistence.LoadSaveContext(process=self)
return cast(state_machine.State, persistence.Savable.load(saved_state, load_context))
return cast(state_machine.State, persistence.load(saved_state, load_context))

# endregion

Expand Down

0 comments on commit 937ad01

Please sign in to comment.