diff --git a/mmf/models/vilbert.py b/mmf/models/vilbert.py
index 1b9b3c44d..07c76c628 100644
--- a/mmf/models/vilbert.py
+++ b/mmf/models/vilbert.py
@@ -10,6 +10,7 @@
 import torch.nn.functional as F
 from mmf.common.registry import registry
 from mmf.models import BaseModel
+from mmf.models.transformers.heads.itm import ITM
 from mmf.modules.hf_layers import replace_with_jit
 from mmf.utils.configuration import get_mmf_cache_dir
 from mmf.utils.modeling import get_optimizer_parameters_for_bert
@@ -1061,6 +1062,11 @@ def __init__(self, config):
         self.visual_target = config.visual_target
         self.num_negative = config.num_negative
         self.loss_fct = CrossEntropyLoss(ignore_index=-1)
+
+        # Assumes an `itm_loss` flag on the pretraining config; the ITM head
+        # pools hidden states, so it is sized by hidden_size.
+        if getattr(config, "itm_loss", False):
+            self.itm_head = ITM({"type": "itm", "hidden_size": config.hidden_size})
 
         if self.visual_target == 0:
             self.vis_criterion = nn.KLDivLoss(reduction="none")
@@ -1099,6 +1105,8 @@ def forward(
         image_label: Optional[Tensor] = None,
         image_target: Optional[Tensor] = None,
         output_all_attention_masks: bool = False,
+        itm_loss: bool = False,
+        next_sentence_label: Optional[Dict[str, Dict[str, Tensor]]] = None,
     ) -> Dict[str, Tensor]:
         masked_img_loss: Optional[Tensor] = None
         (
@@ -1226,6 +1234,16 @@
             prediction_scores_t.view(-1, self.vocab_size), masked_lm_labels.view(-1)
         )
         output["masked_lm_loss"] = masked_lm_loss.unsqueeze(0)
+
+        if itm_loss:
+            # Let the ITM head pool over the concatenated text and image
+            # streams to score image-text alignment.
+            seq_output = torch.cat([sequence_output_t, sequence_output_v], dim=1)
+            itm_output = self.itm_head(
+                seq_output, processed_sample_list=next_sentence_label
+            )
+            output["itm_loss"] = itm_output["losses"]["itm_loss"]
+
         # next_sentence_loss = self.loss_fct(
         #     seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)
         # )
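
For review context, a minimal standalone sketch of the wiring the new code relies on. It assumes the head keeps mmf's usual ITM conventions: targets passed under `processed_sample_list["itm_labels"]["is_correct"]` (matching the nested `Dict[str, Dict[str, Tensor]]` hint added above) and the loss returned under `["losses"]["itm_loss"]`, which the diff also indexes. Shapes and label values are illustrative, not from the PR:

```python
import torch

from mmf.models.transformers.heads.itm import ITM

# Illustrative shapes; 768 is the BERT-base hidden size the head defaults to.
batch_size, text_len, image_len, hidden = 2, 12, 8, 768
head = ITM({"type": "itm", "hidden_size": hidden})

# Stand-ins for the text and image streams produced by ViLBERT's forward pass.
sequence_output_t = torch.randn(batch_size, text_len, hidden)
sequence_output_v = torch.randn(batch_size, image_len, hidden)

# Mirror the forward() wiring in the diff: concatenate the two streams along
# the sequence dimension before the head pools them.
seq_output = torch.cat([sequence_output_t, sequence_output_v], dim=1)

# Assumed label layout per mmf's ITM head conventions:
# 1 = aligned image-text pair, 0 = mismatched pair.
targets = {"itm_labels": {"is_correct": torch.tensor([1, 0])}}

out = head(seq_output, processed_sample_list=targets)
print(out["losses"]["itm_loss"])  # scalar cross-entropy alignment loss
```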