-
Notifications
You must be signed in to change notification settings - Fork 314
/
Copy path eval.py
163 lines (139 loc) · 5.14 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from dataset import ParallelTextReader
from torch.utils.data import DataLoader
from accelerate import find_executable_batch_size
from evaluate import load
from tqdm import tqdm
import torch
import json
import argparse
import numpy as np
import os
def get_dataloader(pred_path: str, gold_path: str, batch_size: int):
    """Build a single-process DataLoader over the prediction/gold files.

    Args:
        pred_path: Path forwarded to ``ParallelTextReader`` as ``pred_path``.
        gold_path: Path forwarded to ``ParallelTextReader`` as ``gold_path``.
        batch_size: Number of dataset items grouped into each batch.

    Returns:
        A ``torch.utils.data.DataLoader`` (``num_workers=0``). Each batch is a
        list of columns, one list per field of the dataset items. Callers
        unpack it as ``predictions, references``.
    """

    def collate_fn(batch):
        # Transpose a list of items into a list of per-field columns; each
        # item is presumably a (prediction, gold) pair -- see ParallelTextReader.
        return [list(column) for column in zip(*batch)]

    return DataLoader(
        ParallelTextReader(pred_path=pred_path, gold_path=gold_path),
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=0,
    )
def _select_device() -> str:
    """Choose the device BertScore runs on and announce the choice.

    Returns:
        ``"cuda:0"`` when ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    if torch.cuda.is_available():
        print("We will use a GPU to calculate BertScore.")
        return "cuda:0"
    print(
        "We will use the CPU to calculate BertScore, this can be slow for large datasets."
    )
    return "cpu"


def _compute_bert_score(
    bert_score, predictions, references, model_type, device, starting_batch_size
):
    """Run BertScore, letting accelerate shrink the batch size on OOM.

    Args:
        bert_score: A loaded ``evaluate`` "bertscore" module.
        predictions: All prediction texts, in dataset order.
        references: All reference entries, aligned with ``predictions``.
        model_type: Model name forwarded to BertScore as ``model_type``.
        device: Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.
        starting_batch_size: First batch size tried; accelerate retries with
            a smaller one after an out-of-memory error.

    Returns:
        The BertScore result dict with ``precision``, ``recall`` and ``f1``
        replaced by their ``np.average`` over all sentences.
    """

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        print(f"Computing bert score with batch size {batch_size} on {device}")
        # The texts are passed explicitly on every attempt. An evaluate
        # module's add_batch() cache is consumed by its first compute() call,
        # so a retry after an OOM would otherwise find no data to score.
        results = bert_score.compute(
            predictions=predictions,
            references=references,
            model_type=model_type,
            batch_size=batch_size,
            device=device,
            use_fast_tokenizer=True,
        )
        for key in ("precision", "recall", "f1"):
            results[key] = np.average(results[key])
        return results

    return inference()


def _save_results(result_dictionary: dict, output_path: str) -> None:
    """Write ``result_dictionary`` to ``output_path`` as indented JSON.

    Missing parent directories are created. ``exist_ok=True`` avoids the race
    between checking for the directory and creating it.
    """
    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result_dictionary, f, indent=4)


def eval_files(
    pred_path: str,
    gold_path: str,
    bert_score_model: str,
    starting_batch_size: int = 128,
    output_path: str = None,
):
    """Score a predictions file against a gold file with several MT metrics.

    Computes sacrebleu, rouge (rouge1/rouge2/rougeL/rougeLsum, aggregated),
    bleu, meteor, ter and BertScore via Hugging Face ``evaluate``.

    Args:
        pred_path: Path to the file of predicted sentences.
        gold_path: Path to the file of gold sentences.
        bert_score_model: Model name used by BertScore (``model_type``).
        starting_batch_size: Batch size for reading the data and the first
            batch size tried for BertScore (reduced automatically on OOM).
        output_path: Optional path of a JSON file to write the results to.
            Results are printed to stdout in every case.

    Returns:
        A dict with key ``"path"`` (``pred_path``) and one entry per metric:
        ``"sacrebleu"``, ``"rouge"``, ``"bleu"``, ``"meteor"``, ``"ter"`` and
        ``"bert_score"``. The BertScore precision/recall/f1 values are averaged
        over sentences.
    """
    device = _select_device()
    dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)

    # Corpus-level metrics are fed incrementally through add_batch().
    corpus_metrics = {}
    for metric_name in ("sacrebleu", "rouge", "bleu", "meteor", "ter"):
        print(f"Loading {metric_name}...")
        corpus_metrics[metric_name] = load(metric_name)
    print("Loading BertScore...")
    bert_score = load("bertscore")

    # BertScore inputs are kept in memory so the OOM-retry loop can resubmit
    # them; see _compute_bert_score.
    all_predictions = []
    all_references = []
    with tqdm(total=len(dataloader.dataset), desc="Loading data...") as pbar:
        for predictions, references in dataloader:
            for metric in corpus_metrics.values():
                metric.add_batch(predictions=predictions, references=references)
            all_predictions.extend(predictions)
            all_references.extend(references)
            pbar.update(len(predictions))

    result_dictionary = {"path": pred_path}
    print("Computing sacrebleu")
    result_dictionary["sacrebleu"] = corpus_metrics["sacrebleu"].compute()
    print("Computing rouge score")
    result_dictionary["rouge"] = corpus_metrics["rouge"].compute(
        use_aggregator=True, rouge_types=["rouge1", "rouge2", "rougeL", "rougeLsum"]
    )
    print("Computing bleu score")
    result_dictionary["bleu"] = corpus_metrics["bleu"].compute()
    print("Computing meteor score")
    result_dictionary["meteor"] = corpus_metrics["meteor"].compute()
    print("Computing ter score")
    result_dictionary["ter"] = corpus_metrics["ter"].compute()

    result_dictionary["bert_score"] = _compute_bert_score(
        bert_score,
        all_predictions,
        all_references,
        model_type=bert_score_model,
        device=device,
        starting_batch_size=starting_batch_size,
    )

    if output_path is not None:
        _save_results(result_dictionary, output_path)

    print(f"Results: {json.dumps(result_dictionary, indent=4)}")
    return result_dictionary
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run the translation evaluation experiments"
)
parser.add_argument(
"--pred_path",
type=str,
required=True,
help="Path to a txt file containing the predicted sentences.",
)
parser.add_argument(
"--gold_path",
type=str,
required=True,
help="Path to a txt file containing the gold sentences.",
)
parser.add_argument(
"--starting_batch_size",
type=int,
default=64,
help="Starting batch size for BertScore, we will automatically reduce it if we find an OOM error.",
)
parser.add_argument(
"--output_path",
type=str,
default=None,
help="Path to a json file to save the results. If not given, the results will be printed to the console.",
)
parser.add_argument(
"--bert_score_model",
type=str,
default="microsoft/deberta-xlarge-mnli",
help="Model to use for BertScore. See: https://github.com/huggingface/datasets/tree/master/metrics/bertscore"
"and https://github.com/Tiiiger/bert_score for more details.",
)
args = parser.parse_args()
eval_files(
pred_path=args.pred_path,
gold_path=args.gold_path,
starting_batch_size=args.starting_batch_size,
output_path=args.output_path,
bert_score_model=args.bert_score_model,
)