
Commit 5579e4b

restored masked scaled dot attention
l-k-11235 committed Dec 29, 2023
1 parent 0436cdd commit 5579e4b
Showing 1 changed file with 3 additions and 0 deletions.
onmt/modules/multi_headed_attn.py: 3 additions & 0 deletions
@@ -439,6 +439,7 @@ def forward(
         if (
             step == 0
             or not self.flash2
+            or self.self_attn_type != "scaled-dot-flash"
             or self.max_relative_positions not in [0, -1]
             or query.size(0) > 128
             or query.dtype != torch.float16
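
For context, this guard appears to route decoding away from the fused flash-attention kernel and onto the regular masked scaled dot-product path whenever any of its clauses holds; the added clause also keeps layers whose self_attn_type is not "scaled-dot-flash" off the fused path. Below is a minimal sketch of the same check written as a standalone predicate; the function name and argument list are illustrative, not the library's API.

import torch

def can_use_flash_decode(step, flash2, self_attn_type, max_relative_positions, query):
    # Illustrative predicate only: True when every clause of the guard above
    # is satisfied, i.e. when the fused decode kernel would be applicable.
    return (
        step > 0
        and flash2
        and self_attn_type == "scaled-dot-flash"
        and max_relative_positions in [0, -1]
        and query.size(0) <= 128
        and query.dtype == torch.float16
    )
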
@@ -685,6 +686,8 @@ def forward(
             scores = self.alibi(scores)

         scores = scores.float()
+        if key_pad_mask is not None and mask is None:
+            mask = key_pad_mask.unsqueeze(1)

         if mask is not None:
             # not 100% necessary but expand to nb of heads
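
These two added lines are what the commit title refers to: when the caller provides a key padding mask but no explicit attention mask, the padding mask is promoted to the attention mask (with an extra head dimension) so padded key positions are excluded from the softmax. Below is a minimal, self-contained sketch of that behaviour, assuming (batch, heads, length, dim_per_head) tensors and a boolean key_pad_mask of shape (batch, 1, key_len) with True marking padding; the function name is hypothetical and this is not the library's implementation.

import torch

def masked_scaled_dot_attention(query, key, value, key_pad_mask=None, mask=None):
    # query/key/value: (batch, heads, len, dim_per_head)
    dim_per_head = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / dim_per_head**0.5
    scores = scores.float()

    # The restored logic: fall back to the key padding mask when no explicit
    # attention mask was given; unsqueeze(1) adds a head dimension so the mask
    # broadcasts over (batch, heads, query_len, key_len).
    if key_pad_mask is not None and mask is None:
        mask = key_pad_mask.unsqueeze(1)

    if mask is not None:
        # Large negative value keeps the row numerically valid even if
        # every position in it happens to be masked.
        scores = scores.masked_fill(mask, -1e18)

    attn = torch.softmax(scores, dim=-1).to(value.dtype)
    return torch.matmul(attn, value)

Checking mask is None before substituting means an explicitly supplied attention mask (for example a causal mask that already accounts for padding) is never overwritten by the padding-only mask.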
