fix(layers): Fix incorrect masked mean/variance in BatchNormalization layer (#20815)

* fix(layers): Fix incorrect masked mean/variance in BatchNormalization layer

Update the masked moments calculation to properly account for broadcast dimensions when summing mask weights (a minimal standalone sketch of the corrected computation follows this change list).

Added a test to verify that broadcast mask handling produces zero-centered outputs.

* change: skip test for OpenVINO

* fix: Fix OpenVINO compatibility in BatchNormalization layer ops

- Convert tuple reduction axes to list format for compatibility with OpenVINO's constant op

- Remove OpenVINO skip decorator after fixing axis format

* fix: Normalize reduction_axes to a list during build (a hypothetical sketch follows the first file diff below)

Avoids repeated type checks and conversions during the forward pass.

* fix: Double type-casting
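
For reference, a minimal standalone sketch of the corrected masked-moments computation (plain NumPy, not the actual Keras ops; shapes and variable names are illustrative). The crux of the fix: the sum of mask weights must be taken over the mask after broadcasting it to the full input shape; otherwise any size-1 mask axes are under-counted in the denominator and the resulting mean and variance come out too large.

    # Illustrative sketch only, assuming NumPy; mirrors the structure of the fix.
    import numpy as np

    eps = 1e-7
    x = np.random.rand(1, 2, 3, 4)        # (batch, t1, t2, channels)
    mask = np.ones((1, 2, 1))             # mask has no channel axis
    reduction_axes = (0, 1, 2)            # reduce everything except channels

    mask_w = np.expand_dims(mask, axis=-1)        # (1, 2, 1, 1)
    mask_b = np.broadcast_to(mask_w, x.shape)     # (1, 2, 3, 4)

    weighted_inputs = mask_b * x
    weighted_sum = weighted_inputs.sum(axis=reduction_axes, keepdims=True)
    sum_of_weights = mask_b.sum(axis=reduction_axes, keepdims=True)  # 6 per channel, not 2
    mean = weighted_sum / (sum_of_weights + eps)
    variance = (mask_b * (weighted_inputs - mean) ** 2).sum(
        axis=reduction_axes, keepdims=True
    ) / (sum_of_weights + eps)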
harshaljanjani authored Jan 28, 2025
1 parent 3e52ce9 commit 871ad7a
Showing 2 changed files with 27 additions and 12 deletions.
21 changes: 9 additions & 12 deletions keras/src/layers/normalization/batch_normalization.py
@@ -318,35 +318,32 @@ def _moments(self, inputs, mask):
                 synchronized=self.synchronized,
             )
 
-        mask_weights = ops.cast(
-            mask,
-            inputs.dtype,
+        mask_weights = ops.cast(mask, inputs.dtype)
+        mask_weights_broadcasted = ops.expand_dims(mask_weights, axis=-1)
+        broadcasted_mask = ops.broadcast_to(
+            mask_weights_broadcasted, ops.shape(inputs)
         )
-        mask_weights_broadcasted = ops.expand_dims(
-            mask_weights,
-            axis=-1,
-        )
-        weighted_inputs = mask_weights_broadcasted * inputs
+        weighted_inputs = broadcasted_mask * inputs
 
         weighted_input_sum = ops.sum(
             weighted_inputs,
             self._reduction_axes,
             keepdims=True,
         )
         sum_of_weights = ops.sum(
-            mask_weights_broadcasted,
+            broadcasted_mask,
             self._reduction_axes,
             keepdims=True,
         )
-        mean = weighted_input_sum / (sum_of_weights + backend.config.epsilon())
+        mean = weighted_input_sum / (sum_of_weights + backend.epsilon())
 
         difference = weighted_inputs - mean
         squared_difference = ops.square(difference)
         weighted_distsq = ops.sum(
-            mask_weights_broadcasted * squared_difference,
+            broadcasted_mask * squared_difference,
             self._reduction_axes,
             keepdims=True,
         )
-        variance = weighted_distsq / (sum_of_weights + backend.config.epsilon())
+        variance = weighted_distsq / (sum_of_weights + backend.epsilon())
 
         return ops.squeeze(mean), ops.squeeze(variance)
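
The reduction_axes normalization mentioned in the change list is made in build(), whose hunk is collapsed above. A hypothetical sketch of the idea (surrounding details are assumptions, not the actual Keras code): resolve the axis and store the reduction axes as a plain Python list once at build time, so the forward pass never hands ops.sum() a tuple that OpenVINO's constant op cannot accept.

    # Hypothetical sketch; the real build() change is not shown in this diff.
    def build(self, input_shape):
        ...
        axis = self.axis % len(input_shape)  # resolve a negative axis index
        # Store a list rather than a tuple so downstream ops.sum() calls stay
        # compatible with OpenVINO, and the conversion happens only once.
        self._reduction_axes = [i for i in range(len(input_shape)) if i != axis]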
18 changes: 18 additions & 0 deletions keras/src/layers/normalization/batch_normalization_test.py
@@ -221,3 +221,21 @@ def test_large_value_within_autocast_scope(self):
         with backend.AutocastScope("float16"):
             layer.moving_variance.assign(large_value)
             self.assertAllClose(layer.moving_variance.value, large_value)
+
+    def test_masked_broadcast_normalization(self):
+        input_shape = (1, 2, 3, 4)
+        mask_shape = (1, 2, 1)
+        x = ops.ones(input_shape)
+        mask = ops.ones(mask_shape)
+
+        layer = layers.BatchNormalization(axis=-1, momentum=0.0, epsilon=1e-3)
+
+        y = layer(x, training=True, mask=mask)
+
+        mean_y = ops.mean(y, axis=[0, 1, 2])
+
+        self.assertAllClose(mean_y, ops.zeros((4,)), atol=1e-6)
+        self.assertAllClose(y, ops.zeros_like(y), atol=1e-6)
+
+        self.assertAllClose(layer.moving_mean, ops.ones((4,)), atol=1e-6)
+        self.assertAllClose(layer.moving_variance, ops.zeros((4,)), atol=1e-6)
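
As a quick check of the expected values in the new test: with momentum=0.0 the moving statistics are replaced wholesale by the batch statistics (moving = momentum * moving + (1 - momentum) * batch), and for an all-ones input under an all-ones broadcast mask the masked batch mean per channel is 1 and the variance is 0, so the normalized output is (1 - 1) / sqrt(0 + 1e-3) = 0 everywhere, which is exactly what the assertions above verify.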
