forked from zhan-xiong/segrnn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate.py
51 lines (46 loc) · 1.44 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from config import *
def count_correct_labels(predicted, gold):
    """Count gold segments that also appear, exactly, in the prediction.

    Both arguments are iterated as ``(tag, length)`` pairs. Lengths are
    accumulated left to right, so each pair becomes a
    ``(tag, start, end)`` span with ``end = start + length``. A gold span
    counts as correct only when the prediction contains a span with the
    same tag, the same start, and the same end.

    Args:
        predicted: Iterable of ``(tag, length)`` pairs for the predicted
            segmentation.
        gold: Iterable of ``(tag, length)`` pairs for the reference
            segmentation.

    Returns:
        The number of gold spans found among the predicted spans (an int;
        0 when either input is empty).
    """

    def spans(segments):
        # Turn consecutive (tag, length) pairs into (tag, start, end) spans
        # by keeping a running offset.
        start = 0
        for tag, length in segments:
            end = start + length
            yield (tag, start, end)
            start = end

    # Build the lookup set once so each gold span is an O(1) membership test.
    predicted_spans = set(spans(predicted))
    return sum(1 for span in spans(gold) if span in predicted_spans)
def eval_f1(seg_rnn, pairs, write_to_file=True):
    """Print segment-level precision, recall and F1 over ``pairs``.

    For every item the model's ``infer`` output is compared with the gold
    labelling via ``count_correct_labels``; predicted, gold and correct
    segment counts are summed over all items before the scores are
    computed.

    Args:
        seg_rnn: Object exposing ``infer(data)``; its result must support
            ``len()`` and be accepted by ``count_correct_labels``.
        pairs: Iterable of ``(datum, (gold_label, sentence))`` items.
            ``datum`` is reshaped to ``(len(sentence), 1, EMBEDDING_DIM)``
            before inference.
        write_to_file: When true, append one ``precision recall f1`` line
            to ``eval_scores.txt`` in the current working directory.

    Returns:
        None. Scores are printed and optionally appended to the file.
    """
    gold_segs = 0
    predicted_segs = 0
    correct_segs = 0
    for idx, (datum, (gold_label, sentence)) in enumerate(pairs):
        # Progress indicator every 25 items.
        if idx % 25 == 0:
            print("eval ", idx)
        # NOTE(review): EMBEDDING_DIM is presumably provided by the
        # `from config import *` at the top of the file -- confirm.
        model_input = datum.reshape(len(sentence), 1, EMBEDDING_DIM)
        predicted_label = seg_rnn.infer(model_input)
        predicted_segs += len(predicted_label)
        gold_segs += len(gold_label)
        correct_segs += count_correct_labels(predicted_label, gold_label)

    # Guard each ratio against a zero denominator (e.g. empty `pairs`).
    precision = correct_segs / predicted_segs if predicted_segs > 0 else 0.0
    print("Precision: ", precision)
    recall = correct_segs / gold_segs if gold_segs > 0 else 0.0
    print("Recall: ", recall)

    # Harmonic mean of precision and recall; defined as 0 when either is 0
    # so the reciprocals below never divide by zero.
    if precision > 0 and recall > 0:
        f1 = 2.0 / (1.0 / precision + 1.0 / recall)
    else:
        f1 = 0.0
    print("F1: " , f1)

    if write_to_file:
        # Context manager guarantees the handle is closed even if the
        # write raises.
        with open("eval_scores.txt", "a+") as f:
            f.write("%f %f %f\n" % (precision, recall, f1))