Skip to content

Commit

Permalink
support save/load collection data in auto_persist
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 22, 2025
1 parent fc13e79 commit 6bfc5c9
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVE
META__TYPES: str = 'types'
META__TYPE__METHOD: str = 'm'
META__TYPE__SAVABLE: str = 'S'
META__TYPE__COLLECTION: str = 'coll'


class SaveUtil:
Expand Down Expand Up @@ -516,6 +517,7 @@ def auto_save(obj: Savable, loader: loaders.ObjectLoader | None = None) -> SAVED
# Save object class name
SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__))

# FIXME: it should be an iter call to save until all resolved
if isinstance(obj, SavableWithAutoPersist):
for member in obj._auto_persist:
value = getattr(obj, member)
Expand All @@ -530,6 +532,16 @@ def auto_save(obj: Savable, loader: loaders.ObjectLoader | None = None) -> SAVED
# of lhs condition.
SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save()
elif isinstance(value, (set, list)):
SaveUtil.set_meta_type(out_state, member, META__TYPE__COLLECTION)
value = [v.save() for v in value if isinstance(v, Savable)]
value_ = []
for v in value:
if isinstance(v, Savable):
value_.append(v.save())
else:
value_.append(copy.deepcopy(v))
value = value_
else:
value = copy.deepcopy(value)
out_state[member] = value
Expand All @@ -548,6 +560,8 @@ def load_auto_persist_params(
value = getattr(obj, value)
elif typ == META__TYPE__SAVABLE:
value = load(value, load_context)
elif typ == META__TYPE__COLLECTION:
value = [load(v, load_context) for v in value]

setattr(obj, member, value)

Expand Down

0 comments on commit 6bfc5c9

Please sign in to comment.