From 90568daee498ec53123a46ebf11d11974cb66f1d Mon Sep 17 00:00:00 2001 From: Harshal Janjani <75426551+harshaljanjani@users.noreply.github.com> Date: Thu, 23 Jan 2025 00:52:38 +0530 Subject: [PATCH] fix(metrics): Fix BinaryAccuracy metric to handle boolean inputs (#20782) * Fix BinaryAccuracy metric to handle boolean inputs Previously, BinaryAccuracy would return incorrect results when given boolean inputs in JAX backend, and would raise errors in TensorFlow backend. This was because the metric expects numerical values (floats/integers) but wasn't properly handling boolean array inputs. Fix by casting y_true and y_pred to floatx() in MeanMetricWrapper.update_state(). This ensures consistent behavior across backends and proper handling of boolean inputs. * fix: Make the linter happy :) * fix: Align update_state casting with metric's data type --- keras/src/metrics/reduction_metrics.py | 3 +++ keras/src/metrics/reduction_metrics_test.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/keras/src/metrics/reduction_metrics.py b/keras/src/metrics/reduction_metrics.py index 3dde46f95835..b4c075e2f626 100644 --- a/keras/src/metrics/reduction_metrics.py +++ b/keras/src/metrics/reduction_metrics.py @@ -199,6 +199,9 @@ def __init__(self, fn, name=None, dtype=None, **kwargs): self._direction = "down" def update_state(self, y_true, y_pred, sample_weight=None): + y_true = backend.cast(y_true, self.dtype) + y_pred = backend.cast(y_pred, self.dtype) + mask = backend.get_keras_mask(y_pred) values = self._fn(y_true, y_pred, **self._fn_kwargs) if sample_weight is not None and mask is not None: diff --git a/keras/src/metrics/reduction_metrics_test.py b/keras/src/metrics/reduction_metrics_test.py index f697918ccd34..679bed081804 100644 --- a/keras/src/metrics/reduction_metrics_test.py +++ b/keras/src/metrics/reduction_metrics_test.py @@ -1,6 +1,9 @@ import numpy as np from keras.src import backend +from keras.src import layers +from keras.src import metrics +from keras.src import models from keras.src import testing from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.metrics import reduction_metrics @@ -174,3 +177,17 @@ def test_weighted_dynamic_shape(self): KerasTensor((None, 5)), ) self.assertAllEqual(result.shape, ()) + + def test_binary_accuracy_with_boolean_inputs(self): + inp = layers.Input(shape=(1,)) + out = inp > 0.5 + model = models.Model(inputs=inp, outputs=out) + + x = np.random.rand(32, 1) + y = x > 0.5 + + res = model.predict(x) + metric = metrics.BinaryAccuracy() + metric.update_state(y, res) + result = metric.result() + assert result == 1.0