
Commit

fix and optimize minicpm v 2 (#11799)
MeouSker77 authored Aug 14, 2024
1 parent d8d887e commit 9a93808
Showing 2 changed files with 45 additions and 12 deletions.
python/llm/src/ipex_llm/transformers/convert.py (10 additions, 1 deletion)
@@ -1726,6 +1726,11 @@ def safe_bmm_fwd(*args, **kwargs):
        minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
        model.generate = MethodType(minicpmv_generate, model)

        if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
            # MiniCPM-V 2
            model.llm.config.model_type = "minicpm"
            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
            model.llm.config.model_type = "minicpmv"
        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
            # MiniCPM-V 2.6
            model.llm.config.model_type = "qwen2"
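
The branch above reuses the optimizations already written for the plain MiniCPM language model: it temporarily relabels the wrapped LLM's config as "minicpm", runs _optimize_post, then restores the "minicpmv" tag. A minimal sketch of that relabelling pattern, using a hypothetical helper (the commit itself does the swap inline):

    from contextlib import contextmanager

    @contextmanager
    def as_model_type(config, temporary_type):
        # Temporarily relabel the config so a dispatcher keyed on
        # config.model_type applies optimizations written for another
        # architecture, then restore the original label on exit.
        original = config.model_type
        config.model_type = temporary_type
        try:
            yield config
        finally:
            config.model_type = original

    # Usage sketch, mirroring the MiniCPM-V 2 branch above:
    # with as_model_type(model.llm.config, "minicpm"):
    #     _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
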
@@ -1739,7 +1744,11 @@ def safe_bmm_fwd(*args, **kwargs):

        vpm_modeling_module_name = model.vpm.__class__.__module__
        vpm_module = importlib.import_module(vpm_modeling_module_name)
        if model.vpm.config.model_type == "siglip":
        if not hasattr(model.vpm, "config"):
            # MiniCPM-V 2
            from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
            model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
        elif model.vpm.config.model_type == "siglip":
            # MiniCPM-V 2.6
            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
            convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
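The hasattr(model.vpm, "config") check above distinguishes MiniCPM-V 2, whose vision tower appears to be a timm module without a Hugging Face config, from the newer SigLIP-based towers; the replacement embedding function is then bound to this one model instance via types.MethodType. A toy illustration of that per-instance binding, with made-up class and function names:

    from types import MethodType

    class ToyVLM:
        def get_vision_embedding(self, pixel_values):
            return "original path"

    def optimized_get_vision_embedding(self, pixel_values):
        # Bound to a single instance; other ToyVLM objects keep the
        # original method.
        return "optimized path"

    model = ToyVLM()
    model.get_vision_embedding = MethodType(optimized_get_vision_embedding, model)
    print(model.get_vision_embedding(None))  # -> "optimized path"
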
python/llm/src/ipex_llm/transformers/models/minicpmv.py (35 additions, 11 deletions)
@@ -15,18 +15,21 @@
#


import math
import torch
from typing import Optional
from ipex_llm.transformers.models.common import merge_qkv_base
from transformers import AutoProcessor
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor


# MiniCPM-V-2_5 and MiniCPM-V-2_6
def merge_qkv(module: torch.nn.Module):
    merge_qkv_base(module, "SiglipAttention")
    merge_qkv_base(module, "Idefics2VisionAttention")
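
merge_qkv_base is defined elsewhere in this repository; judging from its name and use here, it fuses the separate q/k/v projections of the named attention classes so that a single matmul serves all three. A generic, repo-independent illustration of that kind of fusion:

    import torch

    def fuse_qkv_linears(q: torch.nn.Linear, k: torch.nn.Linear, v: torch.nn.Linear) -> torch.nn.Linear:
        # Stack the three projection weights so a single linear layer
        # produces q, k and v together; callers split the output afterwards.
        fused = torch.nn.Linear(q.in_features,
                                q.out_features + k.out_features + v.out_features,
                                bias=q.bias is not None)
        fused.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
        if q.bias is not None:
            fused.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
        return fused
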


# MiniCPM-V-2_5 and MiniCPM-V-2_6
def siglip_attention_forward(
    self,
    hidden_states: torch.Tensor,
@@ -58,17 +61,7 @@ def siglip_attention_forward(
    return attn_output, attn_weights


def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
    if scores.device.type == "xpu":
        import xe_addons
        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
    else:
        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        scores.scatter_(1, input_ids, score)
    return scores


# MiniCPM-V-2_5
def minicpmv_chat_wrapper(origin_chat):
    def minicpmv_chat(
        self,
@@ -106,6 +99,37 @@ def minicpmv_chat(
    return minicpmv_chat


# MiniCPM-V-2
def minicpmv_get_vision_embedding(self, pixel_values):
    res = []
    dtype = self.dtype

    def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
        H, W = pixel_value.shape[-2:]
        target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
        vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))

        if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
            vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
        return resampler(vision_embedding, target_size)

    for pixel_value in pixel_values:
        result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
        res.append(result)
    return torch.vstack(res)
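
The target_size handed to the resampler above is simply each crop's patch grid, computed with ceiling division. For example, assuming a 448x448 crop and patch_size 14 (illustrative numbers only):

    import math

    # Hypothetical sizes: a 448x448 crop with 14x14 patches gives a
    # 32x32 token grid, which is the target_size passed to the resampler.
    H, W, patch_size = 448, 448, 14
    target_size = (math.ceil(H / patch_size), math.ceil(W / patch_size))
    assert target_size == (32, 32)
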


def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
    if scores.device.type == "xpu":
        import xe_addons
        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
    else:
        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        scores.scatter_(1, input_ids, score)
    return scores
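
On CPU the patched processor keeps the stock Hugging Face behaviour (gather the scores of already-generated tokens, scale them by the penalty, scatter them back), while XPU tensors are routed to an in-place xe_addons kernel. One plausible way to apply the patch around generation is a temporary monkey-patch like the sketch below; the wrapper that follows may wire it up differently:

    # Sketch only; origin_generate, inputs and kwargs refer to the names
    # used in minicpmv_generate_wrapper below.
    original_call = RepetitionPenaltyLogitsProcessor.__call__
    RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
    try:
        outputs = origin_generate(*inputs, **kwargs)
    finally:
        RepetitionPenaltyLogitsProcessor.__call__ = original_call
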


def minicpmv_generate_wrapper(origin_generate):
    def generate(
        *inputs,
