Skip to content

Commit

Permalink
Update i6_models/parts/conformer/mhsa_rel_pos.py
Browse files Browse the repository at this point in the history
Co-authored-by: Albert Zeyer <[email protected]>
  • Loading branch information
kuacakuaca and albertz authored Sep 5, 2024
1 parent a4929dc commit 338ff2c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion i6_models/parts/conformer/mhsa_rel_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
mask = (
torch.zeros_like(inv_sequence_mask, dtype=input_tensor.dtype)
.masked_fill(inv_sequence_mask, float("-inf"))
.view(batch_dim_size, 1, 1, time_dim_size)
.reshape(batch_dim_size, 1, 1, time_dim_size)
) # [B, 1, 1, T']

# query, key and value sequences
Expand Down

0 comments on commit 338ff2c

Please sign in to comment.