diff --git a/src/transformers/models/idefics3/image_processing_idefics3.py b/src/transformers/models/idefics3/image_processing_idefics3.py
index 05a1a396dc72d3..f9161416656cd7 100644
--- a/src/transformers/models/idefics3/image_processing_idefics3.py
+++ b/src/transformers/models/idefics3/image_processing_idefics3.py
@@ -38,6 +38,7 @@
 
 logger = logging.get_logger(__name__)
 
+MAX_IMAGE_SIZE = 4096  # 4k resolution as absolute maximum
 
 
 if is_vision_available():
@@ -116,7 +117,6 @@ def _resize_output_size_scale_below_upper_bound(
 def get_resize_output_image_size(
     image,
     resolution_max_side: int,
-    max_image_size: int = 1820,
     input_data_format: Optional[Union[str, ChannelDimension]] = None,
 ) -> Tuple[int, int]:
     """
@@ -126,24 +126,18 @@ def get_resize_output_image_size(
             Image to resize.
         resolution_max_side (`int`):
             The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
-            input aspect ratio, with a lower bound of `min_image_size`.
-        max_image_size (`int`, *optional*, defaults to 1820):
-            Maximum image resolution. If the image is larger than this size, the longest edge will be resized to this
-            value, with the shortest edge resized to keep the input aspect ratio, with a lower bound of `min_image_size`.
+            input aspect ratio.
         input_data_format (`ChannelDimension` or `str`):
             The channel dimension format of the input image.
 
     Returns:
         The output size of the image after resizing.
     """
-    if resolution_max_side > max_image_size:
-        raise ValueError("`resolution_max_side` cannot be larger than `max_image_size`")
-
     height, width = get_image_size(image, channel_dim=input_data_format)
 
     # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
     height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
-    # Find the output size when scaling the image to be below the max_image_size
-    height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=max_image_size)
+    # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
+    height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
 
     return height, width
 
@@ -251,7 +245,7 @@ def convert_to_rgb(
     data_format = input_data_format if data_format is None else data_format
 
     mode = "P" if palette is not None else None
-    image = to_pil_image(image, image_mode=mode)
+    image = to_pil_image(image, image_mode=mode, input_data_format=input_data_format)
     if image.mode == "P" and palette is not None:
         image.putpalette(palette)
 
@@ -404,7 +398,7 @@ def resize(
         image_mode = None
         if image.ndim == 2 or image.shape[-1] == 1:
             image_mode = "P"
-        image = to_pil_image(image, image_mode=image_mode)
+        image = to_pil_image(image, image_mode=image_mode, input_data_format=input_data_format)
 
         resized_image = image.resize((size[1], size[0]), resample=resample)
         resized_image = np.array(resized_image)
 
@@ -754,6 +748,16 @@ def preprocess(
         # All transformations expect numpy arrays.
         images_list = [[to_numpy_array(image) for image in images] for images in images_list]
 
+        # Extra channel dimension for grayscale images
+        if input_data_format in [ChannelDimension.LAST, None]:
+            images_list = [
+                [np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
+            ]
+        elif input_data_format == ChannelDimension.FIRST:
+            images_list = [
+                [np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
+            ]
+
         if is_scaled_image(images_list[0][0]) and do_rescale:
             logger.warning_once(
                 "It looks like you are trying to rescale already rescaled images. If the input"
@@ -764,18 +768,6 @@ def preprocess(
         if input_data_format is None:
             input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
 
-        # Extra channel dimension for grayscale images
-        if input_data_format == ChannelDimension.LAST:
-            images_list = [
-                [np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
-            ]
-        elif input_data_format == ChannelDimension.FIRST:
-            images_list = [
-                [np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
-            ]
-        else:
-            raise ValueError(f"Invalid channel dimension format {input_data_format}.")
-
         if do_resize:
             images_list = [
                 [
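
Note on the two preprocess hunks above: the grayscale channel-expansion now runs before infer_channel_dimension_format, and it treats input_data_format=None like channels-last instead of raising. The snippet below is a minimal, standalone sketch of that step, assuming only numpy and transformers.image_utils.ChannelDimension; the helper name add_channel_dim_to_grayscale is invented for illustration and does not exist in the patch, where this logic is inlined in Idefics3ImageProcessor.preprocess.

# Standalone sketch (not part of the patch) of the grayscale-expansion step.
# `add_channel_dim_to_grayscale` is a hypothetical name used only for this example.
from typing import List, Optional, Union

import numpy as np

from transformers.image_utils import ChannelDimension


def add_channel_dim_to_grayscale(
    images_list: List[List[np.ndarray]],
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> List[List[np.ndarray]]:
    """Give 2-D (grayscale) arrays an explicit channel axis so later transforms see 3-D images."""
    if input_data_format in [ChannelDimension.LAST, None]:
        # Unknown or channels-last input: (H, W) -> (H, W, 1)
        axis = -1
    elif input_data_format == ChannelDimension.FIRST:
        # Channels-first input: (H, W) -> (1, H, W)
        axis = 0
    else:
        # Other formats are left untouched, mirroring the patched hunk (no ValueError anymore).
        return images_list
    return [
        [np.expand_dims(img, axis=axis) if img.ndim == 2 else img for img in images]
        for images in images_list
    ]


# Example: a raw grayscale frame gains a trailing channel axis before any inference happens.
gray = np.zeros((32, 48), dtype=np.uint8)
print(add_channel_dim_to_grayscale([[gray]])[0][0].shape)  # (32, 48, 1)

The likely motivation for the reordering is that infer_channel_dimension_format (called with num_channels=(1, 3, 4)) then always receives a 3-D array, even for 2-D grayscale inputs with no explicit input_data_format.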