Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding BERT score [WIP] #76

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
wip
  • Loading branch information
astariul committed Jul 6, 2019
commit c541b38f28516d90eefdbe2c858526d30c96f1cb
6 changes: 6 additions & 0 deletions nlgeval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from nlgeval.pycocoevalcap.cider.cider import Cider
from nlgeval.pycocoevalcap.meteor.meteor import Meteor
from nlgeval.pycocoevalcap.rouge.rouge import Rouge
from nlgeval.others.bert_scorer import BertScore


# str/unicode stripping in Python 2 and 3 instead of `str.strip`.
Expand All @@ -34,6 +35,7 @@ def compute_metrics(hypothesis, references, no_overlap=False, no_skipthoughts=Fa
(Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
(Meteor(), "METEOR"),
(Rouge(), "ROUGE_L"),
(BertScore(), "BERT_score"),
(Cider(), "CIDEr")
]
for scorer, method in scorers:
Expand Down Expand Up @@ -99,6 +101,7 @@ def compute_individual_metrics(ref, hyp, no_overlap=False, no_skipthoughts=False
(Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
(Meteor(), "METEOR"),
(Rouge(), "ROUGE_L"),
(BertScore(), "BERT_score"),
(Cider(), "CIDEr")
]
for scorer, method in scorers:
Expand Down Expand Up @@ -152,6 +155,7 @@ class NLGEval(object):
'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4',
'METEOR',
'ROUGE_L',
'BERT_score'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a comma missing at the end of the line.

'CIDEr',

# Skip-thought
Expand Down Expand Up @@ -212,6 +216,8 @@ def load_scorers(self):
self.scorers.append((Meteor(), "METEOR"))
if 'ROUGE_L' not in self.metrics_to_omit:
self.scorers.append((Rouge(), "ROUGE_L"))
if 'BERT_score' not in self.metrics_to_omit:
self.scorers.append((BertScore(), "BERT_score"))
if 'CIDEr' not in self.metrics_to_omit:
self.scorers.append((Cider(), "CIDEr"))

Expand Down
47 changes: 47 additions & 0 deletions nlgeval/others/bert_scorer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python
#
# File Name : bert_scorer.py
#
# Description : Computes BERT score as described by Tianyi Zhang et all (2019)
#
# Creation Date : 2019-07-06
# Author : REMOND Nicolas

from bert_score import score

class BertScore():
'''
Class for computing BERT score for a set of candidate sentences
'''

def __init__(self, score_type='f_score'):
# Score type to be returned
if score_type not in ['f_score', 'recall', 'precision']:
raise ValueError("Score type must be either 'f_score', 'precision', or 'recall'. Given : {}".format(score_type))
self.score_type = score_type

def compute_score(self, gts, res):
"""
Computes BERT score given a set of reference and candidate sentences for the dataset
:param res: dict : candidate / test sentences.
:param gts: dict : references.
:returns: average_score: float (mean BERT score computed by averaging scores for all the images), individual scores
"""
assert(gts.keys() == res.keys())
imgIds = gts.keys()

hyp = [res[id][0] for id in imgIds]
ref = [gts[id][0] for id in imgIds] # Take only the first reference
# Because Bert Score support only 1
assert len(hyp) == len(ref)

P, R, F1 = score(hyp, ref, bert="bert-base-uncased", no_idf=(len(ref) == 1))

if self.score_type == 'recall':
s = R
elif self.score_type == 'precision':
s = P
elif self.score_type == 'f_score':
s = F1

return s.mean().item(), s.tolist()
11 changes: 11 additions & 0 deletions nlgeval/tests/test_nlgeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,14 @@ def test_compute_metrics(self):
self.assertAlmostEqual(0.568696, scores['VectorExtremaCosineSimilarity'], places=5)
self.assertAlmostEqual(0.784205, scores['GreedyMatchingScore'], places=5)
self.assertEqual(11, len(scores))

def test_bert_score(self):
n = NLGEval(metrics_to_omit=['Bleu_1', 'Bleu_2', 'Bleu_3', 'ROUGE_L', 'METEOR', 'EmbeddingAverageCosineSimilairty', 'CIDEr', 'SkipThoughtCS', 'VectorExtremaCosineSimilarity', 'GreedyMatchingScore'])

# Individual Metrics
scores = n.compute_individual_metrics(ref=["Until you start talking to Katrin Bahr."],
hyp="Until you talk to Katrin Bahr.")
self.assertAlmostEqual(0.9345, scores['BERT_score'], places=5)

if __name__ == "__main__":
unittest.main()