finish with accuracy issue
JingyaHuang committed Mar 6, 2025
1 parent 1e522ee commit 6466ed0
Showing 3 changed files with 195 additions and 52 deletions.
7 changes: 5 additions & 2 deletions optimum/neuron/__init__.py
@@ -63,7 +63,10 @@
"NeuronPixArtSigmaPipeline",
],
"modeling_decoder": ["NeuronDecoderModel"],
"modeling_seq2seq": ["NeuronModelForSeq2SeqLM"],
"modeling_seq2seq": [
"NeuronModelForSeq2SeqLM",
"NeuronWhisperForConditionalGeneration",
],
"accelerate": [
"NeuronAccelerator",
"NeuronAcceleratorState",
@@ -110,7 +113,7 @@
NeuronStableDiffusionXLInpaintPipeline,
NeuronStableDiffusionXLPipeline,
)
from .modeling_seq2seq import NeuronModelForSeq2SeqLM
from .modeling_seq2seq import NeuronModelForSeq2SeqLM, NeuronWhisperForConditionalGeneration
from .modeling_traced import NeuronTracedModel
from .pipelines import pipeline
from .trainers import NeuronORPOTrainer, NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer
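The change above exposes the new Whisper class at the package root, next to NeuronModelForSeq2SeqLM, in both the lazy-import table and the eager import block. A minimal sketch of the resulting import paths (assuming the two stay in sync, as the diff suggests):

    # Both imports should resolve to the same class once this export is in place.
    from optimum.neuron import NeuronWhisperForConditionalGeneration
    from optimum.neuron.modeling_seq2seq import NeuronWhisperForConditionalGeneration as _SameClass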
238 changes: 189 additions & 49 deletions optimum/neuron/modeling_seq2seq.py
@@ -25,7 +25,8 @@

import torch
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForSeq2SeqLM, GenerationConfig
from transformers import AutoConfig, AutoModelForSeq2SeqLM, GenerationConfig, WhisperForConditionalGeneration
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
@@ -89,9 +90,100 @@
"""


class _NeuronSeq2SeqModelPart:
"""
For the Seq2Seq architecture, we usually compile the model into multiple Neuron models, each representing a part of it.
"""

def __init__(
self,
model: torch.jit._script.ScriptModule,
parent_model: NeuronTracedModel,
config: Optional["PretrainedConfig"] = None,
neuron_config: Optional["NeuronDefaultConfig"] = None,
model_type: str = "encoder",
device: Optional[int] = None,
):
self.model = model
self.parent_model = parent_model
self.config = config
self.neuron_config = neuron_config
self.model_type = model_type
self.device = device

@abstractmethod
def forward(self, *args, **kwargs):
pass

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)


class NeuronEncoder(_NeuronSeq2SeqModelPart):
"""
Encoder part of the encoder-decoder model for Neuron inference. (Actually it's a monolith of encoder + decoder without past_key_values, to work around the control flow in the decoder.)
"""

main_input_name = "input_ids"

def __init__(
self,
model: torch.jit._script.ScriptModule,
parent_model: NeuronTracedModel,
config: Optional["PretrainedConfig"] = None,
neuron_config: Optional[Dict[str, str]] = None,
):
super().__init__(model, parent_model, config, neuron_config, "encoder")

def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor):
inputs = (
input_ids,
attention_mask,
)
outputs = self.model(*inputs)
return outputs


class NeuronDecoder(_NeuronSeq2SeqModelPart):
"""
Decoder part of the encoder-decoder model for Neuron inference. (Actually it's the decoder with past_key_values.)
"""

def __init__(
self,
model: torch.jit._script.ScriptModule,
parent_model: NeuronTracedModel,
config: Optional["PretrainedConfig"] = None,
neuron_config: Optional[Dict[str, str]] = None,
):
super().__init__(model, parent_model, config, neuron_config, "decoder")

def forward(
self,
input_ids: torch.LongTensor,
decoder_attention_mask: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
encoder_attention_mask: torch.FloatTensor,
beam_idx: torch.LongTensor,
beam_scores: torch.FloatTensor,
):
inputs = (
input_ids,
decoder_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
beam_idx,
beam_scores,
)
outputs = self.model(*inputs)
return outputs


class NeuronModelForConditionalGeneration(NeuronTracedModel, ABC):
base_model_prefix = "neuron_model"
config_name = "config.json"
encoder_class = NeuronEncoder
decoder_class = NeuronDecoder

def __init__(
self,
@@ -114,13 +206,13 @@ def __init__(
self.neuron_configs[ENCODER_NAME]
) # only for the encoder
self._attributes_init(model_save_dir, preprocessors, **kwargs)
self.encoder = NeuronEncoder(
self.encoder = self.encoder_class(
encoder,
self,
self.configs[ENCODER_NAME],
self.neuron_configs[ENCODER_NAME],
)
self.decoder = NeuronDecoder(
self.decoder = self.decoder_class(
decoder,
self,
self.configs[DECODER_NAME],
@@ -609,41 +701,20 @@ def can_generate(self):
return True


class _NeuronSeq2SeqModelPart:
"""
For Seq2Seq architecture, we usually compile it to multiple neuron models. Each represents a part of the model.
"""

def __init__(
self,
model: torch.jit._script.ScriptModule,
parent_model: NeuronTracedModel,
config: Optional["PretrainedConfig"] = None,
neuron_config: Optional["NeuronDefaultConfig"] = None,
model_type: str = "encoder",
device: Optional[int] = None,
):
self.model = model
self.parent_model = parent_model
self.config = config
self.neuron_config = neuron_config
self.model_type = model_type
self.device = device

@abstractmethod
def forward(self, *args, **kwargs):
pass

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

class DummyLayer:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)

def __call__(self, x):
return x


class NeuronEncoder(_NeuronSeq2SeqModelPart):
class NeuronWhisperEncoder(_NeuronSeq2SeqModelPart):
"""
Encoder part of the encoder-decoder model for Neuron inference. (Actually it's a monolith of encoder + decoder without past_key_values to workaround the control flow in the decoder).
Encoder part of the Whisper model for Neuron inference.
"""

main_input_name = "input_ids"
main_input_name = "input_features"

def __init__(
self,
@@ -653,19 +724,26 @@ def __init__(
neuron_config: Optional[Dict[str, str]] = None,
):
super().__init__(model, parent_model, config, neuron_config, "encoder")
self.conv1 = DummyLayer(stride=[1])
self.conv2 = DummyLayer(stride=[2])

def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor):
def forward(
self,
input_features: torch.LongTensor,
attention_mask: torch.FloatTensor,
**kwargs,
):
inputs = (
input_ids,
input_features,
attention_mask,
)
outputs = self.model(*inputs)
return outputs
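
The DummyLayer placeholders assigned to conv1 and conv2 are presumably there because the traced Neuron encoder has no real convolution modules, while transformers' Whisper generation utilities still read (and may call) the encoder's conv1/conv2 attributes, for example to derive the input stride. A tiny illustrative check of that assumption; the name input_stride and the arithmetic are for illustration only, the real attribute reads live inside transformers:

    # Illustration only: DummyLayer carries just the `stride` metadata that
    # Whisper's generation code multiplies together; 1 * 2 matches the strides
    # of the real conv1/conv2 stem of the Whisper encoder.
    conv1 = DummyLayer(stride=[1])
    conv2 = DummyLayer(stride=[2])
    input_stride = conv1.stride[0] * conv2.stride[0]
    assert input_stride == 2
    # Calling a DummyLayer is a no-op that returns its input unchanged.
    assert conv1(42) == 42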


class NeuronDecoder(_NeuronSeq2SeqModelPart):
class NeuronWhisperDecoder(_NeuronSeq2SeqModelPart):
"""
Decoder part of the encoder-decoder model for Neuron inference. (Actually it's decoder with past_key_values).
Decoder with the output embedding of the Whisper model for Neuron inference.
"""

def __init__(
Expand All @@ -679,20 +757,82 @@ def __init__(

def forward(
self,
input_ids: torch.LongTensor,
decoder_attention_mask: torch.FloatTensor,
decoder_input_ids: torch.LongTensor,
encoder_hidden_states: torch.FloatTensor,
encoder_attention_mask: torch.FloatTensor,
beam_idx: torch.LongTensor,
beam_scores: torch.FloatTensor,
**kwargs,
):
inputs = (
input_ids,
decoder_attention_mask,
decoder_input_ids,
encoder_hidden_states,
encoder_attention_mask,
beam_idx,
beam_scores,
)
outputs = self.model(*inputs)
return outputs

class NeuronWhisperModel:
def __init__(self, encoder: NeuronWhisperEncoder, decoder: NeuronWhisperDecoder):
self.encoder = encoder
self.decoder = decoder


class NeuronWhisperForConditionalGeneration(NeuronModelForConditionalGeneration, WhisperForConditionalGeneration):
auto_model_class = WhisperForConditionalGeneration
main_input_name = "input_features"
encoder_class = NeuronWhisperEncoder
decoder_class = NeuronWhisperDecoder

def __init__(
self,
encoder: torch.jit._script.ScriptModule,
decoder: torch.jit._script.ScriptModule,
config: "PretrainedConfig",
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
encoder_file_name: Optional[str] = NEURON_FILE_NAME,
decoder_file_name: Optional[str] = NEURON_FILE_NAME,
preprocessors: Optional[List] = None,
neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None,
configs: Optional[Dict[str, "PretrainedConfig"]] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
):
super().__init__(
encoder,
decoder,
config,
model_save_dir,
encoder_file_name,
decoder_file_name,
preprocessors,
neuron_configs,
configs,
generation_config,
**kwargs
)
self.model = NeuronWhisperModel(self.encoder, self.decoder)

@property
def device(self):
return torch.device("cpu")

def get_encoder(self) -> "NeuronWhisperEncoder":
return self.encoder

def forward(
self,
decoder_input_ids: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
# pad `decoder_input_ids` to the sequence length of the compilation
decoder_input_ids_length = decoder_input_ids.shape[1]
pad_size = torch.as_tensor(self.neuron_configs["decoder"].sequence_length - decoder_input_ids_length)
decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, pad_size), "constant", self.preprocessors[0].pad_token_id)

outputs = self.decoder(
decoder_input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs,
)

return Seq2SeqLMOutput(
logits=outputs[0][:, :decoder_input_ids_length, :],
encoder_last_hidden_state=outputs[1],
)
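
Because the decoder is compiled with a static sequence length, the forward above pads decoder_input_ids up to that length and slices the returned logits back down to the original length before wrapping them in Seq2SeqLMOutput. For context, here is a hedged sketch of how the finished class would likely be driven, mirroring the existing NeuronModelForSeq2SeqLM flow; the checkpoint name, export keywords, and compile-time shapes are assumptions rather than values taken from this diff, and the commit message itself still flags an accuracy issue:

    import numpy as np
    from transformers import AutoProcessor

    from optimum.neuron import NeuronWhisperForConditionalGeneration

    # Assumed export flow: trace encoder/decoder once with static shapes, then reuse
    # the compiled artifacts for inference (as the other Neuron seq2seq classes do).
    model = NeuronWhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny",   # hypothetical checkpoint
        export=True,
        batch_size=1,            # assumed compile-time batch size
        sequence_length=448,     # assumed decoder sequence length used at compilation
    )
    processor = AutoProcessor.from_pretrained("openai/whisper-tiny")

    audio = np.zeros(16000, dtype=np.float32)  # placeholder: one second of silence at 16 kHz
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    generated_ids = model.generate(inputs.input_features)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True))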
2 changes: 1 addition & 1 deletion optimum/neuron/modeling_traced.py
@@ -40,7 +40,7 @@
)
from .utils.hub_cache_utils import ModelCacheEntry, build_cache_config, create_hub_compile_cache_proxy
from .utils.import_utils import is_neuronx_available
from .utils.misc import maybe_load_preprocessors
from ..utils.save_utils import maybe_load_preprocessors
from .utils.version_utils import check_compiler_compatibility, get_neuroncc_version, get_neuronxcc_version


