Skip to content

Commit

Permalink
metrics: add rouge score
Browse files Browse the repository at this point in the history
  • Loading branch information
Mehrad0711 committed Mar 2, 2022
1 parent c4ac2cd commit 939f9d0
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions genienlp/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import logging
from collections import Counter, OrderedDict, defaultdict
from typing import Iterable, Union
from typing import List, Union

import sacrebleu
from datasets import load_metric
Expand Down Expand Up @@ -84,6 +84,12 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)


def computeROUGE(outputs, targets, rouge_types):
targets = [target[0] for target in targets]
rouge_metric = load_metric('rouge')
return rouge_metric.compute(references=targets, predictions=outputs, rouge_types=rouge_types)


def computeSequenceClassificationPrecision(outputs, targets):
targets = [target[0] for target in targets]
precision_metric = load_metric('precision')
Expand Down Expand Up @@ -267,12 +273,12 @@ def compute_ner_f1(predictions, answers, schema='IOB2'):


def compute_metrics(
predictions: Iterable[str],
answers: Union[Iterable[str], Iterable[Iterable[str]]],
requested_metrics: Iterable,
predictions: List[str],
answers: Union[List[str], List[List[str]]],
requested_metrics: List,
lang: str,
args,
example_ids: Iterable[str] = None,
example_ids: List[str] = None,
):
"""
Inputs:
Expand Down Expand Up @@ -359,6 +365,13 @@ def compute_metrics(
ner_f1 = compute_ner_f1(predictions, answers)
metric_keys.append('ner_f1')
metric_values.append(ner_f1)
for m in ['rouge1', 'rouge2', 'rougeL']:
if m in requested_metrics:
rouge = computeROUGE(predictions, answers, rouge_types=[m])[m]
requested_metrics.remove(m)
requested_metrics += [f'{m}_low', f'{m}_mid', f'{m}_high']
metric_keys += [f'{m}_low', f'{m}_mid', f'{m}_high']
metric_values += [rouge.low.fmeasure, rouge.mid.fmeasure, rouge.high.fmeasure]

metric_dict = dict(zip(metric_keys, metric_values))
metric_dict = OrderedDict((key, metric_dict[key]) for key in requested_metrics)
Expand Down

0 comments on commit 939f9d0

Please sign in to comment.