Skip to content

Commit

Permalink
fix for error while loading state dict in cifar10
Browse files Browse the repository at this point in the history
Signed-off-by: yes <[email protected]>
  • Loading branch information
tanwarsh committed Nov 7, 2024
1 parent 292289b commit e407f2b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def FedAvg(models): # NOQA: N802
[state[key] for state in state_dicts], axis=0
) / len(models)
# Convert numpy arrays within the state dictionary to PyTorch tensors
state_dict_torch = {k: torch.tensor(v, dtype=torch.float32).cpu() for k, v in state_dict.items()}
state_dict_tensors = {key: torch.tensor(value, dtype=torch.float32).cpu() for key, value in state_dict.items()}

# Load the converted state dictionary into the model
new_model.load_state_dict(state_dict_torch)
new_model.load_state_dict(state_dict_tensors)
return new_model


Expand Down

0 comments on commit e407f2b

Please sign in to comment.