test(losses): unit tests for TauBinaryCrossentropy

cofri committed Aug 21, 2023
1 parent 22f0ad3 commit f2c5da2
Showing 1 changed file with 28 additions and 2 deletions.

30 changes: 28 additions & 2 deletions tests/test_losses.py
@@ -15,6 +15,7 @@
     MultiMargin,
     TauCategoricalCrossentropy,
     TauSparseCategoricalCrossentropy,
+    TauBinaryCrossentropy,
     CategoricalHinge,
 )
 from deel.lip.utils import process_labels_for_multi_gpu
@@ -234,6 +235,28 @@ def test_tau_sparse_catcrossent(self):
         )
         check_serialization(n_class, tau_sparse_catcrossent_loss)
 
+    def test_tau_binary_crossent(self):
+        loss = TauBinaryCrossentropy(2.0)
+        y_true = binary_tf_data([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
+        y_pred = binary_tf_data([0.5, 1.5, -0.5, -0.5, -1.5, 0.5])
+
+        # Assert that loss value is equal to expected value
+        expected_loss_val = 0.279185
+        loss_val = loss(y_true, y_pred).numpy()
+        np.testing.assert_allclose(loss_val, expected_loss_val, rtol=1e-6)
+
+        # Assert that loss value is the same when y_true is of type int32
+        loss_val_2 = loss(tf.cast(y_true, dtype=tf.int32), y_pred).numpy()
+        np.testing.assert_allclose(loss_val_2, expected_loss_val, rtol=1e-6)
+
+        # Assert that loss value is the same when y_true is [-1, 1] instead of [0, 1]
+        y_true2 = tf.where(y_true == 1.0, 1.0, -1.0)
+        loss_val_3 = loss(y_true2, y_pred).numpy()
+        np.testing.assert_allclose(loss_val_3, expected_loss_val, rtol=1e-6)
+
+        # Assert that loss object is correctly serialized
+        check_serialization(1, loss)
+
     def test_no_reduction_binary_losses(self):
         """
         Assert binary losses without reduction. Three losses are tested on hardcoded
@@ -246,12 +269,14 @@ def test_no_reduction_binary_losses(self):
             KR(reduction="none"),
             HingeMargin(0.7 * 2.0, reduction="none"),
             HKR(alpha=2.5, min_margin=2.0, reduction="none"),
+            TauBinaryCrossentropy(tau=0.5, reduction="none"),
         )
 
         expected_loss_values = (
             np.array([1.0, 2.2, -0.2, 1.4, 2.6, 0.8, -0.4, 1.8]),
             np.array([0.2, 0, 0.8, 0, 0, 0.3, 0.9, 0]),
             np.array([0.25, -2.2, 2.95, -0.65, -2.6, 0.7, 3.4, -1.55]),
+            [1.15188, 0.91098, 1.43692, 1.06676, 0.84011, 1.19628, 1.48879, 0.98650],
         )
 
         for loss, expected_loss_val in zip(losses, expected_loss_values):
@@ -323,7 +348,7 @@ def test_no_reduction_multiclass_losses(self):
                     0.114224,
                     0.076357,
                 ]
-            )
+            ),
         )
 
         for loss, expected_loss_val in zip(losses, expected_loss_values):
@@ -357,9 +382,10 @@ def test_minibatches_binary_losses(self):
             KR(multi_gpu=True, reduction=reduction),
             HingeMargin(0.7 * 2.0, reduction=reduction),
             HKR(alpha=2.5, min_margin=2.0, multi_gpu=True, reduction=reduction),
+            TauBinaryCrossentropy(tau=1.5, reduction=reduction),
         )
 
-        expected_loss_values = (9.2, 2.2, 0.3)
+        expected_loss_values = (9.2, 2.2, 0.3, 2.19262)
 
         # Losses are tested on full batch
         for loss, expected_loss_val in zip(losses, expected_loss_values):
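Note on the expected value: the 0.279185 asserted in test_tau_binary_crossent can be reproduced by hand. The sketch below is illustrative and not part of the commit; it assumes (the loss definition itself is not shown in this diff) that TauBinaryCrossentropy applies the sigmoid binary crossentropy to tau-scaled logits and divides the result by tau.

# Hypothetical re-derivation of expected_loss_val = 0.279185 (assumed formula,
# not taken from this diff): per-sample loss = softplus(-sign * tau * y_pred) / tau,
# where sign is +1 for the positive class and -1 for the negative class.
import numpy as np

tau = 2.0
y_true = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
y_pred = np.array([0.5, 1.5, -0.5, -0.5, -1.5, 0.5])

signs = 2.0 * y_true - 1.0                                   # {0, 1} labels -> {-1, +1} signs
per_sample = np.log1p(np.exp(-signs * tau * y_pred)) / tau   # softplus(-sign * tau * logit) / tau
print(per_sample.mean())                                     # ~0.279185, the value asserted above

Under that reading, the sign mapping also makes it natural that the test expects identical values for {0, 1} float labels, int32 labels, and {-1, 1} labels.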
