Skip to content

Commit

Permalink
Merge pull request huggingface#134 from gabinguo/main
Browse files Browse the repository at this point in the history
issue#126: torch.load device issue.
  • Loading branch information
pacman100 authored Feb 25, 2023
2 parents e19ee68 + 85ad682 commit 681ce93
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ def from_pretrained(cls, model, model_id, **kwargs):
f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
)

adapters_weights = torch.load(filename)
adapters_weights = torch.load(
filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# load the weights into the model
model = set_peft_model_state_dict(model, adapters_weights)
if getattr(model, "hf_device_map", None) is not None:
Expand Down

0 comments on commit 681ce93

Please sign in to comment.