enable codegen2 back
eaidova committed Nov 6, 2024
1 parent 50012f2 commit 6dd9064
Showing 2 changed files with 71 additions and 5 deletions.
optimum/exporters/openvino/model_patcher.py (70 additions, 1 deletion)
@@ -1872,6 +1872,68 @@ def __exit__(self, exc_type, exc_value, traceback):
            layer.self_attn.forward = layer.self_attn._orig_forward


# copied from https://github.com/huggingface/optimum/blob/v1.22.0/optimum/bettertransformer/models/attention.py#L169
# for preserving backward compatibility between outdated codegen remote code and new transformers
def _codegen_wrapped_scaled_dot_product_legacy(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
):
    from optimum.bettertransformer.models.attention import raise_on_head_mask

    raise_on_head_mask(head_mask)
    batch_size = query.shape[0]
    mask_value = torch.finfo(value.dtype).min
    mask_value = torch.full([], mask_value, dtype=value.dtype)

    if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
        raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

    # in codegen the query and key are always in fp32 regardless of the dtype of the model
    # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
    query = query.to(value.dtype)
    key = key.to(value.dtype)

    dropout_p = self.dropout_prob_attn if self.training else 0.0
    if batch_size == 1 or self.training:
        if query.shape[2] > 1:
            # first step of the decoding
            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
            )
        else:
            # in this case, which corresponds to the later decoding steps, the `causal_mask` in
            # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
            # is [True, ..., True], so the attention is effectively not causal
            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
            )
    else:
        query_length, key_length = query.size(-2), key.size(-2)

        # causal_mask is always [True, ..., True] otherwise, so executing this
        # is unnecessary
        if query_length > 1:
            causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

            causal_mask = torch.where(causal_mask, 0, mask_value)

            # torch.Tensor.expand does no memory copy
            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

            # we use torch.min to avoid having tensor(-inf)
            attention_mask = torch.min(causal_mask, attention_mask)

        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
        )

    return sdpa_result, None


class CodeGenModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
@@ -1880,10 +1942,17 @@ def __enter__(self):
        # To avoid breaking the model at the tracing stage, we reduce the area of the bettertransformer patch to only _attn.
        from optimum.bettertransformer.models.attention import codegen_wrapped_scaled_dot_product

        attn_fn = codegen_wrapped_scaled_dot_product
        if is_torch_version(">=", "2.1.0") and is_transformers_version(">=", "4.45"):
            # transformers 4.45 removed the causal_mask const buffer from the model;
            # if it still exists, it means the legacy remote code was loaded
            if hasattr(self._model.transformer.h[0].attn, "causal_mask"):
                attn_fn = _codegen_wrapped_scaled_dot_product_legacy

        for layer in self._model.transformer.h:
            if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions:
                orig_self_attn_fwd = layer.attn._attn
                layer.attn._attn = types.MethodType(codegen_wrapped_scaled_dot_product, layer.attn)
                layer.attn._attn = types.MethodType(attn_fn, layer.attn)
                layer.attn._orig_attn = orig_self_attn_fwd
        patch_update_causal_mask(self._model, "4.45.0", "transformer")

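For context, a minimal usage sketch (not part of this diff) of the path that exercises the patcher above: exporting the codegen2 remote-code checkpoint referenced by the tests through optimum-intel. The model id comes from the test file below; the prompt and generation length are illustrative assumptions.

# illustrative sketch: exporting the remote-code codegen2 model applies CodeGenModelPatcher
# during tracing; on transformers >= 4.45 with the legacy remote code, the
# _codegen_wrapped_scaled_dot_product_legacy fallback above is selected
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "katuni4ka/tiny-random-codegen2"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# export=True converts the PyTorch model to OpenVINO IR, applying the patcher while tracing
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("def hello_world():", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

If the integration tests are parameterized by architecture name, something like pytest tests/openvino/test_modeling.py -k codegen2 should pick up the re-enabled cases.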
tests/openvino/test_modeling.py (1 addition, 4 deletions)
@@ -773,6 +773,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"bloom",
"chatglm",
"codegen",
"codegen2",
"gpt2",
"gpt_neo",
"gpt_neox",
@@ -821,10 +822,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"mistral-nemo",
)

# custom modeling defined in https://huggingface.co/katuni4ka/tiny-random-codegen2 differs from transformers after v4.45 resulting in unadapted patching
if is_transformers_version("<", "4.45.0"):
SUPPORTED_ARCHITECTURES += ("codegen2",)

GENERATION_LENGTH = 100
REMOTE_CODE_MODELS = (
"chatglm",
