From 7f2be680dc56957fc051139083d5d7fd79033b98 Mon Sep 17 00:00:00 2001
From: harshaljanjani
Date: Mon, 27 Jan 2025 11:20:17 +0530
Subject: [PATCH 1/5] fix(layers): Fix incorrect masked mean/variance in BatchNormalization layer

Update the masked moments calculation to properly account for broadcast
dimensions when summing mask weights. Add a test to verify that
broadcast mask handling produces zero-centered outputs.
---
 .../layers/normalization/batch_normalization.py   | 17 +++++++----------
 .../normalization/batch_normalization_test.py     | 17 +++++++++++++++++
 2 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py
index 5cd2e37527a7..d9bf82bff129 100644
--- a/keras/src/layers/normalization/batch_normalization.py
+++ b/keras/src/layers/normalization/batch_normalization.py
@@ -318,15 +318,12 @@ def _moments(self, inputs, mask):
             synchronized=self.synchronized,
         )

-        mask_weights = ops.cast(
-            mask,
-            inputs.dtype,
+        mask_weights = ops.cast(mask, inputs.dtype)
+        mask_weights_broadcasted = ops.expand_dims(mask_weights, axis=-1)
+        broadcasted_mask = ops.broadcast_to(
+            mask_weights_broadcasted, ops.shape(inputs)
         )
-        mask_weights_broadcasted = ops.expand_dims(
-            mask_weights,
-            axis=-1,
-        )
-        weighted_inputs = mask_weights_broadcasted * inputs
+        weighted_inputs = broadcasted_mask * inputs

         weighted_input_sum = ops.sum(
             weighted_inputs,
@@ -334,7 +331,7 @@ def _moments(self, inputs, mask):
             keepdims=True,
         )
         sum_of_weights = ops.sum(
-            mask_weights_broadcasted,
+            broadcasted_mask,
             self._reduction_axes,
             keepdims=True,
         )
@@ -343,7 +340,7 @@ def _moments(self, inputs, mask):
         difference = weighted_inputs - mean
         squared_difference = ops.square(difference)
         weighted_distsq = ops.sum(
-            mask_weights_broadcasted * squared_difference,
+            broadcasted_mask * squared_difference,
             self._reduction_axes,
             keepdims=True,
         )
diff --git a/keras/src/layers/normalization/batch_normalization_test.py b/keras/src/layers/normalization/batch_normalization_test.py
index 801fd030b0e9..92d8c09b6462 100644
--- a/keras/src/layers/normalization/batch_normalization_test.py
+++ b/keras/src/layers/normalization/batch_normalization_test.py
@@ -221,3 +221,20 @@ def test_large_value_within_autocast_scope(self):
         with backend.AutocastScope("float16"):
             layer.moving_variance.assign(large_value)
         self.assertAllClose(layer.moving_variance.value, large_value)
+
+    def test_masked_broadcast_normalization(self):
+        input_shape = (1, 2, 3, 4)
+        mask_shape = (1, 2, 1)
+        x = ops.ones(input_shape)
+        mask = ops.ones(mask_shape)
+
+        layer = layers.BatchNormalization(axis=-1, momentum=0.0, epsilon=1e-3)
+
+        y = layer(x, training=True, mask=mask)
+
+        mean_y = ops.mean(y, axis=(0, 1, 2))
+        self.assertAllClose(mean_y, ops.zeros((4,)), atol=1e-6)
+        self.assertAllClose(y, ops.zeros_like(y), atol=1e-6)
+
+        self.assertAllClose(layer.moving_mean, ops.ones((4,)), atol=1e-6)
+        self.assertAllClose(layer.moving_variance, ops.zeros((4,)), atol=1e-6)

From 8d6de0ec7c56bb0ae0eb2b6092ab663d6d07303e Mon Sep 17 00:00:00 2001
From: harshaljanjani
Date: Mon, 27 Jan 2025 12:36:00 +0530
Subject: [PATCH 2/5] change: skip test for OpenVINO

---
 keras/src/layers/normalization/batch_normalization.py      | 4 ++--
 keras/src/layers/normalization/batch_normalization_test.py | 7 +++++++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py
index d9bf82bff129..ec4c53bb86c3 100644
--- a/keras/src/layers/normalization/batch_normalization.py
+++ b/keras/src/layers/normalization/batch_normalization.py
@@ -335,7 +335,7 @@ def _moments(self, inputs, mask):
             self._reduction_axes,
             keepdims=True,
         )
-        mean = weighted_input_sum / (sum_of_weights + backend.config.epsilon())
+        mean = weighted_input_sum / (sum_of_weights + backend.epsilon())

         difference = weighted_inputs - mean
         squared_difference = ops.square(difference)
@@ -344,6 +344,6 @@ def _moments(self, inputs, mask):
             self._reduction_axes,
             keepdims=True,
         )
-        variance = weighted_distsq / (sum_of_weights + backend.config.epsilon())
+        variance = weighted_distsq / (sum_of_weights + backend.epsilon())

         return ops.squeeze(mean), ops.squeeze(variance)
diff --git a/keras/src/layers/normalization/batch_normalization_test.py b/keras/src/layers/normalization/batch_normalization_test.py
index 92d8c09b6462..b191a1b83f71 100644
--- a/keras/src/layers/normalization/batch_normalization_test.py
+++ b/keras/src/layers/normalization/batch_normalization_test.py
@@ -222,6 +222,13 @@ def test_large_value_within_autocast_scope(self):
             layer.moving_variance.assign(large_value)
         self.assertAllClose(layer.moving_variance.value, large_value)

+    @pytest.mark.skipif(
+        backend.backend() == "openvino",
+        reason="""
+        ops.mean() - TypeError: The necessary overload
+        for constant was not found
+        """,
+    )
     def test_masked_broadcast_normalization(self):
         input_shape = (1, 2, 3, 4)
         mask_shape = (1, 2, 1)

From bd5a3100d2c480c6b5af233865c7714d001c8359 Mon Sep 17 00:00:00 2001
From: harshaljanjani
Date: Mon, 27 Jan 2025 20:56:14 +0530
Subject: [PATCH 3/5] fix: Fix OpenVINO compatibility in BatchNormalization layer ops

- Convert tuple reduction axes to list format for compatibility with
  OpenVINO's constant op
- Remove OpenVINO skip decorator after fixing axis format
---
 .../layers/normalization/batch_normalization.py   | 16 ++++++++++++----
 .../normalization/batch_normalization_test.py     | 10 ++--------
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py
index ec4c53bb86c3..6fa408761bee 100644
--- a/keras/src/layers/normalization/batch_normalization.py
+++ b/keras/src/layers/normalization/batch_normalization.py
@@ -314,7 +314,9 @@ def _moments(self, inputs, mask):
         if mask is None:
             return ops.moments(
                 inputs,
-                axes=self._reduction_axes,
+                axes=self._reduction_axes
+                if isinstance(self._reduction_axes, list)
+                else list(self._reduction_axes),
                 synchronized=self.synchronized,
             )

@@ -325,14 +327,20 @@ def _moments(self, inputs, mask):
         )
         weighted_inputs = broadcasted_mask * inputs

+        reduction_axes = (
+            self._reduction_axes
+            if isinstance(self._reduction_axes, list)
+            else list(self._reduction_axes)
+        )
+
         weighted_input_sum = ops.sum(
             weighted_inputs,
-            self._reduction_axes,
+            reduction_axes,
             keepdims=True,
         )
         sum_of_weights = ops.sum(
             broadcasted_mask,
-            self._reduction_axes,
+            reduction_axes,
             keepdims=True,
         )
         mean = weighted_input_sum / (sum_of_weights + backend.epsilon())
@@ -341,7 +349,7 @@ def _moments(self, inputs, mask):
         squared_difference = ops.square(difference)
         weighted_distsq = ops.sum(
             broadcasted_mask * squared_difference,
-            self._reduction_axes,
+            reduction_axes,
             keepdims=True,
         )
         variance = weighted_distsq / (sum_of_weights + backend.epsilon())
diff --git a/keras/src/layers/normalization/batch_normalization_test.py b/keras/src/layers/normalization/batch_normalization_test.py
index b191a1b83f71..d713670aae5c 100644
--- a/keras/src/layers/normalization/batch_normalization_test.py
+++ b/keras/src/layers/normalization/batch_normalization_test.py
@@ -222,13 +222,6 @@ def test_large_value_within_autocast_scope(self):
             layer.moving_variance.assign(large_value)
         self.assertAllClose(layer.moving_variance.value, large_value)

-    @pytest.mark.skipif(
-        backend.backend() == "openvino",
-        reason="""
-        ops.mean() - TypeError: The necessary overload
-        for constant was not found
-        """,
-    )
     def test_masked_broadcast_normalization(self):
         input_shape = (1, 2, 3, 4)
         mask_shape = (1, 2, 1)
@@ -239,7 +232,8 @@ def test_masked_broadcast_normalization(self):

         y = layer(x, training=True, mask=mask)

-        mean_y = ops.mean(y, axis=(0, 1, 2))
+        mean_y = ops.mean(y, axis=[0, 1, 2])
+
         self.assertAllClose(mean_y, ops.zeros((4,)), atol=1e-6)
         self.assertAllClose(y, ops.zeros_like(y), atol=1e-6)

From 9bc9d0b18a4098fe8262bb626b156b7e71be0ca8 Mon Sep 17 00:00:00 2001
From: harshaljanjani
Date: Tue, 28 Jan 2025 07:42:06 +0530
Subject: [PATCH 4/5] fix: Normalize reduction_axes to list during build

Avoid repeated type checks and conversions during forward pass.
---
 .../normalization/batch_normalization.py          | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py
index 6fa408761bee..fa34e0087ca0 100644
--- a/keras/src/layers/normalization/batch_normalization.py
+++ b/keras/src/layers/normalization/batch_normalization.py
@@ -214,7 +214,7 @@ def build(self, input_shape):
         reduction_axes = list(range(len(input_shape)))
         del reduction_axes[self.axis]
-        self._reduction_axes = reduction_axes
+        self._reduction_axes = list(reduction_axes)
         self.built = True

     def compute_output_shape(self, input_shape):
@@ -314,9 +314,7 @@ def _moments(self, inputs, mask):
         if mask is None:
             return ops.moments(
                 inputs,
-                axes=self._reduction_axes
-                if isinstance(self._reduction_axes, list)
-                else list(self._reduction_axes),
+                axes=self._reduction_axes,
                 synchronized=self.synchronized,
             )

@@ -327,20 +325,14 @@ def _moments(self, inputs, mask):
         )
         weighted_inputs = broadcasted_mask * inputs

-        reduction_axes = (
-            self._reduction_axes
-            if isinstance(self._reduction_axes, list)
-            else list(self._reduction_axes)
-        )
-
         weighted_input_sum = ops.sum(
             weighted_inputs,
-            reduction_axes,
+            self._reduction_axes,
             keepdims=True,
         )
         sum_of_weights = ops.sum(
             broadcasted_mask,
-            reduction_axes,
+            self._reduction_axes,
             keepdims=True,
         )
         mean = weighted_input_sum / (sum_of_weights + backend.epsilon())
@@ -349,7 +341,7 @@ def _moments(self, inputs, mask):
         squared_difference = ops.square(difference)
         weighted_distsq = ops.sum(
             broadcasted_mask * squared_difference,
-            reduction_axes,
+            self._reduction_axes,
             keepdims=True,
         )
         variance = weighted_distsq / (sum_of_weights + backend.epsilon())

From 1db8cc023de126bde0a5132426a1604e67343a80 Mon Sep 17 00:00:00 2001
From: harshaljanjani
Date: Tue, 28 Jan 2025 08:53:15 +0530
Subject: [PATCH 5/5] fix: Remove double type-casting

---
 keras/src/layers/normalization/batch_normalization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py
index fa34e0087ca0..ec4c53bb86c3 100644
--- a/keras/src/layers/normalization/batch_normalization.py
+++ b/keras/src/layers/normalization/batch_normalization.py
@@ -214,7 +214,7 @@ def build(self, input_shape):
         reduction_axes = list(range(len(input_shape)))
         del reduction_axes[self.axis]
-        self._reduction_axes = list(reduction_axes)
+        self._reduction_axes = reduction_axes
         self.built = True

     def compute_output_shape(self, input_shape):
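
For readers following the series, below is a minimal standalone NumPy sketch of the
masked-moments math that `_moments` ends up computing after these patches: the mask is
broadcast to the full input shape before the weight sum, so every broadcast element is
counted. This is not part of any patch; the function name `masked_moments`, the NumPy
reference, the example shapes, and the `epsilon=1e-7` default (standing in for
`backend.epsilon()`) are illustrative assumptions, not Keras API.

import numpy as np


def masked_moments(inputs, mask, reduction_axes, epsilon=1e-7):
    # Broadcast the mask weights to the full input shape so the weight sum
    # counts every broadcast element (the core of PATCH 1/5).
    mask_weights = mask.astype(inputs.dtype)
    broadcasted_mask = np.broadcast_to(
        np.expand_dims(mask_weights, axis=-1), inputs.shape
    )

    weighted_inputs = broadcasted_mask * inputs
    sum_of_weights = broadcasted_mask.sum(axis=reduction_axes, keepdims=True)

    mean = weighted_inputs.sum(axis=reduction_axes, keepdims=True) / (
        sum_of_weights + epsilon
    )
    weighted_distsq = (
        broadcasted_mask * np.square(weighted_inputs - mean)
    ).sum(axis=reduction_axes, keepdims=True)
    variance = weighted_distsq / (sum_of_weights + epsilon)
    return np.squeeze(mean), np.squeeze(variance)


# An all-ones broadcast mask should reproduce the plain per-channel moments,
# which is what test_masked_broadcast_normalization relies on.
x = np.ones((1, 2, 3, 4), dtype="float32")
mask = np.ones((1, 2, 1), dtype="float32")
mean, variance = masked_moments(x, mask, reduction_axes=(0, 1, 2))
print(mean.shape, variance.shape)  # (4,) (4,)

Because the broadcast mask, rather than the rank-deficient mask weights, feeds both the
weighted sums and the weight total, an all-ones mask yields the same mean and variance
as unmasked moments over the reduction axes, which is why the test expects zero-centered
outputs and a near-zero moving variance.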