From bca83e1dc39a272e6046b45d3ace5c3e28a4c946 Mon Sep 17 00:00:00 2001 From: SdgJlbl Date: Tue, 13 Aug 2024 17:02:53 +0200 Subject: [PATCH] chore: change TorchAlgo._device to property Signed-off-by: SdgJlbl --- substrafl/algorithms/pytorch/torch_base_algo.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/substrafl/algorithms/pytorch/torch_base_algo.py b/substrafl/algorithms/pytorch/torch_base_algo.py index 0562a1c5..eb826174 100644 --- a/substrafl/algorithms/pytorch/torch_base_algo.py +++ b/substrafl/algorithms/pytorch/torch_base_algo.py @@ -63,7 +63,7 @@ def __init__( np.random.seed(seed) torch.manual_seed(seed) - self._device = self._get_torch_device(use_gpu=use_gpu) + self.use_gpu = use_gpu self._model = model.to(self._device) self._optimizer = optimizer @@ -212,7 +212,8 @@ def _local_train( if self._scheduler is not None: self._scheduler.step() - def _get_torch_device(self, use_gpu: bool) -> torch.device: + @property + def _device(self) -> torch.device: """Get the torch device, CPU or GPU, depending on availability and user input. @@ -223,7 +224,7 @@ def _get_torch_device(self, use_gpu: bool) -> torch.device: torch.device: Torch device """ device = torch.device("cpu") - if use_gpu and torch.cuda.is_available(): + if self.use_gpu and torch.cuda.is_available(): device = torch.device("cuda") return device