Skip to content

Commit

Permalink
fix: Correct spurious change, and fix mean/variance shapes for channe…
Browse files Browse the repository at this point in the history
…ls_first preprocessing in EfficientNetV2

- Reshaped mean and variance tensors to [1,3,1,1] for proper broadcasting in channels_first mode.
- Ensured compatibility with channels_last format while addressing broadcasting errors.
  • Loading branch information
harshaljanjani committed Jan 20, 2025
1 parent 291f4a0 commit 74e3934
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions keras/src/applications/efficientnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,9 +935,17 @@ def EfficientNetV2(
num_channels = input_shape[bn_axis - 1]
if name.split("-")[-1].startswith("b") and num_channels == 3:
x = layers.Rescaling(scale=1.0 / 255)(x)
if backend.image_data_format() == "channels_first":
mean = [[[[0.485]], [[0.456]], [[0.406]]]] # shape [1,3,1,1]
variance = [
[[[0.229**2]], [[0.224**2]], [[0.225**2]]]
] # shape [1,3,1,1]
else:
mean = [0.485, 0.456, 0.406]
variance = [0.229**2, 0.224**2, 0.225**2]
x = layers.Normalization(
mean=[0.485, 0.456, 0.406],
variance=[0.229**2, 0.224**2, 0.225**2],
mean=mean,
variance=variance,
axis=bn_axis,
)(x)
else:
Expand Down

0 comments on commit 74e3934

Please sign in to comment.