From e596cc7e3e8709343047d645c94f38f5e9b3bcb3 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Fri, 20 Oct 2023 11:31:13 +0400
Subject: [PATCH] enable attention mask and fix accuracy issue for chatglm

---
 .../openvino/dummy_input_generators.py      | 12 ++++++
 optimum/exporters/openvino/model_configs.py |  1 -
 optimum/intel/openvino/modeling_decoder.py  | 38 ++++++++++++++++---
 optimum/intel/utils/modeling_utils.py       | 38 +++++++++++++++++++
 4 files changed, 83 insertions(+), 6 deletions(-)

diff --git a/optimum/exporters/openvino/dummy_input_generators.py b/optimum/exporters/openvino/dummy_input_generators.py
index 25439eb432..31673f45c6 100644
--- a/optimum/exporters/openvino/dummy_input_generators.py
+++ b/optimum/exporters/openvino/dummy_input_generators.py
@@ -14,6 +14,8 @@
 
 from typing import Optional, Tuple
 
+import torch
+
 from optimum.utils import (
     DEFAULT_DUMMY_SHAPES,
     DummyPastKeyValuesGenerator,
@@ -30,6 +32,16 @@ class ChatGLN2DummyTextInputGenerator(DummyTextInputGenerator):
         "position_ids",
     }
 
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        input = super().generate(input_name, framework, int_dtype, float_dtype)
+        if input_name == "attention_mask":
+            input = torch.ones((input.shape[0], input.shape[1] + 1), dtype=input.dtype)
+            # input[0] = 0
+        if input_name == "position_ids":
+            input = torch.arange(0, input.shape[1] + 1, dtype=input.dtype).repeat(input.shape[0], 1)
+            # input[0] = 0
+        return input
+
 
 class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
     def __init__(
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index eeec30d75e..fcefbafd58 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -59,7 +59,6 @@ class ChatGLM2OpenVINOConfig(TextDecoderOnnxConfig):
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = super().inputs
-        common_inputs.pop("attention_mask")
         if not self.no_position_ids and self.task == "text-generation":
             common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
 
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 78b2d790bd..4f0a09ffd9 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -16,7 +16,7 @@
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -25,7 +25,7 @@
 from openvino.runtime import Core, Tensor, Type
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 
 from optimum.utils import NormalizedConfigManager
 
@@ -401,9 +401,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         # create position_ids on the fly for batch generation
         position_ids = attention_mask.long().cumsum(-1) - 1
         position_ids.masked_fill_(attention_mask == 0, 1)
-        if past_key_values:
-            position_ids = position_ids[:, -1].unsqueeze(-1)
-
+        if past_key_values:
+            position_ids = position_ids[:, -1].unsqueeze(-1)
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
@@ -413,6 +412,35 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "token_type_ids": None,
         }
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
+
+        # update attention mask
+        if "attention_mask" in model_kwargs:
+            attention_mask = model_kwargs["attention_mask"]
+            model_kwargs["attention_mask"] = torch.cat(
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            )
+
+        # update position ids
+        if "position_ids" in model_kwargs:
+            position_ids = model_kwargs["position_ids"]
+            new_position_id = position_ids[..., -1:].clone()
+            new_position_id += 1
+            model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
+
+        model_kwargs["is_first_forward"] = False
+        return model_kwargs
+
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index 17abf1059e..5e94e94d06 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import types
 from typing import Tuple
 
 import torch
@@ -92,6 +93,40 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
     return combined_attention_mask
 
 
+@torch.jit.script_if_tracing
+def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor):
+    mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype)
+    if query_layer.shape[2] == key_layer.shape[2]:
+        tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1)
+        mask.masked_fill_(tmp_mask, float("-inf"))
+
+    context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask=mask)
+    return context_layer
+
+
+def _core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
+    query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+    if attention_mask is None:
+        context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer)
+    else:
+        attention_mask = ~attention_mask
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer, key_layer, value_layer, attention_mask
+        )
+    context_layer = context_layer.permute(2, 0, 1, 3)
+    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+    context_layer = context_layer.reshape(*new_context_layer_shape)
+
+    return context_layer
+
+
+def _patch_chatglm_core_attention_forward(model: "PreTrainedModel"):
+    for block in model.transformer.encoder.layers:
+        block.self_attention.core_attention.forward = types.MethodType(
+            _core_attention_forward, block.self_attention.core_attention
+        )
+
+
 def patch_decoder_attention_mask(model: "PreTrainedModel"):
     """
     Apply patch on decoder with past model forward to resolve first inference based on model architecture
@@ -108,4 +143,7 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"):
         model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
     elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
         model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+    elif model.config.model_type == "chatglm":
+        _patch_chatglm_core_attention_forward(model)
+
     return model
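
Reviewer note: below is a minimal smoke-test sketch for exercising the patched ChatGLM path end to end. It is not part of the diff; the checkpoint id "THUDM/chatglm2-6b", the prompt, and the generation settings are illustrative assumptions. It relies only on the existing optimum.intel.OVModelForCausalLM.from_pretrained(..., export=True) entry point and transformers.AutoTokenizer.

# Hypothetical smoke test for this patch (not part of the diff); model id,
# prompt, and generation settings are assumptions for illustration only.
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "THUDM/chatglm2-6b"  # assumed ChatGLM2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# export=True converts the PyTorch model to OpenVINO IR. With this patch the
# exported model keeps attention_mask as an input (model_configs.py no longer
# pops it) and traces the SDPA-based core attention added in modeling_utils.py.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("What is OpenVINO?", return_tensors="pt")
# generate() now runs through the overridden prepare_inputs_for_generation and
# _update_model_kwargs_for_generation, which grow attention_mask and
# position_ids by one element per generated token.
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Per the commit subject, the intent is that the exported model now receives the attention mask, so batched or padded prompts should no longer show the accuracy degradation this patch targets.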