Commit 011d56a
refactor llava_next with transformers new feature
deepindeed2022 committed Nov 21, 2024
1 parent f4ae3e8 commit 011d56a
Showing 1 changed file with 7 additions and 38 deletions.
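This change drops llava_next's hand-rolled vision pipeline (patch-count inference, vision-tower forward, feature-layer selection, projection, and per-image split) in favor of the get_image_features and pack_image_features helpers that recent transformers releases expose on the LLaVA-NeXT model class.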
45 changes: 7 additions & 38 deletions lmdeploy/vl/model/llava_next.py
@@ -57,8 +57,6 @@ def build_model(self):

     @torch.no_grad()
     def forward(self, images: List[Image]) -> List[torch.Tensor]:
-        from transformers.models.llava_next.modeling_llava_next import \
-            image_size_to_num_patches
         """forward."""
         processed_inputs = self.processor(images,
                                           return_tensors='pt',
@@ -67,45 +67,16 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]:
             device=self.model.device, dtype=self.model.dtype)
         image_sizes = processed_inputs['image_sizes'].to(
             device=self.model.device, dtype=self.model.dtype)
-        # ! infer image_num_patches from image_sizes
-        image_num_patches = [
-            image_size_to_num_patches(
-                image_size=imsize,
-                grid_pinpoints=self.hf_config.image_grid_pinpoints,
-                patch_size=self.hf_config.vision_config.image_size,
-            ) for imsize in image_sizes
-        ]
-        # figure out if pixel_values is concatenated or stacked
-        if pixel_values.dim() == 5:
-            # stacking when input is
-            # (batch_size, num_patches, num_channels, height, width)
-            _pixel_values_list = [
-                pix_val[:num_patch]
-                for pix_val, num_patch in zip(pixel_values, image_num_patches)
-            ]
-            pixel_values = torch.cat(_pixel_values_list, dim=0)
-        elif pixel_values.dim() != 4:
-            # otherwise has to be stacked from list of
-            # (num_patches, num_channels, height, width)
-            raise ValueError(f'pixel_values of shape {pixel_values.shape}, '
-                             'expect to be of 4 or 5 dimensions')
-        image_outputs = self.model.vision_tower.forward(
-            pixel_values, output_hidden_states=True)
-        image_features = image_outputs.hidden_states[
-            self.hf_config.vision_feature_layer]
-        if self.hf_config.vision_feature_select_strategy == 'default':
-            image_features = image_features[:, 1:]
-        elif self.hf_config.vision_feature_select_strategy == 'full':
-            image_features = image_features
-        else:
-            raise ValueError(
-                'Unexpected select feature strategy: '
-                f'{self.hf_config.vision_feature_select_strategy}')
-        image_features = self.model.multi_modal_projector(image_features)
-        image_features = torch.split(image_features, image_num_patches, dim=0)
+        image_features = self.model.get_image_features(
+            pixel_values,
+            image_sizes,
+            vision_feature_layer=self.hf_config.vision_feature_layer,
+            vision_feature_select_strategy=self.hf_config.
+            vision_feature_select_strategy)
         image_features, feature_lens = self.model.pack_image_features(
             image_features,
             image_sizes,
+            self.hf_config.vision_feature_select_strategy,
             image_newline=self.model.image_newline,
         )
         outputs = torch.split(image_features,
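For context, below is a minimal standalone sketch of the same two-call path against a stock Hugging Face checkpoint. It is an illustration, not LMDeploy code: the checkpoint id, the image path, and the assumption that the installed transformers is recent enough for pack_image_features to accept vision_feature_select_strategy are all hypothetical.

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaNextForConditionalGeneration

model_id = 'llava-hf/llava-v1.6-mistral-7b-hf'  # illustrative checkpoint
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16).to('cuda')
processor = AutoProcessor.from_pretrained(model_id)

images = [Image.open('demo.jpg')]  # illustrative input
inputs = processor.image_processor(images, return_tensors='pt')
pixel_values = inputs['pixel_values'].to(model.device, dtype=model.dtype)
image_sizes = inputs['image_sizes'].to(model.device)

cfg = model.config
# One call covers everything the deleted block did by hand: patch-count
# inference, vision-tower forward, feature-layer selection, projection,
# and the per-image split.
image_features = model.get_image_features(
    pixel_values,
    image_sizes,
    vision_feature_layer=cfg.vision_feature_layer,
    vision_feature_select_strategy=cfg.vision_feature_select_strategy)
# pack_image_features flattens the per-patch features (appending the
# learned image_newline embedding) and returns per-image feature lengths.
image_features, feature_lens = model.pack_image_features(
    image_features,
    image_sizes,
    cfg.vision_feature_select_strategy,
    image_newline=model.image_newline)
# One embedding tensor per input image, mirroring what forward() returns.
outputs = torch.split(image_features, feature_lens.cpu().tolist(), dim=0)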
