-
Notifications
You must be signed in to change notification settings - Fork 13
/
tools.py
32 lines (27 loc) · 1.26 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from sklearn.metrics import confusion_matrix
#taken from https://gist.github.com/zachguo/10296432
def print_cm(cm, labels, hide_zeroes=False, hide_diagonal=False, hide_threshold=None):
    """Pretty-print a confusion matrix to stdout as a fixed-width table.

    The first line is a header holding a "t/p" corner marker (presumably
    "true/predicted" -- confirm against how ``cm`` was built) followed by
    every label.  Each following line holds one label and the values of the
    matching row of ``cm``, formatted with one decimal place.

    Args:
        cm: Matrix indexed as ``cm[i, j]`` (e.g. a 2-D NumPy array) with at
            least ``len(labels)`` rows and columns.
        labels: Row/column labels, in matrix order.  Non-string labels are
            converted with ``str()``.
        hide_zeroes: If true, cells whose value equals zero are left blank.
        hide_diagonal: If true, cells on the main diagonal are left blank.
        hide_threshold: If not ``None``, cells whose value is not strictly
            greater than this threshold are left blank.  ``0`` is a valid
            threshold.

    Returns:
        None.  The table is written with ``print``.
    """
    # Convert once so that integer class ids (common for classifier output)
    # can be measured with len() and aligned like string labels.
    labels = [str(label) for label in labels]
    columnwidth = max([len(label) for label in labels] + [5])  # 5 is value length
    empty_cell = " " * columnwidth

    # Centre "t/p" in the corner cell; when the width is even the centred
    # text comes out one character short, so right-justify to full width.
    side_padding = (columnwidth - 3) // 2 * " "
    fst_empty_cell = (side_padding + "t/p" + side_padding).rjust(columnwidth)

    # Print header
    print(" " + fst_empty_cell, end=" ")
    for label in labels:
        print(label.rjust(columnwidth), end=" ")
    print()

    # Print rows
    for i, row_label in enumerate(labels):
        print(" " + row_label.rjust(columnwidth), end=" ")
        for j in range(len(labels)):
            value = cm[i, j]
            hidden = (
                (hide_zeroes and float(value) == 0)
                or (hide_diagonal and i == j)
                # Compare against None explicitly: a threshold of 0 is falsy
                # but must still be applied.
                or (hide_threshold is not None and not value > hide_threshold)
            )
            cell = empty_cell if hidden else "%*.1f" % (columnwidth, value)
            print(cell, end=" ")
        print()