diff --git a/keras/src/backend/numpy/image.py b/keras/src/backend/numpy/image.py
index 6dbe4b4326c..24d85a1c3c4 100644
--- a/keras/src/backend/numpy/image.py
+++ b/keras/src/backend/numpy/image.py
@@ -231,72 +231,137 @@ def resize(
         pad_width = max(width, pad_width)
         img_box_hstart = int(float(pad_height - height) / 2)
         img_box_wstart = int(float(pad_width - width) / 2)
+
         if data_format == "channels_last":
-            if len(images.shape) == 4:
-                padded_img = (
-                    np.ones(
-                        (
-                            batch_size,
-                            pad_height + height,
-                            pad_width + width,
-                            channels,
-                        ),
-                        dtype=images.dtype,
+            if img_box_hstart > 0:
+                if len(images.shape) == 4:
+                    padded_img = np.concatenate(
+                        [
+                            np.ones(
+                                (batch_size, img_box_hstart, width, channels),
+                                dtype=images.dtype,
+                            )
+                            * fill_value,
+                            images,
+                            np.ones(
+                                (batch_size, img_box_hstart, width, channels),
+                                dtype=images.dtype,
+                            )
+                            * fill_value,
+                        ],
+                        axis=1,
                     )
-                    * fill_value
-                )
-                padded_img[
-                    :,
-                    img_box_hstart : img_box_hstart + height,
-                    img_box_wstart : img_box_wstart + width,
-                    :,
-                ] = images
-            else:
-                padded_img = (
-                    np.ones(
-                        (pad_height + height, pad_width + width, channels),
-                        dtype=images.dtype,
+                else:
+                    padded_img = np.concatenate(
+                        [
+                            np.ones(
+                                (img_box_hstart, width, channels),
+                                dtype=images.dtype,
+                            )
+                            * fill_value,
+                            images,
+                            np.ones(
+                                (img_box_hstart, width, channels),
+                                dtype=images.dtype,
+                            )
+                            * fill_value,
+                        ],
+                        axis=0,
                     )
-                    * fill_value
-                )
-                padded_img[
-                    img_box_hstart : img_box_hstart + height,
-                    img_box_wstart : img_box_wstart + width,
-                    :,
-                ] = images
-        else:
-            if len(images.shape) == 4:
-                padded_img = (
-                    np.ones(
-                        (
-                            batch_size,
-                            channels,
-                            pad_height + height,
-                            pad_width + width,
-                        ),
-                        dtype=images.dtype,
+            elif img_box_wstart > 0:
+                if len(images.shape) == 4:
+                    padded_img = np.concatenate(
+                        [
+                            np.ones(
+                                (batch_size, height, img_box_wstart, channels),
+                                dtype=images.dtype,
+                            )
+                            * fill_value,
+                            images,
+                            np.ones(
+                                (batch_size, height, img_box_wstart, channels),
+                                dtype=images.dtype,
+                            )
+                            * fill_value,
+                        ],
+                        axis=2,
+                    )
+                else:
+                    padded_img = np.concatenate(
+                        [
+                            np.ones(
+                                (height, img_box_wstart, channels),
+                                dtype=images.dtype,
+                            )
+                            * fill_value,
+                            images,
+                            np.ones(
+                                (height, img_box_wstart, channels),
+                                dtype=images.dtype,
+                            )
+                            * fill_value,
+                        ],
+                        axis=1,
                     )
-                    * fill_value
-                )
-                padded_img[
-                    :,
-                    :,
-                    img_box_hstart : img_box_hstart + height,
-                    img_box_wstart : img_box_wstart + width,
-                ] = images
             else:
-                padded_img = (
-                    np.ones(
-                        (channels, pad_height + height, pad_width + width),
-                        dtype=images.dtype,
+                padded_img = images
+        else:
+            if img_box_hstart > 0:
+                if len(images.shape) == 4:
+                    padded_img = np.concatenate(
+                        [
+                            np.ones(
+                                (batch_size, channels, img_box_hstart, width)
+                            )
+                            * fill_value,
+                            images,
+                            np.ones(
+                                (batch_size, channels, img_box_hstart, width)
+                            )
+                            * fill_value,
+                        ],
+                        axis=2,
                     )
-                    * fill_value
-                )
-                padded_img[
-                    :,
-                    img_box_hstart : img_box_hstart + height,
-                    img_box_wstart : img_box_wstart + width,
-                ] = images
+                else:
+                    padded_img = np.concatenate(
+                        [
+                            np.ones((channels, img_box_hstart, width))
+                            * fill_value,
+                            images,
+                            np.ones((channels, img_box_hstart, width))
+                            * fill_value,
+                        ],
+                        axis=1,
+                    )
+            elif img_box_wstart > 0:
+                if len(images.shape) == 4:
+                    padded_img = np.concatenate(
+                        [
+                            np.ones(
+                                (batch_size, channels, height, img_box_wstart)
+                            )
+                            * fill_value,
+                            images,
+                            np.ones(
+                                (batch_size, channels, height, img_box_wstart)
+                            )
+                            * fill_value,
+                        ],
+                        axis=3,
+                    )
+                else:
+                    padded_img = np.concatenate(
+                        [
+                            np.ones((channels, height, img_box_wstart))
+                            * fill_value,
+                            images,
+                            np.ones((channels, height, img_box_wstart))
+                            * fill_value,
+                        ],
+                        axis=2,
+                    )
+            else:
+                padded_img = images
         images = padded_img
 
     return np.array(
diff --git a/keras/src/backend/torch/image.py b/keras/src/backend/torch/image.py
index 3f5e571aa7e..d466ffeb5ef 100644
--- a/keras/src/backend/torch/image.py
+++ b/keras/src/backend/torch/image.py
@@ -224,38 +224,90 @@ def resize(
         if len(images.shape) == 4:
             batch_size = images.shape[0]
             channels = images.shape[1]
-            padded_img = (
-                torch.ones(
-                    (
-                        batch_size,
-                        channels,
-                        pad_height + height,
-                        pad_width + width,
-                    ),
-                    dtype=images.dtype,
+            if img_box_hstart > 0:
+                padded_img = torch.cat(
+                    [
+                        torch.ones(
+                            (batch_size, channels, img_box_hstart, width),
+                            dtype=images.dtype,
+                            device=images.device,
+                        )
+                        * fill_value,
+                        images,
+                        torch.ones(
+                            (batch_size, channels, img_box_hstart, width),
+                            dtype=images.dtype,
+                            device=images.device,
+                        )
+                        * fill_value,
+                    ],
+                    axis=2,
                 )
-                * fill_value
-            )
-            padded_img[
-                :,
-                :,
-                img_box_hstart : img_box_hstart + height,
-                img_box_wstart : img_box_wstart + width,
-            ] = images
+            else:
+                padded_img = images
+
+            if img_box_wstart > 0:
+                padded_img = torch.cat(
+                    [
+                        torch.ones(
+                            (batch_size, channels, height, img_box_wstart),
+                            dtype=images.dtype,
+                            device=images.device,
+                        )
+                        * fill_value,
+                        padded_img,
+                        torch.ones(
+                            (batch_size, channels, height, img_box_wstart),
+                            dtype=images.dtype,
+                            device=images.device,
+                        )
+                        * fill_value,
+                    ],
+                    axis=3,
+                )
+
         else:
             channels = images.shape[0]
-            padded_img = (
-                torch.ones(
-                    (channels, pad_height + height, pad_width + width),
-                    dtype=images.dtype,
+            if img_box_hstart > 0:
+                padded_img = torch.cat(
+                    [
+                        torch.ones(
+                            (channels, img_box_hstart, width),
+                            dtype=images.dtype,
+                            device=images.device,
+                        )
+                        * fill_value,
+                        images,
+                        torch.ones(
+                            (channels, img_box_hstart, width),
+                            dtype=images.dtype,
+                            device=images.device,
+                        )
+                        * fill_value,
+                    ],
+                    axis=1,
+                )
+            else:
+                padded_img = images
+            if img_box_wstart > 0:
+                padded_img = torch.cat(
+                    [
+                        torch.ones(
+                            (channels, height, img_box_wstart),
+                            dtype=images.dtype,
+                            device=images.device,
+                        )
+                        * fill_value,
+                        padded_img,
+                        torch.ones(
+                            (channels, height, img_box_wstart),
+                            dtype=images.dtype,
+                            device=images.device,
+                        )
+                        * fill_value,
+                    ],
+                    axis=2,
                 )
-                * fill_value
-            )
-            padded_img[
-                :,
-                img_box_hstart : img_box_hstart + height,
-                img_box_wstart : img_box_wstart + width,
-            ] = images
         images = padded_img
 
     resized = torchvision.transforms.functional.resize(
diff --git a/keras/src/ops/image_test.py b/keras/src/ops/image_test.py
index 238cbfa8cfd..7931f0267a5 100644
--- a/keras/src/ops/image_test.py
+++ b/keras/src/ops/image_test.py
@@ -791,6 +791,31 @@ def test_resize_with_pad(self, fill_value):
         )
         self.assertEqual(out.shape, (2, 3, 25, 25))
 
+        x = np.ones((2, 3, 10, 10)) * 128
+        out = kimage.resize(
+            x, size=(4, 4), pad_to_aspect_ratio=True, fill_value=fill_value
+        )
+        self.assertEqual(out.shape, (2, 3, 4, 4))
+        self.assertAllClose(out[:, 0, :, :], np.ones((2, 4, 4)) * 128)
+
+        x = np.ones((2, 3, 10, 8)) * 128
+        out = kimage.resize(
+            x, size=(4, 4), pad_to_aspect_ratio=True, fill_value=fill_value
+        )
+        self.assertEqual(out.shape, (2, 3, 4, 4))
+        self.assertAllClose(
+            out,
+            np.concatenate(
+                [
+                    np.ones((2, 3, 4, 1)) * 96.25,
+                    np.ones((2, 3, 4, 2)) * 128.0,
+                    np.ones((2, 3, 4, 1)) * 96.25,
+                ],
+                axis=3,
+            ),
+            atol=1.0,
+        )
+
     @parameterized.named_parameters(
         named_product(
             interpolation=["bilinear", "nearest"],
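For context, a minimal usage sketch (not part of the patch) of the code path the new tests exercise. It assumes the public keras.ops.image.resize API with pad_to_aspect_ratio=True; the explicit data_format="channels_first" and fill_value=0.0 arguments are illustrative choices made here to mirror the shapes asserted in image_test.py, not values taken from the patch.

    import numpy as np
    from keras import ops

    # Channels-first batch whose 10x8 spatial shape does not match the square
    # 4x4 target, so resize() pads the narrower width side before resizing.
    x = np.full((2, 3, 10, 8), 128.0, dtype="float32")

    out = ops.image.resize(
        x,
        size=(4, 4),
        pad_to_aspect_ratio=True,
        fill_value=0.0,  # value used for the padded border
        data_format="channels_first",  # assumed; defaults to the global config
    )

    print(out.shape)  # (2, 3, 4, 4); outer columns blend 128.0 with the fill value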