diff --git a/substrafl/algorithms/pytorch/torch_base_algo.py b/substrafl/algorithms/pytorch/torch_base_algo.py index 0562a1c5..eb826174 100644 --- a/substrafl/algorithms/pytorch/torch_base_algo.py +++ b/substrafl/algorithms/pytorch/torch_base_algo.py @@ -63,7 +63,7 @@ def __init__( np.random.seed(seed) torch.manual_seed(seed) - self._device = self._get_torch_device(use_gpu=use_gpu) + self.use_gpu = use_gpu self._model = model.to(self._device) self._optimizer = optimizer @@ -212,7 +212,8 @@ def _local_train( if self._scheduler is not None: self._scheduler.step() - def _get_torch_device(self, use_gpu: bool) -> torch.device: + @property + def _device(self) -> torch.device: """Get the torch device, CPU or GPU, depending on availability and user input. @@ -223,7 +224,7 @@ def _get_torch_device(self, use_gpu: bool) -> torch.device: torch.device: Torch device """ device = torch.device("cpu") - if use_gpu and torch.cuda.is_available(): + if self.use_gpu and torch.cuda.is_available(): device = torch.device("cuda") return device