From 2894e9b3ac1dff8c6d0d369eeb973bf19a711ab3 Mon Sep 17 00:00:00 2001 From: rskmoi Date: Fri, 3 May 2024 15:26:24 +0900 Subject: [PATCH] [skip ci]Add map_location on torch.load for cpu inference --- .../beta_bert_divider/bert_name_divider_only_katakana.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/namedivider/beta_bert_divider/bert_name_divider_only_katakana.py b/namedivider/beta_bert_divider/bert_name_divider_only_katakana.py index 99b8336..9d5a3fb 100644 --- a/namedivider/beta_bert_divider/bert_name_divider_only_katakana.py +++ b/namedivider/beta_bert_divider/bert_name_divider_only_katakana.py @@ -29,8 +29,8 @@ def __init__(self, model_path: Union[str, Path], separator: str = " ", family_fi self.device = "cuda" if torch.cuda.is_available() else "cpu" config = PretrainedConfig.from_json_file(CURRENT_DIR / "config.json") model = BertForSequenceClassification(config=config) - model.load_state_dict(torch.load(model_path)) - self.model = model.to(self.device) + model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device))) + self.model = model.to(self.device).eval() # Prepare vocabularies with open(CURRENT_DIR / "vocab.json") as f: