-
Notifications
You must be signed in to change notification settings - Fork 826
/
evaluate_color.py
56 lines (40 loc) · 1.34 KB
/
evaluate_color.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import re
import json
import argparse
import sympy
from sympy.parsing.latex import parse_latex
import timeout_decorator
from tqdm import tqdm
def load_multi_line_json(f):
    """Parse a stream of concatenated, multi-line JSON objects.

    Lines are accumulated until one begins with ``}`` (the closing brace of
    a top-level object written at column 0).  The accumulated text is then
    decoded with ``json.loads`` and a fresh record is started.

    Args:
        f: An open text file, or any other iterable that yields lines of
            text (lines are concatenated as-is, so they should keep their
            trailing newlines).

    Returns:
        A list with one decoded object per record, in input order.  Any
        text after the last line starting with ``}`` is discarded without
        being parsed.

    Raises:
        json.JSONDecodeError: If an accumulated record is not valid JSON.
    """
    records = []
    pending = []  # lines of the record currently being assembled
    # Iterate lazily instead of readlines(): no need to hold the whole file,
    # and any iterable of lines is accepted.
    for line in f:
        pending.append(line)
        if line.startswith('}'):
            # Join once per record rather than growing a string per line.
            records.append(json.loads(''.join(pending)))
            pending = []
    return records
def extract_answer(pred, label, solution_index=4):
    """Check whether the answer extracted from a model output matches a label.

    The output is split on the literal marker ``'Solution: '`` and the
    segment at ``solution_index`` is taken.  That segment is then cut at the
    next ``'Problem: '`` marker and at the first newline, and stripped of
    surrounding whitespace.

    Args:
        pred: Raw generated text containing ``'Solution: '`` markers.
        label: Reference answer.  The check is ``pred in label``, i.e. a
            substring test when ``label`` is a string and a membership test
            when it is a container.
        solution_index: Index of the ``'Solution: '``-delimited segment that
            holds the answer to score.  Defaults to 4 — presumably the
            prompt carries several worked examples before the scored one;
            confirm against the prompt template.

    Returns:
        True if a non-empty answer was extracted and it is found in
        ``label``; False otherwise.

    Raises:
        IndexError: If ``pred`` has fewer ``'Solution: '`` segments than
            ``solution_index`` requires.
    """
    answer = pred.split('Solution: ')[solution_index].strip()
    answer = answer.split('Problem: ')[0].strip()
    answer = answer.split('\n')[0].strip()
    # An empty string is "in" every string, so a blank answer must be
    # rejected explicitly instead of being scored as correct.
    if not answer:
        return False
    return answer in label
def main(args):
    """Score a result file and print the accuracy.

    Reads the records from ``args.result_path`` with
    ``load_multi_line_json``, compares each record's ``'pred_ans'`` against
    its ``'real_ans'`` with ``extract_answer``, and prints the percentage of
    correct records together with the raw counts.

    Args:
        args: Object with a ``result_path`` attribute naming the file to
            score (the parsed command-line arguments).
    """
    with open(args.result_path, 'r') as fin:
        datas = load_multi_line_json(fin)

    num_correct = 0
    for data in tqdm(datas):
        try:
            if extract_answer(data['pred_ans'], data['real_ans']):
                num_correct += 1
        except (KeyError, IndexError, TypeError, AttributeError):
            # Best-effort scoring: a record with missing fields or an output
            # that lacks the expected markers counts as incorrect.  Only
            # these lookup/type errors are swallowed, so Ctrl-C and genuine
            # bugs still surface.
            continue

    # Every record counts toward the total, including those that failed
    # answer extraction.
    total_problem = len(datas)
    # Guard against an empty result file, which would otherwise divide by zero.
    accuracy = round(num_correct / total_problem * 100, 2) if total_problem else 0.0
    print('Accuracy: {} ( {} / {} )'.format(accuracy, num_correct, total_problem))
if __name__ == '__main__':
    # Command-line entry point: score the result file given by --result_path.
    parser = argparse.ArgumentParser()
    # The path is mandatory; without it main() would call open(None) and fail
    # with an opaque TypeError, so let argparse report the usage error.
    parser.add_argument('--result_path', type=str, required=True, help='The path to result')
    args = parser.parse_args()
    main(args)