
import numpy as np
import torch
+import torch.nn.functional as F

try:
    import kaldi_io
@@ -90,6 +91,74 @@ def sequence_mask(sequence_length, max_len=None):
    return seq_range_expand < seq_length_expand


+def chunk_streaming_mask(
+    sequence_length: torch.Tensor,
+    chunk_size: int,
+    left_window: int = 0,
+    right_window: int = 0,
+    always_partial_in_last: bool = False,
+):
| 101 | + """Returns a mask for chunk streaming Transformer models. |
| 102 | +
|
| 103 | + Args: |
| 104 | + sequence_length (LongTensor): sequence_length of shape `(batch)` |
| 105 | + chunk_size (int): chunk size |
| 106 | + left_window (int): how many left chunks can be seen (default: 0) |
| 107 | + right_window (int): how many right chunks can be seen (default: 0) |
| 108 | + always_partial_in_last (bool): if True always makes the last chunk partial; |
| 109 | + otherwise makes either the first or last chunk have partial size randomly, |
| 110 | + which is to avoid learning to emit EOS just based on partial chunk size |
| 111 | + (default: False) |
| 112 | +
|
| 113 | + Returns: |
| 114 | + mask: (BoolTensor): a mask tensor of shape `(tgt_len, src_len)`, where |
| 115 | + `tgt_len` is the length of output and `src_len` is the length of input. |
| 116 | + `attn_mask[tgt_i, src_j] = True` means that when calculating the embedding |
| 117 | + for `tgt_i`, we need the embedding of `src_j`. |
| 118 | + """ |
+
+    max_len = sequence_length.data.max()
+    chunk_start_idx = torch.arange(
+        0,
+        max_len,
+        chunk_size,
+        dtype=sequence_length.dtype,
+        device=sequence_length.device,
+    )  # e.g. [0,18,36,54]
+    if not always_partial_in_last and np.random.rand() > 0.5:
+        # randomly make the first chunk the partial one instead of the last;
+        # if the partial chunk were always the last one, EOS could be predicted
+        # from the partial chunk size alone
+        chunk_start_idx = max_len - chunk_start_idx
+        chunk_start_idx = chunk_start_idx.flip([0])
+        chunk_start_idx = chunk_start_idx[:-1]
+        chunk_start_idx = F.pad(chunk_start_idx, (1, 0))
+
+    start_pad = F.pad(chunk_start_idx, (1, 0))  # [0,0,18,36,54]
+    end_pad = F.pad(chunk_start_idx, (0, 1), value=max_len)  # [0,18,36,54,max_len]
+    seq_range = torch.arange(
+        0, max_len, dtype=sequence_length.dtype, device=sequence_length.device
+    )
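+    # index into start_pad/end_pad of the chunk that each frame falls into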
+    idx = (
+        (seq_range.unsqueeze(-1) >= start_pad) & (seq_range.unsqueeze(-1) < end_pad)
+    ).nonzero()[
+        :, 1
+    ]  # max_len
+    seq_range_expand = seq_range.unsqueeze(0).expand(max_len, -1)  # max_len x max_len
+
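+    # left boundary: start of the chunk left_window chunks before each frame's own chunk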
+    idx_left = idx - left_window
+    idx_left[idx_left < 0] = 0
+    boundary_left = start_pad[idx_left]  # max_len
+    mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
+
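+    # right boundary: end of the chunk right_window chunks after each frame's own chunk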
+    idx_right = idx + right_window
+    idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
+    boundary_right = end_pad[idx_right]  # max_len
+    mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
+
+    return mask_left & mask_right
+
+
def convert_padding_direction(
    src_frames,
    src_lengths,
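As a quick sanity check of the semantics described in the new docstring, here is a minimal usage sketch (not part of the commit above); the import path `espresso.tools.utils` and the toy numbers are assumptions for illustration only.

```python
# Minimal sketch of exercising chunk_streaming_mask; the import path is an assumption.
import torch

from espresso.tools.utils import chunk_streaming_mask

lengths = torch.LongTensor([10, 7])  # per-utterance lengths; max_len = 10
mask = chunk_streaming_mask(
    lengths,
    chunk_size=4,
    left_window=1,   # each chunk may also attend to one chunk on its left
    right_window=0,  # no look-ahead beyond the current chunk
    always_partial_in_last=True,  # deterministic chunking: [0,4), [4,8), [8,10)
)
print(mask.shape)  # torch.Size([10, 10]), i.e. (tgt_len, src_len)
print(mask[5])     # frame 5 sees frames 0..7: its own chunk [4,8) plus one chunk to the left
```

Because `right_window=0`, no future chunks are visible, which matches the no-look-ahead streaming setting; increasing `right_window` trades latency for extra right context.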