-
Notifications
You must be signed in to change notification settings - Fork 28
/
focal_loss.py
executable file
·58 lines (51 loc) · 2.03 KB
/
focal_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# -*- coding: utf-8 -*-
import tensorflow as tf
"""
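TensorFlow implementation of the Focal Loss proposed by Kaiming He and co-authors.
The loss is mainly used to address class imbalance in classification problems.
focal_loss_sigmoid: loss for binary classification
focal_loss_softmax: loss for multi-class classification
Reference paper: Focal Loss for Dense Object Detection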
"""
def focal_loss_sigmoid(labels, logits, alpha=0.25, gamma=2):
    """Compute the per-example focal loss for binary classification.

    With ``p = sigmoid(logits)`` and ``y`` the 0/1 label, each element of the
    result is::

        -(1 - alpha) * y * (1 - p) ** gamma * log(p)
        - alpha * (1 - y) * p ** gamma * log(1 - p)

    Args:
        labels: A tensor of 0/1 labels with the same shape as `logits`
            (e.g. int32 of shape [batch_size]). It is cast to the dtype of
            the predictions.
        logits: A float tensor of unnormalized scores, e.g. of shape
            [batch_size].
        alpha: A scalar balancing factor. Positive examples are weighted by
            ``1 - alpha`` and negative examples by ``alpha``, so a larger
            alpha puts more weight on negatives.
        gamma: A scalar focusing exponent applied to the probability assigned
            to the wrong class.

    Returns:
        A tensor of the same shape as `labels` holding the unreduced loss.

    Note:
        NOTE(review): the reference paper applies ``alpha`` to the positive
        class and ``1 - alpha`` to the negative class, i.e. the reverse of
        what is done here. The weighting is kept unchanged so existing
        callers get the same class balance -- confirm which convention is
        intended.
    """
    logits = tf.convert_to_tensor(logits)
    y_pred = tf.nn.sigmoid(logits)
    labels = tf.cast(labels, y_pred.dtype)

    # Numerically stable log-probabilities: log(sigmoid(x)) = -softplus(-x)
    # and log(1 - sigmoid(x)) = -softplus(x). Taking tf.log of a saturated
    # sigmoid would give -inf, and 0 * -inf would then produce NaN.
    log_pos = -tf.nn.softplus(-logits)
    log_neg = -tf.nn.softplus(logits)

    # The modulating factor is raised to the power gamma (not multiplied).
    pos_term = labels * (1 - alpha) * ((1 - y_pred) ** gamma) * log_pos
    neg_term = (1 - labels) * alpha * (y_pred ** gamma) * log_neg
    return -pos_term - neg_term
def focal_loss_softmax(labels, logits, gamma=2):
    """Compute the per-example focal loss for multi-class classification.

    With ``p = softmax(logits)`` and ``t`` the target class of an example,
    the loss of that example is ``-(1 - p[t]) ** gamma * log(p[t])``.

    Args:
        labels: An integer tensor of class ids, e.g. int32 of shape
            [batch_size].
        logits: A float tensor of unnormalized scores whose last axis indexes
            the classes, e.g. of shape [batch_size, num_classes].
        gamma: A scalar focusing exponent applied to ``1 - p``.

    Returns:
        A tensor of the same shape as `labels` holding the unreduced loss.
    """
    logits = tf.convert_to_tensor(logits)
    # Softmax over the last axis (the default), so no version-specific
    # `dim`/`axis` keyword is needed.
    y_pred = tf.nn.softmax(logits)  # [batch_size, num_classes]
    # log_softmax stays finite where the softmax underflows to 0; with
    # tf.log(y_pred) a non-target class could yield 0 * -inf = NaN and
    # poison the whole row sum.
    log_pred = tf.nn.log_softmax(logits)

    # Use the dynamic shape so a statically unknown class dimension works.
    num_classes = tf.shape(logits)[-1]
    one_hot = tf.one_hot(labels, depth=num_classes, dtype=y_pred.dtype)

    per_class = -one_hot * ((1 - y_pred) ** gamma) * log_pred
    # Only the target-class entry of each row is non-zero.
    return tf.reduce_sum(per_class, axis=-1)
if __name__ == '__main__':
    # Smoke test: evaluate both losses on small random inputs and print them.
    logits = tf.random_uniform(shape=[5], minval=-1, maxval=1, dtype=tf.float32)
    labels = tf.Variable([0, 1, 0, 0, 1])
    loss1 = focal_loss_sigmoid(labels=labels, logits=logits)

    logits2 = tf.random_uniform(shape=[5, 4], minval=-1, maxval=1, dtype=tf.float32)
    labels2 = tf.Variable([1, 0, 2, 3, 1])
    # Pass the multi-class labels by keyword; the previous `labels==labels2`
    # was a comparison expression passed positionally, not the labels.
    loss2 = focal_loss_softmax(labels=labels2, logits=logits2)

    with tf.Session() as sess:
        # The label tensors are tf.Variables and must be initialized first.
        sess.run(tf.global_variables_initializer())
        # Call-style print with one argument behaves the same on Python 2
        # and also parses on Python 3.
        print(sess.run(loss1))
        print(sess.run(loss2))