diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py
index 5cd2e37527a..ec4c53bb86c 100644
--- a/keras/src/layers/normalization/batch_normalization.py
+++ b/keras/src/layers/normalization/batch_normalization.py
@@ -318,15 +318,12 @@ def _moments(self, inputs, mask):
                 synchronized=self.synchronized,
             )

-        mask_weights = ops.cast(
-            mask,
-            inputs.dtype,
+        mask_weights = ops.cast(mask, inputs.dtype)
+        mask_weights_broadcasted = ops.expand_dims(mask_weights, axis=-1)
+        broadcasted_mask = ops.broadcast_to(
+            mask_weights_broadcasted, ops.shape(inputs)
         )
-        mask_weights_broadcasted = ops.expand_dims(
-            mask_weights,
-            axis=-1,
-        )
-        weighted_inputs = mask_weights_broadcasted * inputs
+        weighted_inputs = broadcasted_mask * inputs

         weighted_input_sum = ops.sum(
             weighted_inputs,
@@ -334,19 +331,19 @@ def _moments(self, inputs, mask):
             keepdims=True,
         )
         sum_of_weights = ops.sum(
-            mask_weights_broadcasted,
+            broadcasted_mask,
             self._reduction_axes,
             keepdims=True,
         )
-        mean = weighted_input_sum / (sum_of_weights + backend.config.epsilon())
+        mean = weighted_input_sum / (sum_of_weights + backend.epsilon())

         difference = weighted_inputs - mean
         squared_difference = ops.square(difference)
         weighted_distsq = ops.sum(
-            mask_weights_broadcasted * squared_difference,
+            broadcasted_mask * squared_difference,
             self._reduction_axes,
             keepdims=True,
         )
-        variance = weighted_distsq / (sum_of_weights + backend.config.epsilon())
+        variance = weighted_distsq / (sum_of_weights + backend.epsilon())

         return ops.squeeze(mean), ops.squeeze(variance)
diff --git a/keras/src/layers/normalization/batch_normalization_test.py b/keras/src/layers/normalization/batch_normalization_test.py
index 801fd030b0e..d713670aae5 100644
--- a/keras/src/layers/normalization/batch_normalization_test.py
+++ b/keras/src/layers/normalization/batch_normalization_test.py
@@ -221,3 +221,21 @@ def test_large_value_within_autocast_scope(self):
         with backend.AutocastScope("float16"):
             layer.moving_variance.assign(large_value)
         self.assertAllClose(layer.moving_variance.value, large_value)
+
+    def test_masked_broadcast_normalization(self):
+        input_shape = (1, 2, 3, 4)
+        mask_shape = (1, 2, 1)
+        x = ops.ones(input_shape)
+        mask = ops.ones(mask_shape)
+
+        layer = layers.BatchNormalization(axis=-1, momentum=0.0, epsilon=1e-3)
+
+        y = layer(x, training=True, mask=mask)
+
+        mean_y = ops.mean(y, axis=[0, 1, 2])
+
+        self.assertAllClose(mean_y, ops.zeros((4,)), atol=1e-6)
+        self.assertAllClose(y, ops.zeros_like(y), atol=1e-6)
+
+        self.assertAllClose(layer.moving_mean, ops.ones((4,)), atol=1e-6)
+        self.assertAllClose(layer.moving_variance, ops.zeros((4,)), atol=1e-6)
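
Note (commentary, not part of the patch): the bug this fixes is that the rank-expanded mask was summed directly for the weight total, so when the mask only broadcast against the input rather than matching its shape exactly, `sum_of_weights` undercounted the positions each channel actually averaged over. Below is a minimal NumPy sketch of the before/after arithmetic, using the same shapes as the new test; the names here are illustrative, not Keras internals.

    import numpy as np

    inputs = np.ones((1, 2, 3, 4))   # (batch, t1, t2, channels)
    mask = np.ones((1, 2, 1))        # lower-rank, broadcastable mask
    reduction_axes = (0, 1, 2)       # reduce everything but the channel axis
    eps = 1e-7

    mask_w = np.expand_dims(mask, axis=-1)  # (1, 2, 1, 1)

    # Before: the weight sum ranges over the un-broadcast mask and counts
    # 2 positions per channel, while the weighted inputs sum over all 6,
    # so the mean comes out as 6 / 2 = 3 instead of 1.
    bad_mean = (mask_w * inputs).sum(reduction_axes, keepdims=True) / (
        mask_w.sum(reduction_axes, keepdims=True) + eps
    )

    # After: broadcasting the mask to the full input shape makes numerator
    # and denominator range over the same elements, giving 6 / 6 = 1.
    mask_full = np.broadcast_to(mask_w, inputs.shape)  # (1, 2, 3, 4)
    good_mean = (mask_full * inputs).sum(reduction_axes, keepdims=True) / (
        mask_full.sum(reduction_axes, keepdims=True) + eps
    )

    print(bad_mean.squeeze())   # ~[3., 3., 3., 3.] (wrong)
    print(good_mean.squeeze())  # ~[1., 1., 1., 1.] (correct)

With the correct mean (and the analogous variance, which is zero for an all-ones input), the layer normalizes this input to zeros, which is what the new test asserts along with the momentum-0.0 updates to moving_mean and moving_variance.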