Skip to content

Commit

Permalink
chore: (temp) put model to device explicitely on loading
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy committed Aug 20, 2024
1 parent 3de33c1 commit 3992227
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions substrafl/algorithms/pytorch/torch_base_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,12 @@ def _update_from_checkpoint(self, path: Path) -> dict:
return checkpoint
"""
assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
checkpoint = torch.load(path, map_location=self._device)
self._model.load_state_dict(checkpoint.pop("model_state_dict"))
checkpoint = torch.load(path) # TO CHANGE
self.disable_gpu = checkpoint.pop("disable_gpu")

self._model.load_state_dict(checkpoint.pop("model_state_dict"))
self._model.to(self._device)

if self._optimizer is not None:
self._optimizer.load_state_dict(checkpoint.pop("optimizer_state_dict"))

Expand Down

0 comments on commit 3992227

Please sign in to comment.