-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathevaluate.py
110 lines (84 loc) · 2.75 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
__author__ = 'qiao'
'''
evaluate GeneGPT on all GeneTuring tasks and one GeneHop task (Disease gene location)
'''
import glob
import json
import os
import sys
def get_answer(answer, task):
    """Normalize a raw model prediction so it can be compared to the gold answer.

    Args:
        answer: Raw prediction text produced for one question.
        task: Task name; selects the normalization rule applied below.

    Returns:
        - 'SNP location' / 'Gene location': the last whitespace-separated
          token, prefixed with 'chr' when it does not already contain 'chr';
          an empty string when the prediction has no tokens at all.
        - 'Gene disease association' / 'Disease gene location': a list of
          strings obtained by splitting on ', '.
        - 'Protein-coding genes': 'Yes' -> 'TRUE', 'No' -> 'NA', anything
          else unchanged.
        - 'Multi-species DNA aligment': the common species name when the
          prediction is a known scientific name, otherwise unchanged.
        - any other task: the stripped text with 'Answer: ' removed.
    """
    # Scientific name -> common name for the species alignment task.
    mapper = {
        'Caenorhabditis elegans': 'worm',
        'Homo sapiens': 'human',
        'Danio rerio': 'zebrafish',
        'Mus musculus': 'mouse',
        'Saccharomyces cerevisiae': 'yeast',
        'Rattus norvegicus': 'rat',
        'Gallus gallus': 'chicken',
    }

    answer = answer.strip()

    if task in ('SNP location', 'Gene location'):
        tokens = answer.split()
        if not tokens:
            # An empty prediction has no last token; return '' so it is scored
            # as wrong instead of raising IndexError and aborting evaluation.
            return ''
        answer = tokens[-1]
        if 'chr' not in answer:
            answer = 'chr' + answer
        return answer

    # Every remaining task drops the 'Answer: ' marker wherever it occurs.
    answer = answer.replace('Answer: ', '')

    if task in ('Gene disease association', 'Disease gene location'):
        return answer.split(', ')

    if task == 'Protein-coding genes':
        if answer == 'Yes':
            return 'TRUE'
        if answer == 'No':
            return 'NA'
        return answer

    # The task name spelling ('aligment') is the dataset key and must be kept.
    if task == 'Multi-species DNA aligment':
        return mapper.get(answer, answer)

    return answer
if __name__ == '__main__':
    # Usage: python evaluate.py <result_dir>
    # Each file in <result_dir> is scored as one task; the task name is the
    # file's base name with '.json' removed.
    if len(sys.argv) < 2:
        sys.exit('Usage: python evaluate.py <result_dir>')

    # all geneturing and genehop disease gene location tasks are automatically evaluated
    with open('data/geneturing.json') as f:
        qas = json.load(f)
    with open('data/genehop.json') as f:
        qas['Disease gene location'] = json.load(f)['Disease gene location']

    # result dir path to evaluate
    folder = sys.argv[1]

    for path in glob.glob(os.path.join(folder, '*')):
        print(f'\nEvaluating {path}')

        with open(path) as f:
            preds = json.load(f)

        task = os.path.basename(path).replace('.json', '')
        if task not in qas:
            print(f'{task} is not automatically evaluated.')
            continue

        info = qas[task]

        # Index 0 of each entry is used as the question key and index 2 as the
        # raw prediction text (presumably [question, gold, prediction, ...] --
        # TODO confirm against the code that writes the result files).
        pred_q2a = {}
        for entry in preds:
            pred_q2a[entry[0]] = get_answer(entry[2], task)

        correct = []
        for question, answer in info.items():
            if task == 'Gene disease association':
                # Partial credit: fraction of gold items found in the prediction list.
                answer = answer.split(', ')
                answer_in = [ans in pred_q2a[question] for ans in answer]
                correct.append(sum(answer_in) / len(answer_in))
            elif task == 'Disease gene location':
                # Gold answer is iterated as-is here (presumably already a list).
                answer_in = [ans in pred_q2a[question] for ans in answer]
                correct.append(sum(answer_in) / len(answer_in))
            elif task == 'Human genome DNA aligment':
                # Exact match scores 1; matching only the part before ':' scores 0.5.
                pred = pred_q2a[question]
                pred_chr = pred.split(':')[0]
                answer_chr = answer.split(':')[0]
                if pred == answer:
                    correct.append(1)
                elif pred_chr == answer_chr:
                    correct.append(0.5)
                else:
                    correct.append(0)
            else:
                if pred_q2a[question] == answer:
                    correct.append(1)
                else:
                    correct.append(0)

        if not correct:
            # No gold questions for this task: avoid dividing by zero below.
            print(f'{task} has no questions to score.')
            continue

        print(sum(correct) / len(correct))