From 39922278a6be6c93dc2522970712727b9f70b52c Mon Sep 17 00:00:00 2001 From: ThibaultFy Date: Tue, 20 Aug 2024 15:12:16 +0200 Subject: [PATCH] chore: (temp) put model to device explicitely on loading Signed-off-by: ThibaultFy --- substrafl/algorithms/pytorch/torch_base_algo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/substrafl/algorithms/pytorch/torch_base_algo.py b/substrafl/algorithms/pytorch/torch_base_algo.py index 1b7b765a..d9fcdaa2 100644 --- a/substrafl/algorithms/pytorch/torch_base_algo.py +++ b/substrafl/algorithms/pytorch/torch_base_algo.py @@ -246,10 +246,12 @@ def _update_from_checkpoint(self, path: Path) -> dict: return checkpoint """ assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}' - checkpoint = torch.load(path, map_location=self._device) - self._model.load_state_dict(checkpoint.pop("model_state_dict")) + checkpoint = torch.load(path) # TO CHANGE self.disable_gpu = checkpoint.pop("disable_gpu") + self._model.load_state_dict(checkpoint.pop("model_state_dict")) + self._model.to(self._device) + if self._optimizer is not None: self._optimizer.load_state_dict(checkpoint.pop("optimizer_state_dict"))