From f52960bdb42400a0dbd0305b4007e3a0a34cd857 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 23 Oct 2023 21:32:40 +0200 Subject: [PATCH] fix input names --- optimum/intel/openvino/modeling_base.py | 8 +++++++- optimum/intel/utils/modeling_utils.py | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 321f50f570..9384477eb9 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -79,7 +79,13 @@ def __init__( height = -1 if self.export_feature == "image-classification" else None width = -1 if self.export_feature == "image-classification" else None model = self._reshape(model, -1, -1, height, width) - self.input_names = {key.get_any_name(): idx for idx, key in enumerate(model.inputs)} + + input_names = {} + for idx, key in enumerate(model.inputs): + names = tuple(key.get_names()) + input_names[next((name for name in names if "/" not in name), names[0])] = idx + self.input_names = input_names + self.model = model self.request = None if enable_compilation: diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 6b6cd30999..b56e5e4f2d 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -18,6 +18,9 @@ from transformers.modeling_utils import PreTrainedModel +# from ...utils.modeling_utils import _prepare_decoder_sliding_window_attention_mask + + MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"} @@ -109,6 +112,8 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"): model.transformer._prepare_attn_mask = _prepare_attn_mask elif model.config.model_type == "llama": model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + # elif model.config.model_type == "mistral": + # model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}: model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask return model