Skip to content

Commit

Permalink
chore: change TorchAlgo._device to property
Browse files Browse the repository at this point in the history
Signed-off-by: SdgJlbl <[email protected]>
  • Loading branch information
SdgJlbl committed Aug 14, 2024
1 parent 5590e7b commit 3de33c1
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions substrafl/algorithms/pytorch/torch_base_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def __init__(
np.random.seed(seed)
torch.manual_seed(seed)

self._device = self._get_torch_device(disable_gpu=disable_gpu)

self.disable_gpu = disable_gpu
self._model = model.to(self._device)
self._optimizer = optimizer
# Move the optimizer to GPU if needed
Expand Down Expand Up @@ -212,18 +211,16 @@ def _local_train(
if self._scheduler is not None:
self._scheduler.step()

def _get_torch_device(self, disable_gpu: bool) -> torch.device:
@property
def _device(self) -> torch.device:
    """Torch device the algorithm should run on.

    Resolved lazily from ``self.disable_gpu``: CUDA is used only when the
    user has not disabled GPUs AND a CUDA device is actually available.

    Returns:
        torch.device: ``torch.device("cuda")`` when usable, otherwise
            ``torch.device("cpu")``.
    """
    # Keep the disable_gpu check first so torch.cuda.is_available() is
    # never queried when the user explicitly opted out of GPUs.
    use_gpu = not self.disable_gpu and torch.cuda.is_available()
    return torch.device("cuda" if use_gpu else "cpu")

Expand Down Expand Up @@ -251,6 +248,7 @@ def _update_from_checkpoint(self, path: Path) -> dict:
assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
checkpoint = torch.load(path, map_location=self._device)
self._model.load_state_dict(checkpoint.pop("model_state_dict"))
self.disable_gpu = checkpoint.pop("disable_gpu")

if self._optimizer is not None:
self._optimizer.load_state_dict(checkpoint.pop("optimizer_state_dict"))
Expand Down Expand Up @@ -307,17 +305,16 @@ def _get_state_to_save(self) -> dict:
checkpoint = {
"model_state_dict": self._model.state_dict(),
"index_generator": self._index_generator,
"disable_gpu": self.disable_gpu,
"random_rng_state": random.getstate(),
"numpy_rng_state": np.random.get_state(),
}
if self._optimizer is not None:
checkpoint["optimizer_state_dict"] = self._optimizer.state_dict()

if self._scheduler is not None:
checkpoint["scheduler_state_dict"] = self._scheduler.state_dict()

checkpoint["random_rng_state"] = random.getstate()

checkpoint["numpy_rng_state"] = np.random.get_state()

if self._device == torch.device("cpu"):
checkpoint["torch_rng_state"] = torch.get_rng_state()
else:
Expand Down

0 comments on commit 3de33c1

Please sign in to comment.