diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py
index 726b0f366..5065a221e 100644
--- a/lmdeploy/vl/model/llava_next.py
+++ b/lmdeploy/vl/model/llava_next.py
@@ -57,8 +57,6 @@ def build_model(self):
 
     @torch.no_grad()
     def forward(self, images: List[Image]) -> List[torch.Tensor]:
-        from transformers.models.llava_next.modeling_llava_next import \
-            image_size_to_num_patches
         """forward."""
         processed_inputs = self.processor(images,
                                           return_tensors='pt',
@@ -67,45 +65,16 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]:
             device=self.model.device, dtype=self.model.dtype)
         image_sizes = processed_inputs['image_sizes'].to(
             device=self.model.device, dtype=self.model.dtype)
-        # ! infer image_num_patches from image_sizes
-        image_num_patches = [
-            image_size_to_num_patches(
-                image_size=imsize,
-                grid_pinpoints=self.hf_config.image_grid_pinpoints,
-                patch_size=self.hf_config.vision_config.image_size,
-            ) for imsize in image_sizes
-        ]
-        # figure out if pixel_values is concatenated or stacked
-        if pixel_values.dim() == 5:
-            # stacking when input is
-            # (batch_size, num_patches, num_channels, height, width)
-            _pixel_values_list = [
-                pix_val[:num_patch]
-                for pix_val, num_patch in zip(pixel_values, image_num_patches)
-            ]
-            pixel_values = torch.cat(_pixel_values_list, dim=0)
-        elif pixel_values.dim() != 4:
-            # otherwise has to be stacked from list of
-            # (num_patches, num_channels, height, width)
-            raise ValueError(f'pixel_values of shape {pixel_values.shape}, '
-                             'expect to be of 4 or 5 dimensions')
-        image_outputs = self.model.vision_tower.forward(
-            pixel_values, output_hidden_states=True)
-        image_features = image_outputs.hidden_states[
-            self.hf_config.vision_feature_layer]
-        if self.hf_config.vision_feature_select_strategy == 'default':
-            image_features = image_features[:, 1:]
-        elif self.hf_config.vision_feature_select_strategy == 'full':
-            image_features = image_features
-        else:
-            raise ValueError(
-                'Unexpected select feature strategy: '
-                f'{self.hf_config.vision_feature_select_strategy}')
-        image_features = self.model.multi_modal_projector(image_features)
-        image_features = torch.split(image_features, image_num_patches, dim=0)
+        image_features = self.model.get_image_features(
+            pixel_values,
+            image_sizes,
+            vision_feature_layer=self.hf_config.vision_feature_layer,
+            vision_feature_select_strategy=self.hf_config.
+            vision_feature_select_strategy)
         image_features, feature_lens = self.model.pack_image_features(
             image_features,
             image_sizes,
+            self.hf_config.vision_feature_select_strategy,
             image_newline=self.model.image_newline,
         )
         outputs = torch.split(image_features,
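
For context, a minimal standalone sketch (not part of the patch) of the feature path the refactor delegates to. It assumes a transformers version that exposes `LlavaNextForConditionalGeneration.get_image_features()` and the `pack_image_features()` signature used above; the checkpoint name and the dummy image are illustrative placeholders.

```python
# Sketch only: mirrors the patched forward() outside of lmdeploy, assuming a
# transformers release that provides get_image_features()/pack_image_features()
# with the arguments used in the diff. The checkpoint id is a placeholder.
import torch
from PIL import Image
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

model_id = 'llava-hf/llava-v1.6-mistral-7b-hf'  # placeholder checkpoint
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32

processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=dtype).to(device).eval()

# Dummy input image; any list of PIL images works here.
images = [Image.new('RGB', (640, 480), color=(127, 127, 127))]
processed = processor.image_processor(images=images, return_tensors='pt')
pixel_values = processed['pixel_values'].to(device=device, dtype=dtype)
image_sizes = processed['image_sizes'].to(device=device)

with torch.no_grad():
    # get_image_features() runs the vision tower and the multi-modal projector
    # internally, replacing the manual patch-counting / hidden-state selection
    # logic removed by the patch.
    image_features = model.get_image_features(
        pixel_values,
        image_sizes,
        vision_feature_layer=model.config.vision_feature_layer,
        vision_feature_select_strategy=model.config.
        vision_feature_select_strategy)
    # pack_image_features() flattens the per-image patch features, inserts the
    # image_newline embedding, and returns per-image feature lengths.
    packed, feature_lens = model.pack_image_features(
        image_features,
        image_sizes,
        model.config.vision_feature_select_strategy,
        image_newline=model.image_newline)

# Split back into one tensor per input image, as forward() does with torch.split.
outputs = torch.split(packed, feature_lens.tolist(), dim=0)
print([o.shape for o in outputs])
```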