loss.py
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
import numpy as np


def normalize(x):
    """Scale pixel values from [0, 255] to [0, 1]."""
    return x / 255.0


def dice_coeff(prediction, target):
    """Calculate the Dice coefficient from a raw (probability) prediction."""
    # Binarize the prediction at a 0.5 threshold.
    mask = np.zeros_like(prediction)
    mask[prediction >= 0.5] = 1
    inter = np.sum(mask * target)
    union = np.sum(mask) + np.sum(target)
    epsilon = 1e-6
    result = np.mean(2 * inter / (union + epsilon))
    return result


class FocalLoss(nn.modules.loss._WeightedLoss):
    def __init__(self, gamma=0, size_average=None, ignore_index=-100,
                 reduce=None, balance_param=1.0):
        # Pass the reduction flags by keyword so size_average is not
        # mistaken for the positional `weight` argument of _WeightedLoss.
        super(FocalLoss, self).__init__(size_average=size_average, reduce=reduce)
        self.gamma = gamma
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.balance_param = balance_param

    def forward(self, input, target):
        # Inputs and targets are assumed to be Batch x Classes.
        assert len(input.shape) == len(target.shape)
        assert input.size(0) == target.size(0)
        assert input.size(1) == target.size(1)
        # Compute the negative log-likelihood (mean binary cross-entropy with logits).
        logpt = -F.binary_cross_entropy_with_logits(input, target)
        pt = torch.exp(logpt)
        # Compute the focal loss: down-weight easy examples by (1 - pt) ** gamma.
        focal_loss = -((1 - pt) ** self.gamma) * logpt
        balanced_focal_loss = self.balance_param * focal_loss
        return balanced_focal_loss
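

# --- Minimal usage sketch (illustrative only) ---
# The shapes and hyper-parameters below (a batch of 4 samples with 3 binary
# labels, gamma=2.0) are assumptions for demonstration, not values taken from
# this repository.
if __name__ == "__main__":
    logits = torch.randn(4, 3)                     # raw model outputs, Batch x Classes
    targets = torch.randint(0, 2, (4, 3)).float()  # binary ground-truth labels

    criterion = FocalLoss(gamma=2.0, balance_param=1.0)
    loss = criterion(logits, targets)
    print("focal loss:", loss.item())

    # dice_coeff expects NumPy arrays with predictions as probabilities in [0, 1].
    probs = torch.sigmoid(logits).detach().numpy()
    print("dice coefficient:", dice_coeff(probs, targets.numpy()))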