From 939f9d08045be0be88e3e673bc297d30cba3e728 Mon Sep 17 00:00:00 2001 From: mehrad Date: Tue, 1 Mar 2022 14:48:49 -0800 Subject: [PATCH] metrics: add rouge score --- genienlp/metrics.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index 63636651..a186c8b3 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -29,7 +29,7 @@ import logging from collections import Counter, OrderedDict, defaultdict -from typing import Iterable, Union +from typing import List, Union import sacrebleu from datasets import load_metric @@ -84,6 +84,12 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): return max(scores_for_ground_truths) +def computeROUGE(outputs, targets, rouge_types): + targets = [target[0] for target in targets] + rouge_metric = load_metric('rouge') + return rouge_metric.compute(references=targets, predictions=outputs, rouge_types=rouge_types) + + def computeSequenceClassificationPrecision(outputs, targets): targets = [target[0] for target in targets] precision_metric = load_metric('precision') @@ -267,12 +273,12 @@ def compute_ner_f1(predictions, answers, schema='IOB2'): def compute_metrics( - predictions: Iterable[str], - answers: Union[Iterable[str], Iterable[Iterable[str]]], - requested_metrics: Iterable, + predictions: List[str], + answers: Union[List[str], List[List[str]]], + requested_metrics: List, lang: str, args, - example_ids: Iterable[str] = None, + example_ids: List[str] = None, ): """ Inputs: @@ -359,6 +365,13 @@ def compute_metrics( ner_f1 = compute_ner_f1(predictions, answers) metric_keys.append('ner_f1') metric_values.append(ner_f1) + for m in ['rouge1', 'rouge2', 'rougeL']: + if m in requested_metrics: + rouge = computeROUGE(predictions, answers, rouge_types=[m])[m] + requested_metrics.remove(m) + requested_metrics += [f'{m}_low', f'{m}_mid', f'{m}_high'] + metric_keys += [f'{m}_low', f'{m}_mid', f'{m}_high'] + metric_values += [rouge.low.fmeasure, rouge.mid.fmeasure, rouge.high.fmeasure] metric_dict = dict(zip(metric_keys, metric_values)) metric_dict = OrderedDict((key, metric_dict[key]) for key in requested_metrics)