Skip to content

Commit

Permalink
fix(metrics): Fix BinaryAccuracy metric to handle boolean inputs (#20782
Browse files Browse the repository at this point in the history
)

* Fix BinaryAccuracy metric to handle boolean inputs

Previously, BinaryAccuracy would return incorrect results when given boolean
inputs in JAX backend, and would raise errors in TensorFlow backend. This was
because the metric expects numerical values (floats/integers) but wasn't
properly handling boolean array inputs.

Fix by casting y_true and y_pred to floatx() in MeanMetricWrapper.update_state().
This ensures consistent behavior across backends and proper handling of boolean
inputs.

* fix: Make the linter happy :)

* fix: Align update_state casting with metric's data type
  • Loading branch information
harshaljanjani authored Jan 22, 2025
1 parent 5df8fb9 commit 90568da
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras/src/metrics/reduction_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ def __init__(self, fn, name=None, dtype=None, **kwargs):
self._direction = "down"

def update_state(self, y_true, y_pred, sample_weight=None):
y_true = backend.cast(y_true, self.dtype)
y_pred = backend.cast(y_pred, self.dtype)

mask = backend.get_keras_mask(y_pred)
values = self._fn(y_true, y_pred, **self._fn_kwargs)
if sample_weight is not None and mask is not None:
Expand Down
17 changes: 17 additions & 0 deletions keras/src/metrics/reduction_metrics_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np

from keras.src import backend
from keras.src import layers
from keras.src import metrics
from keras.src import models
from keras.src import testing
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.metrics import reduction_metrics
Expand Down Expand Up @@ -174,3 +177,17 @@ def test_weighted_dynamic_shape(self):
KerasTensor((None, 5)),
)
self.assertAllEqual(result.shape, ())

def test_binary_accuracy_with_boolean_inputs(self):
inp = layers.Input(shape=(1,))
out = inp > 0.5
model = models.Model(inputs=inp, outputs=out)

x = np.random.rand(32, 1)
y = x > 0.5

res = model.predict(x)
metric = metrics.BinaryAccuracy()
metric.update_state(y, res)
result = metric.result()
assert result == 1.0

0 comments on commit 90568da

Please sign in to comment.