forked from CoinCheung/pytorch-loss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdual_focal_loss.py
36 lines (30 loc) · 1.2 KB
/
dual_focal_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
class Dual_Focal_loss(nn.Module):
    '''
    This loss is proposed in this paper: https://arxiv.org/abs/1909.11932
    The original author notes that it did not work in their projects and
    welcomes corrections to the implementation.

    For every position the loss is

        sum_c -log(eps + 1 - (softmax(logits)_c - one_hot(label)_c) ** 2)

    where the sum runs over dimension 1 (the class dimension) of ``logits``.
    Positions whose label equals ``ignore_lb`` contribute zero loss.

    Args:
        ignore_lb: label value marking positions to exclude from the loss.
        eps: small constant added inside the logarithm so its argument
            stays positive when the squared error reaches 1.
        reduction: ``'mean'`` divides the summed loss by the number of
            non-ignored positions, ``'sum'`` returns the summed loss and
            ``'none'`` returns the per-position loss. Any other value also
            returns the per-position loss.
    '''

    def __init__(self, ignore_lb=255, eps=1e-5, reduction='mean'):
        super(Dual_Focal_loss, self).__init__()
        self.ignore_lb = ignore_lb
        self.eps = eps
        self.reduction = reduction
        # Element-wise squared error; reduced manually in forward().
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, logits, label):
        '''
        Compute the loss.

        Args:
            logits: float tensor with classes along dimension 1.
            label: integer class-index tensor with the shape of ``logits``
                minus dimension 1 (it is unsqueezed at dim 1 to build the
                one-hot target). Entries equal to ``ignore_lb`` are ignored.
                The caller's tensor is not modified.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``, otherwise a tensor
            shaped like ``label`` with zeros at ignored positions.
        '''
        # Keep the mask on the label's device: no host round-trip needed.
        ignore = label == self.ignore_lb
        # Clamp so a batch made up solely of ignored labels yields 0
        # instead of a 0 / 0 = NaN mean.
        n_valid = (~ignore).sum().clamp(min=1)

        # Ignored entries need a valid class index for scatter_; their loss
        # is zeroed out below, so the placeholder class 0 has no effect.
        safe_label = label.detach().masked_fill(ignore, 0)
        lb_one_hot = torch.zeros_like(logits).scatter_(
            1, safe_label.unsqueeze(1), 1)

        pred = torch.softmax(logits, dim=1)
        loss = -torch.log(
            self.eps + 1. - self.mse(pred, lb_one_hot)).sum(dim=1)
        loss = loss.masked_fill(ignore, 0)

        if self.reduction == 'mean':
            loss = loss.sum() / n_valid
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss