Commit: Format

jakobnissen committed Jan 8, 2025
1 parent 0b6971c commit 1b4a7af
Showing 3 changed files with 9 additions and 3 deletions.
vamb/encode.py (3 additions, 1 deletion)
@@ -514,7 +514,9 @@ def load(
"""

# Forcably load to CPU even if model was saves as GPU model
dictionary = _torch.load(path, map_location=lambda storage, loc: storage, weights_only=True)
dictionary = _torch.load(
path, map_location=lambda storage, loc: storage, weights_only=True
)

nsamples = dictionary["nsamples"]
alpha = dictionary["alpha"]
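For readers unfamiliar with the pattern being reflowed here, a minimal standalone sketch follows (this is not Vamb's API; the helper name and checkpoint path are made up for illustration). The map_location callback returns each storage unchanged, so a checkpoint written on a GPU machine is materialised on the CPU, and weights_only=True restricts unpickling to tensors and plain containers.

    import torch

    def load_checkpoint_on_cpu(path):
        # Hypothetical helper, not part of Vamb.
        # Returning the storage unchanged keeps every tensor on the CPU,
        # even if the checkpoint was saved from CUDA tensors.
        # weights_only=True limits unpickling to tensors and basic containers.
        return torch.load(
            path, map_location=lambda storage, loc: storage, weights_only=True
        )

    # Example usage (hypothetical path):
    # state = load_checkpoint_on_cpu("model.pt")
    # print(sorted(state.keys()))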
vamb/semisupervised_encode.py (3 additions, 1 deletion)
@@ -1125,7 +1125,9 @@ def load(cls, path, cuda=False, evaluate=True):
"""

# Forcably load to CPU even if model was saves as GPU model
dictionary = _torch.load(path, map_location=lambda storage, loc: storage, weights_only=False)
dictionary = _torch.load(
path, map_location=lambda storage, loc: storage, weights_only=False
)

nsamples = dictionary["nsamples"]
nlabels = dictionary["nlabels"]
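One difference worth noting: encode.py loads with weights_only=True, while semisupervised_encode.py here and taxvamb_encode.py below use weights_only=False. Presumably (this is inferred from the parameters, not stated in the commit) those checkpoints contain pickled Python objects beyond plain tensors, which torch.load rejects under weights_only=True; weights_only=False therefore implies trusting the checkpoint file.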
vamb/taxvamb_encode.py (3 additions, 1 deletion)
@@ -630,7 +630,9 @@ def load(cls, path, nodes, table_parent, cuda=False, evaluate=True):
"""

# Forcably load to CPU even if model was saves as GPU model
dictionary = _torch.load(path, map_location=lambda storage, loc: storage, weights_only=False)
dictionary = _torch.load(
path, map_location=lambda storage, loc: storage, weights_only=False
)

nsamples = dictionary["nsamples"]
nlabels = dictionary["nlabels"]
