From 6c9c6c5504ac2082112f49d5f026227e847536d0 Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Wed, 6 Sep 2023 10:22:39 +0200 Subject: [PATCH] Add mask_tensor util Utility to create a boolean tensor mask on the tensors device based on sequence lengths. --- i6_models/util/mask.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 i6_models/util/mask.py diff --git a/i6_models/util/mask.py b/i6_models/util/mask.py new file mode 100644 index 00000000..4248e95d --- /dev/null +++ b/i6_models/util/mask.py @@ -0,0 +1,16 @@ +import torch + + +def mask_tensor(tensor: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor: + """ + Create a "positive" mask for a tensor (boolean true means position is used) + on the same device as the tensor. + + :param tensor: [B,T,....] + :param seq_len: [B] + :return: Mask of [B,T] + """ + seq_len = seq_len.to(device=tensor.device) + r = torch.arange(tensor.shape[1], device=tensor.device) # [T] + seq_mask = torch.less(r[None, :], seq_len[:, None]) # broadcast to [B,T] + return seq_mask