diff --git a/substrafl/algorithms/pytorch/torch_base_algo.py b/substrafl/algorithms/pytorch/torch_base_algo.py index 1b7b765a..d9fcdaa2 100644 --- a/substrafl/algorithms/pytorch/torch_base_algo.py +++ b/substrafl/algorithms/pytorch/torch_base_algo.py @@ -246,10 +246,12 @@ def _update_from_checkpoint(self, path: Path) -> dict: return checkpoint """ assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}' - checkpoint = torch.load(path, map_location=self._device) - self._model.load_state_dict(checkpoint.pop("model_state_dict")) + checkpoint = torch.load(path) # TO CHANGE self.disable_gpu = checkpoint.pop("disable_gpu") + self._model.load_state_dict(checkpoint.pop("model_state_dict")) + self._model.to(self._device) + if self._optimizer is not None: self._optimizer.load_state_dict(checkpoint.pop("optimizer_state_dict"))