diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py
index 9b8747243..75106fc2b 100644
--- a/optimum/exporters/openvino/utils.py
+++ b/optimum/exporters/openvino/utils.py
@@ -103,7 +103,7 @@ def _get_input_info(
                     symbol = name_to_symbol[dim_name]
                 else:
                     symbol = Symbol()
-                    name_to_symbol[name] = symbol
+                    name_to_symbol[dim_name] = symbol
                 dim = Dimension(-1)
                 dim.set_symbol(symbol)
                 shape[idx] = dim
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 733f5a411..56b7a1c5a 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -16,7 +16,7 @@
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -31,7 +31,7 @@
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.stopping_criteria import StoppingCriteriaList
 from transformers.generation.utils import GenerateOutput, GenerationMode
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 
 from optimum.utils.normalized_config import NormalizedConfigManager
 
@@ -504,8 +504,8 @@ def prepare_inputs(
             else:
                 position_ids = np.cumsum(attention_mask, axis=1) - 1
                 position_ids[attention_mask == 0] = 1
-                if past_key_values:
-                    position_ids = position_ids[:, -input_ids.shape[1] :]
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
 
             inputs["position_ids"] = position_ids
 
@@ -604,6 +604,43 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
 
         return model_inputs
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """Update generation kwargs and extend `position_ids` by one position.
+
+        The parent implementation is applied first. If the returned kwargs then
+        hold a non-None `position_ids`, a column equal to its last position
+        plus one is appended along the last dimension.
+
+        Args:
+            outputs: Forwarded unchanged to the parent implementation.
+            model_kwargs: Forwarded to the parent implementation; its result is
+                what gets updated and returned.
+            is_encoder_decoder: Forwarded unchanged to the parent implementation.
+            **kwargs: Extra keyword arguments forwarded to the parent implementation.
+
+        Returns:
+            The kwargs dict produced by the parent implementation, with
+            `position_ids` extended when it is present and not None.
+        """
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder, **kwargs
+        )
+
+        # `position_ids` may be present but None (e.g. passed explicitly by the
+        # caller); indexing None would raise TypeError, so skip it in that case.
+        position_ids = model_kwargs.get("position_ids")
+        if position_ids is not None:
+            # `+ 1` builds a new tensor, so the existing positions are not mutated.
+            next_position_id = position_ids[..., -1:] + 1
+            model_kwargs["position_ids"] = torch.cat([position_ids, next_position_id], dim=-1)
+        return model_kwargs
+
     def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
         batch_size = logits.shape[0]
         if indicies.shape[0] != 1: