Skip to content

Commit

Permalink
run black
Browse files Browse the repository at this point in the history
  • Loading branch information
felixgwu committed Sep 3, 2020
1 parent 9b4116c commit 188b4a4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
15 changes: 11 additions & 4 deletions bert_score/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,22 @@ def score(
print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec")

if return_hash:
return tuple([out, get_hash(model_type, num_layers, idf, rescale_with_baseline,
use_custom_baseline=use_custom_baseline)])
return tuple(
[out, get_hash(model_type, num_layers, idf, rescale_with_baseline, use_custom_baseline=use_custom_baseline)]
)

return out


def plot_example(
candidate, reference, model_type=None, num_layers=None, lang=None, rescale_with_baseline=False,
baseline_path=None, fname="",
candidate,
reference,
model_type=None,
num_layers=None,
lang=None,
rescale_with_baseline=False,
baseline_path=None,
fname="",
):
"""
BERTScore metric.
Expand Down
17 changes: 10 additions & 7 deletions bert_score/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def __init__(
self.baseline_path = baseline_path
self.use_custom_baseline = self.baseline_path is not None
if self.baseline_path is None:
self.baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_type}.tsv")
self.baseline_path = os.path.join(
os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_type}.tsv"
)

@property
def lang(self):
Expand Down Expand Up @@ -143,18 +145,19 @@ def baseline_vals(self):
pd.read_csv(self.baseline_path).iloc[self.num_layers].to_numpy()
)[1:].float()
else:
self._baseline_vals = torch.from_numpy(
pd.read_csv(self.baseline_path).to_numpy()
)[:, 1:].unsqueeze(1).float()
self._baseline_vals = (
torch.from_numpy(pd.read_csv(self.baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()
)
else:
raise ValueError(
f"Baseline not Found for {self.model_type} on {self.lang} at {self.baseline_path}")
raise ValueError(f"Baseline not Found for {self.model_type} on {self.lang} at {self.baseline_path}")

return self._baseline_vals

@property
def hash(self):
return get_hash(self.model_type, self.num_layers, self.idf, self.rescale_with_baseline, self.use_custom_baseline)
return get_hash(
self.model_type, self.num_layers, self.idf, self.rescale_with_baseline, self.use_custom_baseline
)

def compute_idf(self, sents):
"""
Expand Down
6 changes: 5 additions & 1 deletion bert_score/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,16 @@ def sent_encode(tokenizer, sent):
elif isinstance(tokenizer, GPT2Tokenizer):
# for RoBERTa and GPT-2
import transformers

if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, truncation=True)
return tokenizer.encode(
sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, truncation=True
)
else:
return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len)
else:
import transformers

if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len, truncation=True)
else:
Expand Down

0 comments on commit 188b4a4

Please sign in to comment.