Commit e3f87a7: add position_ids in forward

jiqing-feng committed Oct 17, 2023
1 parent: f52d7c8

Showing 1 changed file with 15 additions and 2 deletions.

optimum/intel/generation/modeling.py
```diff
@@ -145,7 +145,7 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional
         torch.jit.save(self.model, os.path.join(save_directory, WEIGHTS_NAME))
 
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
         past_key_values = past_key_values or kwargs.get("past", None)
 
         if self.use_cache and past_key_values is not None:
@@ -156,11 +156,19 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
 
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
             "use_cache": self.use_cache,
-            "position_ids": None,
+            "position_ids": position_ids,
             "attention_mask": kwargs.get("attention_mask", None),
             "token_type_ids": None,
         }
```
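The cumsum/masked_fill pair gives each real token its index among the non-padding tokens and parks padding positions at a dummy value of 1; once a KV cache exists, only the newest position is kept. A small standalone sketch of what the added lines compute for a left-padded batch (illustration only, not part of the commit):

```python
# Standalone illustration of the position_ids logic added above.
import torch

attention_mask = torch.tensor(
    [
        [0, 0, 1, 1, 1],  # left-padded prompt (3 real tokens)
        [1, 1, 1, 1, 1],  # full-length prompt
    ]
)

position_ids = attention_mask.long().cumsum(-1) - 1  # running index per row
position_ids.masked_fill_(attention_mask == 0, 1)    # dummy value on padding
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# With past_key_values present, only the newest token's position is kept:
print(position_ids[:, -1].unsqueeze(-1))
# tensor([[2],
#         [4]])
```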
```diff
@@ -268,6 +276,10 @@ def forward(
             "attention_mask": attention_mask,
         }
 
+        position_ids = kwargs.get("position_ids", None)
+        if position_ids is not None:
+            inputs.update({"position_ids": position_ids})
+
         if self.use_cache:
            if past_key_values is None:
                nb_pkv = 2
```
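Note that forward only adds position_ids to the model inputs when the caller actually supplied one. This keeps existing exports working: a torch.jit.trace'd module accepts the fixed set of inputs it was traced with, so unconditionally passing a new key could break models traced without it. A sketch of the pattern in isolation (build_inputs is a hypothetical stand-in, not a function from this file):

```python
# Isolated sketch of the conditional-input pattern used in forward.
import torch


def build_inputs(input_ids, attention_mask, **kwargs):
    inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
    position_ids = kwargs.get("position_ids", None)
    if position_ids is not None:
        inputs.update({"position_ids": position_ids})
    return inputs


ids = torch.tensor([[1, 2, 3]])
mask = torch.ones_like(ids)
print(build_inputs(ids, mask).keys())
# dict_keys(['input_ids', 'attention_mask'])
print(build_inputs(ids, mask, position_ids=torch.arange(3).unsqueeze(0)).keys())
# dict_keys(['input_ids', 'attention_mask', 'position_ids'])
```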
```diff
@@ -305,6 +317,7 @@ def forward(
                 past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
 
             inputs["past_key_values"] = past_key_values
+
         outputs = self.model(**inputs)
 
         if isinstance(outputs, (list, tuple)):
```
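Taken together, the change makes left-padded batch generation position-aware end to end. A hypothetical usage sketch; the class name TSModelForCausalLM and the export=True flag are assumptions about this module's public API (they may differ between optimum-intel versions) and are not taken from the diff:

```python
# Hypothetical usage sketch; class name and export flag are assumptions.
from transformers import AutoTokenizer
from optimum.intel.generation.modeling import TSModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = TSModelForCausalLM.from_pretrained("gpt2", export=True)

batch = tokenizer(["Hello", "A much longer prompt"], padding=True, return_tensors="pt")
# prepare_inputs_for_generation now derives position_ids from the padded
# attention_mask, so both rows decode from the correct positions.
generated = model.generate(**batch, max_new_tokens=8)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```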