Skip to content

Commit

Permalink
accommodating new update_from_checkpoint function
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Jan 22, 2024
1 parent 3b3eec4 commit a73bdff
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 20 deletions.
15 changes: 6 additions & 9 deletions fedeca/algorithms/torch_dp_fed_avg_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,26 +371,23 @@ def _get_state_to_save(self) -> dict:

return checkpoint

def _update_from_checkpoint(self, path) -> dict:
def _update_from_checkpoint(self, checkpoint: dict) -> None:
"""Set self attributes using saved values.
Parameters
----------
path : Path
Path towards the checkpoint to use.
checkpoint : dict
Checkpoint to load.
Returns
-------
dict
The emptied checkpoint.
"""
# One cannot simply call checkpoint = super()._update_from_checkpoint(path)
# One cannot simply call checkpoint = super()._update_from_checkpoint(chkpt)
# because we have to change the model class if it should be changed
# (and optimizer) aka if we find a specific key in the checkpoint
assert (
path.is_file()
), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
checkpoint = torch.load(path, map_location=self._device)

# For some reason substrafl save and load client before calling train
if "privacy_accountant_state_dict" in checkpoint:
self.accountant = RDPAccountant()
Expand Down Expand Up @@ -447,4 +444,4 @@ def _update_from_checkpoint(self, path) -> dict:
for attr in attr_names:
setattr(self, attr, checkpoint.pop(attr))

return checkpoint
return
16 changes: 5 additions & 11 deletions fedeca/algorithms/torch_webdisco_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
from copy import deepcopy
from math import sqrt
from pathlib import Path
from typing import Any, List, Optional

# hello
Expand Down Expand Up @@ -547,23 +546,18 @@ def _get_state_to_save(self) -> dict:
checkpoint.update({"global_moments": self.global_moments})
return checkpoint

def _update_from_checkpoint(self, path: Path) -> dict:
def _update_from_checkpoint(self, checkpoint: dict) -> None:
"""Load the local state from the checkpoint.
Parameters
----------
path : pathlib.Path
Path where the checkpoint is saved
Returns
-------
dict
Checkpoint
checkpoint : dict
The checkpoint to load.
"""
checkpoint = super()._update_from_checkpoint(path=path)
super()._update_from_checkpoint(checkpoint=checkpoint)
self.server_state = checkpoint.pop("server_state")
self.global_moments = checkpoint.pop("global_moments")
return checkpoint
return

def summary(self):
"""Summary of the class to be exposed in the experiment summary file.
Expand Down

0 comments on commit a73bdff

Please sign in to comment.