From 3de33c10e42a769da8b05aa83ce31ea1585ac158 Mon Sep 17 00:00:00 2001
From: SdgJlbl
Date: Tue, 13 Aug 2024 17:02:53 +0200
Subject: [PATCH] chore: change TorchAlgo._device to property

Signed-off-by: SdgJlbl
---
 .../algorithms/pytorch/torch_base_algo.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/substrafl/algorithms/pytorch/torch_base_algo.py b/substrafl/algorithms/pytorch/torch_base_algo.py
index 5752e1c7..1b7b765a 100644
--- a/substrafl/algorithms/pytorch/torch_base_algo.py
+++ b/substrafl/algorithms/pytorch/torch_base_algo.py
@@ -63,8 +63,7 @@ def __init__(
         np.random.seed(seed)
         torch.manual_seed(seed)
 
-        self._device = self._get_torch_device(disable_gpu=disable_gpu)
-
+        self.disable_gpu = disable_gpu
         self._model = model.to(self._device)
         self._optimizer = optimizer
         # Move the optimizer to GPU if needed
@@ -212,18 +211,16 @@ def _local_train(
             if self._scheduler is not None:
                 self._scheduler.step()
 
-    def _get_torch_device(self, disable_gpu: bool) -> torch.device:
+    @property
+    def _device(self) -> torch.device:
         """Get the torch device, CPU or GPU, depending on availability and user input.
 
-        Args:
-            disable_gpu (bool): whether to use GPUs if available or not.
-
         Returns:
             torch.device: Torch device
         """
         device = torch.device("cpu")
-        if not disable_gpu and torch.cuda.is_available():
+        if not self.disable_gpu and torch.cuda.is_available():
             device = torch.device("cuda")
         return device
 
@@ -251,6 +248,7 @@ def _update_from_checkpoint(self, path: Path) -> dict:
         assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
         checkpoint = torch.load(path, map_location=self._device)
         self._model.load_state_dict(checkpoint.pop("model_state_dict"))
+        self.disable_gpu = checkpoint.pop("disable_gpu")
 
         if self._optimizer is not None:
             self._optimizer.load_state_dict(checkpoint.pop("optimizer_state_dict"))
@@ -307,6 +305,9 @@ def _get_state_to_save(self) -> dict:
         checkpoint = {
             "model_state_dict": self._model.state_dict(),
             "index_generator": self._index_generator,
+            "disable_gpu": self.disable_gpu,
+            "random_rng_state": random.getstate(),
+            "numpy_rng_state": np.random.get_state(),
         }
         if self._optimizer is not None:
             checkpoint["optimizer_state_dict"] = self._optimizer.state_dict()
@@ -314,10 +315,6 @@ def _get_state_to_save(self) -> dict:
         if self._scheduler is not None:
             checkpoint["scheduler_state_dict"] = self._scheduler.state_dict()
 
-        checkpoint["random_rng_state"] = random.getstate()
-
-        checkpoint["numpy_rng_state"] = np.random.get_state()
-
         if self._device == torch.device("cpu"):
             checkpoint["torch_rng_state"] = torch.get_rng_state()
         else:
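
For illustration, the pattern this patch introduces can be sketched in isolation as follows. This is a minimal, hypothetical sketch, not substrafl's actual TorchAlgo: the class name DeviceAwareAlgo, its constructor, and the save/load helpers below are stand-ins; only the property-based device resolution and the disable_gpu / RNG-state checkpoint entries mirror the patch.

    import random

    import numpy as np
    import torch


    class DeviceAwareAlgo:
        """Hypothetical sketch: the device is derived from `disable_gpu` on each
        access instead of being resolved once and cached in __init__."""

        def __init__(self, model: torch.nn.Module, disable_gpu: bool = False):
            self.disable_gpu = disable_gpu
            self._model = model.to(self._device)

        @property
        def _device(self) -> torch.device:
            # Re-evaluated on every access, so restoring `disable_gpu` from a
            # checkpoint is enough to pick the right device on the current host.
            if not self.disable_gpu and torch.cuda.is_available():
                return torch.device("cuda")
            return torch.device("cpu")

        def save_checkpoint(self, path: str) -> None:
            # Persist the user intent (disable_gpu) and the RNG states rather
            # than a torch.device object.
            torch.save(
                {
                    "model_state_dict": self._model.state_dict(),
                    "disable_gpu": self.disable_gpu,
                    "random_rng_state": random.getstate(),
                    "numpy_rng_state": np.random.get_state(),
                },
                path,
            )

        def load_checkpoint(self, path: str) -> None:
            checkpoint = torch.load(path, map_location=self._device)
            self._model.load_state_dict(checkpoint.pop("model_state_dict"))
            self.disable_gpu = checkpoint.pop("disable_gpu")
            random.setstate(checkpoint.pop("random_rng_state"))
            np.random.set_state(checkpoint.pop("numpy_rng_state"))
            # `_device` now reflects the restored flag and the current hardware.
            self._model.to(self._device)

In this sketch the device itself is never serialized: a checkpoint written with disable_gpu=False on a GPU machine and reloaded on a CPU-only machine resolves to "cpu" automatically, because only the flag travels with the state and the property re-checks torch.cuda.is_available() at load time.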