diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 0834c257ea..4998cd5944 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -1872,6 +1872,68 @@ def __exit__(self, exc_type, exc_value, traceback):
             layer.self_attn.forward = layer.self_attn._orig_forward
 
 
+# copied from https://github.com/huggingface/optimum/blob/v1.22.0/optimum/bettertransformer/models/attention.py#L169
+# for preserving backward compatibility between outdated codegen remote code and newer transformers
+def _codegen_wrapped_scaled_dot_product_legacy(
+    self,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+):
+    from optimum.bettertransformer.models.attention import raise_on_head_mask
+
+    raise_on_head_mask(head_mask)
+    batch_size = query.shape[0]
+    mask_value = torch.finfo(value.dtype).min
+    mask_value = torch.full([], mask_value, dtype=value.dtype)
+
+    if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
+        raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")
+
+    # in codegen the query and key are always in fp32 regardless of the dtype of the model
+    # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
+    query = query.to(value.dtype)
+    key = key.to(value.dtype)
+
+    dropout_p = self.dropout_prob_attn if self.training else 0.0
+    if batch_size == 1 or self.training:
+        if query.shape[2] > 1:
+            # first step of the decoding
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
+            )
+        else:
+            # in this case, which is the later decoding steps, the `causal_mask` in
+            # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
+            # is [True, ..., True] so actually not causal
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
+            )
+    else:
+        query_length, key_length = query.size(-2), key.size(-2)
+
+        # causal_mask is always [True, ..., True] otherwise, so executing this
+        # is unnecessary
+        if query_length > 1:
+            causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+
+            causal_mask = torch.where(causal_mask, 0, mask_value)
+
+            # torch.Tensor.expand does no memory copy
+            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
+            # we use torch.min to avoid having tensor(-inf)
+            attention_mask = torch.min(causal_mask, attention_mask)
+
+        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
+        )
+
+    return sdpa_result, None
+
+
 class CodeGenModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
@@ -1880,10 +1942,17 @@ def __enter__(self):
         # For avoiding breaking model on tracing stage, we reduce area of bettertransformer patch only for _attn.
         from optimum.bettertransformer.models.attention import codegen_wrapped_scaled_dot_product
 
+        attn_fn = codegen_wrapped_scaled_dot_product
+        if is_torch_version(">=", "2.1.0") and is_transformers_version(">=", "4.45"):
+            # transformers 4.45 removed the causal_mask const buffer from the model;
+            # if it still exists, it means legacy remote code was loaded
+            if hasattr(self._model.transformer.h[0].attn, "causal_mask"):
+                attn_fn = _codegen_wrapped_scaled_dot_product_legacy
+
         for layer in self._model.transformer.h:
             if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions:
                 orig_self_attn_fwd = layer.attn._attn
-                layer.attn._attn = types.MethodType(codegen_wrapped_scaled_dot_product, layer.attn)
+                layer.attn._attn = types.MethodType(attn_fn, layer.attn)
                 layer.attn._orig_attn = orig_self_attn_fwd
 
         patch_update_causal_mask(self._model, "4.45.0", "transformer")
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 169701e4af..bc862606ba 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -773,6 +773,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "bloom",
         "chatglm",
         "codegen",
+        "codegen2",
         "gpt2",
         "gpt_neo",
         "gpt_neox",
@@ -821,10 +822,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "mistral-nemo",
     )
 
-    # custom modeling defined in https://huggingface.co/katuni4ka/tiny-random-codegen2 differs from transformers after v4.45 resulting in unadapted patching
-    if is_transformers_version("<", "4.45.0"):
-        SUPPORTED_ARCHITECTURES += ("codegen2",)
-
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (
         "chatglm",
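
Usage note (not part of the diff): a minimal sketch of how the restored codegen2 coverage can be exercised end to end, assuming the public optimum-intel OVModelForCausalLM API and the katuni4ka/tiny-random-codegen2 remote-code checkpoint referenced in the removed test comment. During export, CodeGenModelPatcher selects between the regular and legacy SDPA wrappers based on whether the remote-code attention module still carries the old causal_mask buffer.

# Minimal sketch, not part of the PR: exercising the patched CodeGen export path.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "katuni4ka/tiny-random-codegen2"  # remote-code checkpoint from the removed test comment
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# export=True converts the PyTorch model to OpenVINO IR; during tracing,
# CodeGenModelPatcher picks codegen_wrapped_scaled_dot_product or
# _codegen_wrapped_scaled_dot_product_legacy depending on the causal_mask buffer.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("def hello_world():", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))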