Whisper is ExecuTorch compatible #33842
Is there anything I can do to contribute to this? :) @guangy10
Following some work at the Consumer AI Edge Hackathon with @gsuriano, and a huge thanks to @xenova for his experience exporting models to ONNX, here are some first hints on how to `torch.export.export` the Whisper model. A direct export of `WhisperForConditionalGeneration` does not work:

```python
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_sample = ds[0]["audio"]

# Step 1: Load Whisper model and processor from Hugging Face
model_name = "openai/whisper-small"
model = WhisperForConditionalGeneration.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)

# Ensure the model is in evaluation mode
model.eval()

# Step 2: Turn the raw audio into log-mel input features
input_features = processor(
    audio_sample["array"],
    sampling_rate=audio_sample["sampling_rate"],
    return_tensors="pt",
).input_features

# Step 3: A direct export fails inside the decoder
torch.export.export(
    model,
    args=(input_features,),
    strict=False,
)
```
The error comes from the decoder. Let's split the model into two parts: the encoder and the decoder.

**Exporting the encoder**

Actually the encoder exports without any issue, and it works directly in strict mode.
**Exporting the decoder**

We need to provide: `input_ids` (torch.LongTensor), `attention_mask` (torch.Tensor), and `encoder_hidden_states` (torch.FloatTensor). The decoder actually exports directly if no `encoder_hidden_states` are provided, which is not what we want. With `encoder_hidden_states`, the export works in `strict=False` mode but breaks for `strict=True`.
The graph break happens in the [if-condition at line 301 of modeling_whisper.py](https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/whisper/modeling_whisper.py#L301). In strict mode, the tracer assumes the if-condition is always True. In our case we expect `is_cross_attention` to be True and `past_key_value` to be provided. At the moment we are not sure what the workflow behind the `is_updated` variable is, but simply comparing the `exported_decoder` with the original decoder supports that the graph is correct.
The same is also possible with KV-cache, by providing the past_key_values as an argument.
What to do next?
Side note: this procedure can probably be applied to other models, especially encoder-decoder ones. And a shout-out to my teammates @Shaamallow @fracapuano @gsuriano
Feature request
Enable Whisper to "Export to ExecuTorch" workflow
Motivation
See details in #32253
Your contribution
TBD