
INCModelForCausalLM: support gpt_bigcode model
Signed-off-by: changwangss <[email protected]>
changwangss committed Sep 25, 2023
1 parent e0c1fc4 commit 6f42527
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions optimum/intel/generation/modeling.py
@@ -20,11 +20,11 @@
 from typing import Optional, Tuple, Union
 
 import torch
+from huggingface_hub import hf_hub_download
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import WEIGHTS_NAME
 
-from huggingface_hub import hf_hub_download
 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
@@ -49,10 +49,10 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
     onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
     onnx_config = onnx_config_class(model.config)
     if task == "text-generation" and use_cache:
-        onnx_config = onnx_config_class(model.config, use_past=True)
+        onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
     dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
     model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
-    if task == "text-generation" and use_cache:
+    if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode":
         # Workaround for a jit.trace issue with models like llama (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464); otherwise generation output will be incorrect
         pkv = []
         for i in range(len(model_inputs["past_key_values"])):
@@ -70,6 +70,8 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
 
 def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     model_inputs = prepare_jit_inputs(model, task, use_cache)
+    # Run one eager forward pass to check that model_inputs are valid for this model.
+    model(**model_inputs)
     torch._C._jit_set_texpr_fuser_enabled(False)
     if "past_key_values" in model_inputs.keys():
         model.config.return_dict = False
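
The use_past_in_inputs=True addition makes the exporter's dummy-input generator include past_key_values among the inputs handed to jit.trace, so the traced graph covers the cached-decoding path. A minimal sketch of inspecting what this produces, using only the calls already visible in the diff (the checkpoint name is an arbitrary example, not from the commit):

import torch
from transformers import AutoModelForCausalLM
from optimum.exporters import TasksManager

model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder")
ctor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task="text-generation")
onnx_config = ctor(model.config, use_past=True, use_past_in_inputs=True)
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
# Expect input_ids and attention_mask plus a past_key_values entry.
print(sorted(dummy_inputs.keys()))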
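The new model_type guard exists because the past-key-values workaround assumes the llama-style cache layout, a per-layer (key, value) pair, whereas transformers' gpt_bigcode uses multi-query attention and stores one fused tensor per layer. A sketch of the difference (shapes are illustrative and assume the multi-query layout in transformers at the time; nothing here is from the commit itself):

import torch

batch, num_heads, seq_len, head_dim, num_layers = 1, 32, 8, 128, 2

# llama-style cache: one (key, value) pair per layer,
# each tensor of shape (batch, num_heads, seq_len, head_dim)
llama_style = tuple(
    (torch.zeros(batch, num_heads, seq_len, head_dim),
     torch.zeros(batch, num_heads, seq_len, head_dim))
    for _ in range(num_layers)
)

# gpt_bigcode multi-query cache: a single tensor per layer with key and value
# concatenated on the last dimension, shape (batch, seq_len, 2 * head_dim).
# Indexing model_inputs["past_key_values"][i][0] / [1] as in the workaround
# would therefore not mean "key" and "value" here, so the branch is skipped.
gpt_bigcode_style = tuple(
    torch.zeros(batch, seq_len, 2 * head_dim) for _ in range(num_layers)
)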

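For context, a hypothetical end-to-end use of the patched helper, assuming jit_trace returns the traced module (the checkpoint is again an arbitrary example):

import torch
from transformers import AutoModelForCausalLM
from optimum.intel.generation.modeling import jit_trace

model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder")
model.eval()
with torch.no_grad():
    traced = jit_trace(model, task="text-generation", use_cache=True)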