From e7d150976b2936c09c2d5759b69517c69b5afe1f Mon Sep 17 00:00:00 2001 From: Nikita Davidchuk Date: Mon, 15 May 2023 01:35:59 +0300 Subject: [PATCH] 1.2.1 - Update Modeling * Add `custom/modeling_{arch}.py` for every custom architecture instead of `custom/models.py` * Add `gradient_checkpointing` for multimodal archs * Edit `tests/aniemore/recognizers/test_multimodal.py::GENERAL_WAVLM_BERT_MODEL` to WavLMBertFusion * Update `poetry.lock` up-to-date * Other minors updates in classes.py, models.py (improve imports and edit out-dated imports) --- aniemore/custom/modeling_classificators.py | 312 ++++++++ aniemore/custom/modeling_hubert.py | 24 + aniemore/custom/modeling_unispeech_sat.py | 24 + aniemore/custom/modeling_wav2vec2.py | 85 +++ aniemore/custom/modeling_wavlm.py | 305 ++++++++ aniemore/custom/models.py | 676 ------------------ aniemore/models.py | 7 +- aniemore/utils/classes.py | 4 +- poetry.lock | 3 +- pyproject.toml | 2 +- tests/aniemore/recognizers/test_multimodal.py | 10 +- 11 files changed, 764 insertions(+), 688 deletions(-) create mode 100644 aniemore/custom/modeling_classificators.py create mode 100644 aniemore/custom/modeling_hubert.py create mode 100644 aniemore/custom/modeling_unispeech_sat.py create mode 100644 aniemore/custom/modeling_wav2vec2.py create mode 100644 aniemore/custom/modeling_wavlm.py delete mode 100644 aniemore/custom/models.py diff --git a/aniemore/custom/modeling_classificators.py b/aniemore/custom/modeling_classificators.py new file mode 100644 index 0000000..9b1c7f6 --- /dev/null +++ b/aniemore/custom/modeling_classificators.py @@ -0,0 +1,312 @@ +"""Base model classes +""" +from dataclasses import dataclass +from typing import Union, Type, Tuple + +import torch +from huggingface_hub import PyTorchModelHubMixin +from transformers.utils import ModelOutput +from transformers import ( + PreTrainedModel, + PretrainedConfig +) + + +@dataclass +class SpeechModelOutput(ModelOutput): + """Base class for model's outputs, with potential hidden states and attentions. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of + each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attention's weights after the attention softmax, used to compute the + weighted average in the self-attention heads. 
+ + Examples:: + >>> from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Tokenizer + >>> import torch + >>> + >>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") + >>> model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h") + >>> input_values = tokenizer("Hello, my dog is cute", return_tensors="pt").input_values # Batch size 1 + >>> logits = model(input_values).logits + >>> assert logits.shape == (1, 2) + """ + loss: torch.FloatTensor + logits: torch.FloatTensor = None + hidden_states: torch.FloatTensor = None + attentions: torch.FloatTensor = None + + +class MultiModalConfig(PretrainedConfig): + """Base class for multimodal configs""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class BaseClassificationModel(PreTrainedModel, PyTorchModelHubMixin): + config: Type[Union[PretrainedConfig, None]] = None + + def compute_loss(self, logits, labels): + """Compute loss + + Args: + logits (torch.FloatTensor): logits + labels (torch.LongTensor): labels + + Returns: + torch.FloatTensor: loss + + Raises: + ValueError: Invalid number of labels + """ + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1: + self.config.problem_type = "single_label_classification" + else: + raise ValueError("Invalid number of labels: {}".format(self.num_labels)) + + if self.config.problem_type == "single_label_classification": + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + elif self.config.problem_type == "multi_label_classification": + loss_fct = torch.nn.BCEWithLogitsLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) + + elif self.config.problem_type == "regression": + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + raise ValueError("Problem_type {} not supported".format(self.config.problem_type)) + + return loss + + @staticmethod + def merged_strategy( + hidden_states, + mode="mean" + ): + """Merged strategy for pooling + + Args: + hidden_states (torch.FloatTensor): hidden states + mode (str, optional): pooling mode. Defaults to "mean". + + Returns: + torch.FloatTensor: pooled hidden states + """ + if mode == "mean": + outputs = torch.mean(hidden_states, dim=1) + elif mode == "sum": + outputs = torch.sum(hidden_states, dim=1) + elif mode == "max": + outputs = torch.max(hidden_states, dim=1)[0] + else: + raise Exception( + "The pooling method hasn't been defined! 
Your pooling mode must be one of these ['mean', 'sum', 'max']") + + return outputs + + def resize_position_embeddings(self, new_num_position_embeddings: int): + pass + + def get_position_embeddings(self) -> Union[torch.nn.Embedding, Tuple[torch.nn.Embedding]]: + pass + + def prepare_inputs_for_generation(self, *args, **kwargs): + pass + + def _reorder_cache(self, past_key_values, beam_idx): + pass + + +class BaseModelForVoiceBaseClassification(BaseClassificationModel): + def __init__(self, config, num_labels): + """Base model for voice classification + + Args: + config (PretrainedConfig): config + num_labels (int): number of labels + """ + super().__init__(config=config) + self.num_labels = num_labels + self.pooling_mode = config.pooling_mode + self.projector = torch.nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = torch.nn.Linear(config.classifier_proj_size, config.num_labels) + + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + """Forward + + Args: + input_values (torch.FloatTensor): input values + attention_mask (torch.LongTensor, optional): attention mask. Defaults to None. + output_attentions (bool, optional): output attentions. Defaults to None. + output_hidden_states (bool, optional): output hidden states. Defaults to None. + return_dict (bool, optional): return dict. Defaults to None. + labels (torch.LongTensor, optional): labels. Defaults to None. + + Returns: + torch.FloatTensor: logits + + Raises: + ValueError: Invalid number of labels + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.wavlm( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = self.projector(outputs.last_hidden_state) + hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss = self.compute_loss(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SpeechModelOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BaseMultiModalForSequenceBaseClassification(BaseClassificationModel): + config_class = MultiModalConfig + + def __init__(self, config): + """ + Args: + config (MultiModalConfig): config + + Attributes: + config (MultiModalConfig): config + num_labels (int): number of labels + audio_config (Union[PretrainedConfig, None]): audio config + text_config (Union[PretrainedConfig, None]): text config + audio_model (Union[PreTrainedModel, None]): audio model + text_model (Union[PreTrainedModel, None]): text model + classifier (Union[torch.nn.Linear, None]): classifier + """ + super().__init__(config) + self.config = config + self.num_labels = self.config.num_labels + self.audio_config: Union[PretrainedConfig, None] = None + self.text_config: Union[PretrainedConfig, None] = None + self.audio_model: Union[PreTrainedModel, None] = None + self.text_model: Union[PreTrainedModel, None] = None + self.classifier: Union[torch.nn.Linear, None] = None + + def forward( + self, + input_ids=None, + input_values=None, + text_attention_mask=None, + audio_attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + 
inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=True, + ): + """Forward method for multimodal model for sequence classification task (e.g. text + audio) + + Args: + input_ids (torch.LongTensor, optional): input ids. Defaults to None. + input_values (torch.FloatTensor, optional): input values. Defaults to None. + text_attention_mask (torch.LongTensor, optional): text attention mask. Defaults to None. + audio_attention_mask (torch.LongTensor, optional): audio attention mask. Defaults to None. + token_type_ids (torch.LongTensor, optional): token type ids. Defaults to None. + position_ids (torch.LongTensor, optional): position ids. Defaults to None. + head_mask (torch.FloatTensor, optional): head mask. Defaults to None. + inputs_embeds (torch.FloatTensor, optional): inputs embeds. Defaults to None. + labels (torch.LongTensor, optional): labels. Defaults to None. + output_attentions (bool, optional): output attentions. Defaults to None. + output_hidden_states (bool, optional): output hidden states. Defaults to None. + return_dict (bool, optional): return dict. Defaults to True. + + Returns: + torch.FloatTensor: logits + """ + audio_output = self.audio_model( + input_values=input_values, + attention_mask=audio_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + text_output = self.text_model( + input_ids=input_ids, + attention_mask=text_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + audio_mean = self.merged_strategy(audio_output.last_hidden_state, mode=self.config.pooling_mode) + + pooled_output = torch.cat( + (audio_mean, text_output.pooler_output), dim=1 + ) + logits = self.classifier(pooled_output) + loss = None + + if labels is not None: + loss = self.compute_loss(logits, labels) + + return SpeechModelOutput( + loss=loss, + logits=logits + ) + + +class AudioTextFusionModelForSequenceClassificaion(BaseMultiModalForSequenceBaseClassification): + def __init__(self, config): + """ + Args: + config (MultiModalConfig): config + Attributes: + audio_projector (Union[torch.nn.Linear, None]): Projection layer for audio embeds + text_projector (Union[torch.nn.Linear, None]): Projection layer for text embeds + audio_avg_pool (Union[torch.nn.AvgPool1d, None]): Audio average pool (out from fusion block) + text_avg_pool (Union[torch.nn.AvgPool1d, None]): Text average pool (out from fusion block) + """ + super().__init__(config) + self.audio_projector: Union[torch.nn.Linear, None] = None + self.text_projector: Union[torch.nn.Linear, None] = None + self.audio_avg_pool: Union[torch.nn.AvgPool1d, None] = None + self.text_avg_pool: Union[torch.nn.AvgPool1d, None] = None diff --git a/aniemore/custom/modeling_hubert.py b/aniemore/custom/modeling_hubert.py new file mode 100644 index 0000000..5a661c9 --- /dev/null +++ b/aniemore/custom/modeling_hubert.py @@ -0,0 +1,24 @@ +"""Base model classes +""" +from aniemore.custom.modeling_classificators import BaseModelForVoiceBaseClassification +from transformers import HubertForSequenceClassification + + +class HubertForVoiceClassification(BaseModelForVoiceBaseClassification): + """HubertForVoiceClassification is a model for voice classification task + (e.g. speech command, voice activity detection, etc.) 
+ + Args: + config (HubertConfig): config + num_labels (int): number of labels + + Attributes: + config (HubertConfig): config + num_labels (int): number of labels + hubert (HubertForSequenceClassification): hubert model + """ + + def __init__(self, config, num_labels): + super().__init__(config, num_labels) + self.hubert = HubertForSequenceClassification(config) + self.init_weights() diff --git a/aniemore/custom/modeling_unispeech_sat.py b/aniemore/custom/modeling_unispeech_sat.py new file mode 100644 index 0000000..9580c8f --- /dev/null +++ b/aniemore/custom/modeling_unispeech_sat.py @@ -0,0 +1,24 @@ +"""Base model classes +""" +from aniemore.custom.modeling_classificators import BaseModelForVoiceBaseClassification +from transformers import UniSpeechSatForSequenceClassification + + +class UniSpeechSatForVoiceClassification(BaseModelForVoiceBaseClassification): + """UniSpeechSatForVoiceClassification is a model for voice classification task + (e.g. speech command, voice activity detection, etc.) + + Args: + config (UniSpeechSatConfig): config + num_labels (int): number of labels + + Attributes: + config (UniSpeechSatConfig): config + num_labels (int): number of labels + unispeech_sat (UniSpeechSatForSequenceClassification): unispeech_sat model + """ + + def __init__(self, config, num_labels): + super().__init__(config, num_labels) + self.unispeech_sat = UniSpeechSatForSequenceClassification(config) + self.init_weights() diff --git a/aniemore/custom/modeling_wav2vec2.py b/aniemore/custom/modeling_wav2vec2.py new file mode 100644 index 0000000..d56c2e0 --- /dev/null +++ b/aniemore/custom/modeling_wav2vec2.py @@ -0,0 +1,85 @@ +"""Base model classes +""" +from aniemore.custom.modeling_classificators import ( + BaseModelForVoiceBaseClassification, + BaseMultiModalForSequenceBaseClassification +) + +import torch +from transformers import ( + Wav2Vec2ForSequenceClassification, + BertConfig, + BertModel, + Wav2Vec2Config, + Wav2Vec2Model +) + +from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Encoder, + Wav2Vec2EncoderStableLayerNorm, + Wav2Vec2FeatureEncoder +) + +from transformers.models.bert.modeling_bert import BertEncoder + + +class Wav2Vec2ForVoiceClassification(BaseModelForVoiceBaseClassification): + """Wav2Vec2ForVoiceClassification is a model for voice classification task + (e.g. speech command, voice activity detection, etc.) + + Args: + config (Wav2Vec2Config): config + num_labels (int): number of labels + + Attributes: + config (Wav2Vec2Config): config + num_labels (int): number of labels + wav2vec2 (Wav2Vec2ForSequenceClassification): wav2vec2 model + """ + + def __init__(self, config, num_labels): + super().__init__(config, num_labels) + self.wav2vec2 = Wav2Vec2ForSequenceClassification(config) + self.init_weights() + + +class Wav2Vec2BertForSequenceClassification(BaseMultiModalForSequenceBaseClassification): + """Wav2Vec2BertForSequenceClassification is a model for sequence classification task + (e.g. sentiment analysis, text classification, etc.) 
+ + Args: + config (Wav2Vec2BertConfig): config + + Attributes: + config (Wav2Vec2BertConfig): config + audio_config (Wav2Vec2Config): wav2vec2 config + text_config (BertConfig): bert config + audio_model (Wav2Vec2Model): wav2vec2 model + text_model (BertModel): bert model + classifier (torch.nn.Linear): classifier + """ + + def __init__(self, config, finetune=False): + super().__init__(config) + self.supports_gradient_checkpointing = getattr(config, "gradient_checkpointing", True) + + self.audio_config = Wav2Vec2Config.from_dict(self.config.Wav2Vec2Model) + self.text_config = BertConfig.from_dict(self.config.BertModel) + + if not finetune: + self.audio_model = Wav2Vec2Model(self.audio_config) + self.text_model = BertModel(self.text_config) + + else: + self.audio_model = Wav2Vec2Model.from_pretrained(self.audio_config._name_or_path, config=self.audio_config) + self.text_model = BertModel.from_pretrained(self.text_config._name_or_path, config=self.text_config) + + self.classifier = torch.nn.Linear( + self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels + ) + self.init_weights() + + @staticmethod + def _set_gradient_checkpointing(module, value=False): + if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder, BertEncoder)): + module.gradient_checkpointing = value diff --git a/aniemore/custom/modeling_wavlm.py b/aniemore/custom/modeling_wavlm.py new file mode 100644 index 0000000..e77df01 --- /dev/null +++ b/aniemore/custom/modeling_wavlm.py @@ -0,0 +1,305 @@ +"""Base model classes +""" +from typing import Union, Dict + +from aniemore.custom.modeling_classificators import ( + BaseModelForVoiceBaseClassification, + BaseMultiModalForSequenceBaseClassification, + AudioTextFusionModelForSequenceClassificaion, + SpeechModelOutput, + MultiModalConfig +) + +import torch +from transformers import ( + WavLMForSequenceClassification, + WavLMConfig, + BertConfig, + WavLMModel, + BertModel, + PretrainedConfig, +) + +from transformers.models.wavlm.modeling_wavlm import ( + WavLMEncoder, + WavLMEncoderStableLayerNorm, + WavLMFeatureEncoder +) + +from transformers.models.bert.modeling_bert import BertEncoder + + +class FusionConfig(MultiModalConfig): + """Base class for fusion configs + Just for fine-tuning models, no more + + Args: + audio_config (PretrainedConfig): audio config + text_config (PretrainedConfig): text config + id2label (Dict[int, str]): id2label + label2id (Dict[str, int]): label2id + num_heads (int, optional): number of heads. Defaults to 8. + kernel_size (int, optional): kernel size. Defaults to 1. + pooling_mode (str, optional): pooling mode. Defaults to "mean". + problem_type (str, optional): problem type. Defaults to "single_label_classification". + gradient_checkpointing (bool, optional): gradient checkpointing. Defaults to True. 
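+
+    Example (an illustrative sketch, not shipped defaults: the config objects and
+    label maps below are placeholders; for the WavLM + BERT fusion model the two
+    configs must report ``architectures[0]`` of ``"WavLMModel"`` and ``"BertModel"``,
+    because the fusion model later reads them back under those keys)::
+
+        >>> from transformers import WavLMConfig, BertConfig
+        >>> wavlm_config = WavLMConfig(architectures=["WavLMModel"])
+        >>> bert_config = BertConfig(architectures=["BertModel"])
+        >>> fusion_config = FusionConfig(
+        ...     audio_config=wavlm_config,
+        ...     text_config=bert_config,
+        ...     id2label={0: "neutral", 1: "happiness"},
+        ...     label2id={"neutral": 0, "happiness": 1},
+        ... )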
+ """ + + def __init__( + self, + audio_config: PretrainedConfig, + text_config: PretrainedConfig, + id2label: Dict[int, str], + label2id: Dict[str, int], + num_heads: int = 8, + kernel_size: int = 1, + pooling_mode: str = "mean", + problem_type: str = "single_label_classification", + gradient_checkpointing: bool = True, + **kwargs): + super().__init__(**kwargs) + + self.update({audio_config.architectures[0]: audio_config.to_dict()}) + self.update({text_config.architectures[0]: text_config.to_dict()}) + + self.id2label = id2label + self.label2id = label2id + self.num_labels = len(id2label) + self.num_heads = num_heads + self.kernel_size = kernel_size + self.pooling_mode = pooling_mode + self.problem_type = problem_type + self.gradient_checkpointing = gradient_checkpointing + + +class FusionModuleQ(torch.nn.Module): + """FusionModuleQ is a fusion module for the query + https://arxiv.org/abs/2302.13661 + https://arxiv.org/abs/2207.04697 + + Args: + audio_dim (int): audio dimension + text_dim (int): text dimension + num_heads (int): number of heads + """ + + def __init__(self, audio_dim, text_dim, num_heads): + super().__init__() + + # pick the lowest dimension of the two modalities + self.dimension = min(audio_dim, text_dim) + + # attention modules + self.a_self_attention = torch.nn.MultiheadAttention(self.dimension, num_heads=num_heads) + self.t_self_attention = torch.nn.MultiheadAttention(self.dimension, num_heads=num_heads) + + # layer norm + self.audio_norm = torch.nn.LayerNorm(self.dimension) + self.text_norm = torch.nn.LayerNorm(self.dimension) + + +class WavLMForVoiceClassification(BaseModelForVoiceBaseClassification): + """WavLMForVoiceClassification is a model for voice classification task + (e.g. speech command, voice activity detection, etc.) + + Args: + config (WavLMConfig): config + num_labels (int): number of labels + + Attributes: + config (WavLMConfig): config + num_labels (int): number of labels + wavlm (WavLMForSequenceClassification): wavlm model + """ + + def __init__(self, config, num_labels): + super().__init__(config, num_labels) + self.wavlm = WavLMForSequenceClassification(config) + self.init_weights() + + +class WavLMBertForSequenceClassification(BaseMultiModalForSequenceBaseClassification): + """WavLMBertForSequenceClassification is a model for sequence classification task + (e.g. sentiment analysis, text classification, etc.) 
+ + Args: + config (WavLMBertConfig): config + + Attributes: + config (WavLMBertConfig): config + audio_config (WavLMConfig): wavlm config + text_config (BertConfig): bert config + audio_model (WavLMModel): wavlm model + text_model (BertModel): bert model + classifier (torch.nn.Linear): classifier + """ + + def __init__(self, config, finetune=False): + super().__init__(config) + self.supports_gradient_checkpointing = getattr(config, "gradient_checkpointing", True) + + self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel) + self.text_config = BertConfig.from_dict(self.config.BertModel) + + if not finetune: + self.audio_model = WavLMModel(self.audio_config) + self.text_model = BertModel(self.text_config) + + else: + self.audio_model = WavLMModel.from_pretrained(self.audio_config._name_or_path, config=self.audio_config) + self.text_model = BertModel.from_pretrained(self.text_config._name_or_path, config=self.text_config) + + self.classifier = torch.nn.Linear( + self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels + ) + self.init_weights() + + @staticmethod + def _set_gradient_checkpointing(module, value=False): + if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder, BertEncoder)): + module.gradient_checkpointing = value + + +class WavLMBertFusionForSequenceClassification(AudioTextFusionModelForSequenceClassificaion): + """WavLMBertForSequenceClassification is a model for sequence classification task + (e.g. sentiment analysis, text classification, etc.) for fine-tuning + Args: + config (WavLMBertConfig): config + Attributes: + config (WavLMBertConfig): config + audio_config (WavLMConfig): wavlm config + text_config (BertConfig): bert config + audio_model (WavLMModel): wavlm model + text_model (BertModel): bert model + fusion_module_{i} (FusionModuleQ): Fusion Module Q + audio_projector (Union[torch.nn.Linear, None]): Projection layer for audio embeds + text_projector (Union[torch.nn.Linear, None]): Projection layer for text embeds + audio_avg_pool (Union[torch.nn.AvgPool1d, None]): Audio average pool (out from fusion block) + text_avg_pool (Union[torch.nn.AvgPool1d, None]): Text average pool (out from fusion block) + classifier (torch.nn.Linear): classifier + """ + + def __init__(self, config, finetune=False): + super().__init__(config) + self.supports_gradient_checkpointing = getattr(config, "gradient_checkpointing", True) + + self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel) + self.text_config = BertConfig.from_dict(self.config.BertModel) + + if not finetune: + self.audio_model = WavLMModel(self.audio_config) + self.text_model = BertModel(self.text_config) + + else: + self.audio_model = WavLMModel.from_pretrained(self.audio_config._name_or_path, config=self.audio_config) + self.text_model = BertModel.from_pretrained(self.text_config._name_or_path, config=self.text_config) + + # fusion module with V3 strategy (one projection on entry, no projection in continuous) + for i in range(self.config.num_fusion_layers): + setattr(self, f"fusion_module_{i + 1}", FusionModuleQ( + self.audio_config.hidden_size, self.text_config.hidden_size, self.config.num_heads + )) + + self.audio_projector = torch.nn.Linear(self.audio_config.hidden_size, self.text_config.hidden_size) + self.text_projector = torch.nn.Linear(self.text_config.hidden_size, self.text_config.hidden_size) + + # Avg Pool + self.audio_avg_pool = torch.nn.AvgPool1d(self.config.kernel_size) + self.text_avg_pool = torch.nn.AvgPool1d(self.config.kernel_size) + + 
# output dimensions of wav2vec2 and bert are 768 and 1024 respectively + cls_dim = min(self.audio_config.hidden_size, self.text_config.hidden_size) + self.classifier = torch.nn.Linear( + (cls_dim * 2) // self.config.kernel_size, self.config.num_labels + ) + + self.init_weights() + + @staticmethod + def _set_gradient_checkpointing(module, value=False): + if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder, BertEncoder)): + module.gradient_checkpointing = value + + def forward( + self, + input_ids=None, + input_values=None, + text_attention_mask=None, + audio_attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=True, + ): + """Forward method for multimodal model for sequence classification task (e.g. text + audio) + Args: + input_ids (torch.LongTensor, optional): input ids. Defaults to None. + input_values (torch.FloatTensor, optional): input values. Defaults to None. + text_attention_mask (torch.LongTensor, optional): text attention mask. Defaults to None. + audio_attention_mask (torch.LongTensor, optional): audio attention mask. Defaults to None. + token_type_ids (torch.LongTensor, optional): token type ids. Defaults to None. + position_ids (torch.LongTensor, optional): position ids. Defaults to None. + head_mask (torch.FloatTensor, optional): head mask. Defaults to None. + inputs_embeds (torch.FloatTensor, optional): inputs embeds. Defaults to None. + labels (torch.LongTensor, optional): labels. Defaults to None. + output_attentions (bool, optional): output attentions. Defaults to None. + output_hidden_states (bool, optional): output hidden states. Defaults to None. + return_dict (bool, optional): return dict. Defaults to True. 
+ Returns: + torch.FloatTensor: logits + """ + audio_output = self.audio_model( + input_values=input_values, + attention_mask=audio_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + text_output = self.text_model( + input_ids=input_ids, + attention_mask=text_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Mean pooling + audio_avg = self.merged_strategy(audio_output.last_hidden_state, mode=self.config.pooling_mode) + + # Projection + audio_proj = self.audio_projector(audio_avg) + text_proj = self.text_projector(text_output.pooler_output) + + audio_mha, text_mha = None, None + + for i in range(self.config.num_fusion_layers): + fusion_module = getattr(self, f"fusion_module_{i + 1}") + + if i == 0: + audio_mha, text_mha = fusion_module(audio_proj, text_proj) + else: + audio_mha, text_mha = fusion_module(audio_mha, text_mha) + + audio_avg = self.audio_avg_pool(audio_mha) + text_avg = self.text_avg_pool(text_mha) + + fusion_output = torch.concat((audio_avg, text_avg), dim=1) + + logits = self.classifier(fusion_output) + loss = None + + if labels is not None: + loss = self.compute_loss(logits, labels) + + return SpeechModelOutput( + loss=loss, + logits=logits + ) diff --git a/aniemore/custom/models.py b/aniemore/custom/models.py deleted file mode 100644 index 89403c7..0000000 --- a/aniemore/custom/models.py +++ /dev/null @@ -1,676 +0,0 @@ -"""Base model classes -""" -from dataclasses import dataclass -from typing import Union, Type - -import torch -from transformers.utils import ModelOutput -from transformers import ( - Wav2Vec2ForSequenceClassification, - WavLMForSequenceClassification, - UniSpeechSatForSequenceClassification, - HubertForSequenceClassification, - PreTrainedModel, - PretrainedConfig, - WavLMConfig, - BertConfig, - WavLMModel, - BertModel, - Wav2Vec2Config, - Wav2Vec2Model -) - - -@dataclass -class SpeechModelOutput(ModelOutput): - """Base class for model's outputs, with potential hidden states and attentions. - - Args: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of - each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the - weighted average in the self-attention heads. 
- - Examples:: - >>> from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Tokenizer - >>> import torch - >>> - >>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") - >>> model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h") - >>> input_values = tokenizer("Hello, my dog is cute", return_tensors="pt").input_values # Batch size 1 - >>> logits = model(input_values).logits - >>> assert logits.shape == (1, 2) - """ - loss: torch.FloatTensor - logits: torch.FloatTensor = None - hidden_states: torch.FloatTensor = None - attentions: torch.FloatTensor = None - - -class MultiModalConfig(PretrainedConfig): - """Base class for multimodal configs""" - def __init__(self, **kwargs): - super().__init__(**kwargs) - - -class BaseClassificationModel(PreTrainedModel): # noqa - config: Type[Union[PretrainedConfig, None]] = None - - def compute_loss(self, logits, labels): - """Compute loss - - Args: - logits (torch.FloatTensor): logits - labels (torch.LongTensor): labels - - Returns: - torch.FloatTensor: loss - - Raises: - ValueError: Invalid number of labels - """ - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1: - self.config.problem_type = "single_label_classification" - else: - raise ValueError("Invalid number of labels: {}".format(self.num_labels)) - - if self.config.problem_type == "single_label_classification": - loss_fct = torch.nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - elif self.config.problem_type == "multi_label_classification": - loss_fct = torch.nn.BCEWithLogitsLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) - - elif self.config.problem_type == "regression": - loss_fct = torch.nn.MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: - raise ValueError("Problem_type {} not supported".format(self.config.problem_type)) - - return loss - - @staticmethod - def merged_strategy( - hidden_states, - mode="mean" - ): - """Merged strategy for pooling - - Args: - hidden_states (torch.FloatTensor): hidden states - mode (str, optional): pooling mode. Defaults to "mean". - - Returns: - torch.FloatTensor: pooled hidden states - """ - if mode == "mean": - outputs = torch.mean(hidden_states, dim=1) - elif mode == "sum": - outputs = torch.sum(hidden_states, dim=1) - elif mode == "max": - outputs = torch.max(hidden_states, dim=1)[0] - else: - raise Exception( - "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']") - - return outputs - - -class BaseModelForVoiceBaseClassification(BaseClassificationModel): # noqa - def __init__(self, config, num_labels): - """Base model for voice classification - - Args: - config (PretrainedConfig): config - num_labels (int): number of labels - """ - super().__init__(config=config) - self.num_labels = num_labels - self.pooling_mode = config.pooling_mode - self.projector = torch.nn.Linear(config.hidden_size, config.classifier_proj_size) - self.classifier = torch.nn.Linear(config.classifier_proj_size, config.num_labels) - - def forward( - self, - input_values, - attention_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - labels=None, - ): - """Forward - - Args: - input_values (torch.FloatTensor): input values - attention_mask (torch.LongTensor, optional): attention mask. Defaults to None. 
- output_attentions (bool, optional): output attentions. Defaults to None. - output_hidden_states (bool, optional): output hidden states. Defaults to None. - return_dict (bool, optional): return dict. Defaults to None. - labels (torch.LongTensor, optional): labels. Defaults to None. - - Returns: - torch.FloatTensor: logits - - Raises: - ValueError: Invalid number of labels - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.wavlm( - input_values, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = self.projector(outputs.last_hidden_state) - hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode) - logits = self.classifier(hidden_states) - - loss = None - if labels is not None: - loss = self.compute_loss(logits, labels) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return SpeechModelOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class BaseMultiModalForSequenceBaseClassification(BaseClassificationModel): # noqa - config_class = MultiModalConfig - - def __init__(self, config): - """ - Args: - config (MultiModalConfig): config - - Attributes: - config (MultiModalConfig): config - num_labels (int): number of labels - audio_config (Union[PretrainedConfig, None]): audio config - text_config (Union[PretrainedConfig, None]): text config - audio_model (Union[PreTrainedModel, None]): audio model - text_model (Union[PreTrainedModel, None]): text model - classifier (Union[torch.nn.Linear, None]): classifier - """ - super().__init__(config) - self.config = config - self.num_labels = self.config.num_labels - self.audio_config: Union[PretrainedConfig, None] = None - self.text_config: Union[PretrainedConfig, None] = None - self.audio_model: Union[PreTrainedModel, None] = None - self.text_model: Union[PreTrainedModel, None] = None - self.classifier: Union[torch.nn.Linear, None] = None - - def forward( - self, - input_ids=None, - input_values=None, - text_attention_mask=None, - audio_attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=True, - ): - """Forward method for multimodal model for sequence classification task (e.g. text + audio) - - Args: - input_ids (torch.LongTensor, optional): input ids. Defaults to None. - input_values (torch.FloatTensor, optional): input values. Defaults to None. - text_attention_mask (torch.LongTensor, optional): text attention mask. Defaults to None. - audio_attention_mask (torch.LongTensor, optional): audio attention mask. Defaults to None. - token_type_ids (torch.LongTensor, optional): token type ids. Defaults to None. - position_ids (torch.LongTensor, optional): position ids. Defaults to None. - head_mask (torch.FloatTensor, optional): head mask. Defaults to None. - inputs_embeds (torch.FloatTensor, optional): inputs embeds. Defaults to None. - labels (torch.LongTensor, optional): labels. Defaults to None. - output_attentions (bool, optional): output attentions. Defaults to None. - output_hidden_states (bool, optional): output hidden states. Defaults to None. - return_dict (bool, optional): return dict. Defaults to True. 
- - Returns: - torch.FloatTensor: logits - """ - audio_output = self.audio_model( - input_values=input_values, - attention_mask=audio_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict - ) - text_output = self.text_model( - input_ids=input_ids, - attention_mask=text_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - audio_mean = self.merged_strategy(audio_output.last_hidden_state, mode="mean") - - pooled_output = torch.cat( - (audio_mean, text_output.pooler_output), dim=1 - ) - logits = self.classifier(pooled_output) - loss = None - - if labels is not None: - loss = self.compute_loss(logits, labels) - - return SpeechModelOutput( - loss=loss, - logits=logits - ) - - -class Wav2Vec2ForVoiceClassification(BaseModelForVoiceBaseClassification): # noqa - """Wav2Vec2ForVoiceClassification is a model for voice classification task - (e.g. speech command, voice activity detection, etc.) - - Args: - config (Wav2Vec2Config): config - num_labels (int): number of labels - - Attributes: - config (Wav2Vec2Config): config - num_labels (int): number of labels - wav2vec2 (Wav2Vec2ForSequenceClassification): wav2vec2 model - """ - def __init__(self, config, num_labels): - super().__init__(config, num_labels) - self.wav2vec2 = Wav2Vec2ForSequenceClassification(config) - self.init_weights() - - -class WavLMForVoiceClassification(BaseModelForVoiceBaseClassification): # noqa - """WavLMForVoiceClassification is a model for voice classification task - (e.g. speech command, voice activity detection, etc.) - - Args: - config (WavLMConfig): config - num_labels (int): number of labels - - Attributes: - config (WavLMConfig): config - num_labels (int): number of labels - wavlm (WavLMForSequenceClassification): wavlm model - """ - def __init__(self, config, num_labels): - super().__init__(config, num_labels) - self.wavlm = WavLMForSequenceClassification(config) - self.init_weights() - - -class UniSpeechSatForVoiceClassification(BaseModelForVoiceBaseClassification): # noqa - """UniSpeechSatForVoiceClassification is a model for voice classification task - (e.g. speech command, voice activity detection, etc.) - - Args: - config (UniSpeechSatConfig): config - num_labels (int): number of labels - - Attributes: - config (UniSpeechSatConfig): config - num_labels (int): number of labels - unispeech_sat (UniSpeechSatForSequenceClassification): unispeech_sat model - """ - def __init__(self, config, num_labels): - super().__init__(config, num_labels) - self.unispeech_sat = UniSpeechSatForSequenceClassification(config) - self.init_weights() - - -class HubertForVoiceClassification(BaseModelForVoiceBaseClassification): # noqa - """HubertForVoiceClassification is a model for voice classification task - (e.g. speech command, voice activity detection, etc.) 
- - Args: - config (HubertConfig): config - num_labels (int): number of labels - - Attributes: - config (HubertConfig): config - num_labels (int): number of labels - hubert (HubertForSequenceClassification): hubert model - """ - def __init__(self, config, num_labels): - super().__init__(config, num_labels) - self.hubert = HubertForSequenceClassification(config) - self.init_weights() - - -class Wav2Vec2BertForSequenceClassification(BaseMultiModalForSequenceBaseClassification): # noqa - """Wav2Vec2BertForSequenceClassification is a model for sequence classification task - (e.g. sentiment analysis, text classification, etc.) - - Args: - config (Wav2Vec2BertConfig): config - - Attributes: - config (Wav2Vec2BertConfig): config - audio_config (Wav2Vec2Config): wav2vec2 config - text_config (BertConfig): bert config - audio_model (Wav2Vec2Model): wav2vec2 model - text_model (BertModel): bert model - classifier (torch.nn.Linear): classifier - """ - def __init__(self, config): - super().__init__(config) - self.audio_config = Wav2Vec2Config.from_dict(self.config.Wav2Vec2Model) - self.text_config = BertConfig.from_dict(self.config.BertModel) - self.audio_model = Wav2Vec2Model(self.audio_config) - self.text_model = BertModel(self.text_config) - self.classifier = torch.nn.Linear( - self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels - ) - self.init_weights() - - -class WavLMBertForSequenceClassification(BaseMultiModalForSequenceBaseClassification): # noqa - """WavLMBertForSequenceClassification is a model for sequence classification task - (e.g. sentiment analysis, text classification, etc.) - - Args: - config (WavLMBertConfig): config - - Attributes: - config (WavLMBertConfig): config - audio_config (WavLMConfig): wavlm config - text_config (BertConfig): bert config - audio_model (WavLMModel): wavlm model - text_model (BertModel): bert model - classifier (torch.nn.Linear): classifier - """ - def __init__(self, config): - super().__init__(config) - self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel) - self.text_config = BertConfig.from_dict(self.config.BertModel) - self.audio_model = WavLMModel(self.audio_config) - self.text_model = BertModel(self.text_config) - self.classifier = torch.nn.Linear( - self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels - ) - self.init_weights() - - -class FineTuneWav2Vec2BertForSequenceClassification(BaseMultiModalForSequenceBaseClassification): # noqa - """FineTuneWav2Vec2BertForSequenceClassification is a model for sequence classification task - (e.g. sentiment analysis, text classification, etc.) 
for fine-tuning - - Args: - config (Wav2Vec2BertConfig): config - - Attributes: - config (Wav2Vec2BertConfig): config - audio_config (Wav2Vec2Config): wav2vec2 config - text_config (BertConfig): bert config - audio_model (Wav2Vec2Model): wav2vec2 model - text_model (BertModel): bert model - classifier (torch.nn.Linear): classifier - """ - def __init__(self, config): - super().__init__(config) - self.audio_config = Wav2Vec2Config.from_dict(self.config.Wav2Vec2Model) - self.text_config = BertConfig.from_dict(self.config.BertModel) - self.audio_model = Wav2Vec2Model.from_pretrained(self.audio_config._name_or_path, config=self.audio_config) - self.text_model = BertModel.from_pretrained(self.text_config._name_or_path, config=self.text_config) - self.classifier = torch.nn.Linear( - self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels - ) - self.init_weights() - - -class FineTuneWavLMBertForSequenceClassification(BaseMultiModalForSequenceBaseClassification): # noqa - """FineTuneWavLMBertForSequenceClassification is a model for sequence classification task - (e.g. sentiment analysis, text classification, etc.) for fine-tuning - - Args: - config (WavLMBertConfig): config - - Attributes: - config (WavLMBertConfig): config - audio_config (WavLMConfig): wavlm config - text_config (BertConfig): bert config - audio_model (WavLMModel): wavlm model - text_model (BertModel): bert model - classifier (torch.nn.Linear): classifier - """ - def __init__(self, config): - super().__init__(config) - self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel) - self.text_config = BertConfig.from_dict(self.config.BertModel) - self.audio_model = WavLMModel.from_pretrained(self.audio_config._name_or_path, config=self.audio_config) - self.text_model = BertModel.from_pretrained(self.text_config._name_or_path, config=self.text_config) - self.classifier = torch.nn.Linear( - self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels - ) - self.init_weights() - - -class FusionModuleQ(torch.nn.Module): - """FusionModuleQ is a fusion module for the query - https://arxiv.org/abs/2302.13661 - https://arxiv.org/abs/2207.04697 - - Args: - audio_dim (int): audio dimension - text_dim (int): text dimension - num_heads (int): number of heads - """ - def __init__(self, audio_dim, text_dim, num_heads): - super().__init__() - - # pick the lowest dimension of the two modalities - self.dimension = min(audio_dim, text_dim) - - # attention modules - self.a_self_attention = torch.nn.MultiheadAttention(self.dimension, num_heads=num_heads) - self.t_self_attention = torch.nn.MultiheadAttention(self.dimension, num_heads=num_heads) - - # layer norm - self.audio_norm = torch.nn.LayerNorm(self.dimension) - self.text_norm = torch.nn.LayerNorm(self.dimension) - - def forward(self, audio_output, text_output): - """Forward pass - - Args: - audio_output (torch.Tensor): audio output - text_output (torch.Tensor): text output - - Returns: - torch.Tensor: audio output of the fusion module - """ - # Multihead cross attention (dims ARE switched) - audio_attn, _ = self.a_self_attention(audio_output, text_output, text_output) - text_attn, _ = self.t_self_attention(text_output, audio_output, audio_output) - - # Add & Norm with dropout - audio_add = self.audio_norm(audio_output + audio_attn) - text_add = self.text_norm(text_output + text_attn) - - return audio_add, text_add - - -class AudioTextFusionModelForSequenceClassificaion(BaseMultiModalForSequenceBaseClassification): # noqa - def __init__(self, config): - 
""" - Args: - config (MultiModalConfig): config - Attributes: - audio_projector (Union[torch.nn.Linear, None]): Projection layer for audio embeds - text_projector (Union[torch.nn.Linear, None]): Projection layer for text embeds - audio_avg_pool (Union[torch.nn.AvgPool1d, None]): Audio average pool (out from fusion block) - text_avg_pool (Union[torch.nn.AvgPool1d, None]): Text average pool (out from fusion block) - """ - super().__init__(config) - self.audio_projector: Union[torch.nn.Linear, None] = None - self.text_projector: Union[torch.nn.Linear, None] = None - self.audio_avg_pool: Union[torch.nn.AvgPool1d, None] = None - self.text_avg_pool: Union[torch.nn.AvgPool1d, None] = None - - -class WavLMBertFusionForSequenceClassification(AudioTextFusionModelForSequenceClassificaion): # noqa - """ - WavLMBertForSequenceClassification is a model for sequence classification task - (e.g. sentiment analysis, text classification, etc.) for fine-tuning - Args: - config (WavLMBertConfig): config - Attributes: - config (WavLMBertConfig): config - audio_config (WavLMConfig): wavlm config - text_config (BertConfig): bert config - audio_model (WavLMModel): wavlm model - text_model (BertModel): bert model - fusion_module_{i} (FusionModuleQ): Fusion Module Q - audio_projector (Union[torch.nn.Linear, None]): Projection layer for audio embeds - text_projector (Union[torch.nn.Linear, None]): Projection layer for text embeds - audio_avg_pool (Union[torch.nn.AvgPool1d, None]): Audio average pool (out from fusion block) - text_avg_pool (Union[torch.nn.AvgPool1d, None]): Text average pool (out from fusion block) - classifier (torch.nn.Linear): classifier - """ - - def __init__(self, config, finetune=False): - super().__init__(config) - self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel) - self.text_config = BertConfig.from_dict(self.config.BertModel) - - if not finetune: - self.audio_model = WavLMModel(self.audio_config) - self.text_model = BertModel(self.text_config) - - else: - self.audio_model = WavLMModel.from_pretrained(self.audio_config._name_or_path, config=self.audio_config) - self.text_model = BertModel.from_pretrained(self.text_config._name_or_path, config=self.text_config) - - # fusion module with V3 strategy (one projection on entry, no projection in continuous) - for i in range(self.config.num_fusion_layers): - setattr(self, f"fusion_module_{i + 1}", FusionModuleQ( - self.audio_config.hidden_size, self.text_config.hidden_size, self.config.num_heads - )) - - self.audio_projector = torch.nn.Linear(self.audio_config.hidden_size, self.text_config.hidden_size) - self.text_projector = torch.nn.Linear(self.text_config.hidden_size, self.text_config.hidden_size) - - # Avg Pool - self.audio_avg_pool = torch.nn.AvgPool1d(self.config.kernel_size) - self.text_avg_pool = torch.nn.AvgPool1d(self.config.kernel_size) - - # output dimensions of wav2vec2 and bert are 768 and 1024 respectively - cls_dim = min(self.audio_config.hidden_size, self.text_config.hidden_size) - self.classifier = torch.nn.Linear( - (cls_dim * 2) // self.config.kernel_size, self.config.num_labels - ) - - self.init_weights() - - def forward( - self, - input_ids=None, - input_values=None, - text_attention_mask=None, - audio_attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=True, - ): - """Forward method for multimodal model for sequence classification task (e.g. 
text + audio) - Args: - input_ids (torch.LongTensor, optional): input ids. Defaults to None. - input_values (torch.FloatTensor, optional): input values. Defaults to None. - text_attention_mask (torch.LongTensor, optional): text attention mask. Defaults to None. - audio_attention_mask (torch.LongTensor, optional): audio attention mask. Defaults to None. - token_type_ids (torch.LongTensor, optional): token type ids. Defaults to None. - position_ids (torch.LongTensor, optional): position ids. Defaults to None. - head_mask (torch.FloatTensor, optional): head mask. Defaults to None. - inputs_embeds (torch.FloatTensor, optional): inputs embeds. Defaults to None. - labels (torch.LongTensor, optional): labels. Defaults to None. - output_attentions (bool, optional): output attentions. Defaults to None. - output_hidden_states (bool, optional): output hidden states. Defaults to None. - return_dict (bool, optional): return dict. Defaults to True. - Returns: - torch.FloatTensor: logits - """ - audio_output = self.audio_model( - input_values=input_values, - attention_mask=audio_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict - ) - text_output = self.text_model( - input_ids=input_ids, - attention_mask=text_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # Mean pooling - audio_avg = self.merged_strategy(audio_output.last_hidden_state, mode=self.config.pooling_mode) - - # Projection - audio_proj = self.audio_projector(audio_avg) - text_proj = self.text_projector(text_output.pooler_output) - - audio_mha, text_mha = None, None - - for i in range(self.config.num_fusion_layers): - fusion_module = getattr(self, f"fusion_module_{i + 1}") - - if i == 0: - audio_mha, text_mha = fusion_module(audio_proj, text_proj) - else: - audio_mha, text_mha = fusion_module(audio_mha, text_mha) - - audio_avg = self.audio_avg_pool(audio_mha) - text_avg = self.text_avg_pool(text_mha) - - fusion_output = torch.concat((audio_avg, text_avg), dim=1) - - logits = self.classifier(fusion_output) - loss = None - - if labels is not None: - loss = self.compute_loss(logits, labels) - - return SpeechModelOutput( - loss=loss, - logits=logits - ) diff --git a/aniemore/models.py b/aniemore/models.py index 12b698f..c01eb2e 100644 --- a/aniemore/models.py +++ b/aniemore/models.py @@ -13,9 +13,10 @@ PreTrainedModel ) -from aniemore.custom.models import ( - Wav2Vec2BertForSequenceClassification, - WavLMBertForSequenceClassification, WavLMBertFusionForSequenceClassification +from aniemore.custom.modeling_wav2vec2 import Wav2Vec2BertForSequenceClassification +from aniemore.custom.modeling_wavlm import ( + WavLMBertForSequenceClassification, + WavLMBertFusionForSequenceClassification ) diff --git a/aniemore/utils/classes.py b/aniemore/utils/classes.py index ce10221..8de7f95 100644 --- a/aniemore/utils/classes.py +++ b/aniemore/utils/classes.py @@ -16,7 +16,7 @@ PreTrainedModel ) -from aniemore.custom.models import BaseMultiModalForSequenceBaseClassification +from aniemore.custom.modeling_classificators import BaseMultiModalForSequenceBaseClassification from aniemore.models import Model RecognizerOutputOne: Type[Dict[str, float]] = dict @@ -78,7 +78,7 @@ def _setup_variables(self) -> None: try: self.config = AutoConfig.from_pretrained(self.model_url) self._model = 
self.model_cls.from_pretrained(self.model_url, config=self.config) - except Exception as exc: # TODO: needs more precise exception work + except (RuntimeError, ValueError) as exc: self.config = AutoConfig.from_pretrained(self.model_url, trust_remote_code=True) self._model = self.model_cls.from_pretrained( self.model_url, trust_remote_code=True, config=self.config diff --git a/poetry.lock b/poetry.lock index 8837a2b..6ad8ba6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1367,6 +1367,7 @@ files = [ {file = "soundfile-0.12.1-py2.py3-none-any.whl", hash = "sha256:828a79c2e75abab5359f780c81dccd4953c45a2c4cd4f05ba3e233ddf984b882"}, {file = "soundfile-0.12.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d922be1563ce17a69582a352a86f28ed8c9f6a8bc951df63476ffc310c064bfa"}, {file = "soundfile-0.12.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:bceaab5c4febb11ea0554566784bcf4bc2e3977b53946dda2b12804b4fe524a8"}, + {file = "soundfile-0.12.1-py2.py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:2dc3685bed7187c072a46ab4ffddd38cef7de9ae5eb05c03df2ad569cf4dacbc"}, {file = "soundfile-0.12.1-py2.py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:074247b771a181859d2bc1f98b5ebf6d5153d2c397b86ee9e29ba602a8dfe2a6"}, {file = "soundfile-0.12.1-py2.py3-none-win32.whl", hash = "sha256:59dfd88c79b48f441bbf6994142a19ab1de3b9bb7c12863402c2bc621e49091a"}, {file = "soundfile-0.12.1-py2.py3-none-win_amd64.whl", hash = "sha256:0d86924c00b62552b650ddd28af426e3ff2d4dc2e9047dae5b3d8452e0a49a77"}, diff --git a/pyproject.toml b/pyproject.toml index 5b9978c..6ca6ef7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "aniemore" -version = "1.2.0" +version = "1.2.1" authors = [ "Ilya Lubenets ", "Nikita Davidchuk ", diff --git a/tests/aniemore/recognizers/test_multimodal.py b/tests/aniemore/recognizers/test_multimodal.py index f35c83a..e0c9efc 100644 --- a/tests/aniemore/recognizers/test_multimodal.py +++ b/tests/aniemore/recognizers/test_multimodal.py @@ -13,7 +13,7 @@ TESTS_DIR = Path(__file__).parent TEST_VOICE_DATA_PATH = str(TESTS_DIR / 'src' / 'my_voice.ogg') -GENERAL_WAVLM_BERT_MODEL = HuggingFaceModel.MultiModal.WavLMBertBase +GENERAL_WAVLM_BERT_MODEL = HuggingFaceModel.MultiModal.WavLMBertFusion @pytest.fixture(autouse=True) @@ -35,9 +35,9 @@ def test_create_empty(): def test_create_dummy_voice_text(): vtr = VoiceTextRecognizer(model=GENERAL_WAVLM_BERT_MODEL) - assert vtr.model_url == HuggingFaceModel.MultiModal.WavLMBertBase.model_url + assert vtr.model_url == GENERAL_WAVLM_BERT_MODEL.model_url - assert vtr.model_cls == HuggingFaceModel.MultiModal.WavLMBertBase.model_cls + assert vtr.model_cls == GENERAL_WAVLM_BERT_MODEL.model_cls del vtr @@ -45,9 +45,9 @@ def test_create_dummy_voice_text(): def test_create_dummy_multimodal(): mr = MultiModalRecognizer(model=GENERAL_WAVLM_BERT_MODEL, s2t_model=SmallSpeech2Text()) - assert mr.model_url == HuggingFaceModel.MultiModal.WavLMBertBase.model_url + assert mr.model_url == GENERAL_WAVLM_BERT_MODEL.model_url - assert mr.model_cls == HuggingFaceModel.MultiModal.WavLMBertBase.model_cls + assert mr.model_cls == GENERAL_WAVLM_BERT_MODEL.model_cls assert isinstance(mr.s2t_model, Speech2Text)
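
Usage sketch (illustrative, based on the updated tests rather than on code in this
patch: the recognizer module paths, the SmallSpeech2Text location, and the
recognize() call are assumptions about the rest of the package):

    from aniemore.models import HuggingFaceModel
    from aniemore.recognizers.multimodal import MultiModalRecognizer, VoiceTextRecognizer
    from aniemore.utils.speech2text import SmallSpeech2Text

    # The model handle the tests now default to: the WavLM + BERT fusion architecture.
    model = HuggingFaceModel.MultiModal.WavLMBertFusion

    # Text + audio recognizer (no built-in speech-to-text).
    vtr = VoiceTextRecognizer(model=model)

    # Multimodal recognizer that transcribes the audio itself with a small speech-to-text model.
    mr = MultiModalRecognizer(model=model, s2t_model=SmallSpeech2Text())

    # Hypothetical call; the exact recognize() signature is not shown in this patch.
    result = mr.recognize('tests/aniemore/recognizers/src/my_voice.ogg')

    # Because the fusion classes now define _set_gradient_checkpointing, gradient
    # checkpointing can be toggled through the standard transformers hook, e.g.
    # mr._model.gradient_checkpointing_enable(), once the underlying model is loaded.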