Commit

apply review comments
eaidova committed Dec 14, 2023
1 parent 2148358 commit 7323be5
Showing 1 changed file with 18 additions and 62 deletions.
80 changes: 18 additions & 62 deletions optimum/intel/openvino/modeling_seq2seq.py
@@ -627,35 +627,14 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         **kwargs,
     ) -> Seq2SeqLMOutput:
-        # Encode if needed : first prediction pass
-        # Encode if needed (training, first prediction pass)
-        if encoder_outputs is None:
-            encoder_outputs = self.encoder(
-                input_ids=flattened_patches,
-                attention_mask=attention_mask,
-            )
-
-        # Decode
-        if past_key_values is None or self.use_cache is False:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                encoder_attention_mask=attention_mask,
-            )
-        else:
-            decoder_outputs = self.decoder_with_past(
-                input_ids=decoder_input_ids[:, -1:],  # Cut decoder_input_ids if past is used
-                decoder_attention_mask=decoder_attention_mask,
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                encoder_attention_mask=attention_mask,
-            )
-
-        return Seq2SeqLMOutput(
-            logits=decoder_outputs.logits,
-            past_key_values=decoder_outputs.past_key_values,
+        return super().forward(
+            input_ids=flattened_patches,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            **kwargs,
         )
 
     def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True):
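For reference, a usage sketch of the model refactored in the hunk above. Assumptions, since they are not visible in the diff: the enclosing class is taken to be OVModelForPix2Struct, the checkpoint name and image path are only examples, and the standard transformers Pix2StructProcessor / generate() flow is assumed to apply. This is a sketch, not an authoritative snippet from the repository.

# Usage sketch (assumptions flagged above): the processor yields flattened_patches
# and attention_mask, which forward() now hands to the parent class via
# super().forward(input_ids=flattened_patches, ...).
from PIL import Image
from transformers import Pix2StructProcessor
from optimum.intel import OVModelForPix2Struct  # assumed export name

model_id = "google/pix2struct-textcaps-base"  # example checkpoint
processor = Pix2StructProcessor.from_pretrained(model_id)
model = OVModelForPix2Struct.from_pretrained(model_id, export=True)

image = Image.open("example.png")  # placeholder path to any input image
inputs = processor(images=image, return_tensors="pt")

generated_ids = model.generate(**inputs)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))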
@@ -686,7 +665,6 @@ class OVModelForSpeechSeq2Seq(OVModelForSeq2SeqLM):
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        input_features: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         past_key_values=None,
@@ -701,7 +679,6 @@ def prepare_inputs_for_generation(
             decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
 
         return {
-            "input_features": input_features,
             "decoder_input_ids": input_ids,
             "past_key_values": past_key_values,
             "encoder_outputs": encoder_outputs,
@@ -731,35 +708,14 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         **kwargs,
     ) -> Seq2SeqLMOutput:
-        # Encode if needed : first prediction pass
-        # Encode if needed (training, first prediction pass)
-        if encoder_outputs is None:
-            encoder_outputs = self.encoder(
-                input_ids=input_features,
-                attention_mask=attention_mask,
-            )
-
-        # Decode
-        if past_key_values is None or self.use_cache is False:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                encoder_attention_mask=attention_mask,
-            )
-        else:
-            decoder_outputs = self.decoder_with_past(
-                input_ids=decoder_input_ids[:, -1:],  # Cut decoder_input_ids if past is used
-                decoder_attention_mask=decoder_attention_mask,
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                encoder_attention_mask=attention_mask,
-            )
-
-        return Seq2SeqLMOutput(
-            logits=decoder_outputs.logits,
-            past_key_values=decoder_outputs.past_key_values,
+        return super().forward(
+            input_ids=input_features,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            **kwargs,
         )
 
     @classmethod
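Both forward() hunks above now delegate to the parent class. Below is a condensed sketch of the encode-then-decode-with-cache logic that OVModelForSeq2SeqLM.forward is presumed to implement; it mirrors the lines deleted above, since the parent implementation itself is not part of this diff.

# Sketch only: assumed shape of the shared parent forward (method body shown
# without its class for brevity).
from transformers.modeling_outputs import Seq2SeqLMOutput

def forward(self, input_ids=None, attention_mask=None, decoder_input_ids=None,
            decoder_attention_mask=None, encoder_outputs=None, past_key_values=None, **kwargs):
    # First pass: run the encoder on the model-specific inputs
    # (the subclasses pass flattened_patches or input_features as input_ids).
    if encoder_outputs is None:
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

    # Decode: reuse cached key/values when available, otherwise run the full decoder.
    if past_key_values is None or self.use_cache is False:
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
        )
    else:
        decoder_outputs = self.decoder_with_past(
            input_ids=decoder_input_ids[:, -1:],  # only the last token once past is cached
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
        )

    return Seq2SeqLMOutput(logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values)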
@@ -770,7 +726,7 @@ def _from_pretrained(
         **kwargs,
     ):
         if "WhisperForConditionalGeneration" in config.architectures:
-            return OVModelForWhisper._from_pretrained(model_id, config, **kwargs)
+            return _OVModelForWhisper._from_pretrained(model_id, config, **kwargs)
         else:
             return super()._from_pretrained(model_id, config, **kwargs)
 
@@ -790,7 +746,7 @@ def mro(cls):
         )
 
 
-class OVModelForWhisper(
+class _OVModelForWhisper(
     OVModelForSpeechSeq2Seq, WhisperForConditionalGeneration, metaclass=MetaClassRemoveParentsAndReorder
 ):
     """
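The renamed _OVModelForWhisper combines the OpenVINO runtime class with WhisperForConditionalGeneration and relies on MetaClassRemoveParentsAndReorder, whose mro() override appears in the hunk above. The sketch below is illustrative only: it shows the general mechanism of a metaclass pruning entries from the default MRO so that attribute lookup skips selected parents. The names and the skip logic are hypothetical, not the repository's implementation.

# Hypothetical illustration of the technique, not the actual metaclass.
_PARENTS_TO_SKIP = set()

class RemoveParentsMeta(type):
    def mro(cls):
        # Start from the default C3 linearization, then drop unwanted parents
        # so attribute lookup bypasses them.
        return [base for base in super().mro() if base not in _PARENTS_TO_SKIP]

class ModelingImpl:
    def run(self):
        return "transformers implementation"

class RuntimeImpl:
    def run(self):
        return "openvino implementation"

_PARENTS_TO_SKIP.add(ModelingImpl)

class Combined(ModelingImpl, RuntimeImpl, metaclass=RemoveParentsMeta):
    pass

print(Combined().run())  # "openvino implementation": ModelingImpl is skipped in the MRO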

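Finally, a usage sketch for the speech path. Assumptions: the standard optimum-intel from_pretrained(..., export=True) flow and the usual transformers Whisper preprocessing apply; per the config.architectures check shown above, loading a Whisper checkpoint is expected to be routed to _OVModelForWhisper. The checkpoint and dataset names are only examples.

# Usage sketch (assumptions flagged above).
from datasets import load_dataset
from transformers import AutoProcessor
from optimum.intel import OVModelForSpeechSeq2Seq

model_id = "openai/whisper-tiny"  # example checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)

# Any 16 kHz mono clip works; this dummy dataset is a common example source.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features

predicted_ids = model.generate(input_features)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))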