From 937ad012f0f55e2ec874d4d5dbbf34b11c35d66f Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 4 Dec 2024 23:33:13 +0100 Subject: [PATCH] Move static method load outside --- src/plumpy/persistence.py | 102 +++++++++++++++++++------------------- src/plumpy/processes.py | 2 +- 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index a1d083cb..6c3849a6 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -35,8 +35,33 @@ from .processes import Process +class LoadSaveContext: + def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: + self._values = dict(**kwargs) + self.loader = loader + + def __getattr__(self, item: str) -> Any: + try: + return self._values[item] + except KeyError: + raise AttributeError(f"item '{item}' not found") + + def __iter__(self) -> Iterable[Any]: + return self._value.__iter__() + + def __contains__(self, item: Any) -> bool: + return self._values.__contains__(item) + + def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': + """Add additional information to the context by making a copy with the new values""" + extended = self._values.copy() + extended.update(kwargs) + loader = extended.pop('loader', self.loader) + return LoadSaveContext(loader=loader, **extended) + + class Bundle(dict): - def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False): + def __init__(self, savable: 'Savable', save_context: LoadSaveContext | None = None, dereference: bool = False): """ Create a bundle from a savable. Optionally keep information about the class loader that can be used to load the classes in the bundle. @@ -52,7 +77,7 @@ class loader that can be used to load the classes in the bundle. else: self.update(savable.save(save_context)) - def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable': + def unbundle(self, load_context: LoadSaveContext | None = None) -> 'Savable': """ This method loads the class of the object and calls its recreate_from method passing the positional and keyword arguments. @@ -61,7 +86,29 @@ def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable :return: An instance of the Savable """ - return Savable.load(self, load_context) + return load(self, load_context) + + +def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': + """ + Load a `Savable` from a saved instance state. The load context is a way of passing + runtime data to the object being loaded. + + :param saved_state: The saved state + :param load_context: Additional runtime state that can be passed into when loading. + The type and content (if any) is completely user defined + :return: The loaded Savable instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + assert load_context.loader is not None # required for type checking + try: + class_name = Savable._get_class_name(saved_state) + load_cls = load_context.loader.load_object(class_name) + except KeyError: + raise ValueError('Class name not found in saved state') + else: + return load_cls.recreate_from(saved_state, load_context) _BUNDLE_TAG = '!plumpy:Bundle' @@ -392,31 +439,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV return context.copyextend(loader=loader) -class LoadSaveContext: - def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: - self._values = dict(**kwargs) - self.loader = loader - - def __getattr__(self, item: str) -> Any: - try: - return self._values[item] - except KeyError: - raise AttributeError(f"item '{item}' not found") - - def __iter__(self) -> Iterable[Any]: - return self._value.__iter__() - - def __contains__(self, item: Any) -> bool: - return self._values.__contains__(item) - - def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': - """Add additional information to the context by making a copy with the new values""" - extended = self._values.copy() - extended.update(kwargs) - loader = extended.pop('loader', self.loader) - return LoadSaveContext(loader=loader, **extended) - - META: str = '!!meta' META__CLASS_NAME: str = 'class_name' META__OBJECT_LOADER: str = 'object_loader' @@ -465,28 +487,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio for member in self._auto_persist: setattr(self, member, self._get_value(saved_state, member, load_context)) - @staticmethod - def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Load a `Savable` from a saved instance state. The load context is a way of passing - runtime data to the object being loaded. - - :param saved_state: The saved state - :param load_context: Additional runtime state that can be passed into when loading. - The type and content (if any) is completely user defined - :return: The loaded Savable instance - - """ - load_context = _ensure_object_loader(load_context, saved_state) - assert load_context.loader is not None # required for type checking - try: - class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) - except KeyError: - raise ValueError('Class name not found in saved state') - else: - return load_cls.recreate_from(saved_state, load_context) - @super_check def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: self._ensure_persist_configured() @@ -580,7 +580,7 @@ def _get_value( if typ == META__TYPE__METHOD: value = getattr(self, value) elif typ == META__TYPE__SAVABLE: - value = Savable.load(value, load_context) + value = load(value, load_context) return value diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index bae08dd4..56fdc570 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1266,7 +1266,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(state_machine.State, persistence.Savable.load(saved_state, load_context)) + return cast(state_machine.State, persistence.load(saved_state, load_context)) # endregion