Whisper is ExecuTorch compatible #33842

Open
Tracked by #32253
guangy10 opened this issue Sep 30, 2024 · 2 comments
Labels: ExecuTorch, Feature request

Comments

guangy10 (Contributor)

Feature request

Enable the "Export to ExecuTorch" workflow for Whisper

Motivation

See details in #32253

Your contribution

TBD

v-prgmr commented Oct 22, 2024

Is there anything I can do to contribute to this? :) @guangy10

thomasloux commented

Following some work at the Consumer AI Edge Hackathon with @gsuriano, and with huge thanks to @xenova for his experience exporting models to ONNX, here are some first hints on how to torch.export.export the Whisper model.

A direct export of WhisperForConditionalGeneration does not work.

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperModel

from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_sample = ds[0]["audio"]

# Step 1: Load Whisper model and processor from Hugging Face
model_name = "openai/whisper-small"
model = WhisperForConditionalGeneration.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)

# Step 2: Turn the raw audio into log-mel input features
input_features = processor(
    audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
).input_features

# Ensure the model is in evaluation mode
model.eval()
torch.export.export(
    model,
    args=(input_features,),
    strict=False,
)
ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds

This error is raised by the decoder.
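
As a quick sanity check (my own addition, not part of the original experiment), supplying decoder_input_ids does silence this ValueError for a plain forward pass, which confirms the missing input is on the decoder side:

# Hedged sanity check: a regular forward pass works once decoder_input_ids is
# provided, but exporting the whole encoder-decoder as a single module is still awkward.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]], dtype=torch.long)
with torch.no_grad():
    out = model(input_features, decoder_input_ids=decoder_input_ids)
print(out.logits.shape)  # (1, 1, vocab_size)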

Let's split the model into two parts: the encoder and the decoder.

Exporting the encoder

The encoder actually exports without any issue.

# Reload the same checkpoint as a plain WhisperModel to access the encoder and decoder
model_name = "openai/whisper-small"
model = WhisperModel.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)

encoder = model.get_encoder()
encoder.eval()
encoder_output = encoder(input_features)

And it works directly in strict mode.

exported_encoder = torch.export.export(
    encoder,
    args=(input_features,),
    strict=True,
)
from executorch.exir import to_edge
edge_module = to_edge(exported_encoder)

executorch_program = edge_module.to_executorch()
exported_output = exported_encoder.module().forward(input_features)

torch.allclose(exported_output.last_hidden_state, encoder_output.last_hidden_state)
# returns True!
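
As a side note (not from the original comment), the compiled program can then presumably be serialized for the ExecuTorch runtime; this sketch assumes the buffer property exposed by the object returned from to_executorch():

# Hedged sketch: write the compiled ExecuTorch program to a .pte file.
with open("whisper_encoder.pte", "wb") as f:
    f.write(executorch_program.buffer)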

Exporting the decoder

We need to provide: input_ids (torch.LongTensor), attention_mask (torch.Tensor), and encoder_hidden_states (torch.FloatTensor).

decoder = model.get_decoder()
help(decoder.forward) # For more information on the arguments
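
To make the shapes concrete, here is a small illustrative sketch (toy values of my own, not from the original comment) of the three inputs listed above:

# Toy decoder inputs matching the signature described above (illustrative only)
input_ids = torch.tensor([[model.config.decoder_start_token_id]], dtype=torch.long)  # (batch, seq)
attention_mask = torch.ones_like(input_ids)                                          # (batch, seq)
encoder_hidden_states = encoder_output.last_hidden_state                             # (batch, 1500, d_model)
decoder_output = decoder(input_ids, attention_mask, encoder_hidden_states)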

The decoder actually exports directly if no encoder_hidden_states is provided, which is not what we want.

exported_decoder = torch.export.export(
    decoder, 
    args=(torch.randint(0, 100, (1, 16), dtype=torch.int64), torch.ones((1, 32), dtype=torch.int64)),
    strict=True)

With encoder_hidden_states, the export works in strict=False mode but breaks with strict=True.

exported_decoder = torch.export.export(
    decoder, 
    args=(torch.randint(0, 100,(1, 16), dtype=torch.int64), torch.ones((1, 32), dtype=torch.int64), encoder_output.last_hidden_state),
    strict=False)

Actually the graph break happens in the if-condition at line 301 of modeling_whisper.py (https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/whisper/modeling_whisper.py#L301).

Following the strict-mode strategy, export will assume that the if-condition is always True. In our case the if-statement is

if is_cross_attention and past_key_value and is_updated:

and we expect is_cross_attention to be True and past_key_value to be provided. At the moment we are not sure what the workflow behind the is_updated variable is.
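
For what it's worth, a hedged way to poke at that flag: in recent transformers versions the cache passed to cross-attention is an EncoderDecoderCache, and is_updated appears to be a per-layer dict recording whether the cross-attention keys/values have already been computed, so the branch can reuse them instead of recomputing them from encoder_hidden_states.

# Hedged inspection sketch, assuming the transformers v4.46 cache utilities
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
print(cache.is_updated)  # empty until a cross-attention layer marks its layer_idx as updated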

But simply comparing the exported_decoder with the original decoder suggests that the graph is correct.

size = 16
exported_decoder = torch.export.export(
    decoder, 
    args=(torch.randint(0, 100,(1, size), dtype=torch.int64), torch.ones((1, size), dtype=torch.int64), encoder_output.last_hidden_state),
    strict=False)

args = (torch.randint(0, 10, (1, size), dtype=torch.int64), torch.randint(0, 2, (1, size), dtype=torch.int64), torch.rand((1, 1500, 768)))

original_answer = decoder.forward(*args)
exported_answer = exported_decoder.module().forward(*args)

torch.allclose(original_answer.last_hidden_state, exported_answer.last_hidden_state)

The same is also possible with a KV-cache, by providing past_key_values as an argument.

size = 16
# Run the decoder once to obtain past_key_values to feed back in
a = decoder(torch.randint(0, 100,(1, size), dtype=torch.int64), torch.ones((1, size), dtype=torch.int64), torch.rand((1, 1500, 768)))
## Note: attention_mask needs to cover the length of input_ids plus the length of past_key_values
exported_decoder = torch.export.export(
    decoder, 
    args=(torch.randint(0, 100,(1, size), dtype=torch.int64), torch.ones((1, size+size), dtype=torch.int64), encoder_output.last_hidden_state),
    kwargs={"past_key_values": a.past_key_values},
    strict=False)


args = (torch.randint(0, 100, (1, size), dtype=torch.int64), torch.randint(0, 2, (1, size+size), dtype=torch.int64), torch.rand((1, 1500, 768)))
kwargs = {"past_key_values": a.past_key_values}

original_answer = decoder.forward(*args, **kwargs)
exported_answer = exported_decoder.module().forward(*args, **kwargs)

torch.allclose(original_answer.last_hidden_state, exported_answer.last_hidden_state)

What to do next?

  • Check that there is no problem with dynamic shapes for all of the decoder's arguments
  • Try to replace part of decoder.forward() to prevent graph breaks and allow exporting with strict=True
  • Rewrite a generate() pipeline in Python using these exported models (a rough sketch follows below).
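
To make the last item concrete, here is a rough greedy-decoding sketch. It is hypothetical glue code, not a tested pipeline: it assumes the decoder was exported with dynamic sequence lengths (see the first item), without KV-cache, and that the proj_out LM head of WhisperForConditionalGeneration is available to turn hidden states into logits.

def greedy_generate(exported_encoder, exported_decoder, lm_model, input_features, max_new_tokens=32):
    # Run the exported encoder once
    hidden = exported_encoder.module()(input_features).last_hidden_state
    # Start from the decoder start token and grow the sequence greedily
    ids = torch.tensor([[lm_model.config.decoder_start_token_id]], dtype=torch.long)
    for _ in range(max_new_tokens):
        mask = torch.ones_like(ids)
        dec_out = exported_decoder.module()(ids, mask, hidden)
        logits = lm_model.proj_out(dec_out.last_hidden_state[:, -1])  # LM head
        next_id = logits.argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
        if next_id.item() == lm_model.config.eos_token_id:
            break
    return ids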

Side note: this procedure can probably be applied to other models, especially encoder-decoder architectures.

And a shout-out to my teammates @Shaamallow @fracapuano @gsuriano
