Add starcoder past-kv shape for TSModelForCausalLM class (#371)
* add starcoder past-kv shape for TSModelForCausalLM class

Signed-off-by: changwangss <[email protected]>

* improve code style and past-kv shape

Signed-off-by: changwangss <[email protected]>

* fix style

Signed-off-by: changwangss <[email protected]>

* support gpt_bigcode

Signed-off-by: changwangss <[email protected]>

* add gpt_bigcode to ipex test

Signed-off-by: changwangss <[email protected]>

* fix style

Signed-off-by: changwangss <[email protected]>

* fix style

Signed-off-by: changwangss <[email protected]>

---------

Signed-off-by: changwangss <[email protected]>
changwangss authored Sep 26, 2023
1 parent 985d0d1 · commit fc71567
Showing 2 changed files with 15 additions and 11 deletions.
19 changes: 13 additions & 6 deletions optimum/intel/generation/modeling.py
@@ -49,10 +49,10 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
     onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
     onnx_config = onnx_config_class(model.config)
     if task == "text-generation" and use_cache:
-        onnx_config = onnx_config_class(model.config, use_past=True)
+        onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
     dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
     model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
-    if task == "text-generation" and use_cache:
+    if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode":
         # WA jit.trace issue of model like llama in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464, or else, generation output will be incorrect
         pkv = []
         for i in range(len(model_inputs["past_key_values"])):
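With use_past_in_inputs=True, the exporter's dummy inputs now include past_key_values, and the llama-style pkv reshuffling is skipped for gpt_bigcode because its per-layer cache is a single fused tensor rather than a (key, value) pair. A minimal sketch, reusing the same optimum calls as the hunk above, for inspecting what the exporter generates; the checkpoint name is illustrative and the per-layer structure may vary across optimum versions:

import torch
from transformers import AutoModelForCausalLM
from optimum.exporters.tasks import TasksManager

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task="text-generation")
onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")

for i, layer_pkv in enumerate(dummy_inputs["past_key_values"]):
    if isinstance(layer_pkv, torch.Tensor):
        print(i, layer_pkv.shape)  # gpt_bigcode: one fused key/value tensor per layer
    else:
        print(i, [t.shape for t in layer_pkv])  # most models: a (key, value) pair per layer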
@@ -70,6 +70,8 @@
 
 def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     model_inputs = prepare_jit_inputs(model, task, use_cache)
+    # check that the model accepts the prepared inputs before tracing
+    model(**model_inputs)
     torch._C._jit_set_texpr_fuser_enabled(False)
     if "past_key_values" in model_inputs.keys():
         model.config.return_dict = False
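The new eager forward pass is a cheap guard: if prepare_jit_inputs produced inputs the model cannot consume, it fails with a readable Python error instead of an opaque tracing failure. A hedged sketch of how this slots into tracing, with model and model_inputs in scope from the hunk above; example_kwarg_inputs assumes PyTorch >= 2.0, and the freeze step mirrors common IPEX practice rather than anything shown in this hunk:

import torch

model(**model_inputs)  # raises immediately on bad shapes/dtypes, before tracing
traced = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
traced = torch.jit.freeze(traced.eval())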
@@ -273,13 +275,17 @@ def forward(
         num_attention_heads = self.normalized_config.num_attention_heads
         hidden_size = self.normalized_config.hidden_size
         d_k = hidden_size // num_attention_heads
-
-        if self.config.model_type != "bloom":
+        if self.config.model_type == "gpt_bigcode":
+            new_shape = [input_ids.shape[0], 0, d_k * 2]
+            empty_tensor = torch.empty(size=new_shape)
+            if self.model_dtype is not None:
+                empty_tensor = empty_tensor.to(self.model_dtype)
+            past_key_values = tuple([empty_tensor] * num_layers)
+        elif self.config.model_type != "bloom":
             new_shape = [input_ids.shape[0], num_attention_heads, 0, d_k]
             empty_tensor = torch.empty(size=new_shape)
             if self.model_dtype is not None:
                 empty_tensor = empty_tensor.to(self.model_dtype)
-            past_key_values = tuple(tuple(empty_tensor for _ in range(nb_pkv)) for _ in range(num_layers))
             pkv = tuple(empty_tensor for _ in range(nb_pkv))
         else:
             pkv = ()
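The branch above encodes the two cache layouts. gpt_bigcode uses multi-query attention, so each layer caches one tensor with key and value fused along the last dimension, [batch_size, seq_len, 2 * d_k]; standard multi-head models cache a (key, value) pair per layer, each [batch_size, num_attention_heads, seq_len, d_k]. A self-contained sketch with illustrative sizes:

import torch

batch_size, num_attention_heads, d_k, num_layers = 2, 12, 64, 24  # illustrative

# gpt_bigcode (multi-query attention): one fused key/value tensor per layer,
# with an empty sequence dimension for the first generation step
mqa_pkv = tuple(torch.empty(batch_size, 0, 2 * d_k) for _ in range(num_layers))
assert mqa_pkv[0].shape == (batch_size, 0, 2 * d_k)

# standard multi-head attention: a (key, value) pair per layer
mha_pkv = tuple(
    tuple(torch.empty(batch_size, num_attention_heads, 0, d_k) for _ in range(2))
    for _ in range(num_layers)
)
assert mha_pkv[0][0].shape == (batch_size, num_attention_heads, 0, d_k)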
@@ -292,7 +298,8 @@
                 if self.model_dtype is not None:
                     empty_tensor = empty_tensor.to(self.model_dtype)
                 pkv = pkv + (empty_tensor,)
-        past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
+        if past_key_values is None:
+            past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
 
         inputs["past_key_values"] = past_key_values
         outputs = self.model(**inputs)
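The guard added here matters because the bloom branch still assembles past_key_values piecewise at the end of the block; without the `if past_key_values is None` check, that final assignment would overwrite the gpt_bigcode tuple built earlier. For reference, a sketch of bloom's alternating cache shapes as implied by the surrounding loop and by bloom's cache layout in transformers (variable names and sizes are illustrative):

import torch

batch_size, num_attention_heads, d_k, num_layers = 2, 12, 64, 24  # illustrative
nb_pkv = 2  # one key tensor and one value tensor per layer

pkv = ()
for i in range(nb_pkv):
    if i % 2 == 0:
        new_shape = [batch_size * num_attention_heads, d_k, 0]  # key is stored transposed
    else:
        new_shape = [batch_size * num_attention_heads, 0, d_k]  # value
    pkv = pkv + (torch.empty(size=new_shape),)

past_key_values = tuple(tuple(pkv) for _ in range(num_layers))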
7 changes: 2 additions & 5 deletions tests/ipex/test_inference.py
@@ -40,6 +40,7 @@
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
     "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
+    "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
 }
 
 _CLASSIFICATION_TASK_TO_AUTOMODELS = {
@@ -55,11 +56,7 @@ class IPEXIntegrationTest(unittest.TestCase):
         "roberta",
     )
 
-    TEXT_GENERATION_SUPPORTED_ARCHITECTURES = (
-        "gptj",
-        "gpt2",
-        "gpt_neo",
-    )
+    TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("gptj", "gpt2", "gpt_neo", "gpt_bigcode")
 
     QA_SUPPORTED_ARCHITECTURES = (
         "bert",
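With the tiny checkpoint registered and gpt_bigcode added to TEXT_GENERATION_SUPPORTED_ARCHITECTURES, the IPEX generation tests now cover the new cache shape. A hedged end-to-end sketch of the path this exercises; the export/use_cache keyword handling follows the optimum-intel version of this commit and may differ in later releases, and the tiny model is assumed to ship a tokenizer:

from transformers import AutoTokenizer
from optimum.intel.generation.modeling import TSModelForCausalLM

model_id = "hf-internal-testing/tiny-random-GPTBigCodeModel"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = TSModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))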
