From 7323be50d0f4432fd2ce6a508cb8ee1c792d8ef9 Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 14 Dec 2023 09:51:23 +0400 Subject: [PATCH] apply review comments --- optimum/intel/openvino/modeling_seq2seq.py | 80 +++++----------------- 1 file changed, 18 insertions(+), 62 deletions(-) diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index cd0d3bd773..9eb3ae45bd 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -627,35 +627,14 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, **kwargs, ) -> Seq2SeqLMOutput: - # Encode if needed : first prediction pass - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=flattened_patches, - attention_mask=attention_mask, - ) - - # Decode - if past_key_values is None or self.use_cache is False: - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - past_key_values=past_key_values, - encoder_hidden_states=encoder_outputs.last_hidden_state, - encoder_attention_mask=attention_mask, - ) - else: - decoder_outputs = self.decoder_with_past( - input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used - decoder_attention_mask=decoder_attention_mask, - past_key_values=past_key_values, - encoder_hidden_states=encoder_outputs.last_hidden_state, - encoder_attention_mask=attention_mask, - ) - - return Seq2SeqLMOutput( - logits=decoder_outputs.logits, - past_key_values=decoder_outputs.past_key_values, + return super().forward( + input_ids=flattened_patches, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + **kwargs, ) def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True): @@ -686,7 +665,6 @@ class OVModelForSpeechSeq2Seq(OVModelForSeq2SeqLM): def prepare_inputs_for_generation( self, input_ids, - input_features: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, past_key_values=None, @@ -701,7 +679,6 @@ def prepare_inputs_for_generation( decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device) return { - "input_features": input_features, "decoder_input_ids": input_ids, "past_key_values": past_key_values, "encoder_outputs": encoder_outputs, @@ -731,35 +708,14 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, **kwargs, ) -> Seq2SeqLMOutput: - # Encode if needed : first prediction pass - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_features, - attention_mask=attention_mask, - ) - - # Decode - if past_key_values is None or self.use_cache is False: - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - past_key_values=past_key_values, - encoder_hidden_states=encoder_outputs.last_hidden_state, - encoder_attention_mask=attention_mask, - ) - else: - decoder_outputs = self.decoder_with_past( - input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used - decoder_attention_mask=decoder_attention_mask, - past_key_values=past_key_values, - encoder_hidden_states=encoder_outputs.last_hidden_state, - encoder_attention_mask=attention_mask, - ) - - return Seq2SeqLMOutput( - logits=decoder_outputs.logits, - past_key_values=decoder_outputs.past_key_values, + return super().forward( + input_ids=input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + **kwargs, ) @classmethod @@ -770,7 +726,7 @@ def _from_pretrained( **kwargs, ): if "WhisperForConditionalGeneration" in config.architectures: - return OVModelForWhisper._from_pretrained(model_id, config, **kwargs) + return _OVModelForWhisper._from_pretrained(model_id, config, **kwargs) else: return super()._from_pretrained(model_id, config, **kwargs) @@ -790,7 +746,7 @@ def mro(cls): ) -class OVModelForWhisper( +class _OVModelForWhisper( OVModelForSpeechSeq2Seq, WhisperForConditionalGeneration, metaclass=MetaClassRemoveParentsAndReorder ): """