Skip to content

Commit

Permalink
chore: change TorchAlgo._device to property
Browse files Browse the repository at this point in the history
Signed-off-by: SdgJlbl <[email protected]>
  • Loading branch information
SdgJlbl committed Aug 13, 2024
1 parent 69b61bb commit bca83e1
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions substrafl/algorithms/pytorch/torch_base_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
np.random.seed(seed)
torch.manual_seed(seed)

self._device = self._get_torch_device(use_gpu=use_gpu)
self.use_gpu = use_gpu

self._model = model.to(self._device)
self._optimizer = optimizer
Expand Down Expand Up @@ -212,7 +212,8 @@ def _local_train(
if self._scheduler is not None:
self._scheduler.step()

def _get_torch_device(self, use_gpu: bool) -> torch.device:
@property
def _device(self) -> torch.device:
"""Get the torch device, CPU or GPU, depending
on availability and user input.
Expand All @@ -223,7 +224,7 @@ def _get_torch_device(self, use_gpu: bool) -> torch.device:
torch.device: Torch device
"""
device = torch.device("cpu")
if use_gpu and torch.cuda.is_available():
if self.use_gpu and torch.cuda.is_available():
device = torch.device("cuda")
return device

Expand Down

0 comments on commit bca83e1

Please sign in to comment.