Skip to content

Commit

Permalink
Add mask_tensor util
Browse files Browse the repository at this point in the history
Utility to create a boolean tensor mask on the tensors device
based on sequence lengths.
  • Loading branch information
JackTemaki committed Sep 6, 2023
1 parent d2c8a24 commit 6c9c6c5
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions i6_models/util/mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch


def mask_tensor(tensor: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
"""
Create a "positive" mask for a tensor (boolean true means position is used)
on the same device as the tensor.
:param tensor: [B,T,....]
:param seq_len: [B]
:return: Mask of [B,T]
"""
seq_len = seq_len.to(device=tensor.device)
r = torch.arange(tensor.shape[1], device=tensor.device) # [T]
seq_mask = torch.less(r[None, :], seq_len[:, None]) # broadcast to [B,T]
return seq_mask

0 comments on commit 6c9c6c5

Please sign in to comment.