-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmetrics.py
145 lines (112 loc) · 4.39 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import numpy as np
from hausdorff import hausdorff_distance
from medpy.metric.binary import hd, dc
def dice(pred, target):
    """Return the smoothed Dice overlap of two tensors over all their elements.

    Both tensors are flattened into a single row before the score is
    computed, so the result is one global value rather than a per-channel
    or per-sample score.

    Args:
        pred: Tensor of predicted mask values.
        target: Tensor of reference mask values; it is multiplied
            element-wise with ``pred`` after flattening.

    Returns:
        ``(2 * sum(pred * target) + eps) / (sum(pred) + sum(target) + eps)``
        as a Python number, with ``eps = 1e-5``. Two all-zero inputs
        therefore score 1.0.
    """
    eps = 0.00001
    flat_pred = pred.contiguous().view(1, -1)
    flat_target = target.contiguous().view(1, -1)
    overlap = (flat_pred * flat_target).sum().item()
    total = flat_pred.sum().item() + flat_target.sum().item()
    return (2 * overlap + eps) / (total + eps)
def dice3D(img_gt, img_pred, voxel_size):
    """
    Compute the per-class Dice coefficient between two segmentation maps.

    Parameters
    ----------
    img_gt: np.array
        Array of the ground truth segmentation map.
    img_pred: np.array
        Array of the predicted segmentation map.
    voxel_size: list, tuple or np.array
        The size of a voxel of the images. Currently unused: the volume
        computation that needed it is disabled.

    Return
    ------
    A list of three Dice scores, one per class label, in the order
    3, 1, 2.
    NOTE(review): earlier documentation named the structures LV, RV, MYO
    in that order - confirm the label-to-structure mapping.

    Raises
    ------
    ValueError
        If the two arrays do not have the same number of dimensions.
    """
    if img_gt.ndim != img_pred.ndim:
        raise ValueError("The arrays 'img_gt' and 'img_pred' should have the "
                         "same dimension, {} against {}".format(img_gt.ndim,
                                                                img_pred.ndim))

    def _class_mask(label_map, cls):
        # Work on a copy so the caller's array is never modified, zero out
        # every other class, then clip so the kept class becomes 1.
        mask = np.copy(label_map)
        mask[mask != cls] = 0
        return np.clip(mask, 0, 1)

    return [dc(_class_mask(img_gt, c), _class_mask(img_pred, c))
            for c in (3, 1, 2)]
def hd_3D(img_pred, img_gt, labels=(3, 1, 2)):
    """Compute the per-class Hausdorff distance between two label maps.

    Args:
        img_pred: Array of predicted class labels.
        img_gt: Array of reference class labels.
        labels: Iterable of class labels to evaluate; the output follows
            this order. The default is an immutable tuple so the default
            value cannot be shared and mutated across calls.

    Returns:
        A list with one entry per label: the value returned by ``hd`` for
        the two binary masks of that class, or ``0`` when the class is
        missing from either array (``hd`` is not called in that case).
    """
    res = []
    for c in labels:
        # Build 0/1 masks on copies so the inputs are left untouched:
        # zero every other class, then clip so the kept class becomes 1.
        gt_c_i = np.copy(img_gt)
        gt_c_i[gt_c_i != c] = 0
        gt_c_i = np.clip(gt_c_i, 0, 1)

        pred_c_i = np.copy(img_pred)
        pred_c_i[pred_c_i != c] = 0
        pred_c_i = np.clip(pred_c_i, 0, 1)

        if np.sum(pred_c_i) == 0 or np.sum(gt_c_i) == 0:
            # The class is absent from one side; report 0 instead of
            # calling hd on an empty mask.
            hausdorff = 0
        else:
            hausdorff = hd(pred_c_i, gt_c_i)
        res.append(hausdorff)
    return res
def cal_hausdorff_distance(pred, target):
    """Return the Euclidean Hausdorff distance between two tensors.

    Args:
        pred: Tensor of predictions; made contiguous and converted to a
            NumPy array before the distance is computed.
        target: Tensor of references; converted the same way.

    Returns:
        The value returned by ``hausdorff_distance`` for the two arrays
        with ``distance="euclidean"``.
    """
    pred_np, target_np = (np.array(t.contiguous()) for t in (pred, target))
    return hausdorff_distance(pred_np, target_np, distance="euclidean")
def make_one_hot(input, num_classes):
    """Expand a class-index tensor into a one-hot tensor along dimension 1.

    Args:
        input: A tensor of shape [N, 1, *] holding class indices. It is
            moved to the CPU and cast to long before being used as the
            scatter index.
        num_classes: An int, the number of classes.

    Returns:
        A tensor of shape [N, num_classes, *], created with ``torch.zeros``
        defaults, holding 1 at every position selected by ``input`` along
        dimension 1 and 0 elsewhere.
    """
    out_shape = list(input.shape)
    out_shape[1] = num_classes
    one_hot = torch.zeros(out_shape)
    one_hot.scatter_(1, input.cpu().long(), 1)
    return one_hot
def match_pred_gt(pred, gt):
    """Relabel ``pred`` so its labels follow the best-overlapping ``gt`` labels.

    For every ground-truth label (in sorted order) the predicted label with
    the highest Dice score against it is selected, and that predicted region
    is written to the output with the value ``k + 1``, where ``k`` is the
    position of the ground-truth label in the sorted list.

    Args:
        pred: Predicted label tensor, documented as (1, C, H, W).
        gt: Ground-truth label tensor, documented as (1, C, H, W).

    Returns:
        A tensor shaped and typed like ``pred``. It is all zeros when either
        input has no label left after the smallest value is dropped. If
        several ground-truth labels select the same predicted label, the
        last one in sorted order wins.
    """
    # The smallest unique value of each tensor is dropped.
    # NOTE(review): this presumably assumes a background value (0) is always
    # present; if it is not, the smallest real label is skipped - confirm.
    gt_labels = torch.unique(gt, sorted=True)[1:]
    pred_labels = torch.unique(pred, sorted=True)[1:]

    pred_match = torch.zeros_like(pred)
    if len(gt_labels) == 0 or len(pred_labels) == 0:
        return pred_match

    # The ground-truth encodings do not depend on the predicted label, so
    # build them once instead of once per (pred, gt) pair.
    gt_one_hot = [make_one_hot(gt == gl, 2)[0] for gl in gt_labels]

    dice_matrix = torch.zeros((len(pred_labels), len(gt_labels)))
    for i, pl in enumerate(pred_labels):
        # .float() converts the boolean mask directly; wrapping an existing
        # tensor in torch.tensor() copy-constructs it and emits a warning.
        pred_one_hot = make_one_hot((pred == pl).float(), 2)[0]
        for j, gt_hot in enumerate(gt_one_hot):
            dice_matrix[i, j] = dice(pred_one_hot, gt_hot)

    # Column-wise argmax: index of the best predicted label per gt label.
    best_pred_for_gt = torch.argmax(dice_matrix, dim=0)
    for i, arg in enumerate(best_pred_for_gt):
        pred_match[pred == pred_labels[arg]] = i + 1
    return pred_match
if __name__ == "__main__":
    # Manual check: load a saved evaluation dump from a fixed local path.
    npy_path = "/home/fcheng/Cardia/source_code/logs/logs_df_50000/eval_pp_test/200.npy"
    # Load via `npy_path`; the name `npy_p` is not defined anywhere and
    # raised NameError. The result is unpacked along its first axis, so the
    # file is expected to hold exactly two entries.
    pred_df, gt_df = np.load(npy_path)