Skip to content

Commit

Permalink
Merge pull request #1128 from tanwarsh/7-Nov-24
Browse files Browse the repository at this point in the history
fix for error while loading state dict in cifar10
  • Loading branch information
teoparvanov authored Nov 8, 2024
2 parents d4c97b3 + e407f2b commit fa3c516
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ def FedAvg(models): # NOQA: N802
state_dict[key] = np.sum(
[state[key] for state in state_dicts], axis=0
) / len(models)
new_model.load_state_dict(state_dict)
# Convert numpy arrays within the state dictionary to PyTorch tensors
state_dict_tensors = {key: torch.tensor(value, dtype=torch.float32).cpu() for key, value in state_dict.items()}

# Load the converted state dictionary into the model
new_model.load_state_dict(state_dict_tensors)
return new_model


Expand Down

0 comments on commit fa3c516

Please sign in to comment.