Skip to content

Commit

Permalink
add LogMelFeatureExtraction (#26)
Browse files Browse the repository at this point in the history
Adds PyTorch-based log-mel feature extraction that is compatible to the librosa-based feature extraction in RETURNN.

Co-authored-by: vieting <[email protected]>
  • Loading branch information
JackTemaki and vieting authored Jul 25, 2023
1 parent 18d3f1a commit d2c8a24
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 1 deletion.
110 changes: 110 additions & 0 deletions i6_models/primitives/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]

from dataclasses import dataclass
from typing import Optional, Tuple

from librosa import filters
import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class LogMelFeatureExtractionV1Config(ModelConfiguration):
"""
Attributes:
sample_rate: audio sample rate in Hz
win_size: window size in seconds
hop_size: window shift in seconds
f_min: minimum filter frequency in Hz
f_max: maximum filter frequency in Hz
min_amp: minimum amplitude for safe log
num_filters: number of mel windows
center: centered STFT with automatic padding
"""

sample_rate: int
win_size: float
hop_size: float
f_min: int
f_max: int
min_amp: float
num_filters: int
center: bool
n_fft: Optional[int] = None

def __post_init__(self) -> None:
super().__post_init__()
assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate"
assert self.f_min > 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive"
assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive"
assert self.num_filters > 0, "number of filters needs to be positive"
assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense"
if self.n_fft is None:
# if n_fft is not given, set n_fft to the window size (in samples)
self.n_fft = int(self.win_size * self.sample_rate)
else:
assert self.n_fft >= self.win_size * self.sample_rate, "n_fft cannot to be smaller than the window size"


class LogMelFeatureExtractionV1(nn.Module):
"""
Librosa-compatible log-mel feature extraction using log10. Does not use torchaudio.
Using it wrapped with torch.no_grad() is recommended if no gradient is needed
"""

def __init__(self, cfg: LogMelFeatureExtractionV1Config):
super().__init__()
self.register_buffer("n_fft", torch.tensor(cfg.n_fft))
self.register_buffer("window", torch.hann_window(int(cfg.win_size * cfg.sample_rate)))
self.register_buffer("hop_length", torch.tensor(int(cfg.hop_size * cfg.sample_rate)))
self.register_buffer("min_amp", torch.tensor(cfg.min_amp))
self.center = cfg.center
self.register_buffer(
"mel_basis",
torch.tensor(
filters.mel(
sr=cfg.sample_rate,
n_fft=int(cfg.sample_rate * cfg.win_size),
n_mels=cfg.num_filters,
fmin=cfg.f_min,
fmax=cfg.f_max,
)
),
)

def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param raw_audio: [B, T]
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
power_spectrum = (
torch.abs(
torch.stft(
raw_audio,
n_fft=self.n_fft,
hop_length=self.hop_length,
window=self.window,
center=self.center,
pad_mode="constant",
return_complex=True,
)
)
** 2
)
if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
power_spectrum = torch.unsqueeze(power_spectrum, 0)
melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
log_melspec = torch.log10(torch.max(self.min_amp, melspec))
feature_data = torch.transpose(log_melspec, 1, 2)

if self.center:
length = (length // self.hop_length) + 1
else:
length = ((length - self.n_fft) // self.hop_length) + 1

return feature_data, length.int()
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
typeguard
torch
torch
librosa
67 changes: 67 additions & 0 deletions tests/test_feature_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import copy
import numpy
import torch

from librosa.feature import melspectrogram

from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1, LogMelFeatureExtractionV1Config


def test_logmel_librosa_compatibility():

audio = numpy.asarray(numpy.random.random((50000)), dtype=numpy.float32)
librosa_mel = melspectrogram(
y=audio,
sr=16000,
n_fft=int(0.05 * 16000),
hop_length=int(0.0125 * 16000),
win_length=int(0.05 * 16000),
fmin=60,
fmax=7600,
n_mels=80,
)
librosa_log_mel = numpy.log10(numpy.maximum(librosa_mel, 1e-10))

fe_cfg = LogMelFeatureExtractionV1Config(
sample_rate=16000,
win_size=0.05,
hop_size=0.0125,
f_min=60,
f_max=7600,
min_amp=1e-10,
num_filters=80,
center=True,
)
fe = LogMelFeatureExtractionV1(cfg=fe_cfg)
audio_tensor = torch.unsqueeze(torch.Tensor(audio), 0) # [B, T]
audio_length = torch.tensor([50000]) # [B]
pytorch_log_mel, frame_length = fe(audio_tensor, audio_length)
librosa_log_mel = torch.tensor(librosa_log_mel).transpose(0, 1)
assert torch.allclose(librosa_log_mel, pytorch_log_mel, atol=1e-06)


def test_logmel_length():
fe_center_cfg = LogMelFeatureExtractionV1Config(
sample_rate=16000,
win_size=0.05,
hop_size=0.0125,
f_min=60,
f_max=7600,
min_amp=1e-10,
num_filters=80,
center=True,
)
fe_center = LogMelFeatureExtractionV1(cfg=fe_center_cfg)
fe_no_center_cfg = copy.deepcopy(fe_center_cfg)
fe_no_center_cfg.center = False
fe_no_center = LogMelFeatureExtractionV1(cfg=fe_no_center_cfg)
for i in range(10):
audio_length = int(numpy.random.randint(10000, 50000))
audio = numpy.asarray(numpy.random.random(audio_length), dtype=numpy.float32)
audio_length = torch.tensor(int(audio_length))
audio_length = torch.unsqueeze(audio_length, 0)
audio = torch.unsqueeze(torch.tensor(audio), 0)
mel_center, length_center = fe_center(audio, audio_length)
assert torch.all(mel_center.size()[1] == length_center)
mel_no_center, length_no_center = fe_no_center(audio, audio_length)
assert torch.all(mel_no_center.size()[1] == length_no_center)

0 comments on commit d2c8a24

Please sign in to comment.