From 9a3b4003db60efb8bbb42bb64f67dda003f6234c Mon Sep 17 00:00:00 2001
From: "Feng, Jiqing"
Date: Tue, 14 Nov 2023 19:24:34 -0800
Subject: [PATCH] fix position_ids length

---
 optimum/intel/generation/modeling.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index f88ac6a7db..bff66f4783 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -302,7 +302,11 @@ def forward(
             if not self.use_cache:
                 past_key_values_length = 0
             else:
-                past_key_values_length = past_key_values[0][1].shape[-2]
+                past_key_values_length = (
+                    past_key_values[0].shape[-2]
+                    if model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS
+                    else past_key_values[0][1].shape[-2]
+                )
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
             ).unsqueeze(0)