Skip to content

Commit

Permalink
support qwen
Browse files (browse the repository at this point in the history)
Signed-off-by: changwangss <[email protected]>
  • Loading branch information
changwangss committed Oct 19, 2023
1 parent 8273e7f commit 74912c0
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions optimum/intel/generation/modeling.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -284,6 +284,12 @@ def forward(
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
past_key_values = tuple([empty_tensor] * num_layers)
elif self.config.model_type == "qwen":
new_shape = [input_ids.shape[0], 0, num_key_value_heads, d_k]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
pkv = tuple(empty_tensor for _ in range(nb_pkv))
elif self.config.model_type != "bloom":
new_shape = [input_ids.shape[0], num_key_value_heads, 0, d_k]
empty_tensor = torch.empty(size=new_shape)
Expand Down

0 comments on commit 74912c0

Please sign in to comment.