Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

smol improvements to support more flexible usage #34857

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 16 additions & 24 deletions src/transformers/models/idefics3/image_processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@


logger = logging.get_logger(__name__)
MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum


if is_vision_available():
Expand Down Expand Up @@ -116,7 +117,6 @@ def _resize_output_size_scale_below_upper_bound(
def get_resize_output_image_size(
image,
resolution_max_side: int,
max_image_size: int = 1820,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
"""
Expand All @@ -126,24 +126,18 @@ def get_resize_output_image_size(
Image to resize.
resolution_max_side (`int`):
The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
input aspect ratio, with a lower bound of `min_image_size`.
max_image_size (`int`, *optional*, defaults to 1820):
Maximum image resolution. If the image is larger than this size, the longest edge will be resized to this
value, with the shortest edge resized to keep the input aspect ratio, with a lower bound of `min_image_size`.
input aspect ratio.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
The output size of the image after resizing.
"""
if resolution_max_side > max_image_size:
raise ValueError("`resolution_max_side` cannot be larger than `max_image_size`")

height, width = get_image_size(image, channel_dim=input_data_format)

# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
# Find the output size when scaling the image to be below the max_image_size
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=max_image_size)
# Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
return height, width


Expand Down Expand Up @@ -251,7 +245,7 @@ def convert_to_rgb(
data_format = input_data_format if data_format is None else data_format

mode = "P" if palette is not None else None
image = to_pil_image(image, image_mode=mode)
image = to_pil_image(image, image_mode=mode, input_data_format=input_data_format)
if image.mode == "P" and palette is not None:
image.putpalette(palette)

Expand Down Expand Up @@ -404,7 +398,7 @@ def resize(
image_mode = None
if image.ndim == 2 or image.shape[-1] == 1:
image_mode = "P"
image = to_pil_image(image, image_mode=image_mode)
image = to_pil_image(image, image_mode=image_mode, input_data_format=input_data_format)

resized_image = image.resize((size[1], size[0]), resample=resample)
resized_image = np.array(resized_image)
Expand Down Expand Up @@ -754,6 +748,16 @@ def preprocess(
# All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list]

# Extra channel dimension for grayscale images
if input_data_format in [ChannelDimension.LAST, None]:
images_list = [
[np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
]
elif input_data_format == ChannelDimension.FIRST:
images_list = [
[np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
]

Comment on lines +751 to +760
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a method that converts PIL images to RGB format (see do_convert_rgb param in other image processors). It's better to use it + extend to numpy arrays if needed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I know, the order of the conversion here is actually important as RGB conversion before image_splitting was hurting the model performance quite a bit D:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which is why I'm not converting to RGB until later in the pipeline

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I know, the order of the conversion here is actually important as RGB conversion before image_splitting was hurting the model performance quite a bit D:

haha, that's interesting, do you have any clarification for such a behavior? 😄

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My observation is that the order here makes the resulting images are slightly different (if you just plot the differences you see large values on edges in the images). They look the same to me, but clearly not to the model since it was trained with a different pipeline that more closely resembles this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, got it, thanks for clarifying 🤗

if is_scaled_image(images_list[0][0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
Expand All @@ -764,18 +768,6 @@ def preprocess(
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))

# Extra channel dimension for grayscale images
if input_data_format == ChannelDimension.LAST:
images_list = [
[np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
]
elif input_data_format == ChannelDimension.FIRST:
images_list = [
[np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
]
else:
raise ValueError(f"Invalid channel dimension format {input_data_format}.")

if do_resize:
images_list = [
[
Expand Down
Loading