Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Whisper for the task "automatic-speech-recognition" w/o. KV cache #789

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@
is_transformers_neuronx_available,
map_torch_dtype,
)
from ...neuron.utils.misc import maybe_save_preprocessors
from ...neuron.utils.version_utils import (
check_compiler_compatibility_for_stable_diffusion,
)
from ...utils import is_diffusers_available, logging
from ...utils.save_utils import maybe_save_preprocessors
from ..error_utils import AtolError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from .base import NeuronExportConfig
Expand Down
3 changes: 1 addition & 2 deletions optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
is_sentence_transformers_available,
logging,
)
from .config import TextSeq2SeqNeuronConfig


if TYPE_CHECKING:
Expand Down Expand Up @@ -533,7 +532,7 @@ def export_neuronx(
# Prepare the model / function(tp) to trace
aliases = {}
tensor_parallel_size = config.tensor_parallel_size
if isinstance(config, TextSeq2SeqNeuronConfig):
if hasattr(config, "is_encoder_decoder") and config.is_encoder_decoder:
checked_model = config.patch_model_for_export(model_or_path, **input_shapes)
if tensor_parallel_size == 1 and hasattr(config, "generate_io_aliases"):
aliases = config.generate_io_aliases(checked_model)
Expand Down
75 changes: 75 additions & 0 deletions optimum/exporters/neuron/model_configs/traced_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
DummyControNetInputGenerator,
DummyIPAdapterInputGenerator,
DummyMaskedPosGenerator,
WhisperDummyTextInputGenerator,
is_neuronx_distributed_available,
)
from ..config import (
Expand All @@ -63,6 +64,8 @@
T5EncoderForSeq2SeqLMWrapper,
T5EncoderWrapper,
UnetNeuronWrapper,
WhisperDecoderWrapper,
WhisperEncoderWrapper,
)


Expand Down Expand Up @@ -858,6 +861,10 @@ class T5EncoderForDiffusersNeuronConfig(T5EncoderBaseNeuronConfig):
def outputs(self) -> List[str]:
return ["last_hidden_state"]

@property
def is_encoder_decoder(self) -> bool:
return True

def patch_model_for_export(self, model_or_path, **input_shapes):
return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)

Expand Down Expand Up @@ -989,6 +996,10 @@ def outputs(self) -> List[str]:

return common_outputs

@property
def is_encoder_decoder(self) -> bool:
return True

def generate_dummy_inputs(self, **kwargs):
batch_size = kwargs.pop("batch_size") * kwargs.get("num_beams")
dummy_inputs = super().generate_dummy_inputs(batch_size=batch_size, **kwargs)
Expand Down Expand Up @@ -1087,3 +1098,67 @@ def generate_io_aliases(self, decoder):
aliases[decoder.past_key_values_ca[i]] = len(decoder.past_key_values_sa) + i + num_outputs_from_trace

return aliases


@register_in_tasks_manager("whisper-encoder", *["automatic-speech-recognition"])
class WhisperEncoderNeuronConfig(AudioNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
MODEL_TYPE = "whisper-encoder"
CUSTOM_MODEL_WRAPPER = WhisperEncoderWrapper
INPUT_ARGS = AudioNeuronConfig.INPUT_ARGS + ("sequence_length",)
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
encoder_num_layers="encoder_layers",
decoder_num_layers="decoder_layers",
feature_size="num_mel_bins",
allow_new=True,
)
@property
def inputs(self) -> List[str]:
return ["input_features", "attention_mask"]

@property
def outputs(self) -> List[str]:
return ["last_hidden_state"]

@property
def is_encoder_decoder(self) -> bool:
return True

def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
if "audio_sequence_length" in kwargs:
kwargs["sequence_length"] = kwargs["audio_sequence_length"]
self._axes["sequence_length"] = self._axes["audio_sequence_length"]
return super().generate_dummy_inputs(return_tuple=return_tuple, **kwargs)

def patch_model_for_export(self, model_or_path, **input_shapes):
return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)


@register_in_tasks_manager("whisper-decoder", *["automatic-speech-recognition"])
class WhisperDecoderNeuronConfig(AudioNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
MODEL_TYPE = "whisper-decoder"
DUMMY_INPUT_GENERATOR_CLASSES = (WhisperDummyTextInputGenerator, )
INPUT_ARGS = AudioNeuronConfig.INPUT_ARGS + ("sequence_length",)
CUSTOM_MODEL_WRAPPER = WhisperDecoderWrapper
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
encoder_num_layers="encoder_layers",
decoder_num_layers="decoder_layers",
feature_size="num_mel_bins",
hidden_size="d_model",
allow_new=True,
)
@property
def inputs(self) -> List[str]:
return ["decoder_input_ids", "encoder_hidden_states"]

@property
def outputs(self) -> List[str]:
return ["lm_logits", "encoder_last_hidden_state"]

@property
def is_encoder_decoder(self) -> bool:
return True

def patch_model_for_export(self, model_or_path, **input_shapes):
return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)
74 changes: 74 additions & 0 deletions optimum/exporters/neuron/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,80 @@ def forward(self, input_ids, pixel_values, attention_mask):
return (text_embeds, image_embeds)


class WhisperEncoderWrapper(torch.nn.Module):
"""Wrapper to trace the encoder of Whisper."""

def __init__(
self,
model: "PreTrainedModel",
batch_size: int,
audio_sequence_length: int,
output_hidden_states: bool = False,
output_attentions: bool = False,
**kwargs,
):
super().__init__()
self.model = model
self.config = model.config
self.batch_size = batch_size
self.sequence_length = audio_sequence_length
self.output_hidden_states = output_hidden_states
self.output_attentions = output_attentions

def forward(
self,
input_features,
attention_mask=None,
**kwargs,
):
outputs = self.model.model.encoder(
input_features=input_features,
attention_mask=attention_mask,
output_attentions=self.output_attentions,
output_hidden_states=self.output_hidden_states,
return_dict=True,
)
return outputs.last_hidden_state


class WhisperDecoderWrapper(torch.nn.Module):
"""Wrapper to trace the decoder and projection output layer of Whisper."""

def __init__(
self,
model: "PreTrainedModel",
batch_size: int,
sequence_length: int,
output_hidden_states: bool = False,
output_attentions: bool = False,
**kwargs,
):
super().__init__()
self.model = model
self.config = model.config
self.batch_size = batch_size
self.sequence_length = sequence_length
self.output_hidden_states = output_hidden_states
self.output_attentions = output_attentions

def forward(
self,
input_ids,
encoder_hidden_states,
**kwargs,
):
outputs = self.model.model.decoder(
input_ids=input_ids,
encoder_hidden_states=encoder_hidden_states,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
)
lm_logits = self.model.proj_out(outputs[0])
return (lm_logits, outputs.last_hidden_state)



class NoCacheModelWrapper(torch.nn.Module):
def __init__(self, model: "PreTrainedModel", input_names: List[str]):
super().__init__()
Expand Down
7 changes: 5 additions & 2 deletions optimum/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@
"NeuronPixArtSigmaPipeline",
],
"modeling_decoder": ["NeuronDecoderModel"],
"modeling_seq2seq": ["NeuronModelForSeq2SeqLM"],
"modeling_seq2seq": [
"NeuronModelForSeq2SeqLM",
"NeuronWhisperForConditionalGeneration",
],
"accelerate": [
"NeuronAccelerator",
"NeuronAcceleratorState",
Expand Down Expand Up @@ -110,7 +113,7 @@
NeuronStableDiffusionXLInpaintPipeline,
NeuronStableDiffusionXLPipeline,
)
from .modeling_seq2seq import NeuronModelForSeq2SeqLM
from .modeling_seq2seq import NeuronModelForSeq2SeqLM, NeuronWhisperForConditionalGeneration
from .modeling_traced import NeuronTracedModel
from .pipelines import pipeline
from .trainers import NeuronORPOTrainer, NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer
Expand Down
Loading
Loading