Skip to content

Commit

Permalink
Move static method load outside
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 4, 2024
1 parent ef964ed commit 57a13e4
Showing 1 changed file with 50 additions and 50 deletions.
100 changes: 50 additions & 50 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,33 @@
from .processes import Process


class LoadSaveContext:
def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None:
self._values = dict(**kwargs)
self.loader = loader

def __getattr__(self, item: str) -> Any:
try:
return self._values[item]
except KeyError:
raise AttributeError(f"item '{item}' not found")

def __iter__(self) -> Iterable[Any]:
return self._value.__iter__()

def __contains__(self, item: Any) -> bool:
return self._values.__contains__(item)

def copyextend(self, **kwargs: Any) -> 'LoadSaveContext':
"""Add additional information to the context by making a copy with the new values"""
extended = self._values.copy()
extended.update(kwargs)
loader = extended.pop('loader', self.loader)
return LoadSaveContext(loader=loader, **extended)


class Bundle(dict):
def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False):
def __init__(self, savable: 'Savable', save_context: LoadSaveContext | None = None, dereference: bool = False):
"""
Create a bundle from a savable. Optionally keep information about the
class loader that can be used to load the classes in the bundle.
Expand All @@ -52,7 +77,7 @@ class loader that can be used to load the classes in the bundle.
else:
self.update(savable.save(save_context))

def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable':
def unbundle(self, load_context: LoadSaveContext | None = None) -> 'Savable':
"""
This method loads the class of the object and calls its recreate_from
method passing the positional and keyword arguments.
Expand All @@ -61,7 +86,29 @@ def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable
:return: An instance of the Savable
"""
return Savable.load(self, load_context)
return load(self, load_context)


def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable':
"""
Load a `Savable` from a saved instance state. The load context is a way of passing
runtime data to the object being loaded.
:param saved_state: The saved state
:param load_context: Additional runtime state that can be passed into when loading.
The type and content (if any) is completely user defined
:return: The loaded Savable instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
assert load_context.loader is not None # required for type checking
try:
class_name = Savable._get_class_name(saved_state)
load_cls = load_context.loader.load_object(class_name)
except KeyError:
raise ValueError('Class name not found in saved state')
else:
return load_cls.recreate_from(saved_state, load_context)


_BUNDLE_TAG = '!plumpy:Bundle'
Expand Down Expand Up @@ -392,31 +439,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV
return context.copyextend(loader=loader)


class LoadSaveContext:
def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None:
self._values = dict(**kwargs)
self.loader = loader

def __getattr__(self, item: str) -> Any:
try:
return self._values[item]
except KeyError:
raise AttributeError(f"item '{item}' not found")

def __iter__(self) -> Iterable[Any]:
return self._value.__iter__()

def __contains__(self, item: Any) -> bool:
return self._values.__contains__(item)

def copyextend(self, **kwargs: Any) -> 'LoadSaveContext':
"""Add additional information to the context by making a copy with the new values"""
extended = self._values.copy()
extended.update(kwargs)
loader = extended.pop('loader', self.loader)
return LoadSaveContext(loader=loader, **extended)


META: str = '!!meta'
META__CLASS_NAME: str = 'class_name'
META__OBJECT_LOADER: str = 'object_loader'
Expand Down Expand Up @@ -465,28 +487,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio
for member in self._auto_persist:
setattr(self, member, self._get_value(saved_state, member, load_context))

@staticmethod
def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
"""
Load a `Savable` from a saved instance state. The load context is a way of passing
runtime data to the object being loaded.
:param saved_state: The saved state
:param load_context: Additional runtime state that can be passed into when loading.
The type and content (if any) is completely user defined
:return: The loaded Savable instance
"""
load_context = _ensure_object_loader(load_context, saved_state)
assert load_context.loader is not None # required for type checking
try:
class_name = Savable._get_class_name(saved_state)
load_cls = load_context.loader.load_object(class_name)
except KeyError:
raise ValueError('Class name not found in saved state')
else:
return load_cls.recreate_from(saved_state, load_context)

@super_check
def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None:
self._ensure_persist_configured()
Expand Down

0 comments on commit 57a13e4

Please sign in to comment.