accommodating substrafl api change
jeandut committed Jan 23, 2024
1 parent 390a466 commit d773739
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions fedeca/algorithms/torch_dp_fed_avg_algo.py
@@ -390,6 +390,7 @@ def _update_from_checkpoint(self, checkpoint: dict) -> None:
 
         # For some reason substrafl save and load client before calling train
         if "privacy_accountant_state_dict" in checkpoint:
+
             self.accountant = RDPAccountant()
             self.accountant.load_state_dict(
                 checkpoint.pop("privacy_accountant_state_dict")
@@ -427,9 +428,9 @@ def _update_from_checkpoint(self, checkpoint: dict) -> None:
         self._index_generator = checkpoint.pop("index_generator")
 
         if self._device == torch.device("cpu"):
-            torch.set_rng_state(checkpoint.pop("rng_state").to(self._device))
+            torch.set_rng_state(checkpoint.pop("torch_rng_state").to(self._device))
         else:
-            torch.cuda.set_rng_state(checkpoint.pop("rng_state").to("cpu"))
+            torch.cuda.set_rng_state(checkpoint.pop("torch_rng_state").to("cpu"))
 
         attr_names = [
             "dp_max_grad_norm",

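For context, a minimal sketch of how the renamed checkpoint key might be written and read back. This assumes the save path mirrors the keys shown in the diff (the rename to "torch_rng_state" presumably follows the key substrafl now uses, which is what the commit accommodates); the helper names save_rng_and_accountant and restore_rng_and_accountant are illustrative and not part of fedeca or substrafl.

    # Illustrative sketch only: shows how a checkpoint using the renamed
    # "torch_rng_state" key (plus the privacy accountant state) could be
    # produced and restored. Helper names are hypothetical, not library API.
    import torch
    from opacus.accountants import RDPAccountant


    def save_rng_and_accountant(
        checkpoint: dict, accountant: RDPAccountant, device: torch.device
    ) -> dict:
        """Store the RNG state under the renamed key and the accountant state."""
        if device == torch.device("cpu"):
            checkpoint["torch_rng_state"] = torch.get_rng_state()
        else:
            checkpoint["torch_rng_state"] = torch.cuda.get_rng_state()
        checkpoint["privacy_accountant_state_dict"] = accountant.state_dict()
        return checkpoint


    def restore_rng_and_accountant(checkpoint: dict, device: torch.device) -> RDPAccountant:
        """Mirror of the restore logic shown in the diff above."""
        accountant = RDPAccountant()
        if "privacy_accountant_state_dict" in checkpoint:
            accountant.load_state_dict(checkpoint.pop("privacy_accountant_state_dict"))
        if device == torch.device("cpu"):
            torch.set_rng_state(checkpoint.pop("torch_rng_state").to(device))
        else:
            torch.cuda.set_rng_state(checkpoint.pop("torch_rng_state").to("cpu"))
        return accountant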