diff --git a/examples/images/cifar10/compute_fid.py b/examples/images/cifar10/compute_fid.py index ffa66c2..7596699 100644 --- a/examples/images/cifar10/compute_fid.py +++ b/examples/images/cifar10/compute_fid.py @@ -51,7 +51,7 @@ # Load the model PATH = f"{FLAGS.input_dir}/{FLAGS.model}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt" print("path: ", PATH) -checkpoint = torch.load(PATH) +checkpoint = torch.load(PATH, map_location=device) state_dict = checkpoint["ema_model"] try: new_net.load_state_dict(state_dict)