-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
executable file
·82 lines (65 loc) · 2.26 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import torch
def dice_loss(inp, target):
    """Return the soft Dice loss ``1 - Dice`` between two tensors.

    Both tensors are flattened to 1-D, so any shape works as long as they
    hold the same number of elements. A small smoothing term keeps the
    ratio defined when both tensors sum to zero.

    Args:
        inp: prediction tensor (soft or hard values).
        target: ground-truth tensor with the same number of elements.

    Returns:
        A 0-dim tensor: 1 - (2*sum(inp*target) + eps) / (sum(inp) + sum(target) + eps).
    """
    eps = 1e-7
    pred = inp.contiguous().view(-1)
    truth = target.contiguous().view(-1)
    overlap = torch.sum(pred * truth)
    denom = torch.sum(pred) + torch.sum(truth) + eps
    score = (2. * overlap + eps) / denom
    return 1 - score
def MCD_loss(pm, gt, num_class):
    """Return the mean soft Dice loss over the first ``num_class`` channels.

    Args:
        pm: prediction tensor whose channel axis is dim 1 (indexed as
            ``pm[:, c]``). Any rank >= 2 is accepted, e.g. (N, C, H, W) or
            (N, C, D, H, W).
        gt: ground-truth tensor with the same layout as ``pm``.
        num_class: number of channels (starting at 0) to average over.

    Returns:
        A 0-dim tensor: the average of ``dice_loss`` over channels
        0..num_class-1.

    Raises:
        ZeroDivisionError: if ``num_class`` is 0.
    """
    # Index only the channel axis so 2-D and 3-D inputs both work; the old
    # ``[:, i, :, :, :]`` form required at least 5 dimensions.
    # Dice is symmetric, so (prediction, target) order does not change the value.
    total = sum(dice_loss(pm[:, c], gt[:, c]) for c in range(num_class))
    return total / num_class
# Setting up the Evaluation Metric
def dice(out, target):
    """Return the Dice coefficient between two tensors (evaluation metric).

    Both tensors are flattened to 1-D. A small smoothing term keeps the ratio
    defined when both tensors sum to zero.

    Args:
        out: prediction tensor.
        target: ground-truth tensor with the same number of elements.

    Returns:
        A 0-dim tensor: (2*sum(out*target) + eps) / (sum(out) + sum(target) + eps).
    """
    smooth = 1e-7
    # ``.contiguous()`` first: ``.view(-1)`` raises on non-contiguous inputs
    # (e.g. transposed/permuted tensors). This matches the other flatteners here.
    oflat = out.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (oflat * tflat).sum()
    return (2 * intersection + smooth) / (oflat.sum() + tflat.sum() + smooth)
def CE(out, target, eps=1e-7):
    """Return the cross-entropy of ``out`` against ``target``, normalised by target mass.

    Computes ``sum(-log(out) * target) / sum(target)`` over the flattened
    tensors. ``out`` is passed straight to ``log``, so it presumably holds
    probabilities (e.g. softmax output), not logits. Confirm with callers.

    Args:
        out: predicted probabilities.
        target: target weights/mask with the same number of elements.
        eps: lower bound applied to ``out`` before the log and to the
            normaliser, so that
            (a) a zero probability does not produce ``0 * inf = NaN``, and
            (b) an all-zero target (class absent from the batch) yields 0
                instead of ``0 / 0 = NaN``.

    Returns:
        A 0-dim tensor.
    """
    oflat = out.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    log_p = torch.log(torch.clamp(oflat, min=eps))
    mass = torch.clamp(tflat.sum(), min=eps)
    return torch.dot(-log_p, tflat) / mass
def CCE(out, target, num_class):
    """Return the mean of ``CE`` over the first ``num_class`` channels.

    Args:
        out: prediction tensor whose channel axis is dim 1 (indexed as
            ``out[:, c]``). Any rank >= 2 is accepted.
        target: ground-truth tensor with the same layout as ``out``.
        num_class: number of channels (starting at 0) to average over.

    Returns:
        A 0-dim tensor.

    Raises:
        ZeroDivisionError: if ``num_class`` is 0.
    """
    # Channel-only indexing replaces the 5-D-only ``[:, i, :, :, :]`` form.
    total = sum(CE(out[:, c], target[:, c]) for c in range(num_class))
    return total / num_class
def DCCE(out, target, n_classes):
    """Return the combined loss ``MCD_loss + CCE`` (Dice plus cross-entropy).

    Args:
        out: prediction tensor, passed unchanged to both component losses.
        target: ground-truth tensor, passed unchanged to both component losses.
        n_classes: number of channels each component averages over.

    Returns:
        The unweighted sum of the two component losses.
    """
    dice_term = MCD_loss(out, target, n_classes)
    ce_term = CCE(out, target, n_classes)
    return dice_term + ce_term
def TV_loss(inp, target, alpha=0.3, beta=0.7):
    """Return the Tversky loss ``1 - TI`` between two flattened tensors.

    With soft predictions p and targets t:
        TP = sum(p * t), FP = sum(p * (1 - t)), FN = sum((1 - p) * t)
        TI = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)

    ``alpha`` weights false positives and ``beta`` weights false negatives.
    The previous denominator ``alpha*sum(p) + beta*sum(t)`` equals this
    denominator only when ``alpha + beta == 1``. Computing FP/FN explicitly
    keeps the default (0.3, 0.7) result unchanged and makes other settings
    correct, e.g. alpha = beta = 1 gives the Jaccard index.

    Args:
        inp: prediction tensor.
        target: ground-truth tensor with the same number of elements.
        alpha: false-positive weight.
        beta: false-negative weight.

    Returns:
        A 0-dim tensor.
    """
    smooth = 1e-7
    iflat = inp.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    tp = (iflat * tflat).sum()
    fp = (iflat * (1 - tflat)).sum()
    fn = ((1 - iflat) * tflat).sum()
    return 1 - (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
def MCT_loss(inp, target, num_class):
    """Return the mean Tversky loss (default alpha/beta) over the first ``num_class`` channels.

    Args:
        inp: prediction tensor whose channel axis is dim 1 (indexed as
            ``inp[:, c]``). Any rank >= 2 is accepted.
        target: ground-truth tensor with the same layout as ``inp``.
        num_class: number of channels (starting at 0) to average over.

    Returns:
        A 0-dim tensor.

    Raises:
        ZeroDivisionError: if ``num_class`` is 0.
    """
    # Channel-only indexing replaces the 5-D-only ``[:, i, :, :, :]`` form.
    total = sum(TV_loss(inp[:, c], target[:, c]) for c in range(num_class))
    return total / num_class
def MSE(inp, target):
    """Return the mean squared error between two flattened tensors.

    The squared residuals are summed and divided by the element count of the
    flattened ``inp``.

    Args:
        inp: prediction tensor.
        target: target tensor with the same number of elements.

    Returns:
        A 0-dim tensor.
    """
    flat_in = inp.contiguous().view(-1)
    flat_tgt = target.contiguous().view(-1)
    residual = flat_in - flat_tgt
    return (residual * residual).sum() / len(flat_in)
def MSE_loss(inp, target, num_classes):
    """Return the mean of per-channel ``MSE`` over the first ``num_classes`` channels.

    Args:
        inp: prediction tensor whose channel axis is dim 1 (indexed as
            ``inp[:, c]``). Any rank >= 2 is accepted.
        target: target tensor with the same layout as ``inp``.
        num_classes: number of channels (starting at 0) to average over.

    Returns:
        A 0-dim tensor.

    Raises:
        ZeroDivisionError: if ``num_classes`` is 0.
    """
    # Channel-only indexing replaces the 5-D-only ``[:, i, :, :, :]`` form.
    total = sum(MSE(inp[:, c], target[:, c]) for c in range(num_classes))
    return total / num_classes
def MCD_MSE_loss(inp, target, num_classes, mse_weight=0.1):
    """Return ``MCD_loss + mse_weight * MSE_loss``.

    Args:
        inp: prediction tensor, passed unchanged to both component losses.
        target: ground-truth tensor, passed unchanged to both component losses.
        num_classes: number of channels each component averages over.
        mse_weight: scale applied to the MSE term. Defaults to 0.1, the
            previously hard-coded value.

    Returns:
        The weighted sum of the two component losses.
    """
    dice_term = MCD_loss(inp, target, num_classes)
    mse_term = MSE_loss(inp, target, num_classes)
    return dice_term + mse_weight * mse_term