Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TSModelForCausalLM support model_type qwen/chatglm #458

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions optimum/intel/generation/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def _reorder_cache(
"""
if self.config.model_type == "bloom":
return self._reorder_cache_bloom(past_key_values, beam_idx)
if self.config.model_type == "chatglm":
return tuple(
tuple(past_state.index_select(1, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past_key_values
)

# from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
return tuple(
Expand Down Expand Up @@ -276,6 +281,8 @@ def forward(
num_attention_heads = self.normalized_config.num_key_value_heads
else:
num_attention_heads = self.normalized_config.num_attention_heads
if model_type in {"chatglm"} and hasattr(self.normalized_config, "multi_query_group_num"):
num_attention_heads = self.normalized_config.multi_query_group_num

if model_type == "bloom":
shape_key = (batch_size * num_attention_heads, d_k, 0)
Expand All @@ -290,15 +297,22 @@ def forward(
pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(pkv for _ in range(num_layers))
else:
shape = (batch_size, num_attention_heads, 0, d_k)
if self.config.model_type == "qwen":
shape = (batch_size, 0, num_attention_heads, d_k)
elif self.config.model_type == "chatglm":
shape = (0, batch_size, num_attention_heads, d_k)
else:
shape = (batch_size, num_attention_heads, 0, d_k)
pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers))

inputs["past_key_values"] = past_key_values

if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
inputs["position_ids"] = position_ids

if position_ids is None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
position_ids = torch.arange(input_ids.shape[1]).repeat(input_ids.shape[0], 1)
inputs["position_ids"] = position_ids
outputs = self.model(**inputs)

if isinstance(outputs, (list, tuple)):
Expand Down