From b05617797ce920959c0326ea942589fe65f3e9f3 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Thu, 14 Dec 2023 17:54:19 +0400 Subject: [PATCH] Add support whisper for openvino (#470) * add support whisper for openvino * add test * fix tests * restrict transformers version for now... * allow to run on GPU * apply review comments * fix compatibility with transformers 4.36 * fix generate * apply comments * fix pix2struct --- optimum/exporters/openvino/__main__.py | 18 +- optimum/intel/__init__.py | 2 + optimum/intel/openvino/__init__.py | 2 +- .../intel/openvino/modeling_base_seq2seq.py | 2 - optimum/intel/openvino/modeling_seq2seq.py | 624 +++++++++++++++++- optimum/intel/openvino/trainer.py | 24 +- optimum/intel/utils/dummy_openvino_objects.py | 11 + tests/openvino/test_modeling.py | 65 ++ tests/openvino/utils_tests.py | 2 + 9 files changed, 705 insertions(+), 45 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index cb011706c8..9be180621c 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -26,10 +26,19 @@ from optimum.utils import DEFAULT_DUMMY_SHAPES from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors -from ...intel.utils.import_utils import is_nncf_available +from ...intel.utils.import_utils import is_nncf_available, is_optimum_version, is_transformers_version from .convert import export_models +if is_optimum_version(">=", "1.16.0"): + from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED +else: + # Copied from https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/constants.py + SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED = [ + "bart", + "whisper", + ] + OV_XML_FILE_NAME = "openvino_model.xml" _MAX_UNCOMPRESSED_SIZE = 1e9 @@ -140,10 +149,12 @@ def main_export( do_gptq_patching = False try: config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code) + model_type = config.model_type.replace("_", "-") config_dict = config.to_dict() quantization_config = config_dict.get("quantization_config", None) do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq" except Exception: + model_type = None pass if do_gptq_patching: @@ -192,6 +203,10 @@ class StoreAttr(object): f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. 
Detailed error: {e}" ) + loading_kwargs = {} + if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: + loading_kwargs["attn_implementation"] = "eager" + model = TasksManager.get_model_from_task( task, model_name_or_path, @@ -204,6 +219,7 @@ class StoreAttr(object): trust_remote_code=trust_remote_code, framework=framework, device=device, + **loading_kwargs, ) custom_architecture = False diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index 5fe65dcd41..570a451bd8 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -99,6 +99,7 @@ "OVModelForPix2Struct", "OVModelForQuestionAnswering", "OVModelForSeq2SeqLM", + "OVModelForSpeechSeq2Seq", "OVModelForSequenceClassification", "OVModelForTokenClassification", ] @@ -195,6 +196,7 @@ OVModelForQuestionAnswering, OVModelForSeq2SeqLM, OVModelForSequenceClassification, + OVModelForSpeechSeq2Seq, OVModelForTokenClassification, ) diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py index 7ed550ceb0..6999c6b48f 100644 --- a/optimum/intel/openvino/__init__.py +++ b/optimum/intel/openvino/__init__.py @@ -46,7 +46,7 @@ OVModelForTokenClassification, ) from .modeling_decoder import OVModelForCausalLM -from .modeling_seq2seq import OVModelForPix2Struct, OVModelForSeq2SeqLM +from .modeling_seq2seq import OVModelForPix2Struct, OVModelForSeq2SeqLM, OVModelForSpeechSeq2Seq if is_diffusers_available(): diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index 527adc4347..e3dd1d7aa0 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -68,8 +68,6 @@ def __init__( self.ov_config = ov_config if ov_config is not None else {} self.preprocessors = kwargs.get("preprocessors", []) - if "GPU" in self._device: - raise ValueError("Support of dynamic shapes for GPU devices is not yet available.") if self.is_dynamic: encoder = self._reshape(encoder, -1, -1, is_decoder=False) decoder = self._reshape(decoder, -1, -1) diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 6b759054d0..d43dbf3427 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -12,19 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import logging from pathlib import Path from tempfile import gettempdir -from typing import Dict, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union import numpy as np import openvino import torch import transformers from openvino.runtime import Core -from transformers import AutoConfig, AutoModelForSeq2SeqLM, Pix2StructForConditionalGeneration +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoModelForSpeechSeq2Seq, + Pix2StructForConditionalGeneration, + WhisperForConditionalGeneration, +) from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward +from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE from ..utils.import_utils import is_transformers_version from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM @@ -35,6 +44,9 @@ else: from transformers.generation import GenerationMixin +if TYPE_CHECKING: + from transformers import PretrainedConfig + core = Core() logger = logging.getLogger(__name__) @@ -176,6 +188,56 @@ ``` """ +SPEECH_SEQ2SEQ_MODEL_DOCSTRING = r""" + Args: + input_features (`torch.FloatTensor`): + Mel features extracted from the raw speech waveform. + `(batch_size, feature_size, encoder_sequence_length)`. + decoder_input_ids (`torch.LongTensor`): + Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`. + encoder_outputs (`torch.FloatTensor`): + The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor), *optional*, defaults to `None`)` + Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding. + The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
+""" + +AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE = r""" + Example of text generation: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.intel.openvino import {model_class} + >>> from datasets import load_dataset + + >>> processor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor.feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + + >>> gen_tokens = model.generate(inputs=inputs.input_features) + >>> outputs = processor.tokenizer.batch_decode(gen_tokens) + ``` + + Example using `transformers.pipeline`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.intel.openvino import {model_class} + >>> from datasets import load_dataset + + >>> processor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> speech_recognition = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> pred = speech_recognition(ds[0]["audio"]["array"]) + ``` +""" + @add_start_docstrings( """ @@ -262,6 +324,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.FloatTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, **kwargs, @@ -276,6 +339,7 @@ def forward( input_ids=decoder_input_ids, encoder_hidden_states=encoder_outputs.last_hidden_state, encoder_attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, ) else: decoder_outputs = self.decoder_with_past( @@ -283,6 +347,7 @@ def forward( past_key_values=past_key_values, encoder_hidden_states=encoder_outputs.last_hidden_state, encoder_attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, ) return Seq2SeqLMOutput(logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values) @@ -394,9 +459,9 @@ def forward( inputs["attention_mask"] = attention_mask # Run inference - last_hidden_state = torch.from_numpy(self.request(inputs, shared_memory=True)["last_hidden_state"]).to( - self.device - ) + last_hidden_state = torch.from_numpy( + self.request(inputs, share_inputs=True, share_outputs=True)["last_hidden_state"] + ).to(self.device) return BaseModelOutput(last_hidden_state=last_hidden_state) @@ -475,7 +540,7 @@ def forward( if "decoder_attention_mask" in self.input_names and decoder_attention_mask is not None: inputs["decoder_attention_mask"] = decoder_attention_mask # Run inference - self.request.start_async(inputs, shared_memory=True) + self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device) @@ -567,35 +632,14 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, **kwargs, ) -> Seq2SeqLMOutput: - # Encode if needed : first prediction pass - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=flattened_patches, - attention_mask=attention_mask, - ) - - # Decode - if past_key_values is 
None or self.use_cache is False: - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - past_key_values=past_key_values, - encoder_hidden_states=encoder_outputs.last_hidden_state, - encoder_attention_mask=attention_mask, - ) - else: - decoder_outputs = self.decoder_with_past( - input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used - decoder_attention_mask=decoder_attention_mask, - past_key_values=past_key_values, - encoder_hidden_states=encoder_outputs.last_hidden_state, - encoder_attention_mask=attention_mask, - ) - - return Seq2SeqLMOutput( - logits=decoder_outputs.logits, - past_key_values=decoder_outputs.past_key_values, + return super().forward( + input_ids=flattened_patches, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + **kwargs, ) def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True): @@ -610,3 +654,513 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng shapes[inputs][1] = -1 model.reshape(shapes) return model + + +@add_start_docstrings( + """ + Speech Sequence-to-sequence model with a language modeling head for OpenVINO inference. This class officially supports whisper, speech_to_text. + """, + INPUTS_DOCSTRING, +) +class OVModelForSpeechSeq2Seq(OVModelForSeq2SeqLM): + auto_model_class = AutoModelForSpeechSeq2Seq + main_input_name = "input_features" + export_feature = "automatic-speech-recognition" + + def prepare_inputs_for_generation( + self, + input_ids, + attention_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ) -> Dict: + if decoder_attention_mask is None: + decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device) + + return { + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + @add_start_docstrings_to_model_forward( + SPEECH_SEQ2SEQ_MODEL_DOCSTRING + + AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE.format( + processor_class=_PROCESSOR_FOR_DOC, + model_class="OVModelForSpeechSeq2Seq", + checkpoint="openai/whisper-tiny", + ) + ) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + **kwargs, + ) -> Seq2SeqLMOutput: + return super().forward( + input_ids=input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + **kwargs, + ) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + **kwargs, + ): + if "WhisperForConditionalGeneration" in 
config.architectures: + return _OVModelForWhisper._from_pretrained(model_id, config, **kwargs) + else: + return super()._from_pretrained(model_id, config, **kwargs) + + +class _OVModelForWhisper(OVModelForSpeechSeq2Seq): + """ + Whisper implements its own generate() method. + """ + + auto_model_class = WhisperForConditionalGeneration + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + **kwargs, + ): + return super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs) + + # Adapted from transformers.models.whisper.modeling_whisper + def generate( + self, + input_features: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_timestamps=None, + task=None, + language=None, + is_multilingual=None, + prompt_ids: Optional[torch.Tensor] = None, + num_segment_frames: Optional[int] = None, + return_token_timestamps: Optional[bool] = None, + return_segments: bool = False, + attention_mask: Optional[torch.Tensor] = None, + time_precision: int = 0.02, + return_dict_in_generate: Optional[bool] = None, + **kwargs, + ): + if "inputs" in kwargs: + input_features = kwargs.pop("inputs") + logging.warn( + "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.", + FutureWarning, + ) + + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + + input_stride = ( + 1 * 2 + ) # NOTE: replaced from `self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]` + if num_segment_frames is None: + num_segment_frames = input_stride * self.config.max_source_positions + + # 1. Check whether we're in shortform or longform mode + if input_features is not None: + total_input_frames = input_features.shape[-1] + elif "encoder_outputs" in kwargs: + encoder_outputs_shape = ( + kwargs["encoder_outputs"][0].shape + if isinstance(kwargs["encoder_outputs"], BaseModelOutput) + else kwargs["encoder_outputs"].shape + ) + total_input_frames = encoder_outputs_shape[1] * input_stride + else: + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") + + is_shortform = total_input_frames <= num_segment_frames + + # 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not + if return_timestamps is True: + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError( + "You are trying to return timestamps, but the generation config is not properly set. " + "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " + "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + ) + generation_config.return_timestamps = return_timestamps + elif not is_shortform: + if return_timestamps is False: + raise ValueError( + "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " + "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." 
+ ) + + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError( + "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " + "requires the generation config to have `no_timestamps_token_id` correctly. " + "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " + "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + "or make sure to pass no more than 3000 mel input features." + ) + + logger.info("Setting `return_timestamps=True` for long-form generation.") + generation_config.return_timestamps = True + else: + generation_config.return_timestamps = False + + # 3. Make sure to correctly set language-related parameters + if is_multilingual is not None: + if not hasattr(generation_config, "is_multilingual"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `is_multilingual` argument " + "to `generate`. Please update the generation config as per the instructions " + "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + generation_config.is_multilingual = is_multilingual + + if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual: + if task is not None or language is not None: + raise ValueError( + "Cannot specify `task` or `language` for an English-only model. If the model is intended to be " + "multilingual, pass `is_multilingual=True` to generate, or update the generation config." + ) + + if language is not None: + if not hasattr(generation_config, "lang_to_id"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `language` argument " + "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, " + "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + language = language.lower() + generation_config.language = language + if task is not None: + if not hasattr(generation_config, "task_to_id"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `task` argument " + "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, " + "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + generation_config.task = task + + # 4. 
Add forced decoder ids depending on passed `language`, `task`,`prompt_ids`, `return_token_timestamps` and `return_timestamps` + forced_decoder_ids = None + # Legacy code for backward compatibility + if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: + forced_decoder_ids = self.config.forced_decoder_ids + elif ( + hasattr(self.generation_config, "forced_decoder_ids") + and self.generation_config.forced_decoder_ids is not None + ): + forced_decoder_ids = self.generation_config.forced_decoder_ids + else: + forced_decoder_ids = kwargs.get("forced_decoder_ids", None) + + if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): + forced_decoder_ids = [] + if hasattr(generation_config, "language"): + if generation_config.language in generation_config.lang_to_id.keys(): + language_token = generation_config.language + elif generation_config.language in TO_LANGUAGE_CODE.keys(): + language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" + elif generation_config.language in TO_LANGUAGE_CODE.values(): + language_token = f"<|{generation_config.language}|>" + else: + is_language_code = len(generation_config.language) == 2 + raise ValueError( + f"Unsupported language: {generation_config.language}. Language should be one of:" + f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." + ) + forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) + else: + forced_decoder_ids.append((1, None)) # automatically detect the language + + if hasattr(generation_config, "task"): + if generation_config.task in TASK_IDS: + forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) + else: + raise ValueError( + f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`" + ) + elif hasattr(generation_config, "task_to_id"): + forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe + if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: + idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 + forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) + + if forced_decoder_ids is not None: + generation_config.forced_decoder_ids = forced_decoder_ids + + if prompt_ids is not None: + if kwargs.get("decoder_start_token_id") is not None: + raise ValueError( + "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." 
+ ) + prompt_ids = prompt_ids.tolist() + decoder_start_token_id, *text_prompt_ids = prompt_ids + # Slicing the text prompt ids in a manner consistent with the OpenAI implementation + # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) + text_prompt_ids = text_prompt_ids[-self.config.max_target_positions // 2 - 1 :] + # Set the decoder_start_token_id to <|startofprev|> + kwargs.update({"decoder_start_token_id": decoder_start_token_id}) + + # If the user passes `max_new_tokens`, increase its number to account for the prompt + if kwargs.get("max_new_tokens", None) is not None: + kwargs["max_new_tokens"] += len(text_prompt_ids) + if kwargs["max_new_tokens"] >= self.config.max_target_positions: + raise ValueError( + f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` " + f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced " + f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the " + f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. " + "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " + f"so that their combined length is less that {self.config.max_target_positions}." + ) + + # Reformat the forced_decoder_ids to incorporate the prompt + non_prompt_forced_decoder_ids = ( + kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids + ) + forced_decoder_ids = [ + *text_prompt_ids, + generation_config.decoder_start_token_id, + *[token for _rank, token in non_prompt_forced_decoder_ids], + ] + forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] + generation_config.forced_decoder_ids = forced_decoder_ids + + if return_token_timestamps: + kwargs["output_attentions"] = True + return_dict_in_generate = True + + if getattr(generation_config, "task", None) == "translate": + logger.warning("Token-level timestamps may not be reliable for task 'translate'.") + if not hasattr(generation_config, "alignment_heads"): + raise ValueError( + "Model generation config has no `alignment_heads`, token-level timestamps not available. " + "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." + ) + + if kwargs.get("num_frames") is not None: + generation_config.num_frames = kwargs.pop("num_frames") + + if generation_config.return_timestamps is True: + last_forced_decoder_ids = ( + generation_config.forced_decoder_ids[-1][-1] + if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids + else None + ) + if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id: + # remove no_timestamp to be forcefully generated if we want to return timestamps + # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly + forced_decoder_ids = generation_config.forced_decoder_ids[:-1] + # Make sure that if list is empty we set it to None + generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids + + timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)] + logits_processor = ( + timestamp_processor if logits_processor is None else timestamp_processor + logits_processor + ) + + # 5. 
If we're in shortform mode, simple generate the whole input at once and return the output + if is_shortform: + outputs = super().generate( + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + num_frames = getattr(generation_config, "num_frames", None) + outputs["token_timestamps"] = self._extract_token_timestamps( + outputs, generation_config.alignment_heads, num_frames=num_frames + ) + + return outputs + + # 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated + # timestamp tokens + # 6.1 Set running parameters for while loop + if not return_segments and return_dict_in_generate: + raise ValueError( + "Make sure to set `return_segments=True` to return generation outputs as part of the `'segments' key.`" + ) + + # if input is longer than 30 seconds we default to long-form generation + timestamp_begin = self.generation_config.no_timestamps_token_id + 1 + # input stride is mel frames per encoder output vector which is the product of all conv strides + batch_size = input_features.shape[0] + + if batch_size > 1 and attention_mask is None: + raise ValueError( + "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " + ) + elif batch_size > 1: + max_frames = attention_mask.sum(-1).cpu().to(torch.long) + seek = torch.zeros((batch_size,), dtype=torch.long) + else: + max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames + seek = torch.zeros((1,), dtype=torch.long) + + current_segments = [[] for _ in range(batch_size)] + cur_to_prev_index_map = list(range(batch_size)) + + # batch size can decrease during the run + cur_bsz = prev_bsz = batch_size + + # 6.2 Transcribe audio until we reach the end of all input audios + while (seek < max_frames).any(): + prev_bsz = cur_bsz + + # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop + # in case one audio finished earlier than another one. 
Thus, we need to keep a table of "previous-index-2-current-index" in order + # to know which original audio is being decoded + new_cur_to_prev_index_map = [] + for i in range(prev_bsz): + prev_i = cur_to_prev_index_map[i] + if seek[prev_i] >= max_frames[prev_i]: + cut_index = i + (cur_bsz - prev_bsz) + cur_bsz -= 1 + input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0) + else: + # cut out index that goes away + new_cur_to_prev_index_map.append(prev_i) + + # 6.4 Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk + cur_to_prev_index_map = new_cur_to_prev_index_map + time_offset = seek * time_precision / input_stride + seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) + + # 6.5 Make sure that all inputs are padded to the same input length + segment_input = [] + for i in range(cur_bsz): + prev_i = cur_to_prev_index_map[i] + segment_input_slice = input_features[ + i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i] + ] + + if segment_input_slice.shape[-1] < num_segment_frames: + # pad to 3000 if necessary + segment_input_slice = torch.nn.functional.pad( + segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1]) + ) + + segment_input.append(segment_input_slice) + + segment_input = torch.cat(segment_input, dim=0) + + # 6.6 Batch generate current chunk + seek_outputs = super().generate( + segment_input, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + num_frames = getattr(generation_config, "num_frames", None) + seek_outputs["token_timestamps"] = self._extract_token_timestamps( + seek_outputs, generation_config.alignment_heads, num_frames=num_frames + ) + + if return_dict_in_generate: + seek_sequences = seek_outputs["sequences"] + seek_outputs = [ + {k: v[i] for k, v in seek_outputs.items()} + for i in range(next(iter(seek_outputs.values())).size(0)) + ] + else: + seek_sequences = seek_outputs + + # 6.7 Loop over each decoded audio individually as each decoding can be of a different length + for i, seek_sequence in enumerate(seek_sequences): + prev_i = cur_to_prev_index_map[i] + + # make sure we cut a predicted EOS token if we are not finished with the generation yet + is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] + if is_not_final and seek_sequence[-1] == self.generation_config.eos_token_id: + seek_sequence = seek_sequence[:-1] + + # remove all padding tokens + if seek_sequence[-1] == self.generation_config.pad_token_id: + num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum() + seek_sequence = seek_sequence[:-num_paddings] + + segments, segment_offset = self._retrieve_segment( + seek_sequence=seek_sequence, + seek_outputs=seek_outputs, + time_offset=time_offset, + timestamp_begin=timestamp_begin, + seek_num_frames=seek_num_frames, + cur_bsz=cur_bsz, + time_precision=time_precision, + input_stride=input_stride, + prev_idx=prev_i, + idx=i, + ) + + current_segments[prev_i] += segments + seek[prev_i] += segment_offset + + # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted + # output tokens from the list of dicts. 
If we use batch size > 1, we make sure to pad the output + sequences = [] + max_total_length = 0 + for current_segment_list in current_segments: + sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1)) + max_total_length = max(max_total_length, len(sequences[-1])) + + for i in range(batch_size): + sequences[i] = torch.nn.functional.pad( + sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id + ) + + sequences = torch.stack(sequences, dim=0) + + # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. + if return_segments: + return {"sequences": sequences, "segments": current_segments} + + return sequences diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index 618d3807b8..f5badac7b6 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -637,6 +637,10 @@ def _inner_training_loop( if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping + if getattr(self, "do_grad_scaling", False): + # AMP: gradients need unscaling + self.scaler.unscale_(self.optimizer) + if is_sagemaker_mp_enabled() and args.fp16: self.optimizer.clip_master_grads(args.max_grad_norm) elif self.use_apex: @@ -652,12 +656,20 @@ def _inner_training_loop( ) # Optimizer step - self.optimizer.step() - optimizer_was_run = not self.accelerator.optimizer_step_was_skipped - if optimizer_was_run: - # Delay optimizer scheduling until metrics are generated - if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - self.lr_scheduler.step() + optimizer_was_run = True + if self.deepspeed: + pass # called outside the loop + elif getattr(self, "do_grad_scaling", False): + scale_before = self.scaler.get_scale() + self.scaler.step(self.optimizer) + self.scaler.update() + scale_after = self.scaler.get_scale() + optimizer_was_run = scale_before <= scale_after + else: + self.optimizer.step() + + if optimizer_was_run and not self.deepspeed: + self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 diff --git a/optimum/intel/utils/dummy_openvino_objects.py b/optimum/intel/utils/dummy_openvino_objects.py index a6d62652d5..9e17035d70 100644 --- a/optimum/intel/utils/dummy_openvino_objects.py +++ b/optimum/intel/utils/dummy_openvino_objects.py @@ -136,6 +136,17 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["openvino"]) +class OVModelForSpeechSeq2Seq(metaclass=DummyObject): + _backends = ["openvino"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["openvino"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["openvino"]) + + class OVModelForSequenceClassification(metaclass=DummyObject): _backends = ["openvino"] diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index c29e8c2eef..dc33b39f2a 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -40,6 +40,7 @@ AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, + AutoModelForSpeechSeq2Seq, AutoModelForTokenClassification, AutoTokenizer, GenerationConfig, @@ -65,6 +66,7 @@ OVModelForQuestionAnswering, OVModelForSeq2SeqLM, OVModelForSequenceClassification, + OVModelForSpeechSeq2Seq, OVModelForTokenClassification, OVStableDiffusionPipeline, ) @@ -1205,3 +1207,66 @@ def test_compare_with_and_without_past_key_values(self): del model_with_pkv del model_without_pkv gc.collect() + + 
+class OVModelForSpeechSeq2SeqIntegrationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ("whisper",) + + def _generate_random_audio_data(self): + np.random.seed(10) + t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False) + # generate pure sine wave at 220 Hz + audio_data = 0.5 * np.sin(2 * np.pi * 220 * t) + return audio_data + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True) + self.assertIsInstance(ov_model.config, PretrainedConfig) + transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id) + processor = get_preprocessor(model_id) + data = self._generate_random_audio_data() + features = processor.feature_extractor(data, return_tensors="pt") + + decoder_start_token_id = transformers_model.config.decoder_start_token_id + decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id} + + with torch.no_grad(): + transformers_outputs = transformers_model(**features, **decoder_inputs) + + for input_type in ["pt", "np"]: + features = processor.feature_extractor(data, return_tensors=input_type) + + if input_type == "np": + decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id} + + ov_outputs = ov_model(**features, **decoder_inputs) + self.assertIn("logits", ov_outputs) + # Compare tensor outputs + self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-3)) + + del transformers_model + del ov_model + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline(self, model_arch): + model_id = MODEL_NAMES[model_arch] + model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True) + processor = get_preprocessor(model_id) + GenerationConfig.from_pretrained(model_id) + pipe = pipeline( + "automatic-speech-recognition", + model=model, + tokenizer=processor.tokenizer, + feature_extractor=processor.feature_extractor, + ) + data = self._generate_random_audio_data() + outputs = pipe(data) + self.assertIsInstance(outputs["text"], str) + + del pipe + del model + gc.collect() diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index f8abf6bc6a..044d32f012 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -68,6 +68,7 @@ "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", "segformer": "hf-internal-testing/tiny-random-SegformerModel", + "speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel", "squeezebert": "hf-internal-testing/tiny-random-squeezebert", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", @@ -84,6 +85,7 @@ "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", "wav2vec2-hf": "hf-internal-testing/tiny-random-Wav2Vec2Model", "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer", + "whisper": "openai/whisper-tiny.en", "xlm": "hf-internal-testing/tiny-random-xlm", "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta", }
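
---

For reference, a minimal usage sketch of the new `OVModelForSpeechSeq2Seq` class, assembled from the `AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE` docstring and the integration test added above. The `openai/whisper-tiny` checkpoint and the dummy LibriSpeech split are the ones used in those examples, and `export=True` triggers the on-the-fly OpenVINO conversion path exercised by the tests; treat it as a sketch of intended usage, not an additional requirement of the patch.

```python
# Sketch assuming optimum-intel with this patch installed, plus `transformers` and `datasets`.
from datasets import load_dataset
from transformers import AutoProcessor, pipeline

from optimum.intel.openvino import OVModelForSpeechSeq2Seq

model_id = "openai/whisper-tiny"
processor = AutoProcessor.from_pretrained(model_id)
# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly
model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor.feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")

# Whisper checkpoints are routed to _OVModelForWhisper, which provides its own generate()
generated = model.generate(inputs.input_features)
print(processor.tokenizer.batch_decode(generated, skip_special_tokens=True))

# The same model also plugs into the transformers pipeline, as in the docstring example
asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)
print(asr(ds[0]["audio"]["array"])["text"])
```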
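A hedged sketch of the short-form/long-form dispatch at the top of the copied-in `_OVModelForWhisper.generate()`. The helper function and the example tensors below are illustrative only, but the constants mirror the patch: `input_stride = 1 * 2` stands in for the product of the encoder conv strides, and `max_source_positions` is 1500 for Whisper, giving 3000 mel frames (roughly 30 seconds) as the threshold between a single `super().generate()` call and the segment-by-segment long-form loop.

```python
import torch


def whisper_generation_mode(input_features: torch.Tensor, max_source_positions: int = 1500) -> str:
    """Illustrative mirror of the dispatch at the top of _OVModelForWhisper.generate()."""
    input_stride = 1 * 2  # the patch replaces encoder conv1.stride[0] * conv2.stride[0] with this constant
    num_segment_frames = input_stride * max_source_positions  # 3000 mel frames ~= 30 seconds of audio
    total_input_frames = input_features.shape[-1]
    # short-form: a single super().generate() call on the whole input;
    # long-form: the features are sliced into <= num_segment_frames chunks and the while
    # loop advances `seek` per batch item based on the predicted timestamp tokens.
    return "short-form" if total_input_frames <= num_segment_frames else "long-form"


# 10 s of audio -> ~1000 mel frames -> short-form; 60 s -> ~6000 frames -> long-form
print(whisper_generation_mode(torch.zeros(1, 80, 1000)))  # "short-form"
print(whisper_generation_mode(torch.zeros(1, 80, 6000)))  # "long-form"
```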