Skip to content

Commit

Permalink
support minicpm3 (#1029)
Browse files Browse the repository at this point in the history
* support minicpm3

* model patcher

* add test

* update readme

* Update tests/openvino/test_modeling.py

* Update tests/openvino/test_modeling.py

* Update optimum/exporters/openvino/model_patcher.py
  • Loading branch information
eaidova authored Nov 27, 2024
1 parent 860bdf8 commit 16c27ca
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/openvino/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Here is the list of the supported architectures :
- MT5
- Marian
- MiniCPM
- MiniCPM3
- Mistral
- Mixtral
- MobileBert
Expand Down
55 changes: 55 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
LlamaModelPatcher,
LlavaImageEmbeddingModelPatcher,
LlavaQwen2ImageEmbeddingsModelPatcher,
MiniCPM3Patcher,
MiniCPMVImageEmbeddingsModelPatcher,
MiniCPMVResamplerModelPatcher,
MistralModelPatcher,
Expand Down Expand Up @@ -192,6 +193,60 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class OVMiniCPM3DummyPastKeyValuesGenerator(MistralDummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
**kwargs,
)
self.v_head_dim = getattr(normalized_config, "v_head_dim", self.hidden_size // self.num_attention_heads)
self.k_head_dim = normalized_config.qk_nope_head_dim + normalized_config.qk_rope_head_dim

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
v_shape = (
self.batch_size,
self.num_key_value_heads,
self.sequence_length,
self.v_head_dim,
)
k_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.k_head_dim)
return [
(
self.random_float_tensor(k_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(v_shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]


@register_in_tasks_manager("minicpm3", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MiniCPM3OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, OVMiniCPM3DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = OVMiniCPM3DummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> ModelPatcher:
return MiniCPM3Patcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14
Expand Down
141 changes: 141 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3237,3 +3237,144 @@ def __init__(
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward


def minicpm3_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
q_fp32 = q.to(dtype=torch.float32, device=q.device)
k_fp32 = k.to(dtype=torch.float32, device=k.device)
q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)

if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)

bsz, q_len, _ = hidden_states.shape

q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(hidden_states.shape[0], hidden_states.shape[1], 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)

k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

# Difference with original code, k_pe.new_empty create constant tensor in torchscript
query_states = torch.concat([q_nope, q_pe], dim=-1)
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
# key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(hidden_states.shape[0], hidden_states.shape[1], self.hidden_size)

attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value


class MiniCPM3Patcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for block in self._model.model.layers:
block.self_attn._orig_forward = block.self_attn.forward
block.self_attn.forward = types.MethodType(minicpm3_attn_forward, block.self_attn)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.model.layers:
block.self_attn.forward = block.self_attn._orig_forward
2 changes: 2 additions & 0 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"arctic",
"exaone",
"mistral-nemo",
"minicpm3",
)

GENERATION_LENGTH = 100
Expand All @@ -935,6 +936,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"glm4",
"exaone",
"decilm",
"minicpm3",
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
Expand Down
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"marian": "sshleifer/tiny-marian-en-de",
"mbart": "hf-internal-testing/tiny-random-mbart",
"minicpm": "katuni4ka/tiny-random-minicpm",
"minicpm3": "katuni4ka/tiny-random-minicpm3",
"minicpmv": "katuni4ka/tiny-random-minicpmv-2_6",
"mistral": "echarlaix/tiny-random-mistral",
"mistral-nemo": "katuni4ka/tiny-random-mistral-nemo",
Expand Down

0 comments on commit 16c27ca

Please sign in to comment.