Commit

fix position_ids length
jiqing-feng committed Nov 15, 2023
1 parent f187713 commit 9a3b400
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion optimum/intel/generation/modeling.py
@@ -302,7 +302,11 @@ def forward(
 if not self.use_cache:
     past_key_values_length = 0
 else:
-    past_key_values_length = past_key_values[0][1].shape[-2]
+    past_key_values_length = (
+        past_key_values[0].shape[-2]
+        if model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS
+        else past_key_values[0][1].shape[-2]
+    )
 position_ids = torch.arange(
     past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
 ).unsqueeze(0)
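
The change above reads the cached sequence length from a different place depending on the model family. The following is a minimal sketch of that idea, not the library's implementation. It assumes that a multi-query attention model stores one fused key/value tensor per layer with shape (batch, past_len, 2 * head_dim), while other models store a (key, value) pair per layer with shape (batch, heads, past_len, head_dim). The contents of `MULTI_QUERY_ATTN_MODELS`, the helper names `past_length` and `make_position_ids`, and the tensor sizes are all hypothetical and chosen only for illustration.

```python
import torch

# Hypothetical stand-in for the MULTI_QUERY_ATTN_MODELS constant named in the diff;
# its real contents are not shown in this commit.
MULTI_QUERY_ATTN_MODELS = {"gpt_bigcode"}


def past_length(past_key_values, model_type):
    """Return the number of cached positions, mirroring the patched expression."""
    if model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
        # Assumed layout: one fused tensor per layer, (batch, past_len, 2 * head_dim).
        return past_key_values[0].shape[-2]
    # Assumed layout: a (key, value) pair per layer, each (batch, heads, past_len, head_dim).
    return past_key_values[0][1].shape[-2]


def make_position_ids(past_len, seq_length):
    """Positions for the new tokens start right after the cached ones."""
    return torch.arange(past_len, seq_length + past_len, dtype=torch.long).unsqueeze(0)


batch, heads, past_len, head_dim = 1, 4, 7, 8

# One layer of cache in each assumed layout.
standard_cache = [
    (torch.zeros(batch, heads, past_len, head_dim), torch.zeros(batch, heads, past_len, head_dim))
]
fused_cache = [torch.zeros(batch, past_len, 2 * head_dim)]

std_len = past_length(standard_cache, "gpt2")
mqa_len = past_length(fused_cache, "gpt-bigcode")

# Under the assumed fused layout, the pre-fix expression indexes the batch
# dimension with 1, which is out of range when the batch size is 1.
try:
    fused_cache[0][1].shape[-2]
    old_expression_failed = False
except IndexError:
    old_expression_failed = True

print(std_len, mqa_len, make_position_ids(mqa_len, 3).tolist(), old_expression_failed)
```

Under these assumptions both layouts report the same cache length, so `position_ids` for newly generated tokens continues from the end of the cache in either case.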
