-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
25 lines (22 loc) · 866 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    """Focal loss built on top of ``nn.CrossEntropyLoss``.

    The per-element loss is ``alpha * (1 - pt) ** gamma * ce``, where ``ce``
    is the unreduced cross-entropy and ``pt = exp(-ce)``.

    Args:
        alpha: Constant multiplier applied to every element's loss.
        gamma: Exponent on ``(1 - pt)``; larger values shrink the
            contribution of elements whose ``pt`` is close to 1.
        reduction: One of ``'mean'``, ``'sum'`` or ``'none'``.
        ignore_index: Forwarded to ``nn.CrossEntropyLoss``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=1, gamma=2, reduction='mean', ignore_index=-100):
        super().__init__()
        # Fail fast: an unknown reduction used to make forward() return None.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        # Keep per-element losses; the focal weighting is applied before any
        # reduction happens.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none')

    @torch.cuda.amp.autocast()
    def forward(self, inputs, targets, mixup=None):
        """Compute the focal loss.

        Args:
            inputs: Predictions, in whatever form ``nn.CrossEntropyLoss``
                accepts.
            targets: Targets, in whatever form ``nn.CrossEntropyLoss``
                accepts.
            mixup: Optional factor multiplied element-wise into the
                unreduced loss (must broadcast against it). Presumably
                per-sample mixup weights -- confirm against the caller.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``, or the unreduced loss
            tensor for ``'none'``.

        Raises:
            ValueError: If ``self.reduction`` was changed to an unsupported
                value after construction.
        """
        ce = self.loss_fn(inputs, targets)
        # ce = -log(pt), so exp(-ce) recovers pt.
        pt = torch.exp(-ce)
        focal = self.alpha * (1 - pt) ** self.gamma * ce
        if mixup is not None:
            focal = focal * mixup

        # NOTE: 'mean' averages over every element of the unreduced loss,
        # including positions equal to ignore_index (those contribute 0).
        if self.reduction == 'mean':
            return torch.mean(focal)
        if self.reduction == 'sum':
            return torch.sum(focal)
        if self.reduction == 'none':
            return focal
        raise ValueError(
            f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
        )