From accfc002b1b24d4deb16cdb9e2cc36eb10113991 Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Wed, 20 Nov 2024 15:33:31 +0800 Subject: [PATCH 1/2] Refactor VL modules for qwen-vl, llava and llava_next (#2773) * qwen2-vl * internvl * qwen2 * get image_tokens_per_patch for internvl2 * deepseek-vl * cogvlm * glm4v * update internvl * internvl_llava * llava * glm4v * upate internvl * cogvlm * deepseek * llava_hf * rollback llava, internvl-llava * refactor qwen * update internvl * update llava_hf * update qwen2-vl * llava_next * update llava_next * update llava * update llava * update llava --- lmdeploy/vl/model/internvl.py | 36 +++-- lmdeploy/vl/model/llava.py | 231 +++++++++++++++++++++++++++----- lmdeploy/vl/model/llava_hf.py | 8 +- lmdeploy/vl/model/llava_next.py | 141 ++++++++++++------- lmdeploy/vl/model/qwen.py | 79 +++++++++-- lmdeploy/vl/model/qwen2.py | 84 ++---------- 6 files changed, 391 insertions(+), 188 deletions(-) diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py index 62a1041f5c..44feab3f63 100644 --- a/lmdeploy/vl/model/internvl.py +++ b/lmdeploy/vl/model/internvl.py @@ -192,18 +192,7 @@ def _forward(self, inputs): return outputs def preprocess(self, messages: List[Dict]) -> List[Dict]: - assert isinstance(messages, List) - assert isinstance(messages[-1]['content'], List) - content = [ - item['text'] for item in messages[-1]['content'] - if item['type'] == 'text' - ] - if len(content) > 1: - logger.warning(f'There are {len(content)} text in {content}. ' - 'Only the first one is considered') - - # get images and their corresponding preprocess parameters from - # messages, and perform preprocessing + """refers to `super.preprocess() for spec.""" outputs = [] for item in messages[-1]['content']: item_type = item['type'] @@ -233,22 +222,29 @@ def forward(self, inputs: List[Dict]) -> List[torch.Tensor]: @classmethod def proc_messages(cls, messages, chat_template, sequence_start): - # apply chat template to get the prompt + """apply chat template to get the prompt.""" prompt_messages = [] IMAGE_TOKEN = '' for message in messages: if isinstance(message['content'], str): prompt_messages.append(message) continue - n_images = [ - 1 for item in message['content'] if item['type'] == 'image' - ] - n_images = sum(n_images) + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) content = [ - item['text'] for item in message['content'] - if item['type'] == 'text' + x['text'] for x in message['content'] if x['type'] == 'text' ] - content = f'{IMAGE_TOKEN * n_images}\n' + content[0] + prompt = content[0] + if IMAGE_TOKEN in prompt and f'{IMAGE_TOKEN}' not in prompt: + prompt = prompt.replace(f'{IMAGE_TOKEN}', + f'{IMAGE_TOKEN}') + prompt = prompt.replace('', '') + prompt = prompt.replace('', '') + prompt = prompt.replace('', '') + elif IMAGE_TOKEN not in prompt: + prompt = f'{IMAGE_TOKEN * n_images}\n' + prompt + else: + pass prompt_messages.append(dict(role='user', content=content)) prompt = chat_template.messages2prompt(prompt_messages, sequence_start) return prompt, IMAGE_TOKEN diff --git a/lmdeploy/vl/model/llava.py b/lmdeploy/vl/model/llava.py index 0b18f460cd..671288dd7f 100644 --- a/lmdeploy/vl/model/llava.py +++ b/lmdeploy/vl/model/llava.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-# Modified from -# https://github.com/haotian-liu/LLaVA.git +# Modified from https://github.com/haotian-liu/LLaVA.git +import ast +import math import warnings from contextlib import contextmanager -from typing import List, Union +from typing import Dict, List import torch -from PIL.Image import Image +from PIL import Image from transformers import AutoConfig, AutoModelForCausalLM from lmdeploy.utils import get_logger -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel from lmdeploy.vl.model.utils import disable_logging, rewrite_ctx logger = get_logger('lmdeploy') @@ -53,8 +54,166 @@ def init_llava_vision_tower(config): yield +def select_best_resolution(original_size, possible_resolutions): + """Selects the best resolution from a list of possible resolutions based on + the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). + """ # noqa + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float('inf') + + for width, height in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int( + original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, + original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def resize_and_pad_image(image, target_resolution): + """Resize and pad an image to a target resolution while maintaining aspect + ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ # noqa + original_width, original_height = image.size + target_width, target_height = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + + return new_image + + +def divide_to_patches(image, patch_size): + """Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. 
+ """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + + +def process_anyres_image(image, processor, grid_pinpoints): + """Process an image with variable resolutions. + + Args: + image (PIL.Image.Image): The input image to be processed. + processor: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. + + Returns: + torch.Tensor: A tensor containing the processed image patches. + """ # noqa + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + best_resolution = select_best_resolution(image.size, possible_resolutions) + image_padded = resize_and_pad_image(image, best_resolution) + + patches = divide_to_patches(image_padded, processor.crop_size['height']) + + image_original_resize = image.resize( + (processor.size['shortest_edge'], processor.size['shortest_edge'])) + + image_patches = [image_original_resize] + patches + image_patches = [ + processor.preprocess(image_patch, + return_tensors='pt')['pixel_values'][0] + for image_patch in image_patches + ] + return torch.stack(image_patches, dim=0) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_images(images, image_processor, model_cfg): + image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None) + new_images = [] + if image_aspect_ratio == 'pad': + for image in images: + image = expand2square( + image, tuple(int(x * 255) for x in image_processor.image_mean)) + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + new_images.append(image) + elif image_aspect_ratio == 'anyres': + for image in images: + image = process_anyres_image(image, image_processor, + model_cfg.image_grid_pinpoints) + new_images.append(image) + else: + return image_processor(images, return_tensors='pt')['pixel_values'] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + @VISION_MODELS.register_module() -class LlavaVisionModel(VisonModel): +class LlavaVisionModel(LlavaHfVisionModel): """Llava visual model.""" @classmethod @@ -73,9 +232,14 @@ def match(cls, config: AutoConfig): return True return False + def build_preprocessor(self): + from transformers import CLIPImageProcessor + vision_tower_name = 'openai/clip-vit-large-patch14-336' + self.image_processor = CLIPImageProcessor.from_pretrained( + vision_tower_name) + def build_model(self): """build model & load weights.""" - # check llava install check_llava_install() self.arch = self.hf_config.architectures[0] @@ -143,38 +307,35 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor: image_features = self.mm_projector(image_features) return image_features - def preprocess( - self, - images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]: - """preprocess.""" - # TODO: gpu processor - from llava.mm_utils import process_images - images = [x.convert('RGB') for x in 
images] - image_processor = self.vision_tower.image_processor - outputs = process_images(images, image_processor, self.config) + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" + outputs = [] + for item in messages[-1]['content']: + if item['type'] == 'image': + image = item['image'].convert('RGB') + pixel_values = process_images([image], self.image_processor, + self.config) + outputs.append( + dict( + pixel_values=pixel_values, + image_size=image.size, + image_tokens=576, # TODO + image_token_id=0)) return outputs @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" + def forward(self, inputs: List[Dict]) -> List[torch.Tensor]: from llava.model.llava_arch import (get_anyres_image_grid_shape, unpad_image) - image_sizes = [x.size for x in images] - images = self.preprocess(images) - if isinstance(images, list): - images = [ - x.to(device=self.vision_tower.device, dtype=torch.float16) - for x in images - ] - else: - images = images.to(device=self.vision_tower.device, - dtype=torch.float16) - if type(images) is list or images.ndim == 5: - if type(images) is list: - images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] - concat_images = torch.cat([image for image in images], dim=0) - image_features = self.encode_images(concat_images) - split_sizes = [image.shape[0] for image in images] + image_sizes = [x['image_size'] for x in inputs] + pixel_values = [x['pixel_values'] for x in inputs] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.vision_tower.device, + dtype=torch.float16) + if pixel_values.ndim == 5: + split_sizes = [x.shape[0] for x in pixel_values] + pixel_values = torch.cat([x for x in pixel_values], dim=0) + image_features = self.encode_images(pixel_values) image_features = torch.split(image_features, split_sizes, dim=0) mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') @@ -238,6 +399,6 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]: raise ValueError('Unexpected mm_patch_merge_type: ' f'{self.config.mm_patch_merge_type}') else: - image_features = self.encode_images(images) + image_features = self.encode_images(pixel_values) image_features = [x for x in image_features] return image_features diff --git a/lmdeploy/vl/model/llava_hf.py b/lmdeploy/vl/model/llava_hf.py index ee24b6df59..731b3849fe 100644 --- a/lmdeploy/vl/model/llava_hf.py +++ b/lmdeploy/vl/model/llava_hf.py @@ -53,7 +53,7 @@ def build_model(self): self.model = model def preprocess(self, messages: List[Dict]) -> List[Dict]: - """refers to the spec of `super.preprocess()""" + """refers to `super.preprocess() for spec.""" outputs = [] for item in messages[-1]['content']: if item['type'] == 'image': @@ -102,10 +102,8 @@ def proc_messages(cls, messages, chat_template, sequence_start): if isinstance(message['content'], str): prompt_messages.append(message) continue - n_images = [ - 1 for item in message['content'] if item['type'] == 'image' - ] - n_images = sum(n_images) + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) content = [ item['text'] for item in message['content'] if item['type'] == 'text' diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py index 9223ebea4f..d4f993d40d 100644 --- a/lmdeploy/vl/model/llava_next.py +++ b/lmdeploy/vl/model/llava_next.py @@ -1,46 +1,59 @@ # Copyright (c) OpenMMLab. All rights reserved. 
- +import itertools import warnings -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoProcessor -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.utils import get_logger +from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + @VISION_MODELS.register_module() -class LlavaNextVisionModel(VisonModel): +class LlavaNextVisionModel(LlavaHfVisionModel): """Llava hf vision model.""" _arch = 'LlavaNextForConditionalGeneration' - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import get_balanced_memory, infer_auto_device_map - + def build_preprocessor(self): + processor = AutoProcessor.from_pretrained(self.model_path, + trust_remote_code=True) + if hasattr(processor, 'tokenizer'): + del processor.tokenizer + processor.prtokenizer = None + self.processor = processor.image_processor + # build the model with empty weights. The model will be used in + # `preprocess` to get the image token number + from accelerate import init_empty_weights with init_empty_weights(), warnings.catch_warnings(): warnings.simplefilter('ignore') from transformers import LlavaNextForConditionalGeneration - model = LlavaNextForConditionalGeneration._from_config( + self.model = LlavaNextForConditionalGeneration._from_config( self.hf_config) - if not self.with_llm: - del model.language_model - for key in ['language_model']: - setattr(model, key, None) - else: - self.vl_model = model + + def build_model(self): + from accelerate import load_checkpoint_and_dispatch + from accelerate.utils import get_balanced_memory, infer_auto_device_map + + if not self.with_llm: + del self.model.language_model + for key in ['language_model']: + setattr(self.model, key, None) + else: + self.vl_model = self.model no_split_module_classes = ['CLIPEncoderLayer'] max_memory = get_balanced_memory( - model, + self.model, max_memory=self.max_memory, dtype=torch.half, no_split_module_classes=no_split_module_classes) device_map = infer_auto_device_map( - model, + self.model, no_split_module_classes=no_split_module_classes, max_memory=max_memory, dtype=torch.half) @@ -55,41 +68,68 @@ def build_model(self): with disable_logging(): load_checkpoint_and_dispatch( - model=model, + model=self.model, checkpoint=self.model_path, device_map=device_map if not self.with_llm else {'': 'cpu'}, no_split_module_classes=no_split_module_classes, dtype=torch.half) - model.eval() - self.model = model - # processor - processor = AutoProcessor.from_pretrained(self.model_path, - trust_remote_code=True) - if hasattr(processor, 'tokenizer'): - del processor.tokenizer - processor.prtokenizer = None - self.processor = processor.image_processor + self.model.eval() - @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to the spec of `super.preprocess()""" from transformers.models.llava_next.modeling_llava_next import \ image_size_to_num_patches - """forward.""" - processed_inputs = self.processor(images, - return_tensors='pt', - input_data_format='channels_last') - pixel_values = processed_inputs['pixel_values'].to( - device=self.model.device, dtype=self.model.dtype) - image_sizes = processed_inputs['image_sizes'].to( - device=self.model.device, dtype=self.model.dtype) - # ! 
infer image_num_patches from image_sizes - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.hf_config.image_grid_pinpoints, - patch_size=self.hf_config.vision_config.image_size, - ) for imsize in image_sizes + outputs = [] + for item in messages[-1]['content']: + item_type = item['type'] + if item_type == 'image': + image = item['image'].convert('RGB') + result = self.processor(image, + return_tensors='pt', + input_data_format='channels_last') + # ! infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.hf_config.image_grid_pinpoints, + patch_size=self.hf_config.vision_config.image_size, + ) for imsize in result['image_sizes'] + ] + # TODO(remove hardcode 576) + hidden_size = self.hf_config.text_config.hidden_size + fake_image_features = torch.zeros( + [image_num_patches[0], 576, hidden_size]) + image_sizes = result['image_sizes'] + image_newline = torch.randn( + self.hf_config.text_config.hidden_size) + strategy = self.hf_config.vision_feature_select_strategy + _, image_tokens = self.model.pack_image_features( + [fake_image_features], + image_sizes, + vision_feature_select_strategy=strategy, + image_newline=image_newline) + result.update( + dict(image_size=image.size, + image_patches=image_num_patches, + image_tokens=image_tokens, + image_token_id=0)) + outputs.append(result) + return outputs + + @torch.no_grad() + def forward(self, inputs: List[Dict]) -> List[torch.Tensor]: + pixel_values = [ + x['pixel_values'].to(device=self.model.device, + dtype=self.model.dtype) for x in inputs + ] + pixel_values = torch.cat(pixel_values, dim=0) + image_sizes = [ + x['image_sizes'].to(device=self.model.device, + dtype=self.model.dtype) for x in inputs ] + image_sizes = torch.cat(image_sizes, dim=0) + image_num_patches = [x['num_patch'] for x in inputs] + image_num_patches = list(itertools.chain(*image_num_patches)) # figure out if pixel_values is concatenated or stacked if pixel_values.dim() == 5: # stacking when input is @@ -108,19 +148,20 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]: pixel_values, output_hidden_states=True) image_features = image_outputs.hidden_states[ self.hf_config.vision_feature_layer] - if self.hf_config.vision_feature_select_strategy == 'default': + strategy = self.hf_config.vision_feature_select_strategy + if strategy == 'default': image_features = image_features[:, 1:] - elif self.hf_config.vision_feature_select_strategy == 'full': + elif strategy == 'full': image_features = image_features else: - raise ValueError( - 'Unexpected select feature strategy: ' - f'{self.hf_config.vision_feature_select_strategy}') + raise ValueError('Unexpected select feature strategy: ' + f'{strategy}') image_features = self.model.multi_modal_projector(image_features) image_features = torch.split(image_features, image_num_patches, dim=0) image_features, feature_lens = self.model.pack_image_features( image_features, image_sizes, + vision_feature_select_strategy=strategy, image_newline=self.model.image_newline, ) outputs = torch.split(image_features, diff --git a/lmdeploy/vl/model/qwen.py b/lmdeploy/vl/model/qwen.py index 3968f27d97..49556b9a60 100644 --- a/lmdeploy/vl/model/qwen.py +++ b/lmdeploy/vl/model/qwen.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import List +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoModelForCausalLM from lmdeploy.vl.model.base import VISION_MODELS, VisonModel @@ -16,6 +15,19 @@ class QwenVisionModel(VisonModel): _arch = 'QWenLMHeadModel' + def build_preprocessor(self): + from torchvision import transforms + from torchvision.transforms import InterpolationMode + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + image_size = self.hf_config.visual['image_size'] + self.image_transform = transforms.Compose([ + transforms.Resize((image_size, image_size), + interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + def build_model(self): from accelerate import init_empty_weights with init_empty_weights(): @@ -60,13 +72,64 @@ def build_model(self): self.model = model.transformer.visual.eval() + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to `super.preprocess() for spec.""" + outputs = [] + for item in messages[-1]['content']: + if item['type'] == 'image': + image = item['image'].convert('RGB') + pixel_values = self.image_transform(image) + outputs.append( + dict( + pixel_values=pixel_values, + image_size=image.size, + image_tokens=256, # TODO + image_token_id=0)) + return outputs + @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [self.model.image_transform(x) for x in outputs] - outputs = torch.stack(outputs, dim=0) - outputs = self.model(outputs) + def forward(self, inputs: List[Dict]) -> List[torch.Tensor]: + pixel_values = [x['pixel_values'] for x in inputs] + pixel_values = torch.stack(pixel_values, dim=0) + outputs = self.model(pixel_values) outputs = torch.split(outputs, 1, dim=0) outputs = [x.squeeze() for x in outputs] return outputs + + @classmethod + def proc_messages(cls, messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + prompt = content[0] + if IMAGE_TOKEN in prompt: + pass + else: + prompt = ''.join([ + f'Picture {str(i)}:{IMAGE_TOKEN}\n' + for i in range(n_images) + ]) + prompt + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return super().to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return super().to_turbomind_aux(messages, prompt, IMAGE_TOKEN, + tokenizer, sequence_start) diff --git a/lmdeploy/vl/model/qwen2.py b/lmdeploy/vl/model/qwen2.py index f53ffb6441..23a38bf96b 100644 --- a/lmdeploy/vl/model/qwen2.py +++ b/lmdeploy/vl/model/qwen2.py @@ -1,6 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. 
- -import itertools from typing import Dict, List import torch @@ -67,44 +65,7 @@ def build_model(self): self.model = model.eval() def preprocess(self, messages: List[Dict]) -> List[Dict]: - """preprocess multimodal data in the messages, of which only the last - item includes the mulitmodal data. - - Args: - message(Dict): multimodal data in a dict, which is as follows: - [ - {'role': 'user', 'content': 'user prompt'}, - {'role': 'assisant', 'content': 'AI reponse'}, - { - 'role': 'user', - 'content': [ - { - 'type': 'text', - 'text': 'string', - }, - { - 'type': 'image', - 'image': pillow.Image, - key1: value1, - ... - }, - { - 'type': 'image', - 'image': pillow.Image, - key1: value1, - ... - }, - ... - ] - } - ] - Returns: - the preprocessing results in a list. list[i] is a dict. It refers - to the preprocessing result of an image - """ - assert isinstance(messages, List) - assert isinstance(messages[-1]['content'], List) - + """refer to `super().preprocess()` for spec.""" from qwen_vl_utils import process_vision_info optional_keys = { @@ -120,9 +81,15 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: {key: x[key] for key in x.keys() if key in optional_keys}) image_inputs, _ = process_vision_info([dict(content=[item])]) - image_inputs = self.processor.image_processor( - images=image_inputs, videos=None, return_tensors='pt') - outputs.append(image_inputs) + result = self.processor.image_processor(images=image_inputs, + videos=None, + return_tensors='pt') + result.update( + dict( + image_size=image.size, + image_tokens=1, # TODO + image_token_id=0)) + outputs.append(result) return outputs @torch.no_grad() @@ -152,37 +119,14 @@ def proc_messages(cls, messages, chat_template, sequence_start): ) prompt_messages.append(dict(role='user', content=''.join(content))) prompt = chat_template.messages2prompt(prompt_messages, sequence_start) - segs = prompt.split(IMAGE_TOKEN) - # collect all preprocessing result from messages - preps = [ - message.pop('preprocess') for message in messages - if 'preprocess' in message.keys() - ] - # flatten the list - preps = list(itertools.chain(*preps)) - assert len(segs) == len(preps) + 1, ( - f'the number of {IMAGE_TOKEN} is not equal ' - f'to input images, {len(segs) - 1} vs {len(preps)}') - - return prompt, segs, preps + return prompt, IMAGE_TOKEN def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): """return to the information needed by pytorch engine.""" - prompt, segs, preps = self.proc_messages(messages, chat_template, + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start) - input_ids = [] - IMAGE_DUMMY_TOKEN_INDEX = 0 - merge_length = self.processor.image_processor.merge_size**2 - for i, seg in enumerate(segs): - if i > 0 and i <= len(preps): - preps[i - 1].update(offset=len(input_ids)) - image_grid_thw = preps[i - 1]['image_grid_thw'] - imag_tokens = image_grid_thw.prod() // merge_length - input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * imag_tokens) - token_ids = tokenizer.encode(seg, - add_bos=((i == 0) and sequence_start)) - input_ids.extend(token_ids) - return dict(prompt=prompt, input_ids=input_ids, multimodal=preps) + return super().to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): assert 0, 'TODO: support turbomind engine' From 463b5087da506030ab041467f4dfcf11872e3e28 Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Wed, 20 Nov 2024 16:01:24 +0800 Subject: [PATCH 2/2] Refactor VL modules for 
qwen2-vl (#2777) * qwen2-vl * internvl * qwen2 * get image_tokens_per_patch for internvl2 * deepseek-vl * cogvlm * glm4v * update internvl * internvl_llava * llava * glm4v * upate internvl * cogvlm * deepseek * llava_hf * rollback llava, internvl-llava * refactor qwen * update internvl * update llava_hf * update qwen2-vl * llava_next * update llava_next * update llava * update llava * update llava * qwen2 --- lmdeploy/vl/model/qwen2.py | 47 +++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/lmdeploy/vl/model/qwen2.py b/lmdeploy/vl/model/qwen2.py index 23a38bf96b..610b1d28e4 100644 --- a/lmdeploy/vl/model/qwen2.py +++ b/lmdeploy/vl/model/qwen2.py @@ -84,11 +84,13 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: result = self.processor.image_processor(images=image_inputs, videos=None, return_tensors='pt') + merge_length = self.processor.image_processor.merge_size**2 + image_tokens = result['image_grid_thw'].prod( + dim=1) // merge_length result.update( - dict( - image_size=image.size, - image_tokens=1, # TODO - image_token_id=0)) + dict(image_size=image.size, + image_tokens=image_tokens, + image_token_id=0)) outputs.append(result) return outputs @@ -98,26 +100,33 @@ def forward(self, inputs: List[Dict]) -> List[torch.Tensor]: @classmethod def proc_messages(cls, messages, chat_template, sequence_start): - # apply chat template to get the prompt - IMAGE_TOKEN = '<|image_pad|>' + """apply chat template to get the prompt.""" prompt_messages = [] + IMAGE_TOKEN = '' for message in messages: if isinstance(message['content'], str): prompt_messages.append(message) continue - content = [] - for item in message['content']: - item_type = item['type'] - if item_type == 'text': - content.append(item['text']) - elif item_type == 'image': - content.append( - f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>') - else: - assert 0, ( - f'unsupported type {item_type} in {message["content"]}' - ) - prompt_messages.append(dict(role='user', content=''.join(content))) + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = content[0] + if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt: + prompt = prompt.replace( + IMAGE_TOKEN, + f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>') + else: + # Qwen2-VL-2B-Instruct will concat image and user prompt + # according to their order in the content list + # we insert image token before user prompt by default. The + # user can use custom image token position if they want the + # same decorated prompt as Qwen2-VL + prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \ + n_images + prompt + prompt_messages.append(dict(role='user', content=prompt)) prompt = chat_template.messages2prompt(prompt_messages, sequence_start) return prompt, IMAGE_TOKEN
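
The "anyres" helpers added to lmdeploy/vl/model/llava.py (select_best_resolution, resize_and_pad_image, divide_to_patches, process_anyres_image) pick the grid resolution that preserves the most image content with the least padding, then crop the padded image into fixed-size patches plus one globally resized view. A minimal sketch of that selection and the resulting patch count, assuming illustrative grid pinpoints, a hypothetical 1000x747 input, and the 336-pixel crop size of a CLIP-ViT-L/14-336 vision tower:

import math


def select_best_resolution(original_size, possible_resolutions):
    """Pick the candidate resolution that keeps the most effective pixels
    while wasting the least padding area (same logic as in the patch)."""
    original_width, original_height = original_size
    best_fit, max_effective, min_wasted = None, 0, float('inf')
    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        down_w = int(original_width * scale)
        down_h = int(original_height * scale)
        effective = min(down_w * down_h, original_width * original_height)
        wasted = width * height - effective
        if effective > max_effective or (effective == max_effective
                                         and wasted < min_wasted):
            max_effective, min_wasted = effective, wasted
            best_fit = (width, height)
    return best_fit


grid_pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
image_size = (1000, 747)  # (width, height) of a hypothetical input image
best = select_best_resolution(image_size, grid_pinpoints)
crop = 336                # assumed crop_size of the CLIP image processor
n_patches = math.ceil(best[0] / crop) * math.ceil(best[1] / crop)
print(best, n_patches)    # (672, 672) 4 -> 4 grid patches, plus 1 global view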
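
For the 'pad' aspect-ratio branch, process_images squares each image with expand2square before handing it to the image processor, filling the border with the processor's mean color so the padding stays neutral after normalization. A short usage sketch with a synthetic image; the mean values mirror CLIP's image_mean:

from PIL import Image


def expand2square(pil_img, background_color):
    """Pad the shorter side so the image becomes square, keeping it centered."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


img = Image.new('RGB', (640, 480), (255, 0, 0))       # synthetic input image
mean = (0.48145466, 0.4578275, 0.40821073)            # CLIP image_mean
square = expand2square(img, tuple(int(x * 255) for x in mean))
print(square.size)  # (640, 640); the original is pasted centered vertically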
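
The Qwen2-VL preprocess in the second patch now reports image_tokens directly from the processor output instead of deriving it later in to_pytorch: each image's (t, h, w) patch grid is multiplied out and divided by merge_size**2, since the spatial merge folds merge_size x merge_size ViT patches into one LLM token. A rough sketch with an invented grid value (real values come from the Qwen2-VL image processor):

import torch

merge_size = 2                     # spatial merge factor of the Qwen2-VL ViT
merge_length = merge_size ** 2     # 4 ViT patches collapse into 1 LLM token

# one (t, h, w) row per image, in ViT patch units; hypothetical example
image_grid_thw = torch.tensor([[1, 98, 146]])

image_tokens = image_grid_thw.prod(dim=1) // merge_length
print(image_tokens)  # tensor([3577]) -> tokens reserved for this image in input_ids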
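
proc_messages for Qwen2-VL now decorates the prompt itself: an image placeholder already present in the user text is wrapped with <|vision_start|>/<|vision_end|>, otherwise one wrapped placeholder per image is prepended before the user text. A sketch of that branching, using Qwen2-VL's literal <|image_pad|> token as a stand-in for lmdeploy's image placeholder constant; chat-template handling is omitted and the example message is made up:

IMAGE_TOKEN = '<|image_pad|>'  # stand-in for the placeholder used by proc_messages


def decorate(user_text: str, n_images: int) -> str:
    block = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>'
    if IMAGE_TOKEN in user_text and '<|vision_start|>' not in user_text:
        # the caller placed bare image tokens; wrap each in vision markers
        return user_text.replace(IMAGE_TOKEN, block)
    # default: insert one decorated image block per image before the user text
    return block * n_images + user_text


print(decorate('Describe the two images.', n_images=2))
# <|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|>Describe the two images.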