losses.py
"""
Define our custom loss function.
"""
import numpy as np
from tensorflow.keras import backend as K
import tensorflow as tf
import dill


def binary_focal_loss(gamma=2., alpha=.25):
    """
    Binary form of focal loss.

        FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)

    where p = sigmoid(x) and p_t = p if the label is 1, or 1 - p if the label is 0.

    References:
        https://arxiv.org/pdf/1708.02002.pdf

    Usage:
        model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    def binary_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred: A tensor resulting from a sigmoid
        :return: Output tensor.
        """
        y_true = tf.cast(y_true, tf.float32)
        # Define epsilon so that back-propagation will not produce NaN in the 0-divisor case
        epsilon = K.epsilon()
        # Clip the prediction value to [epsilon, 1 - epsilon]
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        # Calculate p_t
        p_t = tf.where(K.equal(y_true, 1), y_pred, 1 - y_pred)
        # Calculate alpha_t
        alpha_factor = K.ones_like(y_true) * alpha
        alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor)
        # Calculate cross entropy
        cross_entropy = -K.log(p_t)
        weight = alpha_t * K.pow((1 - p_t), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum the losses over the outputs, then average over the mini-batch
        loss = K.mean(K.sum(loss, axis=1))
        return loss

    return binary_focal_loss_fixed
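

# Illustrative sanity check added alongside this write-up (not part of the
# original module): runs binary_focal_loss on a toy batch. The toy values and
# the (batch, 1) shape are assumptions, chosen to match the sum over axis=1.
def _demo_binary_focal_loss():
    y_true = tf.constant([[1.], [0.], [1.]])
    y_pred = tf.constant([[0.9], [0.1], [0.3]])
    loss_fn = binary_focal_loss(gamma=2., alpha=.25)
    print('binary focal loss on a toy batch:', float(loss_fn(y_true, y_pred)))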


def categorical_focal_loss(alpha, gamma=2.):
    """
    Softmax version of focal loss.

    When there is a skew between different categories/labels in your data set, you can try to apply this function
    as a loss:

             m
        FL = ∑  -alpha * (1 - p_o,c)^gamma * y_o,c * log(p_o,c)
            c=1

    where m = number of classes, c = class and o = observation.

    Parameters:
        alpha -- the same as the weighting factor in balanced cross entropy. Alpha is used to specify the weight of
            different categories/labels; the size of the array needs to be consistent with the number of classes.
        gamma -- focusing parameter for the modulating factor (1 - p)

    Default values:
        gamma -- 2.0 as mentioned in the paper
        alpha -- 0.25 as mentioned in the paper

    References:
        Official paper: https://arxiv.org/pdf/1708.02002.pdf
        https://www.tensorflow.org/api_docs/python/tf/keras/backend/categorical_crossentropy

    Usage:
        model.compile(loss=[categorical_focal_loss(alpha=[[.25, .25, .25]], gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    alpha = np.array(alpha, dtype=np.float32)

    def categorical_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred: A tensor resulting from a softmax
        :return: Output tensor.
        """
        # Clip the prediction value to prevent NaN's and Inf's
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
        # Calculate cross entropy
        cross_entropy = -y_true * K.log(y_pred)
        # Calculate focal loss
        loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy
        # Sum over classes, then average over the mini-batch
        return K.mean(K.sum(loss, axis=-1))

    return categorical_focal_loss_fixed
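

# Illustrative sanity check (added sketch, not from the original file): runs
# categorical_focal_loss on a toy 3-class batch; alpha has shape
# (1, num_classes) as in the docstring usage. All values are assumptions.
def _demo_categorical_focal_loss():
    y_true = tf.constant([[1., 0., 0.], [0., 1., 0.]])
    y_pred = tf.constant([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]])
    loss_fn = categorical_focal_loss(alpha=[[.25, .25, .25]], gamma=2.)
    print('categorical focal loss on a toy batch:', float(loss_fn(y_true, y_pred)))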


def get_effective_class_weights(args):
    '''
    Determine class weights from the effective number of samples, following
    https://arxiv.org/abs/1901.05555
    '''
    unique, class_frequencies = np.unique(args['y_train'], return_counts=True)
    # Weight each class by the inverse effective number of samples,
    # (1 - beta) / (1 - beta^n_c), then normalize so the weights sum to num_classes.
    effective_num = [(1 - args['reweight_beta']) / (1 - np.power(args['reweight_beta'], c_i))
                     for c_i in class_frequencies]
    class_weights = np.array(effective_num) / sum(effective_num) * args['num_classes']
    print('calculated effective-number class weights')
    class_weights = {k: v for k, v in enumerate(class_weights)}
    return class_weights
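

# Illustrative worked example (added sketch): effective-number reweighting on
# an imbalanced toy label set (8 samples of class 0, 2 of class 1). The args
# keys mirror what get_effective_class_weights reads; beta=0.9 is an assumption.
def _demo_effective_class_weights():
    args = {'y_train': np.array([0] * 8 + [1] * 2),
            'reweight_beta': 0.9,
            'num_classes': 2}
    # With beta=0.9 this yields roughly {0: 0.50, 1: 1.50}: the rare class
    # receives the larger weight.
    print('effective class weights:', get_effective_class_weights(args))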


def get_loss(args):
    if args['reweight'] == 'effective_num':
        class_weights = get_effective_class_weights(args)
    else:
        class_weights = {i: 1 for i in range(args['num_classes'])}

    if args['loss'] == 'categorical_focal_loss':
        if isinstance(class_weights, dict):
            alpha = [class_weights[i] for i in class_weights.keys()]
        else:
            alpha = class_weights
        loss = [categorical_focal_loss(alpha=[alpha], gamma=2)]
        # Class weighting is already incorporated in the focal loss alpha
        class_weights = {i: 1.0 / args['num_classes'] for i in range(args['num_classes'])}
    else:
        loss = args['loss']
    return loss, class_weights
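

# Illustrative end-to-end sketch (added, hypothetical config): builds a loss
# and class weights via get_loss. The args keys ('reweight', 'loss',
# 'num_classes', ...) mirror what get_loss reads; the values are assumptions.
def _demo_get_loss():
    args = {'y_train': np.array([0] * 8 + [1] * 2),
            'reweight': 'effective_num',
            'reweight_beta': 0.9,
            'num_classes': 2,
            'loss': 'categorical_focal_loss'}
    loss, class_weights = get_loss(args)
    print('loss:', loss)
    print('class weights passed to fit():', class_weights)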


if __name__ == '__main__':
    # Test serialization of nested functions
    bin_inner = dill.loads(dill.dumps(binary_focal_loss(gamma=2., alpha=.25)))
    print(bin_inner)
    cat_inner = dill.loads(dill.dumps(categorical_focal_loss(gamma=2., alpha=.25)))
    print(cat_inner)
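
    # Run the illustrative demos defined above (toy inputs added for this
    # write-up; not part of the original tests).
    _demo_binary_focal_loss()
    _demo_categorical_focal_loss()
    _demo_effective_class_weights()
    _demo_get_loss()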