diff --git a/docs/source/openvino/models.mdx b/docs/source/openvino/models.mdx
index 0ebd74fbb..6b73b7bde 100644
--- a/docs/source/openvino/models.mdx
+++ b/docs/source/openvino/models.mdx
@@ -70,6 +70,7 @@ Here is the list of the supported architectures :
 - MT5
 - Marian
 - MiniCPM
+- MiniCPM3
 - Mistral
 - Mixtral
 - MobileBert
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index b8310882b..d9c0165d9 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -79,6 +79,7 @@
     LlamaModelPatcher,
     LlavaImageEmbeddingModelPatcher,
     LlavaQwen2ImageEmbeddingsModelPatcher,
+    MiniCPM3Patcher,
     MiniCPMVImageEmbeddingsModelPatcher,
     MiniCPMVResamplerModelPatcher,
     MistralModelPatcher,
@@ -192,6 +193,60 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
+class OVMiniCPM3DummyPastKeyValuesGenerator(MistralDummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+            **kwargs,
+        )
+        self.v_head_dim = getattr(normalized_config, "v_head_dim", self.hidden_size // self.num_attention_heads)
+        self.k_head_dim = normalized_config.qk_nope_head_dim + normalized_config.qk_rope_head_dim
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        v_shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.v_head_dim,
+        )
+        k_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.k_head_dim)
+        return [
+            (
+                self.random_float_tensor(k_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(v_shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
+@register_in_tasks_manager("minicpm3", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class MiniCPM3OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, OVMiniCPM3DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = OVMiniCPM3DummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        return MiniCPM3Patcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 58659e637..c71cbfe00 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -3237,3 +3237,144 @@ def __init__(
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+
+
+def minicpm3_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value=None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors.
+        Args:
+            q (`torch.Tensor`): The query tensor.
+            k (`torch.Tensor`): The key tensor.
+            cos (`torch.Tensor`): The cosine part of the rotary embedding.
+            sin (`torch.Tensor`): The sine part of the rotary embedding.
+            position_ids (`torch.Tensor`):
+                The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+                used to pass offsetted position ids when working with a KV-cache.
+            unsqueeze_dim (`int`, *optional*, defaults to 1):
+                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+        Returns:
+            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+        """
+        orig_dtype = k.dtype
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        q_fp32 = q.to(dtype=torch.float32, device=q.device)
+        k_fp32 = k.to(dtype=torch.float32, device=k.device)
+        q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
+        k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
+        return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    bsz, q_len, _ = hidden_states.shape
+
+    q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+    q = q.view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.q_head_dim).transpose(1, 2)
+    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+    k_pe = k_pe.view(hidden_states.shape[0], hidden_states.shape[1], 1, self.qk_rope_head_dim).transpose(1, 2)
+    kv = (
+        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        .view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        .transpose(1, 2)
+    )
+
+    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+    kv_seq_len = value_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+    # Difference with original code, k_pe.new_empty create constant tensor in torchscript
+    query_states = torch.concat([q_nope, q_pe], dim=-1)
+    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+    # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+    key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
+    # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+    # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(hidden_states.shape[0], hidden_states.shape[1], self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
+class MiniCPM3Patcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for block in self._model.model.layers:
+            block.self_attn._orig_forward = block.self_attn.forward
+            block.self_attn.forward = types.MethodType(minicpm3_attn_forward, block.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.model.layers:
+            block.self_attn.forward = block.self_attn._orig_forward
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index f7f677bf8..240f4f9e3 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -914,6 +914,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "arctic",
         "exaone",
         "mistral-nemo",
+        "minicpm3",
     )
 
     GENERATION_LENGTH = 100
@@ -935,6 +936,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "glm4",
         "exaone",
         "decilm",
+        "minicpm3",
     )
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index b646b5b52..17d9dd1fb 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -86,6 +86,7 @@
    "marian": "sshleifer/tiny-marian-en-de",
    "mbart": "hf-internal-testing/tiny-random-mbart",
    "minicpm": "katuni4ka/tiny-random-minicpm",
+    "minicpm3": "katuni4ka/tiny-random-minicpm3",
    "minicpmv": "katuni4ka/tiny-random-minicpmv-2_6",
    "mistral": "echarlaix/tiny-random-mistral",
    "mistral-nemo": "katuni4ka/tiny-random-mistral-nemo",
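
With the patch applied, the new export path can be exercised end to end through the regular `optimum-intel` API. The sketch below is illustrative only and not part of the patch: the `openbmb/MiniCPM3-4B` checkpoint id and the prompt are assumptions, and any MiniCPM3 checkpoint (including the `katuni4ka/tiny-random-minicpm3` test model referenced above) should work the same way.

```python
# Illustrative sketch: exporting and running a MiniCPM3 checkpoint through the
# newly registered "minicpm3" architecture. The model id below is an assumption.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "openbmb/MiniCPM3-4B"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# export=True converts the PyTorch model to OpenVINO IR on the fly, which is
# where MiniCPM3OpenVINOConfig and MiniCPM3Patcher from this patch are used.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("Write a short poem about efficient inference.", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Equivalently, the model can be converted ahead of time with `optimum-cli export openvino --model <model_id> <output_dir>` and the resulting directory loaded with `OVModelForCausalLM.from_pretrained` without `export=True`.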