diff --git a/bert_score/score.py b/bert_score/score.py
index 018e73e..2c2abd7 100644
--- a/bert_score/score.py
+++ b/bert_score/score.py
@@ -168,15 +168,22 @@ def score(
         print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec")
 
     if return_hash:
-        return tuple([out, get_hash(model_type, num_layers, idf, rescale_with_baseline,
-                                    use_custom_baseline=use_custom_baseline)])
+        return tuple(
+            [out, get_hash(model_type, num_layers, idf, rescale_with_baseline, use_custom_baseline=use_custom_baseline)]
+        )
 
     return out
 
 
 def plot_example(
-    candidate, reference, model_type=None, num_layers=None, lang=None, rescale_with_baseline=False,
-    baseline_path=None, fname="",
+    candidate,
+    reference,
+    model_type=None,
+    num_layers=None,
+    lang=None,
+    rescale_with_baseline=False,
+    baseline_path=None,
+    fname="",
 ):
     """
     BERTScore metric.
diff --git a/bert_score/scorer.py b/bert_score/scorer.py
index 057976e..18e6c7f 100644
--- a/bert_score/scorer.py
+++ b/bert_score/scorer.py
@@ -112,7 +112,9 @@ def __init__(
         self.baseline_path = baseline_path
         self.use_custom_baseline = self.baseline_path is not None
         if self.baseline_path is None:
-            self.baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_type}.tsv")
+            self.baseline_path = os.path.join(
+                os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_type}.tsv"
+            )
 
     @property
     def lang(self):
@@ -143,18 +145,19 @@ def baseline_vals(self):
                         pd.read_csv(self.baseline_path).iloc[self.num_layers].to_numpy()
                     )[1:].float()
                 else:
-                    self._baseline_vals = torch.from_numpy(
-                        pd.read_csv(self.baseline_path).to_numpy()
-                    )[:, 1:].unsqueeze(1).float()
+                    self._baseline_vals = (
+                        torch.from_numpy(pd.read_csv(self.baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()
+                    )
             else:
-                raise ValueError(
-                    f"Baseline not Found for {self.model_type} on {self.lang} at {self.baseline_path}")
+                raise ValueError(f"Baseline not Found for {self.model_type} on {self.lang} at {self.baseline_path}")
 
         return self._baseline_vals
 
     @property
     def hash(self):
-        return get_hash(self.model_type, self.num_layers, self.idf, self.rescale_with_baseline, self.use_custom_baseline)
+        return get_hash(
+            self.model_type, self.num_layers, self.idf, self.rescale_with_baseline, self.use_custom_baseline
+        )
 
     def compute_idf(self, sents):
         """
diff --git a/bert_score/utils.py b/bert_score/utils.py
index 632ec5a..649ebeb 100644
--- a/bert_score/utils.py
+++ b/bert_score/utils.py
@@ -106,12 +106,16 @@ def sent_encode(tokenizer, sent):
     elif isinstance(tokenizer, GPT2Tokenizer):
         # for RoBERTa and GPT-2
         import transformers
+
         if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
-            return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, truncation=True)
+            return tokenizer.encode(
+                sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, truncation=True
+            )
         else:
             return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len)
     else:
         import transformers
+
         if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
             return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len, truncation=True)
         else: