support internvl-mono
grimoire committed Nov 20, 2024
1 parent a2bbec5 commit 35b3459
Showing 6 changed files with 143 additions and 30 deletions.
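For context, a rough usage sketch of what this change enables: running a Mono-InternVL checkpoint through the PyTorch engine. The model path, bfloat16 setting, image URL, and prompt below are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch (not part of this commit); model path and dtype are assumed.
from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('OpenGVLab/Mono-InternVL-2B',  # assumed HF model path
                backend_config=PytorchEngineConfig(dtype='bfloat16'))  # FP16 is rejected below

image = load_image('https://example.com/demo.jpg')  # placeholder URL
print(pipe(('describe this image', image)).text)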
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/models/chatglm2.py
@@ -941,7 +941,9 @@ def preprocess_input(self,
         for input_mm in input_multimodals:
             pixel_values = input_mm['pixel_values'].to(self.dtype)
             offset = input_mm['offset']
-            num_pad = input_mm.get('image_tokens', self.vision_token_num)
+            num_pad = input_mm['image_tokens']
+            if isinstance(num_pad, torch.Tensor):
+                num_pad = num_pad.item()
 
             mm_data = MultiModalTensor(data=pixel_values,
                                        start=offset,
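The same three-line guard recurs in several files below: `image_tokens` may now arrive as a 0-d tensor, and it is converted to a plain Python int before being used as the number of image placeholder tokens. A toy illustration of the conversion (not lmdeploy code; values are made up):

import torch

input_mm = {'offset': 10, 'image_tokens': torch.tensor(256)}  # tensor-valued count

num_pad = input_mm['image_tokens']
if isinstance(num_pad, torch.Tensor):
    num_pad = num_pad.item()  # 0-d tensor -> plain Python int

end = input_mm['offset'] + num_pad  # 266, safe for downstream bookkeeping
print(type(num_pad).__name__, end)  # int 266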
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/models/cogvlm.py
@@ -981,7 +981,9 @@ def preprocess_input(self,
         for input_mm in input_multimodals:
             pixel_values = input_mm['pixel_values'].to(self.dtype)
             offset = input_mm['offset']
-            num_pad = input_mm.get('image_tokens', self.vision_token_num)
+            num_pad = input_mm['image_tokens']
+            if isinstance(num_pad, torch.Tensor):
+                num_pad = num_pad.item()
 
             mm_data = MultiModalTensor(data=pixel_values,
                                        start=offset,
59 changes: 37 additions & 22 deletions lmdeploy/pytorch/models/internvl.py
@@ -316,12 +316,23 @@ def __init__(self,
         self.ctx_mgr = ctx_mgr
         self.select_layer = config.select_layer
 
+        llm_config = config.llm_config
+        self.llm_arch_name = llm_config.architectures[0]
+        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
+
         vision_config = config.vision_config
-        self.vision_model = InternVisionModel(vision_config,
-                                              dtype=dtype,
-                                              device=device)
+        if self.is_mono:
+            from .internvl_patch import InternVisionPatchModel
+            self.vision_model = InternVisionPatchModel(
+                vision_config,
+                dtype=dtype,
+                device=device,
+            )
+        else:
+            self.vision_model = InternVisionModel(vision_config,
+                                                  dtype=dtype,
+                                                  device=device)
 
-        llm_config = config.llm_config
         self.language_model = build_model_from_hf_config(llm_config,
                                                          dtype=dtype,
                                                          device=device)
@@ -342,10 +353,7 @@ def __init__(self,
                       dtype=dtype,
                       device=device))
 
-        self.llm_arch_name = llm_config.architectures[0]
-
         # for Mono-InternVL
-        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
         if self.is_mono:
             assert dtype != torch.float16, (
                 'Currently Mono-InternVL does not support FP16 due to'
@@ -370,7 +378,11 @@ def extract_feature(self, pixel_values):
         """extract vision feature."""
         assert self.select_layer == -1
         vit_embeds = self.vision_model(pixel_values)
-        vit_embeds = vit_embeds[:, 1:, :]
+        if self.is_mono:
+            if int(vit_embeds.shape[1]**0.5)**2 != vit_embeds.shape[1]:
+                vit_embeds = vit_embeds[:, 1:, :]
+        else:
+            vit_embeds = vit_embeds[:, 1:, :]
 
         h = w = int(vit_embeds.shape[1]**0.5)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
@@ -394,19 +406,13 @@ def forward(
         **kwargs,
     ):
         if inputs_embeds is None and pixel_values is not None:
-
-            # get vis idx
-            from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX
-            vis_mask = input_ids[0] == IMAGE_DUMMY_TOKEN_INDEX
-            vis_range = torch.arange(0,
-                                     input_ids.size(-1),
-                                     device=input_ids.device)
-            vis_idx = vis_range[vis_mask]
-
             # extract feature
             vit_embeds = self.extract_feature(pixel_values)
             lang_embeds = self.language_model.get_input_embeddings()(input_ids)
-            lang_embeds[0, vis_idx] = vit_embeds.flatten(0, 1)
+            from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX
+            vis_mask = input_ids == IMAGE_DUMMY_TOKEN_INDEX
+            lang_embeds.masked_scatter_(vis_mask[..., None], vit_embeds)
 
             inputs_embeds = lang_embeds
 
+        if self.is_mono:
@@ -443,6 +449,8 @@ def prepare_inputs_for_generation(
         input_ids = context.input_ids
         position_ids = context.position_ids
         attn_metadata = context.attn_metadata
+        vision_embeddings = context.input_embeddings
+        vision_embedding_indexing = None
 
         # vision inputs
         pixel_values = None
@@ -460,11 +468,15 @@
             else:
                 pixel_values = None
 
-        # get inputs from context
-        vision_embeddings = context.input_embeddings
-        vision_embedding_indexing = context.input_embedding_indexing
+        if self.is_mono and pixel_values is not None:
+            vision_embedding_indexing = torch.arange(input_ids.shape[1],
+                                                     device=input_ids.device)
+            vision_embedding_indexing = vision_embedding_indexing[input_ids[0]
+                                                                  == 0]
 
+        # get inputs from context
         if vision_embeddings is not None and len(vision_embeddings) > 0:
+            vision_embedding_indexing = context.input_embedding_indexing
             if inputs_embeds is None:
                 inputs_embeds = self.get_input_embeddings()(input_ids)
             inputs_embeds[:,
@@ -484,6 +496,7 @@
             position_ids=position_ids,
             past_key_values=past_key_values,
             attn_metadata=attn_metadata,
+            pixel_values=pixel_values,
             inputs_embeds=inputs_embeds,
             vision_embedding_indexing=vision_embedding_indexing,
             text_embedding_indexing=text_embedding_indexing,
@@ -558,7 +571,9 @@ def preprocess_input(self,
         for input_mm in input_multimodals:
             pixel_values = input_mm['pixel_values'].to(self.dtype)
             offset = input_mm['offset']
-            num_pad = input_mm.get('image_tokens', self.vision_token_num)
+            num_pad = input_mm['image_tokens']
+            if isinstance(num_pad, torch.Tensor):
+                num_pad = num_pad.item()
 
             mm_data = MultiModalTensor(data=pixel_values,
                                        start=offset,
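As a side note, the masked_scatter_ form introduced in forward above writes the vision features into every position that holds the image dummy token, instead of indexing into batch 0 only. A self-contained toy sketch of that pattern (token id, shapes, and values are invented for illustration):

import torch

IMAGE_DUMMY_TOKEN_INDEX = 0  # placeholder id; lmdeploy defines its own constant
hidden_size = 4

input_ids = torch.tensor([[11, 0, 0, 0, 12]])  # one image -> three dummy tokens
lang_embeds = torch.zeros(1, 5, hidden_size)  # stand-in for embed_tokens(input_ids)
vit_embeds = torch.arange(3 * hidden_size, dtype=torch.float).reshape(1, 3, hidden_size)

vis_mask = input_ids == IMAGE_DUMMY_TOKEN_INDEX  # (1, 5) bool mask
lang_embeds.masked_scatter_(vis_mask[..., None], vit_embeds)
print(lang_embeds[0, 1:4])  # rows 1..3 now hold the vision embeddings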
96 changes: 96 additions & 0 deletions lmdeploy/pytorch/models/internvl_patch.py
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig


class InternVisionEmbeddings(nn.Module):
    """mono vision."""

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), )

        self.patch_embedding = nn.Conv2d(in_channels=3,
                                         out_channels=self.embed_dim,
                                         kernel_size=self.patch_size,
                                         stride=self.patch_size,
                                         dtype=dtype,
                                         device=device)

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.empty(1,
                        self.num_positions,
                        self.embed_dim,
                        dtype=dtype,
                        device=device))

    def _get_pos_embed(self, pos_embed, H, W):
        target_dtype = pos_embed.dtype
        pos_embed = pos_embed.float().reshape(
            1, self.image_size // self.patch_size,
            self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed,
                                  size=(H, W),
                                  mode='bicubic',
                                  align_corners=False)
        pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
                                                            1).to(target_dtype)
        return pos_embed

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values)  # shape = [*, channel, width, height]
        batch_size, _, height, width = patch_embeds.shape
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1,
                                                   -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        position_embedding = torch.cat([
            self.position_embedding[:, :1, :],
            self._get_pos_embed(self.position_embedding[:, 1:, :], height,
                                width)
        ],
                                       dim=1)
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings


class InternVisionPatchModel(nn.Module):
    """mono vision."""

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config,
                                                 dtype=dtype,
                                                 device=device)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        if len(pixel_values.shape) != 4:
            raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')

        hidden_states = self.embeddings(pixel_values)[:, 1:]
        return hidden_states
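A rough shape check for the new patch-only vision path (the config values below are illustrative assumptions; the real ones come from the InternVL checkpoint). With a 448x448 input and 16x16 patches there are (448 / 16)**2 = 784 patch tokens, and InternVisionPatchModel returns them without a class token:

import torch
from transformers import PretrainedConfig

from lmdeploy.pytorch.models.internvl_patch import InternVisionPatchModel

config = PretrainedConfig(hidden_size=64, image_size=448, patch_size=16)
model = InternVisionPatchModel(config, dtype=torch.float32, device='cpu')

# The module allocates its parameters with torch.empty, so give them values
# before running a forward pass.
for param in model.parameters():
    torch.nn.init.normal_(param, std=0.02)

pixel_values = torch.rand(2, 3, 448, 448)
print(model(pixel_values).shape)  # torch.Size([2, 784, 64]) -- no class token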
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/models/llava.py
@@ -634,7 +634,9 @@ def preprocess_input(self,
         for input_mm in input_multimodals:
             pixel_values = input_mm['pixel_values'].to(self.dtype)
             offset = input_mm['offset']
-            num_pad = input_mm.get('image_tokens', 1)
+            num_pad = input_mm['image_tokens']
+            if isinstance(num_pad, torch.Tensor):
+                num_pad = num_pad.item()
 
             mm_data = MultiModalTensor(data=pixel_values,
                                        start=offset,
6 changes: 1 addition & 5 deletions lmdeploy/pytorch/models/qwen2_vl.py
@@ -1016,11 +1016,7 @@ def preprocess_input(self,
             image_grid_thw = input_mm['image_grid_thw']
             offset = input_mm['offset']
             start = offset
-
-            if 'image_tokens' in input_mm:
-                num_pad = input_mm['image_tokens']
-            else:
-                num_pad = pixel_values.size(0) // 4
+            num_pad = input_mm['image_tokens'].item()
 
             mm_data = MultiModalTensor(data=pixel_values,
                                        start=start,
