From 45dd2eba0748b81ba58347fc0c6eec3e92ceaaaa Mon Sep 17 00:00:00 2001 From: eaidova Date: Fri, 13 Sep 2024 16:06:09 +0400 Subject: [PATCH 1/7] support minicpm3 --- optimum/exporters/openvino/model_configs.py | 49 +++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index b8310882ba..2ec1fefde0 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -192,6 +192,55 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig +class OVMiniCPM3DummyPastKeyValuesGenerator(MistralDummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + **kwargs, + ) + self.v_head_dim = getattr(normalized_config, "v_head_dim", self.hidden_size // self.num_attention_heads) + self.k_head_dim = normalized_config.qk_nope_head_dim + normalized_config.qk_rope_head_dim + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + v_shape = ( + self.batch_size, + self.num_key_value_heads, + self.sequence_length, + self.v_head_dim, + ) + k_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.k_head_dim) + return [ + ( + self.random_float_tensor(k_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(v_shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] + + +@register_in_tasks_manager("minicpm3", *["text-generation", "text-generation-with-past"], library_name="transformers") +class MiniCPM3OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, OVMiniCPM3DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = OVMiniCPM3DummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + @register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers") class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 From 89e6f3ebaeb2802ecd08df59830c5b0996536065 Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 25 Nov 2024 16:49:01 +0400 Subject: [PATCH 2/7] model patcher --- optimum/exporters/openvino/model_configs.py | 6 + optimum/exporters/openvino/model_patcher.py | 145 ++++++++++++++++++++ 2 files changed, 151 insertions(+) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 2ec1fefde0..f65485f36e 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -81,6 +81,7 @@ LlavaQwen2ImageEmbeddingsModelPatcher, MiniCPMVImageEmbeddingsModelPatcher, MiniCPMVResamplerModelPatcher, + MiniCPM3Patcher, MistralModelPatcher, MixtralModelPatcher, MPTModelPatcher, @@ -240,6 +241,11 @@ class MiniCPM3OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_PKV_GENERATOR_CLASS = OVMiniCPM3DummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> ModelPatcher: + return MiniCPM3Patcher(self, model, model_kwargs=model_kwargs) + @register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers") class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 58659e637b..e45d8fc918 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -3237,3 +3237,148 @@ def __init__( def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward + + +def minicpm3_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value=None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # cos = cos[position_ids].unsqueeze(unsqueeze_dim) + # sin = sin[position_ids].unsqueeze(unsqueeze_dim) + # q_embed = (q * cos) + (rotate_half(q) * sin) + # k_embed = (k * cos) + (rotate_half(k) * sin) + orig_dtype = k.dtype + cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + q_fp32 = q.to(dtype=torch.float32, device=q.device) + k_fp32 = k.to(dtype=torch.float32, device=k.device) + q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin) + k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin) + return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype) + + if output_attentions: + return self._orig_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.shape + + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(hidden_states.shape[0], hidden_states.shape[1], 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + # Difference with original code, k_pe.new_empty create constant tensor in torchscript + query_states = torch.concat([q_nope, q_pe], dim=-1) + # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1) + # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(hidden_states.shape[0], hidden_states.shape[1], self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class MiniCPM3Patcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for block in self._model.model.layers: + block.self_attn._orig_forward = block.self_attn.forward + block.self_attn.forward = types.MethodType(minicpm3_attn_forward, block.self_attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for block in self._model.model.layers: + block.self_attn.forward = block.self_attn._orig_forward From 21850e64d14f4e3b51d9733f4ce993b01abe3fda Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 25 Nov 2024 17:36:08 +0400 Subject: [PATCH 3/7] add test --- tests/openvino/test_modeling.py | 6 ++++-- tests/openvino/utils_tests.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index f7f677bf8c..ae50976db8 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -914,6 +914,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "arctic", "exaone", "mistral-nemo", + "minicpm3", ) GENERATION_LENGTH = 100 @@ -935,6 +936,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "glm4", "exaone", "decilm", + "minicpm3", ) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -1182,8 +1184,8 @@ def test_default_filling_attention_mask_and_position_ids(self): gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) - @pytest.mark.run_slow - @slow + # @pytest.mark.run_slow + # @slow def test_beam_search(self, model_arch): model_kwargs = {} model_id = MODEL_NAMES[model_arch] diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index b646b5b52a..17d9dd1fbe 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -86,6 +86,7 @@ "marian": "sshleifer/tiny-marian-en-de", "mbart": "hf-internal-testing/tiny-random-mbart", "minicpm": "katuni4ka/tiny-random-minicpm", + "minicpm3": "katuni4ka/tiny-random-minicpm3", "minicpmv": "katuni4ka/tiny-random-minicpmv-2_6", "mistral": "echarlaix/tiny-random-mistral", "mistral-nemo": "katuni4ka/tiny-random-mistral-nemo", From 9eeedf4b83c16e2b2d1fa2e0dfc7b9b007b97aac Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 25 Nov 2024 17:36:56 +0400 Subject: [PATCH 4/7] update readme --- docs/source/openvino/models.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/openvino/models.mdx b/docs/source/openvino/models.mdx index 0ebd74fbbf..6b73b7bdeb 100644 --- a/docs/source/openvino/models.mdx +++ b/docs/source/openvino/models.mdx @@ -70,6 +70,7 @@ Here is the list of the supported architectures : - MT5 - Marian - MiniCPM +- MiniCPM3 - Mistral - Mixtral - MobileBert From d1d3d048eff382eb28a5075048d90adab6d3f5bb Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Mon, 25 Nov 2024 17:38:20 +0400 Subject: [PATCH 5/7] Update tests/openvino/test_modeling.py --- tests/openvino/test_modeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index ae50976db8..240f4f9e3f 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1184,8 +1184,8 @@ def test_default_filling_attention_mask_and_position_ids(self): gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) - # @pytest.mark.run_slow - # @slow + @pytest.mark.run_slow + @slow def test_beam_search(self, model_arch): model_kwargs = {} model_id = MODEL_NAMES[model_arch] From 3a724ee7844efb77fbbd091ce847a1115bad9ac9 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Mon, 25 Nov 2024 17:38:20 +0400 Subject: [PATCH 6/7] Update tests/openvino/test_modeling.py --- optimum/exporters/openvino/model_configs.py | 2 +- tests/openvino/test_modeling.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index f65485f36e..d9c0165d98 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -79,9 +79,9 @@ LlamaModelPatcher, LlavaImageEmbeddingModelPatcher, LlavaQwen2ImageEmbeddingsModelPatcher, + MiniCPM3Patcher, MiniCPMVImageEmbeddingsModelPatcher, MiniCPMVResamplerModelPatcher, - MiniCPM3Patcher, MistralModelPatcher, MixtralModelPatcher, MPTModelPatcher, diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index ae50976db8..240f4f9e3f 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1184,8 +1184,8 @@ def test_default_filling_attention_mask_and_position_ids(self): gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) - # @pytest.mark.run_slow - # @slow + @pytest.mark.run_slow + @slow def test_beam_search(self, model_arch): model_kwargs = {} model_id = MODEL_NAMES[model_arch] From d8acf30ae1999eab737ef604d69e3febd90b0d48 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Mon, 25 Nov 2024 19:07:48 +0400 Subject: [PATCH 7/7] Update optimum/exporters/openvino/model_patcher.py --- optimum/exporters/openvino/model_patcher.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index e45d8fc918..c71cbfe003 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -3274,10 +3274,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - # cos = cos[position_ids].unsqueeze(unsqueeze_dim) - # sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # q_embed = (q * cos) + (rotate_half(q) * sin) - # k_embed = (k * cos) + (rotate_half(k) * sin) orig_dtype = k.dtype cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]