diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index f41c398d9f6..1f2d470ff6e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -288,6 +288,12 @@ def pad(img, padding, fill=0, padding_mode='constant'): 'Padding mode should be either constant, edge, reflect or symmetric' if padding_mode == 'constant': + if img.mode == 'P': + palette = img.getpalette() + image = ImageOps.expand(img, border=padding, fill=fill) + image.putpalette(palette) + return image + return ImageOps.expand(img, border=padding, fill=fill) else: if isinstance(padding, int): @@ -301,6 +307,14 @@ def pad(img, padding, fill=0, padding_mode='constant'): pad_right = padding[2] pad_bottom = padding[3] + if img.mode == 'P': + palette = img.getpalette() + img = np.asarray(img) + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) + img = Image.fromarray(img) + img.putpalette(palette) + return img + img = np.asarray(img) # RGB image if len(img.shape) == 3: