Merge branch 'refactor-vl' into refactor-multimodal
grimoire committed Nov 20, 2024
2 parents 5e47955 + 463b508 commit 4124c56
Showing 6 changed files with 415 additions and 203 deletions.
36 changes: 16 additions & 20 deletions lmdeploy/vl/model/internvl.py
@@ -192,18 +192,7 @@ def _forward(self, inputs):
return outputs

def preprocess(self, messages: List[Dict]) -> List[Dict]:
assert isinstance(messages, List)
assert isinstance(messages[-1]['content'], List)
content = [
item['text'] for item in messages[-1]['content']
if item['type'] == 'text'
]
if len(content) > 1:
logger.warning(f'There are {len(content)} text in {content}. '
'Only the first one is considered')

# get images and their corresponding preprocess parameters from
# messages, and perform preprocessing
"""refers to `super.preprocess() for spec."""
outputs = []
for item in messages[-1]['content']:
item_type = item['type']
@@ -233,22 +222,29 @@ def forward(self, inputs: List[Dict]) -> List[torch.Tensor]:

@classmethod
def proc_messages(cls, messages, chat_template, sequence_start):
# apply chat template to get the prompt
"""apply chat template to get the prompt."""
prompt_messages = []
IMAGE_TOKEN = '<IMAGE_TOKEN>'
for message in messages:
if isinstance(message['content'], str):
prompt_messages.append(message)
continue
n_images = [
1 for item in message['content'] if item['type'] == 'image'
]
n_images = sum(n_images)
n_images = len(
[1 for x in message['content'] if x['type'] == 'image'])
content = [
item['text'] for item in message['content']
if item['type'] == 'text'
x['text'] for x in message['content'] if x['type'] == 'text'
]
content = f'<img>{IMAGE_TOKEN * n_images}</img>\n' + content[0]
prompt = content[0]
if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
prompt = prompt.replace(f'{IMAGE_TOKEN}',
f'<img>{IMAGE_TOKEN}</img>')
prompt = prompt.replace('</img><img>', '')
prompt = prompt.replace('<img><img>', '<img>')
prompt = prompt.replace('</img></img>', '</img>')
elif IMAGE_TOKEN not in prompt:
prompt = f'<img>{IMAGE_TOKEN * n_images}</img>\n' + prompt
else:
pass
prompt_messages.append(dict(role='user', content=content))
prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
return prompt, IMAGE_TOKEN
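
For reference, a minimal sketch of the <img> wrapping the new branch above performs: any user-supplied IMAGE_TOKEN is wrapped in <img>...</img> and back-to-back tags are merged (the prompt string below is a made-up example):

IMAGE_TOKEN = '<IMAGE_TOKEN>'
prompt = f'compare {IMAGE_TOKEN} with {IMAGE_TOKEN} please'
if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
    # wrap every occurrence, then collapse tags that end up adjacent
    prompt = prompt.replace(IMAGE_TOKEN, f'<img>{IMAGE_TOKEN}</img>')
    prompt = prompt.replace('</img><img>', '')
print(prompt)
# compare <img><IMAGE_TOKEN></img> with <img><IMAGE_TOKEN></img> please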
231 changes: 196 additions & 35 deletions lmdeploy/vl/model/llava.py
@@ -1,16 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/haotian-liu/LLaVA.git
# Modified from https://github.com/haotian-liu/LLaVA.git
import ast
import math
import warnings
from contextlib import contextmanager
from typing import List, Union
from typing import Dict, List

import torch
from PIL.Image import Image
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
from lmdeploy.vl.model.utils import disable_logging, rewrite_ctx

logger = get_logger('lmdeploy')
@@ -53,8 +54,166 @@ def init_llava_vision_tower(config):
yield


def select_best_resolution(original_size, possible_resolutions):
"""Selects the best resolution from a list of possible resolutions based on
the original size.
Args:
original_size (tuple): The original size of the image in the format (width, height).
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
Returns:
tuple: The best fit resolution in the format (width, height).
""" # noqa
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float('inf')

for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(
original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height,
original_width * original_height)
wasted_resolution = (width * height) - effective_resolution

if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution
and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)

return best_fit
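
As a worked example of the selection rule (made-up sizes; assuming the helper is importable from lmdeploy.vl.model.llava once this change lands), a 1000x600 image is matched to the 672x672 candidate because it preserves the most effective pixels:

from lmdeploy.vl.model.llava import select_best_resolution

# candidate grids as (width, height), e.g. 1x2, 2x1 and 2x2 tilings of 336 px cells
candidates = [(336, 672), (672, 336), (672, 672)]
print(select_best_resolution((1000, 600), candidates))
# (672, 672): scale 0.672 keeps about 270k effective pixels, the maximum among the three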


def resize_and_pad_image(image, target_resolution):
"""Resize and pad an image to a target resolution while maintaining aspect
ratio.
Args:
image (PIL.Image.Image): The input image.
target_resolution (tuple): The target resolution (width, height) of the image.
Returns:
PIL.Image.Image: The resized and padded image.
""" # noqa
original_width, original_height = image.size
target_width, target_height = target_resolution

scale_w = target_width / original_width
scale_h = target_height / original_height

if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)

# Resize the image
resized_image = image.resize((new_width, new_height))

new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))

return new_image


def divide_to_patches(image, patch_size):
"""Divides an image into patches of a specified size.
Args:
image (PIL.Image.Image): The input image.
patch_size (int): The size of each patch.
Returns:
list: A list of PIL.Image.Image objects representing the patches.
"""
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)

return patches
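
A quick runnable check of the tiling order (blank stand-in image; assumes the helper is importable from lmdeploy.vl.model.llava):

from PIL import Image
from lmdeploy.vl.model.llava import divide_to_patches

tiles = divide_to_patches(Image.new('RGB', (672, 672)), 336)
print(len(tiles), tiles[0].size)  # 4 (336, 336): patches are produced row by row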


def process_anyres_image(image, processor, grid_pinpoints):
"""Process an image with variable resolutions.
Args:
image (PIL.Image.Image): The input image to be processed.
processor: The image processor object.
grid_pinpoints (str): A string representation of a list of possible resolutions.
Returns:
torch.Tensor: A tensor containing the processed image patches.
""" # noqa
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)

patches = divide_to_patches(image_padded, processor.crop_size['height'])

image_original_resize = image.resize(
(processor.size['shortest_edge'], processor.size['shortest_edge']))

image_patches = [image_original_resize] + patches
image_patches = [
processor.preprocess(image_patch,
return_tensors='pt')['pixel_values'][0]
for image_patch in image_patches
]
return torch.stack(image_patches, dim=0)
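
Putting the pieces together, a hedged end-to-end sketch of the anyres path (stand-in image and grid; assumes these helpers are importable from lmdeploy.vl.model.llava and that the 336 px CLIP processor loaded in build_preprocessor below is used):

from PIL import Image
from transformers import CLIPImageProcessor
from lmdeploy.vl.model.llava import process_anyres_image

processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
image = Image.new('RGB', (1000, 600), (127, 127, 127))   # stand-in image
grid_pinpoints = '[(336, 672), (672, 336), (672, 672)]'  # string form, parsed via ast
patches = process_anyres_image(image, processor, grid_pinpoints)
print(patches.shape)  # torch.Size([5, 3, 336, 336]): 1 base resize + a 2x2 grid of tiles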


def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result


def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None)
new_images = []
if image_aspect_ratio == 'pad':
for image in images:
image = expand2square(
image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
new_images.append(image)
elif image_aspect_ratio == 'anyres':
for image in images:
image = process_anyres_image(image, image_processor,
model_cfg.image_grid_pinpoints)
new_images.append(image)
else:
return image_processor(images, return_tensors='pt')['pixel_values']
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
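
And a sketch of the 'pad' branch, where expand2square letterboxes each image to a square filled with the processor's mean color before standard CLIP preprocessing (model_cfg is a hypothetical stand-in for the loaded HF config):

from types import SimpleNamespace
from PIL import Image
from transformers import CLIPImageProcessor
from lmdeploy.vl.model.llava import process_images

processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
model_cfg = SimpleNamespace(image_aspect_ratio='pad')
images = [Image.new('RGB', (400, 200), (255, 0, 0))]  # stand-in non-square image
pixel_values = process_images(images, processor, model_cfg)
print(pixel_values.shape)  # torch.Size([1, 3, 336, 336]) after pad-to-square and resize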


@VISION_MODELS.register_module()
class LlavaVisionModel(VisonModel):
class LlavaVisionModel(LlavaHfVisionModel):
"""Llava visual model."""

@classmethod
@@ -73,9 +232,14 @@ def match(cls, config: AutoConfig):
return True
return False

def build_preprocessor(self):
from transformers import CLIPImageProcessor
vision_tower_name = 'openai/clip-vit-large-patch14-336'
self.image_processor = CLIPImageProcessor.from_pretrained(
vision_tower_name)

def build_model(self):
"""build model & load weights."""
# check llava install
check_llava_install()

self.arch = self.hf_config.architectures[0]
@@ -143,38 +307,35 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
image_features = self.mm_projector(image_features)
return image_features

def preprocess(
self,
images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]:
"""preprocess."""
# TODO: gpu processor
from llava.mm_utils import process_images
images = [x.convert('RGB') for x in images]
image_processor = self.vision_tower.image_processor
outputs = process_images(images, image_processor, self.config)
def preprocess(self, messages: List[Dict]) -> List[Dict]:
"""refer to `super().preprocess() for spec."""
outputs = []
for item in messages[-1]['content']:
if item['type'] == 'image':
image = item['image'].convert('RGB')
pixel_values = process_images([image], self.image_processor,
self.config)
outputs.append(
dict(
pixel_values=pixel_values,
image_size=image.size,
image_tokens=576, # TODO
image_token_id=0))
return outputs
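
To illustrate the reworked interface (schematic only; vision_model and the image path are hypothetical placeholders), preprocess() now consumes OpenAI-style messages and emits one record per image:

from PIL import Image

messages = [dict(role='user',
                 content=[dict(type='text', text='describe the image'),
                          dict(type='image', image=Image.open('demo.jpg'))])]
outputs = vision_model.preprocess(messages)  # vision_model: a built LlavaVisionModel
# -> [dict(pixel_values=<tensor>, image_size=(w, h), image_tokens=576, image_token_id=0)]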

@torch.no_grad()
def forward(self, images: List[Image]) -> List[torch.Tensor]:
"""forward."""
def forward(self, inputs: List[Dict]) -> List[torch.Tensor]:
from llava.model.llava_arch import (get_anyres_image_grid_shape,
unpad_image)
image_sizes = [x.size for x in images]
images = self.preprocess(images)
if isinstance(images, list):
images = [
x.to(device=self.vision_tower.device, dtype=torch.float16)
for x in images
]
else:
images = images.to(device=self.vision_tower.device,
dtype=torch.float16)
if type(images) is list or images.ndim == 5:
if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_sizes = [x['image_size'] for x in inputs]
pixel_values = [x['pixel_values'] for x in inputs]
pixel_values = torch.cat(pixel_values, dim=0)
pixel_values = pixel_values.to(device=self.vision_tower.device,
dtype=torch.float16)
if pixel_values.ndim == 5:
split_sizes = [x.shape[0] for x in pixel_values]
pixel_values = torch.cat([x for x in pixel_values], dim=0)
image_features = self.encode_images(pixel_values)
image_features = torch.split(image_features, split_sizes, dim=0)
mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type',
'flat')
@@ -238,6 +399,6 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]:
raise ValueError('Unexpected mm_patch_merge_type: '
f'{self.config.mm_patch_merge_type}')
else:
image_features = self.encode_images(images)
image_features = self.encode_images(pixel_values)
image_features = [x for x in image_features]
return image_features
8 changes: 3 additions & 5 deletions lmdeploy/vl/model/llava_hf.py
@@ -53,7 +53,7 @@ def build_model(self):
self.model = model

def preprocess(self, messages: List[Dict]) -> List[Dict]:
"""refers to the spec of `super.preprocess()"""
"""refers to `super.preprocess() for spec."""
outputs = []
for item in messages[-1]['content']:
if item['type'] == 'image':
@@ -102,10 +102,8 @@ def proc_messages(cls, messages, chat_template, sequence_start):
if isinstance(message['content'], str):
prompt_messages.append(message)
continue
n_images = [
1 for item in message['content'] if item['type'] == 'image'
]
n_images = sum(n_images)
n_images = len(
[1 for x in message['content'] if x['type'] == 'image'])
content = [
item['text'] for item in message['content']
if item['type'] == 'text'