diff --git a/deel/lip/losses.py b/deel/lip/losses.py
index 11043365..50cdba25 100644
--- a/deel/lip/losses.py
+++ b/deel/lip/losses.py
@@ -9,7 +9,12 @@ from functools import partial
 import numpy as np
 import tensorflow as tf
-from tensorflow.keras.losses import categorical_crossentropy, Loss, Reduction
+from tensorflow.keras.losses import (
+    categorical_crossentropy,
+    sparse_categorical_crossentropy,
+    Loss,
+    Reduction,
+)
 from tensorflow.keras.utils import register_keras_serializable
@@ -493,3 +498,66 @@ def get_config(self):
         config = {"tau": self.tau.numpy()}
         base_config = super(TauCategoricalCrossentropy, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
+
+
+@register_keras_serializable("deel-lip", "TauSparseCategoricalCrossentropy")
+class TauSparseCategoricalCrossentropy(Loss):
+    def __init__(
+        self, tau, reduction=Reduction.AUTO, name="TauSparseCategoricalCrossentropy"
+    ):
+        """
+        Similar to the original sparse categorical crossentropy, but with a settable
+        temperature parameter.
+
+        Args:
+            tau (float): temperature parameter.
+            reduction: reduction of the loss, passed to the original loss.
+            name (str): name of the loss.
+        """
+        self.tau = tf.Variable(tau, dtype=tf.float32)
+        super().__init__(name=name, reduction=reduction)
+
+    def call(self, y_true, y_pred):
+        return (
+            sparse_categorical_crossentropy(y_true, self.tau * y_pred, from_logits=True)
+            / self.tau
+        )
+
+    def get_config(self):
+        config = {"tau": self.tau.numpy()}
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
+@register_keras_serializable("deel-lip", "TauBinaryCrossentropy")
+class TauBinaryCrossentropy(Loss):
+    def __init__(self, tau, reduction=Reduction.AUTO, name="TauBinaryCrossentropy"):
+        """
+        Similar to the original binary crossentropy, but with a settable temperature
+        parameter. y_pred must be a logits tensor (before sigmoid) and not
+        probabilities.
+
+        Note that `y_true` and `y_pred` must be of rank 2: (batch_size, 1). `y_true`
+        accepts label values in (0, 1) or (-1, 1).
+
+        Args:
+            tau: temperature parameter.
+            reduction: reduction of the loss, passed to the original loss.
+            name: name of the loss.
+        """
+        self.tau = tf.Variable(tau, dtype=tf.float32)
+        super().__init__(name=name, reduction=reduction)
+
+    def call(self, y_true, y_pred):
+        y_true = tf.cast(y_true > 0, y_pred.dtype)
+        return (
+            tf.keras.losses.binary_crossentropy(
+                y_true, self.tau * y_pred, from_logits=True
+            )
+            / self.tau
+        )
+
+    def get_config(self):
+        config = {"tau": self.tau.numpy()}
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
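For reviewers, a minimal usage sketch of the two losses added above. This is illustrative only and not part of the diff; the logits, labels, and tau values are made up:

import tensorflow as tf
from deel.lip.losses import TauBinaryCrossentropy, TauSparseCategoricalCrossentropy

# Sparse variant: integer class labels, logits of shape (batch, n_class).
logits = tf.constant([[2.0, -1.0, 0.5], [0.1, 0.3, 1.8]])
labels = tf.constant([0, 2])
sparse_loss = TauSparseCategoricalCrossentropy(tau=4.0)
print(sparse_loss(labels, logits).numpy())  # scalar: Reduction.AUTO averages over the batch

# Binary variant: rank-2 logits of shape (batch, 1), labels in (0, 1) or (-1, 1).
bin_logits = tf.constant([[0.5], [-1.5], [2.0]])
bin_labels = tf.constant([[1.0], [-1.0], [1.0]])  # -1 is accepted; call() casts y_true > 0
bin_loss = TauBinaryCrossentropy(tau=1.5)
print(bin_loss(bin_labels, bin_logits).numpy())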
diff --git a/tests/test_losses.py b/tests/test_losses.py
index 53c23e1a..9c17b0c3 100644
--- a/tests/test_losses.py
+++ b/tests/test_losses.py
@@ -14,6 +14,8 @@
     MulticlassHKR,
     MultiMargin,
     TauCategoricalCrossentropy,
+    TauSparseCategoricalCrossentropy,
+    TauBinaryCrossentropy,
     CategoricalHinge,
 )
 from deel.lip.utils import process_labels_for_multi_gpu
@@ -218,6 +220,43 @@ def test_tau_catcrossent(self):
         )
         check_serialization(1, taucatcrossent_loss)
 
+    def test_tau_sparse_catcrossent(self):
+        tau_sparse_catcrossent_loss = TauSparseCategoricalCrossentropy(1.0)
+        n_class = 10
+        n_items = 10000
+        y_true = np.random.randint(0, n_class, n_items)
+        y_pred = tf.random.normal((n_items, n_class))
+        loss_val = tau_sparse_catcrossent_loss(y_true, y_pred).numpy()
+        loss_val_2 = tau_sparse_catcrossent_loss(
+            tf.cast(y_true, dtype=tf.int32), y_pred
+        ).numpy()
+        np.testing.assert_almost_equal(
+            loss_val_2, loss_val, 1, "test failed when y_true has dtype int32"
+        )
+        check_serialization(n_class, tau_sparse_catcrossent_loss)
+
+    def test_tau_binary_crossent(self):
+        loss = TauBinaryCrossentropy(2.0)
+        y_true = binary_tf_data([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
+        y_pred = binary_tf_data([0.5, 1.5, -0.5, -0.5, -1.5, 0.5])
+
+        # Assert that the loss value equals the expected value
+        expected_loss_val = 0.279185
+        loss_val = loss(y_true, y_pred).numpy()
+        np.testing.assert_allclose(loss_val, expected_loss_val, rtol=1e-6)
+
+        # Assert that the loss value is the same when y_true has dtype int32
+        loss_val_2 = loss(tf.cast(y_true, dtype=tf.int32), y_pred).numpy()
+        np.testing.assert_allclose(loss_val_2, expected_loss_val, rtol=1e-6)
+
+        # Assert that the loss is unchanged when y_true is [-1, 1] instead of [0, 1]
+        y_true2 = tf.where(y_true == 1.0, 1.0, -1.0)
+        loss_val_3 = loss(y_true2, y_pred).numpy()
+        np.testing.assert_allclose(loss_val_3, expected_loss_val, rtol=1e-6)
+
+        # Assert that the loss object is correctly serialized
+        check_serialization(1, loss)
+
     def test_no_reduction_binary_losses(self):
         """
         Assert binary losses without reduction. Three losses are tested on hardcoded
@@ -230,12 +269,14 @@ def test_no_reduction_binary_losses(self):
             KR(reduction="none"),
             HingeMargin(0.7 * 2.0, reduction="none"),
             HKR(alpha=2.5, min_margin=2.0, reduction="none"),
+            TauBinaryCrossentropy(tau=0.5, reduction="none"),
         )
 
         expected_loss_values = (
             np.array([1.0, 2.2, -0.2, 1.4, 2.6, 0.8, -0.4, 1.8]),
             np.array([0.2, 0, 0.8, 0, 0, 0.3, 0.9, 0]),
             np.array([0.25, -2.2, 2.95, -0.65, -2.6, 0.7, 3.4, -1.55]),
+            [1.15188, 0.91098, 1.43692, 1.06676, 0.84011, 1.19628, 1.48879, 0.98650],
         )
 
         for loss, expected_loss_val in zip(losses, expected_loss_values):
@@ -275,6 +316,7 @@ def test_no_reduction_multiclass_losses(self):
             MulticlassHKR(alpha=2.5, min_margin=1.0, reduction="none"),
             CategoricalHinge(1.1, reduction="none"),
             TauCategoricalCrossentropy(2.0, reduction="none"),
+            TauSparseCategoricalCrossentropy(2.0, reduction="none"),
         )
 
         expected_loss_values = (
@@ -295,10 +337,25 @@ def test_no_reduction_multiclass_losses(self):
                     0.076357,
                 ]
             ),
+            np.float64(
+                [
+                    0.044275,
+                    0.115109,
+                    1.243572,
+                    0.084923,
+                    0.010887,
+                    2.802300,
+                    0.114224,
+                    0.076357,
+                ]
+            ),
         )
 
         for loss, expected_loss_val in zip(losses, expected_loss_values):
-            loss_val = loss(y_true, y_pred)
+            if isinstance(loss, TauSparseCategoricalCrossentropy):
+                loss_val = loss(tf.argmax(y_true, axis=-1), y_pred)
+            else:
+                loss_val = loss(y_true, y_pred)
             np.testing.assert_allclose(
                 loss_val,
                 expected_loss_val,
@@ -325,9 +382,10 @@ def test_minibatches_binary_losses(self):
             KR(multi_gpu=True, reduction=reduction),
             HingeMargin(0.7 * 2.0, reduction=reduction),
             HKR(alpha=2.5, min_margin=2.0, multi_gpu=True, reduction=reduction),
+            TauBinaryCrossentropy(tau=1.5, reduction=reduction),
        )
 
-        expected_loss_values = (9.2, 2.2, 0.3)
+        expected_loss_values = (9.2, 2.2, 0.3, 2.19262)
 
         # Losses are tested on full batch
         for loss, expected_loss_val in zip(losses, expected_loss_values):
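Note that the expected values added for TauSparseCategoricalCrossentropy in test_no_reduction_multiclass_losses duplicate the TauCategoricalCrossentropy row: the two losses should agree when the one-hot labels are collapsed with argmax, exactly as the test's isinstance branch does. A quick standalone check of that property (illustrative, not part of the test suite; it assumes TauCategoricalCrossentropy applies the same tau scaling and division as the sparse variant shown above):

import numpy as np
import tensorflow as tf
from deel.lip.losses import (
    TauCategoricalCrossentropy,
    TauSparseCategoricalCrossentropy,
)

n_class, n_items = 10, 8
labels = np.random.randint(0, n_class, n_items)  # integer labels for the sparse loss
one_hot = tf.one_hot(labels, n_class)            # same labels, one-hot, for the dense loss
logits = tf.random.normal((n_items, n_class))

dense = TauCategoricalCrossentropy(2.0, reduction="none")
sparse = TauSparseCategoricalCrossentropy(2.0, reduction="none")

# Per-sample values should match: only the label encoding differs.
np.testing.assert_allclose(
    dense(one_hot, logits).numpy(), sparse(labels, logits).numpy(), rtol=1e-5
)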