diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index bbfc3db63d..5dfb7e9819 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -169,6 +169,11 @@ def _reorder_cache(
         """
         if self.config.model_type == "bloom":
             return self._reorder_cache_bloom(past_key_values, beam_idx)
+        if self.config.model_type == "chatglm":
+            return tuple(
+                tuple(past_state.index_select(1, beam_idx.to(past_state.device)) for past_state in layer_past)
+                for layer_past in past_key_values
+            )
 
         # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
         return tuple(
@@ -276,6 +281,8 @@ def forward(
                 num_attention_heads = self.normalized_config.num_key_value_heads
             else:
                 num_attention_heads = self.normalized_config.num_attention_heads
+            if model_type in {"chatglm"} and hasattr(self.normalized_config, "multi_query_group_num"):
+                num_attention_heads = self.normalized_config.multi_query_group_num
 
             if model_type == "bloom":
                 shape_key = (batch_size * num_attention_heads, d_k, 0)
@@ -290,7 +297,12 @@ def forward(
                 pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
                 past_key_values = tuple(pkv for _ in range(num_layers))
             else:
-                shape = (batch_size, num_attention_heads, 0, d_k)
+                if self.config.model_type == "qwen":
+                    shape = (batch_size, 0, num_attention_heads, d_k)
+                elif self.config.model_type == "chatglm":
+                    shape = (0, batch_size, num_attention_heads, d_k)
+                else:
+                    shape = (batch_size, num_attention_heads, 0, d_k)
                 pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
                 past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers))
 
@@ -298,7 +310,9 @@ def forward(
         if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
             inputs["position_ids"] = position_ids
-
+        if position_ids is None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
+            position_ids = torch.arange(input_ids.shape[1]).repeat(input_ids.shape[0], 1)
+            inputs["position_ids"] = position_ids
         outputs = self.model(**inputs)
 
         if isinstance(outputs, (list, tuple)):
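
Note on the cache layouts above (not part of the patch): the three `shape` branches encode where each architecture keeps its batch and sequence axes in the key/value cache, and the new chatglm branch of `_reorder_cache` follows directly from that layout. A minimal sketch with dummy dimensions and tensors:

```python
import torch

batch_size, num_attention_heads, d_k = 2, 4, 8

# Standard gpt2-like layout: (batch, heads, seq, d_k) -- batch is dim 0,
# so the stock _reorder_cache uses index_select(0, beam_idx).
standard = torch.empty(batch_size, num_attention_heads, 0, d_k)

# qwen: (batch, seq, heads, d_k) -- batch still dim 0, seq moved to dim 1.
qwen = torch.empty(batch_size, 0, num_attention_heads, d_k)

# chatglm: (seq, batch, heads, d_k) -- sequence-first, batch is dim 1,
# which is why the chatglm branch reorders with index_select(1, beam_idx).
chatglm = torch.empty(0, batch_size, num_attention_heads, d_k)

# Reordering a non-empty chatglm-style past state across beams:
seq_len = 3
past_state = torch.randn(seq_len, batch_size, num_attention_heads, d_k)
beam_idx = torch.tensor([1, 0])  # dummy beam permutation
reordered = past_state.index_select(1, beam_idx)
assert reordered.shape == past_state.shape
```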
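
Similarly for the last hunk (illustration only, with dummy inputs): when the caller passes no `position_ids` for a model type that requires them, the patch synthesizes one `0..seq_len-1` row per batch item:

```python
import torch

input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])  # dummy (batch=2, seq=3) input
position_ids = torch.arange(input_ids.shape[1]).repeat(input_ids.shape[0], 1)
print(position_ids)
# tensor([[0, 1, 2],
#         [0, 1, 2]])
```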