# evaluate_medqa_handmade.py
# Runs the project benchmark on the NQG predictions for the handmade MedQA test split.
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
from defs import NQG_MEDQA_HANDMADE_PREDS_OUTPUT_PATH, MEDQA_HANDMADE_FOR_NQG_DATASET
from evaluating.model_benchmark import benchmark
def read_text_lines(path):
    """Read a text file into a 1-D NumPy array holding one string per line.

    Replaces ``np.loadtxt(path, dtype=str, comments=None)`` with a newline
    delimiter: NumPy 1.23 and later reject a newline as ``delimiter``, and
    ``loadtxt`` squeezes a one-line file into a 0-d array that cannot be
    iterated or zipped.

    Args:
        path: Location of the text file.

    Returns:
        A 1-D array of ``str``. Line terminators are removed and lines that
        are empty afterwards are skipped (as ``loadtxt`` did); any other
        whitespace is preserved.
    """
    # NOTE(review): assumes the data files are UTF-8 encoded -- confirm.
    with open(path, encoding="utf-8") as handle:
        stripped = (raw.strip("\r\n") for raw in handle)
        return np.array([line for line in stripped if line], dtype=str)


if __name__ == '__main__':
    # Counters for the disabled per-sample BLEU-1 report at the bottom.
    distribution = {
        "total": 0, "0.1": 0, "0.2": 0, "0.3": 0, "0.5": 0, "1": 0
    }

    preds = read_text_lines(NQG_MEDQA_HANDMADE_PREDS_OUTPUT_PATH)
    # Column vector: one row per sample (presumably a single-reference list
    # per prediction -- confirm against ``benchmark``). Positional shape
    # avoids the ``newshape=`` keyword deprecated in NumPy 2.1.
    targets = read_text_lines(
        f"{MEDQA_HANDMADE_FOR_NQG_DATASET}/test/data.txt.target.txt"
    ).reshape(-1, 1)
    # Only consumed by the disabled report below.
    answers = read_text_lines(
        f"{MEDQA_HANDMADE_FOR_NQG_DATASET}/test/data.txt.source.txt"
    )

    benchmark(preds, targets)

    # NOTE: disabled per-sample BLEU report, kept for reference. As written it
    # would fail if re-enabled: after the reshape above each ``target`` is a
    # length-1 array, which has no ``split`` method.
    # for (pred, target, answer) in zip(preds, targets, answers):
    #     print(f"Target: {target}")
    #     print(f"Prediction: {pred}")
    #     print(f"Answer: {answer}")
    #     bleu = {}
    #     for i in range(1, 5):
    #         bleu[i] = sentence_bleu([target.split(" ")], pred.split(" "), weights=list(1.0 / i for _ in range(i)))
    #         print(f"BLEU-{i}: {bleu[i]}")
    #     for threshold in distribution.keys():
    #         if threshold != "total":
    #             if bleu[1] >= float(threshold):
    #                 distribution[threshold] += 1
    #     distribution["total"] += 1
    #     print()
    #
    # print(f"Total: {distribution['total']}")
    # for threshold, size in distribution.items():
    #     if threshold != "total":
    #         print(f"BLEU-1 >= {threshold}: {size}")