Skip to content

Commit

Permalink
Merge pull request #45 from rskmoi/patch/bert_run_on_cpu
Browse files Browse the repository at this point in the history
Add map_location on torch.load for cpu inference
  • Loading branch information
rskmoi authored May 3, 2024
2 parents 45a2086 + 2894e9b commit ac3b18d
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(self, model_path: Union[str, Path], separator: str = " ", family_fi
self.device = "cuda" if torch.cuda.is_available() else "cpu"
config = PretrainedConfig.from_json_file(CURRENT_DIR / "config.json")
model = BertForSequenceClassification(config=config)
model.load_state_dict(torch.load(model_path))
self.model = model.to(self.device)
model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))
self.model = model.to(self.device).eval()

# Prepare vocabularies
with open(CURRENT_DIR / "vocab.json") as f:
Expand Down

0 comments on commit ac3b18d

Please sign in to comment.