# # -*- coding: utf-8 -*-
# # @Author : LG
# from torch import nn
# import torch
# from torch.nn import functional as F
#
#
# class focal_loss(nn.Module):
#     def __init__(self, alpha=0.25, gamma=2, num_classes=2, size_average=True):
#         """
#         Focal loss: -α(1-yi)**γ * ce_loss(xi, yi)
#         A step-by-step implementation of the focal loss.
#         :param alpha: class weight α. When α is a list, it holds one weight per class;
#                       when α is a scalar, the weights become [α, 1-α, 1-α, ...], which is
#                       commonly used in object detection to suppress the background class
#                       (RetinaNet uses 0.25).
#         :param gamma: the hard/easy-example modulation factor γ (RetinaNet uses 2).
#         :param num_classes: number of classes.
#         :param size_average: reduction mode; defaults to taking the mean.
#         """
#         super(focal_loss, self).__init__()
#         self.size_average = size_average
#         if isinstance(alpha, list):
#             # α given as a list of size [num_classes]: fine-grained per-class weights
#             assert len(alpha) == num_classes
#             print("Focal_loss alpha = {}, assigning a separate weight to each class".format(alpha))
#             self.alpha = torch.Tensor(alpha)
#         else:
#             # α given as a scalar: down-weight the first class (the background class
#             # in object detection)
#             assert alpha < 1
#             print(" --- Focal_loss alpha = {}, down-weighting the background class; intended for object detection --- ".format(alpha))
#             self.alpha = torch.zeros(num_classes)
#             self.alpha[0] += alpha
#             self.alpha[1:] += (1 - alpha)  # α becomes [α, 1-α, 1-α, ...], size [num_classes]
#         self.gamma = gamma
#
#     def forward(self, preds, labels):
#         """
#         Compute the focal loss.
#         :param preds: predicted scores, size [B, N, C] (detection) or [B, C] (classification);
#                       B = batch size, N = number of boxes, C = number of classes.
#         :param labels: ground-truth classes, size [B, N] or [B].
#         :return: the loss.
#         """
#         preds = preds.view(-1, preds.size(-1))
#         alpha = self.alpha.to(preds.device)
#         # softmax is kept explicit (rather than log_softmax) because its output is
#         # reused below; log_softmax followed by exp would work just as well
#         preds_softmax = F.softmax(preds, dim=1)
#         preds_logsoft = torch.log(preds_softmax)
#         # these gathers implement nll_loss (cross entropy = log_softmax + nll)
#         preds_softmax = preds_softmax.gather(1, labels.view(-1, 1))
#         preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1))
#         alpha = alpha.gather(0, labels.view(-1))
#         # torch.pow((1 - preds_softmax), self.gamma) is the (1-pt)**γ term of the focal loss
#         loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), preds_logsoft)
#         loss = torch.mul(alpha, loss.t())
#         if self.size_average:
#             loss = loss.mean()
#         else:
#             loss = loss.sum()
#         return loss
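#
# # A minimal usage sketch for this multi-class version (illustrative shapes and
# # values; runnable only once the class above is uncommented):
# # loss_fn = focal_loss(alpha=0.25, gamma=2, num_classes=3)
# # preds = torch.randn(8, 3)             # [B, C] raw class scores
# # labels = torch.randint(0, 3, (8,))    # [B] ground-truth class indices
# # loss = loss_fn(preds, labels)
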
from torch import nn
import torch
from torch.nn import functional as F


class focal_loss(nn.Module):
    def __init__(self, alpha=1, gamma=3, logits=False, reduce=False):
        super(focal_loss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # assign a fixed weight to positive (10) and negative (0.1) samples
        full_pos = torch.full_like(targets, 10)
        full_neg = torch.full_like(targets, 0.1)
        weights = torch.where(targets == 1, full_pos, full_neg)
        if self.logits:
            # reduction='none' keeps the per-element losses needed below; the
            # deprecated size_average/reduce flags would reduce them to a scalar
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none', pos_weight=weights)
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # approximately recover the sigmoid probability of the true class
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
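

if __name__ == '__main__':
    # Minimal usage sketch for the binary focal loss above; the shapes and the
    # logits=True / reduce=True settings are illustrative assumptions.
    torch.manual_seed(0)
    criterion = focal_loss(alpha=1, gamma=3, logits=True, reduce=True)
    inputs = torch.randn(4, 1)                     # raw logits
    targets = torch.randint(0, 2, (4, 1)).float()  # binary labels in {0, 1}
    print(criterion(inputs, targets))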