Commit 5d00d18

mtthss authored and JAXopt authors committed

Use optax losses as backend for jaxopt losses.

PiperOrigin-RevId: 616117244

1 parent 8da3350 · commit 5d00d18

File tree: 3 files changed (+16 -52)

  jaxopt/_src/loss.py
  jaxopt/loss.py
  requirements.txt


jaxopt/_src/loss.py (+11 -51)
@@ -17,11 +17,11 @@
 from typing import Callable
 
 import jax
-from jax.nn import softplus
 import jax.numpy as jnp
-from jax.scipy.special import logsumexp
 from jaxopt._src.projection import projection_simplex, projection_hypercube
 
+from optax import losses as optax_losses
+
 
 # Regression
 
@@ -39,10 +39,7 @@ def huber_loss(target: float, pred: float, delta: float = 1.0) -> float:
   References:
     https://en.wikipedia.org/wiki/Huber_loss
   """
-  abs_diff = jnp.abs(target - pred)
-  return jnp.where(abs_diff > delta,
-                   delta * (abs_diff - .5 * delta),
-                   0.5 * abs_diff ** 2)
+  return optax_losses.huber_loss(pred, target, delta)
 
 # Binary classification.
 
@@ -56,12 +53,8 @@ def binary_logistic_loss(label: int, logit: float) -> float:
   Returns:
     loss value
   """
-  # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1].
-  # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba),
-  # where xlogx(proba) = proba * log(proba).
-  # Use -log sigmoid(logit) = softplus(-logit)
-  # and 1 - sigmoid(logit) = sigmoid(-logit).
-  return softplus(jnp.where(label, -logit, logit))
+  return optax_losses.sigmoid_binary_cross_entropy(
+      jnp.asarray(logit), jnp.asarray(label))
 
 
 def binary_sparsemax_loss(label: int, logit: float) -> float:
@@ -77,33 +70,7 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
     Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
     Vlad Niculae. JMLR 2020. (Sec. 4.4)
   """
-  return sparse_plus(jnp.where(label, -logit, logit))
-
-
-def sparse_plus(x: float) -> float:
-  r"""Sparse plus function.
-
-  Computes the function:
-
-  .. math::
-
-    \mathrm{sparse\_plus}(x) = \begin{cases}
-      0, & x \leq -1\\
-      \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
-      x, & 1 \leq x
-    \end{cases}
-
-  This is the twin function of the softplus activation ensuring a zero output
-  for inputs less than -1 and a linear output for inputs greater than 1,
-  while remaining smooth, convex, monotonic by an adequate definition between
-  -1 and 1.
-
-  Args:
-    x: input (float)
-  Returns:
-    sparse_plus(x) as defined above
-  """
-  return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
+  return jax.nn.sparse_plus(jnp.where(label, -logit, logit))
 
 
 def sparse_sigmoid(x: float) -> float:
@@ -144,8 +111,7 @@ def binary_hinge_loss(label: int, score: float) -> float:
   References:
     https://en.wikipedia.org/wiki/Hinge_loss
   """
-  signed_label = 2.0 * label - 1.0
-  return jnp.maximum(0, 1 - score * signed_label)
+  return optax_losses.hinge_loss(score, 2.0 * label - 1.0)
 
 
 def binary_perceptron_loss(label: int, score: float) -> float:
@@ -160,8 +126,7 @@ def binary_perceptron_loss(label: int, score: float) -> float:
   References:
     https://en.wikipedia.org/wiki/Perceptron
   """
-  signed_label = 2.0 * label - 1.0
-  return jnp.maximum(0, - score * signed_label)
+  return optax_losses.perceptron_loss(score, 2.0 * label - 1.0)
 
 # Multiclass classification.
 
@@ -175,13 +140,7 @@ def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float:
   Returns:
     loss value
   """
-  logits = jnp.asarray(logits)
-  # Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex.
-  # logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba))
-  # To avoid roundoff error, subtract target inside logsumexp.
-  # logsumexp(logits) - logits[y] = logsumexp(logits - logits[y])
-  logits = (logits - logits[label]).at[label].set(0.0)
-  return logsumexp(logits)
+  return optax_losses.softmax_cross_entropy_with_integer_labels(logits, label)
 
 
 def multiclass_sparsemax_loss(label: int, scores: jnp.ndarray) -> float:
@@ -272,5 +231,6 @@ def make_fenchel_young_loss(max_fun: Callable[[jnp.ndarray], float]):
   """
 
  def fy_loss(y_true, scores, *args, **kwargs):
-    return max_fun(scores, *args, **kwargs) - jnp.vdot(y_true, scores)
+    return optax_losses.make_fenchel_young_loss(max_fun)(
+        scores.ravel(), y_true.ravel(), *args, **kwargs)
  return fy_loss
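
Since every rewritten function above now delegates to optax, a quick way to sanity-check the change is to compare the optax-backed jaxopt losses against the closed forms they replace. The snippet below is an illustrative sketch, not part of the commit; it assumes a jaxopt build that includes this revision plus optax>=0.2.2, and the sample values are arbitrary.

import jax.numpy as jnp
from jax.nn import softplus
from jax.scipy.special import logsumexp
from jaxopt import loss as jaxopt_loss

# Huber: optax_losses.huber_loss(pred, target, delta) vs. the removed piecewise form.
target, pred, delta = 1.5, 0.2, 1.0
abs_diff = jnp.abs(target - pred)
old_huber = jnp.where(abs_diff > delta,
                      delta * (abs_diff - .5 * delta),
                      0.5 * abs_diff ** 2)
assert jnp.allclose(jaxopt_loss.huber_loss(target, pred, delta), old_huber)

# Binary logistic: sigmoid_binary_cross_entropy vs. softplus of the signed logit.
label, logit = 1, -0.7
old_logistic = softplus(jnp.where(label, -logit, logit))
assert jnp.allclose(jaxopt_loss.binary_logistic_loss(label, logit), old_logistic)

# Multiclass logistic: softmax_cross_entropy_with_integer_labels vs. logsumexp.
y, logits = 2, jnp.array([0.3, -1.2, 0.8, 0.1])
old_multiclass = logsumexp(logits) - logits[y]
assert jnp.allclose(jaxopt_loss.multiclass_logistic_loss(y, logits), old_multiclass)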

jaxopt/loss.py (+4 -1)
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import jax
+sparse_plus = jax.nn.sparse_plus
+
 from jaxopt._src.loss import binary_logistic_loss
-from jaxopt._src.loss import binary_sparsemax_loss, sparse_plus, sparse_sigmoid
+from jaxopt._src.loss import binary_sparsemax_loss, sparse_sigmoid
 from jaxopt._src.loss import huber_loss
 from jaxopt._src.loss import make_fenchel_young_loss
 from jaxopt._src.loss import multiclass_logistic_loss
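
The sparse_plus = jax.nn.sparse_plus alias keeps jaxopt.loss.sparse_plus importable even though the local implementation was removed from jaxopt/_src/loss.py. An illustrative compatibility check, not part of the commit, with arbitrary test values:

import jax
import jax.numpy as jnp
from jaxopt.loss import sparse_plus  # now an alias of jax.nn.sparse_plus

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
# Agrees with the alias target...
assert jnp.allclose(sparse_plus(x), jax.nn.sparse_plus(x))
# ...and with the piecewise definition from the removed docstring:
# 0 for x <= -1, (x + 1)**2 / 4 on (-1, 1), and x for x >= 1.
piecewise = jnp.where(x <= -1.0, 0.0,
                      jnp.where(x >= 1.0, x, (x + 1.0) ** 2 / 4))
assert jnp.allclose(sparse_plus(x), piecewise)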

requirements.txt (+1)
@@ -1,4 +1,5 @@
 jax>=0.2.18
 jaxlib>=0.1.69
 numpy>=1.18.4
+optax>=0.2.2
 scipy>=1.0.0
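
The new minimum version presumably tracks optax.losses.make_fenchel_young_loss, which the rewritten backend calls. A quick, illustrative environment check, not part of the commit:

import optax
from optax import losses as optax_losses

print(optax.__version__)  # expected to be >= 0.2.2 after installing requirements.txt
assert hasattr(optax_losses, "make_fenchel_young_loss")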
