 from typing import Callable

 import jax
-from jax.nn import softplus
 import jax.numpy as jnp
-from jax.scipy.special import logsumexp
 from jaxopt._src.projection import projection_simplex, projection_hypercube

+from optax import losses as optax_losses
+

 # Regression
@@ -39,10 +39,7 @@ def huber_loss(target: float, pred: float, delta: float = 1.0) -> float:
   References:
     https://en.wikipedia.org/wiki/Huber_loss
   """
-  abs_diff = jnp.abs(target - pred)
-  return jnp.where(abs_diff > delta,
-                   delta * (abs_diff - .5 * delta),
-                   0.5 * abs_diff ** 2)
+  return optax_losses.huber_loss(pred, target, delta)


 # Binary classification.
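A quick equivalence check for the Huber hunk (not part of the patch; a minimal sketch using only the optax_losses import introduced above). Note the swapped argument order: optax takes predictions first.

import jax.numpy as jnp
from optax import losses as optax_losses

def old_huber_loss(target, pred, delta=1.0):
  # The removed jaxopt formula, reproduced for comparison.
  abs_diff = jnp.abs(target - pred)
  return jnp.where(abs_diff > delta,
                   delta * (abs_diff - .5 * delta),
                   0.5 * abs_diff ** 2)

target, pred = 1.0, 2.7
assert jnp.allclose(old_huber_loss(target, pred),
                    optax_losses.huber_loss(pred, target, delta=1.0))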
@@ -56,12 +53,8 @@ def binary_logistic_loss(label: int, logit: float) -> float:
   Returns:
     loss value
   """
-  # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1].
-  # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba),
-  # where xlogx(proba) = proba * log(proba).
-  # Use -log sigmoid(logit) = softplus(-logit)
-  # and 1 - sigmoid(logit) = sigmoid(-logit).
-  return softplus(jnp.where(label, -logit, logit))
+  return optax_losses.sigmoid_binary_cross_entropy(
+      jnp.asarray(logit), jnp.asarray(label))


 def binary_sparsemax_loss(label: int, logit: float) -> float:
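For labels in {0, 1}, sigmoid_binary_cross_entropy reduces to the removed softplus expression. A minimal check (not part of the patch):

import jax
import jax.numpy as jnp
from optax import losses as optax_losses

label, logit = 1, 2.5
old = jax.nn.softplus(jnp.where(label, -logit, logit))   # removed formulation
new = optax_losses.sigmoid_binary_cross_entropy(jnp.asarray(logit),
                                                jnp.asarray(label))
assert jnp.allclose(old, new)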
@@ -77,33 +70,7 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
     Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
     Vlad Niculae. JMLR 2020. (Sec. 4.4)
   """
-  return sparse_plus(jnp.where(label, -logit, logit))
-
-
-def sparse_plus(x: float) -> float:
-  r"""Sparse plus function.
-
-  Computes the function:
-
-  .. math::
-
-    \mathrm{sparse\_plus}(x) = \begin{cases}
-      0, & x \leq -1\\
-      \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
-      x, & 1 \leq x
-    \end{cases}
-
-  This is the twin function of the softplus activation ensuring a zero output
-  for inputs less than -1 and a linear output for inputs greater than 1,
-  while remaining smooth, convex, monotonic by an adequate definition between
-  -1 and 1.
-
-  Args:
-    x: input (float)
-  Returns:
-    sparse_plus(x) as defined above
-  """
-  return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2 / 4))
+  return jax.nn.sparse_plus(jnp.where(label, -logit, logit))


 def sparse_sigmoid(x: float) -> float:
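The deleted helper is available upstream as jax.nn.sparse_plus in recent JAX releases; the two agree on all three branches (a sanity sketch, not part of the patch):

import jax
import jax.numpy as jnp

def old_sparse_plus(x):
  # The removed jaxopt helper, reproduced for comparison.
  return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2 / 4))

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
assert jnp.allclose(old_sparse_plus(x), jax.nn.sparse_plus(x))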
@@ -144,8 +111,7 @@ def binary_hinge_loss(label: int, score: float) -> float:
   References:
     https://en.wikipedia.org/wiki/Hinge_loss
   """
-  signed_label = 2.0 * label - 1.0
-  return jnp.maximum(0, 1 - score * signed_label)
+  return optax_losses.hinge_loss(score, 2.0 * label - 1.0)


 def binary_perceptron_loss(label: int, score: float) -> float:
@@ -160,8 +126,7 @@ def binary_perceptron_loss(label: int, score: float) -> float:
   References:
     https://en.wikipedia.org/wiki/Perceptron
   """
-  signed_label = 2.0 * label - 1.0
-  return jnp.maximum(0, -score * signed_label)
+  return optax_losses.perceptron_loss(score, 2.0 * label - 1.0)


 # Multiclass classification.
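Both optax losses expect signed targets in {-1, +1}, hence the explicit 2.0 * label - 1.0 remapping of jaxopt's {0, 1} labels in the two hunks above. A minimal check (not part of the patch):

import jax.numpy as jnp
from optax import losses as optax_losses

label, score = 0, 0.3
signed_label = 2.0 * label - 1.0                   # 0 -> -1.0, 1 -> +1.0
assert jnp.allclose(jnp.maximum(0, 1 - score * signed_label),   # removed hinge formula
                    optax_losses.hinge_loss(score, signed_label))
assert jnp.allclose(jnp.maximum(0, -score * signed_label),      # removed perceptron formula
                    optax_losses.perceptron_loss(score, signed_label))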
@@ -175,13 +140,7 @@ def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float:
   Returns:
     loss value
   """
-  logits = jnp.asarray(logits)
-  # Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex.
-  # logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba))
-  # To avoid roundoff error, subtract target inside logsumexp.
-  # logsumexp(logits) - logits[y] = logsumexp(logits - logits[y])
-  logits = (logits - logits[label]).at[label].set(0.0)
-  return logsumexp(logits)
+  return optax_losses.softmax_cross_entropy_with_integer_labels(logits, label)


 def multiclass_sparsemax_loss(label: int, scores: jnp.ndarray) -> float:
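softmax_cross_entropy_with_integer_labels(logits, label) computes logsumexp(logits) - logits[label], which is exactly the value the deleted code evaluated. A minimal check (not part of the patch; the label is wrapped in an array here for safety):

import jax.numpy as jnp
from jax.scipy.special import logsumexp
from optax import losses as optax_losses

logits = jnp.array([1.0, -2.0, 0.5])
label = 2
old = logsumexp(logits) - logits[label]   # value computed by the removed code
new = optax_losses.softmax_cross_entropy_with_integer_labels(
    logits, jnp.asarray(label))
assert jnp.allclose(old, new)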
@@ -272,5 +231,6 @@ def make_fenchel_young_loss(max_fun: Callable[[jnp.ndarray], float]):
   """

   def fy_loss(y_true, scores, *args, **kwargs):
-    return max_fun(scores, *args, **kwargs) - jnp.vdot(y_true, scores)
+    return optax_losses.make_fenchel_young_loss(max_fun)(
+        scores.ravel(), y_true.ravel(), *args, **kwargs)
   return fy_loss
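The wrapper keeps jaxopt's (y_true, scores) argument order and flattens both inputs before delegating. Assuming the optax factory implements the same max_fun(scores) - <y_true, scores> objective as the line it replaces, usage is unchanged; with max_fun = logsumexp this recovers the multiclass logistic loss (a sketch, not part of the patch):

import jax.numpy as jnp
from jax.scipy.special import logsumexp
from optax import losses as optax_losses

fy_loss = optax_losses.make_fenchel_young_loss(logsumexp)
scores = jnp.array([0.5, 1.5, -1.0])
y_true = jnp.array([0.0, 1.0, 0.0])   # one-hot target
loss = fy_loss(scores, y_true)        # scores first, as in the wrapper above
# Per the replacement above, this should equal the removed expression:
# logsumexp(scores) - jnp.vdot(y_true, scores)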