diff --git a/lmdeploy/vl/model/phi3_vision.py b/lmdeploy/vl/model/phi3_vision.py
index 032b8404da..2d2076d8da 100644
--- a/lmdeploy/vl/model/phi3_vision.py
+++ b/lmdeploy/vl/model/phi3_vision.py
@@ -1,13 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
-from typing import List
+from typing import Dict, List
 
 import torch
-from PIL.Image import Image
 from transformers import AutoProcessor
 
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
 from lmdeploy.vl.model.utils import disable_logging
 
@@ -119,11 +118,20 @@ def _process_image_embedding(self, pixel_values: torch.Tensor,
 
 
 @VISION_MODELS.register_module()
-class Phi3VisionModel(VisonModel):
+class Phi3VisionModel(LlavaHfVisionModel):
    """Llava hf vision model."""
 
     _arch = 'Phi3VForCausalLM'
 
+    def build_preprocessor(self):
+        processor = AutoProcessor.from_pretrained(self.model_path,
+                                                  trust_remote_code=True)
+        if hasattr(processor, 'tokenizer'):
+            # the tokenizer is not needed for vision-only preprocessing
+            del processor.tokenizer
+        # keep the whole processor; preprocess() uses its image_processor
+        self.processor = processor
+
     def build_model(self):
         from accelerate import init_empty_weights, load_checkpoint_and_dispatch
         from accelerate.utils import get_balanced_memory, infer_auto_device_map
@@ -173,23 +181,36 @@ def build_model(self):
         model.eval()
         self.model = model
 
-        # processor
-        processor = AutoProcessor.from_pretrained(self.model_path,
-                                                  trust_remote_code=True)
-        if hasattr(processor, 'tokenizer'):
-            del processor.tokenizer
-            processor.prtokenizer = None
-        self.processor = processor.image_processor
-        self.processor = processor
+
+    def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """Refer to `super().preprocess()` for spec."""
+        outputs = []
+        for item in messages[-1]['content']:
+            if item['type'] == 'image':
+                image = item['image'].convert('RGB')
+                result = self.processor.image_processor(image,
+                                                        return_tensors='pt')
+                # the HD transform pads the image to multiples of 336; each
+                # 336x336 crop (plus one global view) costs 144 tokens, with
+                # 12 separator tokens per crop row
+                h = result['image_sizes'][0][0].item() // 336
+                w = result['image_sizes'][0][1].item() // 336
+                image_tokens = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
+                result.update(
+                    dict(image_size=image.size,
+                         image_tokens=image_tokens,
+                         image_token_id=0))
+                outputs.append(result)
+        return outputs
 
     @torch.no_grad()
-    def forward(self, images: List[Image]) -> List[torch.Tensor]:
-        """forward."""
-        process_outputs = self.processor.image_processor(
-            images, return_tensors='pt').to(device=self.model.device,
-                                            dtype=self.model.dtype)
-        pixel_values = process_outputs['pixel_values']
-        image_sizes = process_outputs['image_sizes']
+    def forward(self, inputs: List[Dict]) -> List[torch.Tensor]:
+        # every preprocess() result carries a leading batch dim of 1, so
+        # concatenate (not stack) along dim 0 to batch the images
+        pixel_values = torch.cat([x['pixel_values'] for x in inputs], dim=0)
+        pixel_values = pixel_values.to(self.model.device, self.model.dtype)
+        image_sizes = torch.cat([x['image_sizes'] for x in inputs], dim=0)
+        image_sizes = image_sizes.to(self.model.device)
         image_features = _process_image_embedding(
             self.model.model.vision_embed_tokens,
             pixel_values=pixel_values,
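
For reviewers, the `image_tokens` arithmetic in `preprocess()` can be checked in isolation. Below is a minimal sketch of that formula; the helper name `phi3v_image_tokens` and the sample sizes are hypothetical, and the reading of the constants (144 tokens per 336x336 crop, 12 separator tokens per crop row) is inferred from the formula rather than stated by this PR:

```python
def phi3v_image_tokens(padded_h: int, padded_w: int) -> int:
    """Mirror of the token-count formula in Phi3VisionModel.preprocess().

    `padded_h`/`padded_w` are the image dimensions after the HD transform,
    which pads both sides to multiples of 336 (what `image_sizes` reports).
    """
    h = padded_h // 336  # rows of 336x336 crops
    w = padded_w // 336  # columns of 336x336 crops
    # h*w local crops plus one global view, 144 tokens each,
    # plus 1 extra token and 12 separator tokens per crop row
    return (h * w + 1) * 144 + 1 + (h + 1) * 12


# e.g. an image padded to 672x1344 yields a 2x4 crop grid:
# (2*4 + 1) * 144 + 1 + (2 + 1) * 12 = 1333 tokens
assert phi3v_image_tokens(672, 1344) == 1333
```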
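
The `forward()` rewrite batches the per-image tensors produced by `preprocess()`. Since each `pixel_values` entry already carries a leading batch dimension of 1, `torch.cat` along dim 0 is the right combinator; `torch.stack` would add an extra dimension. A self-contained shape sketch with dummy tensors (the crop count of 17 assumes the processor's default of 16 crops plus one global view):

```python
import torch

# two preprocess() outputs, (1 image, 17 crops, 3, 336, 336) each;
# the processor pads every image to the same number of crops
a = torch.zeros(1, 17, 3, 336, 336)
b = torch.zeros(1, 17, 3, 336, 336)

print(torch.cat([a, b], dim=0).shape)    # [2, 17, 3, 336, 336] -> a valid batch
print(torch.stack([a, b], dim=0).shape)  # [2, 1, 17, 3, 336, 336] -> 6-D, not
                                         # the 5-D input the vision tower expects
```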