Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ops): Fix inconsistent padding calculation in PyTorch backend ops #20774

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 95 additions & 67 deletions keras/src/backend/torch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch.nn.functional as tnn

from keras.src import backend
from keras.src import tree
from keras.src.backend.common.backend_utils import (
compute_conv_transpose_padding_args_for_torch,
)
Expand Down Expand Up @@ -204,17 +203,27 @@ def sparsemax(logits, axis=-1):
def _compute_padding_length(
input_length, kernel_length, stride, dilation_rate=1
):
"""Compute padding length along one dimension."""
total_padding_length = (
dilation_rate * (kernel_length - 1) - (input_length - 1) % stride
)
left_padding = total_padding_length // 2
right_padding = (total_padding_length + 1) // 2
"""Compute padding length along one dimension with support
for asymmetric padding."""
effective_k_size = (kernel_length - 1) * dilation_rate + 1
if stride == 1:
# total padding is kernel_size - 1
total_padding = effective_k_size - 1
else:
# calc. needed padding for case with stride involved
output_size = (input_length + stride - 1) // stride
total_padding = max(
0, (output_size - 1) * stride + effective_k_size - input_length
)

# divide padding evenly, with extra pixel going at the end if needed
left_padding = total_padding // 2
right_padding = total_padding - left_padding
return (left_padding, right_padding)


def _apply_same_padding(
inputs, kernel_size, strides, operation_type, dilation_rate=1
inputs, kernel_size, strides, data_format, operation_type, dilation_rate=1
):
"""Apply same padding to the input tensor.
Expand All @@ -231,50 +240,49 @@ def _apply_same_padding(
"""
spatial_shape = inputs.shape[2:]
num_spatial_dims = len(spatial_shape)
padding = ()
padding = []

if operation_type != "pooling":
dilation_rate = standardize_tuple(
dilation_rate, num_spatial_dims, "dilation_rate"
)

for i in range(num_spatial_dims):
if operation_type == "pooling":
padding_size = _compute_padding_length(
spatial_shape[i], kernel_size[i], strides[i]
)
mode = "replicate"
else:
dilation_rate = standardize_tuple(
dilation_rate, num_spatial_dims, "dilation_rate"
)
padding_size = _compute_padding_length(
spatial_shape[i], kernel_size[i], strides[i], dilation_rate[i]
)
mode = "constant"
padding = (padding_size,) + padding
dil = 1 if operation_type == "pooling" else dilation_rate[i]
pad = _compute_padding_length(
spatial_shape[i], kernel_size[i], strides[i], dil
)
padding.append(pad)

if all([left == right for left, right in padding]):
# convert padding to torch format
if all(left == right for left, right in padding):
return inputs, [left for left, _ in padding]

flattened_padding = tuple(
value for left_and_right in padding for value in left_and_right
)
return tnn.pad(inputs, pad=flattened_padding, mode=mode), 0
# else, need to pad manually
flattened_padding = []
for pad in reversed(padding):
flattened_padding.extend(pad)

mode = "replicate" if operation_type == "pooling" else "constant"
return tnn.pad(inputs, pad=tuple(flattened_padding), mode=mode), 0


def _transpose_spatial_inputs(inputs):
num_spatial_dims = inputs.ndim - 2
"""Transpose inputs from channels_last to channels_first format."""
# Torch pooling does not support `channels_last` format, so
# we need to transpose to `channels_first` format.
if num_spatial_dims == 1:
inputs = torch.permute(inputs, (0, 2, 1))
elif num_spatial_dims == 2:
inputs = torch.permute(inputs, (0, 3, 1, 2))
elif num_spatial_dims == 3:
inputs = torch.permute(inputs, (0, 4, 1, 2, 3))
else:
raise ValueError(
"Inputs must have ndim=3, 4 or 5, "
"corresponding to 1D, 2D and 3D inputs. "
f"Received input shape: {inputs.shape}."
)
return inputs
ndim = inputs.ndim - 2
if ndim == 1: # 1D case
return torch.permute(inputs, (0, 2, 1))
elif ndim == 2: # 2D case
return torch.permute(inputs, (0, 3, 1, 2))
elif ndim == 3: # 3D case
return torch.permute(inputs, (0, 4, 1, 2, 3))
raise ValueError(
"Inputs must have ndim=3, 4 or 5, "
"corresponding to 1D, 2D and 3D inputs. "
f"Received input shape: {inputs.shape}."
)


def _transpose_spatial_outputs(outputs):
Expand Down Expand Up @@ -309,6 +317,7 @@ def max_pool(
padding="valid",
data_format=None,
):
"""Fixed max pooling implementation."""
inputs = convert_to_tensor(inputs)
num_spatial_dims = inputs.ndim - 2
pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
Expand All @@ -325,7 +334,7 @@ def max_pool(
# Torch does not natively support `"same"` padding, we need to manually
# apply the right amount of padding to `inputs`.
inputs, padding = _apply_same_padding(
inputs, pool_size, strides, operation_type="pooling"
inputs, pool_size, strides, data_format, "pooling"
)
else:
padding = 0
Expand Down Expand Up @@ -370,26 +379,36 @@ def average_pool(
padding="valid",
data_format=None,
):
"""Fixed average pooling with correct padding calculation."""
inputs = convert_to_tensor(inputs)
num_spatial_dims = inputs.ndim - 2
pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
if strides is None:
strides = pool_size
else:
strides = standardize_tuple(strides, num_spatial_dims, "strides")
strides = (
pool_size
if strides is None
else standardize_tuple(strides, num_spatial_dims, "strides")
)

data_format = backend.standardize_data_format(data_format)
orig_format = data_format

if data_format == "channels_last":
inputs = _transpose_spatial_inputs(inputs)

if padding == "same":
# Torch does not natively support `"same"` padding, we need to manually
# apply the right amount of padding to `inputs`.
inputs, padding = _apply_same_padding(
inputs, pool_size, strides, operation_type="pooling"
inputs,
pool_size,
strides,
"channels_first", # we're in channels_first here
"pooling",
)
else:
padding = 0

# apply pooling
if num_spatial_dims == 1:
outputs = tnn.avg_pool1d(
inputs,
Expand Down Expand Up @@ -420,8 +439,10 @@ def average_pool(
"corresponding to 1D, 2D and 3D inputs. "
f"Received input shape: {inputs.shape}."
)
if data_format == "channels_last":

if orig_format == "channels_last":
outputs = _transpose_spatial_outputs(outputs)

return outputs


Expand All @@ -433,6 +454,7 @@ def conv(
data_format=None,
dilation_rate=1,
):
"""Convolution with fixed group handling."""
inputs = convert_to_tensor(inputs)
kernel = convert_to_tensor(kernel)
num_spatial_dims = inputs.ndim - 2
Expand All @@ -441,53 +463,59 @@ def conv(
data_format = backend.standardize_data_format(data_format)
if data_format == "channels_last":
inputs = _transpose_spatial_inputs(inputs)
# Transpose kernel from keras format to torch format.

kernel = _transpose_conv_kernel(kernel)
if padding == "same" and any(d != 1 for d in tree.flatten(strides)):
# Torch does not support this case in conv2d().
# Manually pad the tensor.

# calc. groups snippet
in_channels = inputs.shape[1]
kernel_in_channels = kernel.shape[1]
if in_channels % kernel_in_channels != 0:
raise ValueError(
f"Input channels ({in_channels}) must be divisible by "
f"kernel input channels ({kernel_in_channels})"
)
groups = in_channels // kernel_in_channels

# handle padding
if padding == "same":
inputs, padding = _apply_same_padding(
inputs,
kernel.shape[2:],
strides,
operation_type="conv",
dilation_rate=dilation_rate,
)
channels = inputs.shape[1]
kernel_in_channels = kernel.shape[1]
if channels % kernel_in_channels > 0:
raise ValueError(
"The number of input channels must be evenly divisible by "
f"kernel.shape[1]. Received: inputs.shape={inputs.shape}, "
f"kernel.shape={kernel.shape}"
data_format,
"conv",
dilation_rate,
)
groups = channels // kernel_in_channels
else:
padding = 0

# apply convolution
if num_spatial_dims == 1:
outputs = tnn.conv1d(
inputs,
kernel,
stride=strides,
padding=padding,
dilation=dilation_rate,
groups=groups,
padding=padding,
)
elif num_spatial_dims == 2:
outputs = tnn.conv2d(
inputs,
kernel,
stride=strides,
padding=padding,
dilation=dilation_rate,
groups=groups,
padding=padding,
)
elif num_spatial_dims == 3:
outputs = tnn.conv3d(
inputs,
kernel,
stride=strides,
padding=padding,
dilation=dilation_rate,
groups=groups,
padding=padding,
)
else:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions keras/src/layers/pooling/average_pooling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_average_pooling1d(
(2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)),
((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)),
((2, 3), (2, 2), "same", "channels_last", (3, 5, 5, 4), (3, 3, 3, 4)),
((2, 3), (3, 3), "same", "channels_first", (3, 5, 5, 4), (3, 5, 2, 2)),
)
def test_average_pooling2d(
self,
Expand Down
12 changes: 12 additions & 0 deletions keras/src/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,18 @@ def test_average_pool_same_padding(self):
knn.average_pool(x, 2, (2, 1), padding="same"),
np_avgpool2d(x, 2, (2, 1), padding="same", data_format=data_format),
)
# Test 2D average pooling with different pool size.
if data_format == "channels_last":
input_shape = (2, 10, 9, 3)
else:
input_shape = (2, 3, 10, 9)
x = np.arange(540, dtype=float).reshape(input_shape)
self.assertAllClose(
knn.average_pool(x, (2, 3), (3, 3), padding="same"),
np_avgpool2d(
x, (2, 3), (3, 3), padding="same", data_format=data_format
),
)

@parameterized.product(
strides=(1, 2, 3),
Expand Down
Loading