diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index d71ffbc460..1780871d9c 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -439,6 +439,7 @@ def forward( if ( step == 0 or not self.flash2 + or self.self_attn_type != "scaled-dot-flash" or self.max_relative_positions not in [0, -1] or query.size(0) > 128 or query.dtype != torch.float16 @@ -685,6 +686,8 @@ def forward( scores = self.alibi(scores) scores = scores.float() + if key_pad_mask is not None and mask is None: + mask = key_pad_mask.unsqueeze(1) if mask is not None: # not 100% necessary but expand to nb of heads