diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index 6467455c62..0299f77fad 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -20,11 +20,11 @@ from typing import Optional, Tuple, Union
 
 import torch
-from huggingface_hub import hf_hub_download
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import WEIGHTS_NAME
 
+from huggingface_hub import hf_hub_download
 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
@@ -49,10 +49,10 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
     onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
     onnx_config = onnx_config_class(model.config)
     if task == "text-generation" and use_cache:
-        onnx_config = onnx_config_class(model.config, use_past=True)
+        onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
     dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
     model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
-    if task == "text-generation" and use_cache:
+    if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode":
         # WA jit.trace issue of model like llama in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464, or else, generation output will be incorrect
         pkv = []
         for i in range(len(model_inputs["past_key_values"])):
@@ -70,6 +70,8 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
 
 def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     model_inputs = prepare_jit_inputs(model, task, use_cache)
+    # check that the generated model inputs are valid before tracing
+    model(**model_inputs)
     torch._C._jit_set_texpr_fuser_enabled(False)
     if "past_key_values" in model_inputs.keys():
         model.config.return_dict = False
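
For context, a minimal sketch of how the patched path can be exercised (the `gpt2` checkpoint and the surrounding setup are illustrative assumptions, not part of this change):

```python
import torch
from transformers import AutoModelForCausalLM

from optimum.intel.generation.modeling import jit_trace

# Any small decoder-only checkpoint works; "gpt2" is an arbitrary choice here.
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

with torch.no_grad():
    # use_cache=True hits the changed branch: dummy past_key_values are built
    # with use_past_in_inputs=True, and the inputs are validated by a plain
    # forward pass before torch.jit.trace runs.
    traced_model = jit_trace(model, task="text-generation", use_cache=True)
```

The extra `model(**model_inputs)` call added in `jit_trace` surfaces shape or dtype problems in the dummy inputs as an eager-mode error, which is easier to debug than a failure inside `torch.jit.trace`.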