-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds PyTorch-based log-mel feature extraction that is compatible to the librosa-based feature extraction in RETURNN. Co-authored-by: vieting <[email protected]>
- Loading branch information
1 parent
18d3f1a
commit d2c8a24
Showing
3 changed files
with
179 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"] | ||
|
||
from dataclasses import dataclass | ||
from typing import Optional, Tuple | ||
|
||
from librosa import filters | ||
import torch | ||
from torch import nn | ||
|
||
from i6_models.config import ModelConfiguration | ||
|
||
|
||
@dataclass | ||
class LogMelFeatureExtractionV1Config(ModelConfiguration): | ||
""" | ||
Attributes: | ||
sample_rate: audio sample rate in Hz | ||
win_size: window size in seconds | ||
hop_size: window shift in seconds | ||
f_min: minimum filter frequency in Hz | ||
f_max: maximum filter frequency in Hz | ||
min_amp: minimum amplitude for safe log | ||
num_filters: number of mel windows | ||
center: centered STFT with automatic padding | ||
""" | ||
|
||
sample_rate: int | ||
win_size: float | ||
hop_size: float | ||
f_min: int | ||
f_max: int | ||
min_amp: float | ||
num_filters: int | ||
center: bool | ||
n_fft: Optional[int] = None | ||
|
||
def __post_init__(self) -> None: | ||
super().__post_init__() | ||
assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate" | ||
assert self.f_min > 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive" | ||
assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive" | ||
assert self.num_filters > 0, "number of filters needs to be positive" | ||
assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense" | ||
if self.n_fft is None: | ||
# if n_fft is not given, set n_fft to the window size (in samples) | ||
self.n_fft = int(self.win_size * self.sample_rate) | ||
else: | ||
assert self.n_fft >= self.win_size * self.sample_rate, "n_fft cannot to be smaller than the window size" | ||
|
||
|
||
class LogMelFeatureExtractionV1(nn.Module): | ||
""" | ||
Librosa-compatible log-mel feature extraction using log10. Does not use torchaudio. | ||
Using it wrapped with torch.no_grad() is recommended if no gradient is needed | ||
""" | ||
|
||
def __init__(self, cfg: LogMelFeatureExtractionV1Config): | ||
super().__init__() | ||
self.register_buffer("n_fft", torch.tensor(cfg.n_fft)) | ||
self.register_buffer("window", torch.hann_window(int(cfg.win_size * cfg.sample_rate))) | ||
self.register_buffer("hop_length", torch.tensor(int(cfg.hop_size * cfg.sample_rate))) | ||
self.register_buffer("min_amp", torch.tensor(cfg.min_amp)) | ||
self.center = cfg.center | ||
self.register_buffer( | ||
"mel_basis", | ||
torch.tensor( | ||
filters.mel( | ||
sr=cfg.sample_rate, | ||
n_fft=int(cfg.sample_rate * cfg.win_size), | ||
n_mels=cfg.num_filters, | ||
fmin=cfg.f_min, | ||
fmax=cfg.f_max, | ||
) | ||
), | ||
) | ||
|
||
def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: | ||
""" | ||
:param raw_audio: [B, T] | ||
:param length in samples: [B] | ||
:return features as [B,T,F] and length in frames [B] | ||
""" | ||
power_spectrum = ( | ||
torch.abs( | ||
torch.stft( | ||
raw_audio, | ||
n_fft=self.n_fft, | ||
hop_length=self.hop_length, | ||
window=self.window, | ||
center=self.center, | ||
pad_mode="constant", | ||
return_complex=True, | ||
) | ||
) | ||
** 2 | ||
) | ||
if len(power_spectrum.size()) == 2: | ||
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again | ||
power_spectrum = torch.unsqueeze(power_spectrum, 0) | ||
melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) | ||
log_melspec = torch.log10(torch.max(self.min_amp, melspec)) | ||
feature_data = torch.transpose(log_melspec, 1, 2) | ||
|
||
if self.center: | ||
length = (length // self.hop_length) + 1 | ||
else: | ||
length = ((length - self.n_fft) // self.hop_length) + 1 | ||
|
||
return feature_data, length.int() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
typeguard | ||
torch | ||
torch | ||
librosa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import copy | ||
import numpy | ||
import torch | ||
|
||
from librosa.feature import melspectrogram | ||
|
||
from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1, LogMelFeatureExtractionV1Config | ||
|
||
|
||
def test_logmel_librosa_compatibility(): | ||
|
||
audio = numpy.asarray(numpy.random.random((50000)), dtype=numpy.float32) | ||
librosa_mel = melspectrogram( | ||
y=audio, | ||
sr=16000, | ||
n_fft=int(0.05 * 16000), | ||
hop_length=int(0.0125 * 16000), | ||
win_length=int(0.05 * 16000), | ||
fmin=60, | ||
fmax=7600, | ||
n_mels=80, | ||
) | ||
librosa_log_mel = numpy.log10(numpy.maximum(librosa_mel, 1e-10)) | ||
|
||
fe_cfg = LogMelFeatureExtractionV1Config( | ||
sample_rate=16000, | ||
win_size=0.05, | ||
hop_size=0.0125, | ||
f_min=60, | ||
f_max=7600, | ||
min_amp=1e-10, | ||
num_filters=80, | ||
center=True, | ||
) | ||
fe = LogMelFeatureExtractionV1(cfg=fe_cfg) | ||
audio_tensor = torch.unsqueeze(torch.Tensor(audio), 0) # [B, T] | ||
audio_length = torch.tensor([50000]) # [B] | ||
pytorch_log_mel, frame_length = fe(audio_tensor, audio_length) | ||
librosa_log_mel = torch.tensor(librosa_log_mel).transpose(0, 1) | ||
assert torch.allclose(librosa_log_mel, pytorch_log_mel, atol=1e-06) | ||
|
||
|
||
def test_logmel_length(): | ||
fe_center_cfg = LogMelFeatureExtractionV1Config( | ||
sample_rate=16000, | ||
win_size=0.05, | ||
hop_size=0.0125, | ||
f_min=60, | ||
f_max=7600, | ||
min_amp=1e-10, | ||
num_filters=80, | ||
center=True, | ||
) | ||
fe_center = LogMelFeatureExtractionV1(cfg=fe_center_cfg) | ||
fe_no_center_cfg = copy.deepcopy(fe_center_cfg) | ||
fe_no_center_cfg.center = False | ||
fe_no_center = LogMelFeatureExtractionV1(cfg=fe_no_center_cfg) | ||
for i in range(10): | ||
audio_length = int(numpy.random.randint(10000, 50000)) | ||
audio = numpy.asarray(numpy.random.random(audio_length), dtype=numpy.float32) | ||
audio_length = torch.tensor(int(audio_length)) | ||
audio_length = torch.unsqueeze(audio_length, 0) | ||
audio = torch.unsqueeze(torch.tensor(audio), 0) | ||
mel_center, length_center = fe_center(audio, audio_length) | ||
assert torch.all(mel_center.size()[1] == length_center) | ||
mel_no_center, length_no_center = fe_no_center(audio, audio_length) | ||
assert torch.all(mel_no_center.size()[1] == length_no_center) |