# coding=utf-8
# Copyright The XTREME Benchmark Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation."""
import argparse
from seqeval.metrics import precision_score, recall_score, f1_score
import sys
from third_party.evaluate_squad import evaluate as squad_eval
from third_party.evaluate_mlqa import evaluate as mlqa_eval
def read_tag(file):
  """Read the tag column of a tagging file; returns one list of tags per sentence."""
  labels = []
  example = []
  for line in open(file, 'r'):
    items = line.strip().split('\t')
    if len(items) == 2:
      example.append(items[1].strip())
    else:
      labels.append(example)
      example = []
  return labels
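
# Hypothetical input for read_tag above, assuming the file ends with a blank
# line so that the final sentence is flushed:
#
#   The<TAB>O
#   EU<TAB>B-ORG
#   <blank>
#   Paris<TAB>B-LOC
#   <blank>
#
# read_tag would return [['O', 'B-ORG'], ['B-LOC']]; any line that does not
# split into exactly two tab-separated fields closes the current sentence.
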
def read_label(file):
  return [l.strip() for l in open(file)]

def read_squad(file):
  """Read a SQuAD-format json file; returns its 'data' field if present, else the whole object."""
  expected_version = '1.1'
  with open(file) as dataset_file:
    dataset_json = json.load(dataset_file)
    if 'version' in dataset_json and dataset_json['version'] != expected_version:
      print('Evaluation expects v-' + expected_version,
            ', but got dataset with v-' + dataset_json['version'],
            file=sys.stderr)
    if 'data' in dataset_json:
      return dataset_json['data']
    else:
      return dataset_json

def f1(labels, predictions, language=None):
  f1 = f1_score(labels, predictions)
  precision = precision_score(labels, predictions)
  recall = recall_score(labels, predictions)
  return {'f1': f1, 'precision': precision, 'recall': recall}

def accuracy(labels, predictions, language=None):
  correct = sum([int(p == l) for p, l in zip(predictions, labels)])
  accuracy = float(correct) / len(predictions)
  return {'accuracy': accuracy}

def squad_em_f1(labels, predictions, language=None):
  return squad_eval(labels, predictions)

def mlqa_em_f1(labels, predictions, language):
  if language is None:
    print('A 2-character language code is required via the `language` argument for MLQA evaluation.', file=sys.stderr)
    sys.exit(1)
  return mlqa_eval(labels, predictions, language)

GROUP2TASK = {
"classification": ["pawsx", "xnli"],
"tagging": ["udpos", "panx"],
"qa": ["xquad", "mlqa", "tydiqa"],
"retrieval": ["bucc2018", "tatoeba"],
}

TASK2LANGS = {
"pawsx": "de,en,es,fr,ja,ko,zh".split(","),
"xnli": "ar,bg,de,el,en,es,fr,hi,ru,sw,th,tr,ur,vi,zh".split(","),
"panx": "ar,he,vi,id,jv,ms,tl,eu,ml,ta,te,af,nl,en,de,el,bn,hi,mr,ur,fa,fr,it,pt,es,bg,ru,ja,ka,ko,th,sw,yo,my,zh,kk,tr,et,fi,hu".split(","),
"udpos": "af,ar,bg,de,el,en,es,et,eu,fa,fi,fr,he,hi,hu,id,it,ja,kk,ko,mr,nl,pt,ru,ta,te,th,tl,tr,ur,vi,yo,zh".split(","),
"bucc2018": "de,fr,ru,zh".split(","),
"tatoeba": "ar,he,vi,id,jv,tl,eu,ml,ta,te,af,nl,en,de,el,bn,hi,mr,ur,fa,fr,it,pt,es,bg,ru,ja,ka,ko,th,sw,zh,kk,tr,et,fi,hu".split(","),
"xquad": "en,es,de,el,ru,tr,ar,vi,th,zh,hi".split(","),
"mlqa": "en,es,de,ar,hi,vi,zh".split(","),
"tydiqa": "en,ar,bn,fi,id,ko,ru,sw,te".split(","),
}

READER_FUNCTION = {
'pawsx': read_label,
'xnli': read_label,
'panx': read_tag,
'udpos': read_tag,
'bucc2018': read_label,
'tatoeba': read_label,
'xquad': read_squad,
'mlqa': read_squad,
'tydiqa': read_squad,
}

METRIC_FUNCTION = {
'pawsx': accuracy,
'xnli': accuracy,
'panx': f1,
'udpos': f1,
'bucc2018': f1,
'tatoeba': accuracy,
'xquad': squad_em_f1,
'mlqa': mlqa_em_f1,
'tydiqa': squad_em_f1,
}

def evaluate_one_task(prediction_file, label_file, task, language=None):
  """Evaluate predictions against ground-truth labels for a single task and language.

  Args:
    prediction_file (string): path to the prediction file.
    label_file (string): path to the ground-truth file.
    task (string): task name; a key of READER_FUNCTION and METRIC_FUNCTION.
    language (string): 2-character language code; passed on to the metric function.

  Returns:
    result (dict): a dictionary mapping metric names to scores.

  For classification tasks, both input files contain one example per line as follows:
  ``[label]\t[sentence1]\t[sentence2]``
  """
  predictions = READER_FUNCTION[task](prediction_file)
  labels = READER_FUNCTION[task](label_file)
  assert len(predictions) == len(labels), \
      'Number of examples in {} and {} does not match'.format(prediction_file, label_file)
  result = METRIC_FUNCTION[task](labels, predictions, language)
  return result
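
# Minimal usage sketch (hypothetical paths): for German XNLI predictions,
#   evaluate_one_task('predictions/xnli/test-de.tsv', 'labels/xnli/test-de.tsv',
#                     task='xnli', language='de')
# returns a dict of the form {'accuracy': <float>}, since METRIC_FUNCTION maps
# 'xnli' to accuracy().
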
def evaluate(prediction_folder, label_folder):
  """Evaluate all tasks for which both predictions and labels are available.

  Args:
    prediction_folder (string): folder that contains each task's predictions in its own subfolder.
    label_folder (string): folder that contains each task's ground-truth labels in its own subfolder.

  Returns:
    overall_scores (dict): a dictionary with sub-group scores. key: group label.
    detailed_scores (dict): a dictionary with all detailed scores. key: task label.
  """
  prediction_tasks = next(os.walk(prediction_folder))[1]
  label_tasks = next(os.walk(label_folder))[1]
  detailed_scores = {}
  for task, langs in TASK2LANGS.items():
    if task in prediction_tasks and task in label_tasks:
      suffix = "json" if task in GROUP2TASK["qa"] else "tsv"
      # collect scores over all languages
      score = defaultdict(dict)
      for lg in langs:
        prediction_file = os.path.join(prediction_folder, task, f"test-{lg}.{suffix}")
        label_file = os.path.join(label_folder, task, f"test-{lg}.{suffix}")
        score_lg = evaluate_one_task(prediction_file, label_file, task, language=lg)
        for metric in score_lg:
          score[metric][lg] = score_lg[metric]
      # average over all languages (iterate over a snapshot so the dict can be extended)
      for m in list(score):
        score[f'avg_{m}'] = sum(score[m].values()) / len(score[m])
      if task in GROUP2TASK["qa"]:
        score['avg_metric'] = (score['avg_exact_match'] + score['avg_f1']) / 2
      elif 'avg_f1' in score:
        score['avg_metric'] = score['avg_f1']
      elif 'avg_accuracy' in score:
        score['avg_metric'] = score['avg_accuracy']
      detailed_scores[task] = score

  # Display logic:
  # If scores of all tasks in a sub-group are available, show the score in the sub table
  overall_scores = {}
  for group, group_tasks in GROUP2TASK.items():
    if all(task in detailed_scores for task in group_tasks):
      overall_scores[group] = sum(detailed_scores[task]['avg_metric'] for task in group_tasks) / len(group_tasks)

  # If scores of all tasks are available, show the overall score in the main table
  all_tasks = list(TASK2LANGS.keys())
  if all(task in detailed_scores for task in all_tasks):
    overall_scores['all_task'] = sum(detailed_scores[task]['avg_metric'] for task in all_tasks) / len(all_tasks)
  return overall_scores, detailed_scores
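
# Expected on-disk layout, as implied by the path construction in evaluate()
# (file names are test-<language>.<tsv|json>), e.g.:
#   <prediction_folder>/xnli/test-de.tsv     <label_folder>/xnli/test-de.tsv
#   <prediction_folder>/xquad/test-el.json   <label_folder>/xquad/test-el.json
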
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--prediction_folder', default=None, type=str, required=True,
                      help='folder containing the predictions of one model')
  parser.add_argument('--label_folder', default=None, type=str, required=True,
                      help='folder containing the ground-truth labels')
  args = parser.parse_args()
  overall_scores, detailed_scores = evaluate(args.prediction_folder, args.label_folder)
  print(overall_scores)
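
# Example invocation (hypothetical folder names):
#   python evaluate.py --prediction_folder predictions/ --label_folder labels/
# This prints the group-level averages, e.g. {'classification': ..., 'qa': ..., 'all_task': ...},
# where a group score only appears once every task in that group has been scored
# for all of its languages.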