diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index b4c41e0be1..6a080409ed 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -145,7 +145,7 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional torch.jit.save(self.model, os.path.join(save_directory, WEIGHTS_NAME)) # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs): past_key_values = past_key_values or kwargs.get("past", None) if self.use_cache and past_key_values is not None: @@ -156,11 +156,19 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg if past_key_values[0][0].shape[0] == input_ids.shape[0]: past_key_values = self._convert_to_bloom_cache(past_key_values) + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + return { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": self.use_cache, - "position_ids": None, + "position_ids": position_ids, "attention_mask": kwargs.get("attention_mask", None), "token_type_ids": None, } @@ -268,6 +276,10 @@ def forward( "attention_mask": attention_mask, } + position_ids = kwargs.get("position_ids", None) + if position_ids is not None: + inputs.update({"position_ids": position_ids}) + if self.use_cache: if past_key_values is None: nb_pkv = 2 @@ -305,6 +317,7 @@ def forward( past_key_values = tuple(tuple(pkv) for _ in range(num_layers)) inputs["past_key_values"] = past_key_values + outputs = self.model(**inputs) if isinstance(outputs, (list, tuple)):