Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add starcode past-kv shape for TSModelForCausal class #371

Merged
merged 7 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions optimum/intel/generation/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_class(model.config)
if task == "text-generation" and use_cache:
onnx_config = onnx_config_class(model.config, use_past=True)
onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
if task == "text-generation" and use_cache:
if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode":
# WA jit.trace issue of model like llama in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464, or else, generation output will be incorrect
pkv = []
for i in range(len(model_inputs["past_key_values"])):
Expand All @@ -70,6 +70,8 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals

def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
model_inputs = prepare_jit_inputs(model, task, use_cache)
# check if the model_inputs is correct.
model(**model_inputs)
torch._C._jit_set_texpr_fuser_enabled(False)
if "past_key_values" in model_inputs.keys():
model.config.return_dict = False
Expand Down Expand Up @@ -273,13 +275,17 @@ def forward(
num_attention_heads = self.normalized_config.num_attention_heads
hidden_size = self.normalized_config.hidden_size
d_k = hidden_size // num_attention_heads

if self.config.model_type != "bloom":
if self.config.model_type == "gpt_bigcode":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test to verify inference is behaving as expected ? (using INCModelForCausalLM can be enough for now)

new_shape = [input_ids.shape[0], 0, d_k * 2]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
past_key_values = tuple([empty_tensor] * num_layers)
elif self.config.model_type != "bloom":
new_shape = [input_ids.shape[0], num_attention_heads, 0, d_k]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
past_key_values = tuple(tuple(empty_tensor for _ in range(nb_pkv)) for _ in range(num_layers))
pkv = tuple(empty_tensor for _ in range(nb_pkv))
else:
pkv = ()
Expand All @@ -292,7 +298,8 @@ def forward(
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
pkv = pkv + (empty_tensor,)
past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
if past_key_values is None:
past_key_values = tuple(tuple(pkv) for _ in range(num_layers))

inputs["past_key_values"] = past_key_values
outputs = self.model(**inputs)
Expand Down
7 changes: 2 additions & 5 deletions tests/ipex/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
}

_CLASSIFICATION_TASK_TO_AUTOMODELS = {
Expand All @@ -55,11 +56,7 @@ class IPEXIntegrationTest(unittest.TestCase):
"roberta",
)

TEXT_GENERATION_SUPPORTED_ARCHITECTURES = (
"gptj",
"gpt2",
"gpt_neo",
)
TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("gptj", "gpt2", "gpt_neo", "gpt_bigcode")

QA_SUPPORTED_ARCHITECTURES = (
"bert",
Expand Down
Loading