Legacy-like zero-Specaugment implementation (#34)
* Legacy-like zero-SpecAugment implementation: adds SpecAugment compatible with old RETURNN/TF setups; the code follows the TensorFlow implementation (Co-authored-by: Mohammad Zeineldeen <[email protected]>)
* Fixes for SpecAugment
* Apply suggestions from code review (Co-authored-by: SimBe195 <[email protected]>, Albert Zeyer <[email protected]>)
* Changes from review
* Update i6_models/primitives/specaugment.py (Co-authored-by: Albert Zeyer <[email protected]>)
* Correct num_mask shape and device placement
* Apply suggestions from code review (Co-authored-by: Albert Zeyer <[email protected]>)
* Update i6_models/primitives/specaugment.py (Co-authored-by: vieting <[email protected]>)
* Correct number of masks
* Better naming

Co-authored-by: Mohammad Zeineldeen <[email protected]>
Co-authored-by: SimBe195 <[email protected]>
Co-authored-by: Albert Zeyer <[email protected]>
Co-authored-by: vieting <[email protected]>
1 parent 95f5521 · commit f6fdaf7
Showing 1 changed file with 137 additions and 0 deletions.
i6_models/primitives/specaugment.py (new file):
import torch


def _mask(tensor: torch.Tensor, batch_axis: int, axis: int, pos: torch.Tensor, max_len: int) -> torch.Tensor:
    """
    Apply one zero-mask per batch entry along the given axis.

    :param tensor: e.g. [B, ..., A, ...] but arbitrary axis order
    :param batch_axis: index of the batch axis
    :param axis: which axis A to mask
    :param pos: at which positions along axis to start the mask (size [B])
    :param max_len: maximum mask length, the actual length is drawn uniformly from [1, max_len]
    :return: tensor with masked positions set to zero
    """
    batch_dim_size = tensor.shape[batch_axis]
    mask_dim_size = tensor.shape[axis]
    mask_len = torch.randint(low=1, high=max_len + 1, size=(batch_dim_size,), dtype=torch.int32, device=tensor.device)
    end_pos = torch.min(pos + mask_len, torch.tensor([mask_dim_size] * batch_dim_size, device=tensor.device))
    idxs = torch.arange(0, mask_dim_size, device=tensor.device).unsqueeze(0)  # [1,dim]
    pos_bc = pos.unsqueeze(1)  # [B,1]
    end_pos_bc = end_pos.unsqueeze(1)  # [B,1]
    mask = torch.logical_and(torch.greater_equal(idxs, pos_bc), torch.less(idxs, end_pos_bc))  # [B,dim]
    if batch_axis > axis:
        mask = mask.transpose(0, 1)  # [dim,B]
    mask = torch.reshape(mask, shape=[tensor.shape[i] if i in (batch_axis, axis) else 1 for i in range(tensor.ndim)])
    tensor = torch.where(mask, 0.0, tensor)
    return tensor
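
# Illustration (hypothetical values): for a [B=2, T=6, F] input masked along the
# time axis with pos=[1, 3] and max_len=2, a drawn mask_len of [2, 1] zeroes time
# steps 1-2 of sequence 0 and time step 3 of sequence 1; end positions are clipped
# to the axis size, so a mask never runs past the tensor boundary.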

def _random_mask(tensor: torch.Tensor, batch_axis: int, axis: int, min_num: int, max_num: int, max_len: int) -> torch.Tensor:
    """
    Mask tensor along axis using N masks, with N drawn from [min_num, max_num] per batch entry
    and each mask length drawn from [1, max_len].

    :param tensor: e.g. [B, ..., A, ...] but arbitrary axis order
    :param batch_axis: index of the batch axis
    :param axis: which axis to mask
    :param min_num: minimum number of masks
    :param max_num: maximum number of masks
    :param max_len: maximum mask length, the actual length is drawn uniformly from [1, max_len]
    """
    batch_dim_size = tensor.shape[batch_axis]
    if max_num < min_num:
        max_num = min_num
    num_masks = torch.randint(min_num, max_num + 1, size=(batch_dim_size,), device="cpu")  # [B]

    max_num_masks = num_masks.max().item()

    z = torch.rand((batch_dim_size, tensor.shape[axis]), device=tensor.device)  # [B,dim]
    _, indices = torch.topk(z, max_num_masks, dim=1)

    # Make num_masks broadcastable to the shape of tensor for torch.where.
    num_masks = torch.reshape(num_masks, [1] * batch_axis + [batch_dim_size] + [1] * (tensor.dim() - batch_axis - 1))

    num_masks = num_masks.to(device=tensor.device)

    for i in range(max_num_masks):
        tensor = torch.where(i < num_masks, _mask(tensor, batch_axis, axis, indices[:, i], max_len), tensor)

    return tensor
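
# Note on the construction above: torch.topk on uniform noise yields max_num_masks
# distinct random start positions per sequence; the `i < num_masks` condition in the
# loop then applies only the first num_masks[b] of those masks to batch entry b, so
# every sequence gets its own number of masks while the loop stays fully batched.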

def specaugment_v1(
    audio_features: torch.Tensor,
    *,
    time_min_num_masks: int,
    time_max_num_masks: int,
    time_mask_max_size: int,
    freq_min_num_masks: int,
    freq_max_num_masks: int,
    freq_mask_max_size: int,
) -> torch.Tensor:
    """
    SpecAugment from the legacy rossenbach/zeineldeen/zeyer attention setups, e.g.
    https://github.com/rwth-i6/i6_experiments/blob/main/users/zeineldeen/data_aug/specaugment/specaug_tf2.py
    but without any step-based scheduling and without dependence on the sequence length.
    See `specaugment_v1_by_length` for a variant that is closer to the original.

    Fills masks with zeros. Basically just a convenience wrapper around `_random_mask`.

    See also: https://arxiv.org/abs/1904.08779

    :param audio_features: e.g. log-mel features as [B, T, F]
    :param time_min_num_masks: minimum number of masks along T
    :param time_max_num_masks: maximum number of masks along T
    :param time_mask_max_size: maximum size of masks along T
    :param freq_min_num_masks: minimum number of masks along F
    :param freq_max_num_masks: maximum number of masks along F
    :param freq_mask_max_size: maximum size of masks along F
    :return: masked audio features
    """
    assert len(audio_features.shape) == 3
    assert time_min_num_masks <= time_max_num_masks
    assert freq_min_num_masks <= freq_max_num_masks
    masked_audio_features = _random_mask(
        audio_features, 0, 1, time_min_num_masks, time_max_num_masks, time_mask_max_size
    )  # time masking
    masked_audio_features = _random_mask(
        masked_audio_features, 0, 2, freq_min_num_masks, freq_max_num_masks, freq_mask_max_size
    )  # freq masking
    return masked_audio_features

def specaugment_v1_by_length(
    audio_features: torch.Tensor,
    *,
    time_min_num_masks: int,
    time_max_mask_per_n_frames: int,
    time_mask_max_size: int,
    freq_min_num_masks: int,
    freq_max_num_masks: int,
    freq_mask_max_size: int,
) -> torch.Tensor:
    """
    Convenience wrapper around `specaugment_v1` with a time-length-adaptive number of masks.

    :param audio_features: e.g. log-mel features as [B, T, F]
    :param time_min_num_masks: minimum number of masks along T
    :param time_max_mask_per_n_frames: used for the maximum number of time masks,
        i.e. time_max_num_masks = T // time_max_mask_per_n_frames for the whole batch.
        The masks are still drawn based on the padded batch length T, so shorter
        sequences might get more masks than that by chance, or none at all when all
        masks fall into the padding space.
    :param time_mask_max_size: maximum size of masks along T
    :param freq_min_num_masks: minimum number of masks along F
    :param freq_max_num_masks: maximum number of masks along F
    :param freq_mask_max_size: maximum size of masks along F
    :return: masked audio features
    """
    return specaugment_v1(
        audio_features,
        time_min_num_masks=time_min_num_masks,
        time_max_num_masks=audio_features.size(1) // time_max_mask_per_n_frames,
        time_mask_max_size=time_mask_max_size,
        freq_min_num_masks=freq_min_num_masks,
        freq_max_num_masks=freq_max_num_masks,
        freq_mask_max_size=freq_mask_max_size,
    )
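
A minimal usage sketch (not part of this commit; the shapes and parameter values are illustrative assumptions, and like other SpecAugment variants this would typically only be applied during training):

import torch

from i6_models.primitives.specaugment import specaugment_v1, specaugment_v1_by_length

# hypothetical batch: 4 sequences, 100 frames, 80 log-mel bins
audio_features = torch.randn(4, 100, 80)

masked = specaugment_v1(
    audio_features,
    time_min_num_masks=1,
    time_max_num_masks=3,
    time_mask_max_size=20,
    freq_min_num_masks=1,
    freq_max_num_masks=2,
    freq_mask_max_size=10,
)
assert masked.shape == audio_features.shape

# length-adaptive variant: up to one time mask per 25 frames (100 // 25 = 4 here)
masked_by_len = specaugment_v1_by_length(
    audio_features,
    time_min_num_masks=1,
    time_max_mask_per_n_frames=25,
    time_mask_max_size=20,
    freq_min_num_masks=1,
    freq_max_num_masks=2,
    freq_mask_max_size=10,
)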