
Commit

update docstring
kuacakuaca committed Aug 28, 2024
1 parent 600752e commit 7eb61a1
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions i6_models/parts/conformer/mhsa_rel_pos.py
@@ -124,16 +124,16 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
         Apply layer norm and multi-head self attention and dropout

         :param input_tensor: Input to the self attention of shape (B, T, F)
-        :param sequence_mask: bool mask of shape (B, T), True signals within sequence, False outside, will be inverted to match the torch.nn.MultiheadAttention module
-            which will be applied/added to dot product, used to mask padded key positions out
+        :param sequence_mask: bool mask of shape (B, T), True signals within sequence, False outside
         """
-        output_tensor = self.layernorm(input_tensor)  # [B,T,F]
+        output_tensor = self.layernorm(input_tensor)  # [B, T, F]

         time_dim_size = output_tensor.shape[1]
         batch_dim_size = output_tensor.shape[0]

         # attention mask
-        inv_sequence_mask = compat.logical_not(sequence_mask)  # [B, T]
+        # T: query seq. length, T' key/value seg length; T = T' if same input tensor
+        inv_sequence_mask = compat.logical_not(sequence_mask)  # [B, T']
         mask = (
             torch.zeros_like(inv_sequence_mask, dtype=input_tensor.dtype)
             .masked_fill(inv_sequence_mask, float("-inf"))
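For context on what the updated comment describes: the inverted boolean sequence mask is turned into an additive mask of 0.0 / -inf over the key/value axis T', which can then be broadcast across the query axis T before the softmax so that padded key positions receive zero attention weight. Below is a minimal, self-contained sketch of that pattern, assuming plain torch.logical_not in place of the repository's compat.logical_not helper; the shapes and the names attn_scores / masked_scores are illustrative, not taken from the file.

    import torch

    # Hypothetical shapes: batch of 2 sequences, key/value length T' = 5.
    batch_dim_size, key_time_dim_size = 2, 5

    # True marks positions inside the sequence, False marks padding,
    # matching the docstring "True signals within sequence, False outside".
    sequence_mask = torch.tensor(
        [
            [True, True, True, True, True],
            [True, True, True, False, False],  # second sequence has 2 padded frames
        ]
    )

    # Invert so that padded key positions are the ones selected for masking.
    inv_sequence_mask = torch.logical_not(sequence_mask)  # [B, T']

    # Build an additive mask: 0.0 at valid positions, -inf at padded key positions.
    mask = (
        torch.zeros_like(inv_sequence_mask, dtype=torch.float32)
        .masked_fill(inv_sequence_mask, float("-inf"))
    )  # [B, T']

    # Broadcast over the query dimension T when adding to attention scores [B, T, T'],
    # so the softmax assigns zero weight to padded keys.
    attn_scores = torch.randn(batch_dim_size, 4, key_time_dim_size)  # [B, T=4, T']
    masked_scores = attn_scores + mask.unsqueeze(1)  # [B, T, T']
    attn_weights = torch.softmax(masked_scores, dim=-1)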
