From a01f7fa7d00d6c5bbcbbbf00671a68d75bbabfc3 Mon Sep 17 00:00:00 2001 From: Wilfried Michel Date: Thu, 24 Aug 2023 15:26:58 +0200 Subject: [PATCH 1/4] fix n_fft, win_length, f_min --- i6_models/primitives/feature_extraction.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index a6eb0bf1..78fedefb 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -37,7 +37,7 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): def __post_init__(self) -> None: super().__post_init__() assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate" - assert self.f_min > 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive" + assert self.f_min >= 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive" assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive" assert self.num_filters > 0, "number of filters needs to be positive" assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense" @@ -58,6 +58,7 @@ class LogMelFeatureExtractionV1(nn.Module): def __init__(self, cfg: LogMelFeatureExtractionV1Config): super().__init__() self.register_buffer("n_fft", torch.tensor(cfg.n_fft)) + self.register_buffer("win_length", torch.tensor(int(cfg.win_size * cfg.sample_rate))) self.register_buffer("window", torch.hann_window(int(cfg.win_size * cfg.sample_rate))) self.register_buffer("hop_length", torch.tensor(int(cfg.hop_size * cfg.sample_rate))) self.register_buffer("min_amp", torch.tensor(cfg.min_amp)) @@ -67,7 +68,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): torch.tensor( filters.mel( sr=cfg.sample_rate, - n_fft=int(cfg.sample_rate * cfg.win_size), + n_fft=cfg.n_fft, n_mels=cfg.num_filters, fmin=cfg.f_min, fmax=cfg.f_max, @@ -87,6 +88,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: raw_audio, n_fft=self.n_fft, hop_length=self.hop_length, + win_length=self.win_length, window=self.window, center=self.center, pad_mode="constant", From 2c750441a227f82b8d453d70f3b3b53746a2b9ca Mon Sep 17 00:00:00 2001 From: Wilfried Michel Date: Thu, 24 Aug 2023 18:33:06 +0200 Subject: [PATCH 2/4] change buffers to class member --- i6_models/primitives/feature_extraction.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 78fedefb..77ff4fec 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -57,12 +57,12 @@ class LogMelFeatureExtractionV1(nn.Module): def __init__(self, cfg: LogMelFeatureExtractionV1Config): super().__init__() - self.register_buffer("n_fft", torch.tensor(cfg.n_fft)) - self.register_buffer("win_length", torch.tensor(int(cfg.win_size * cfg.sample_rate))) - self.register_buffer("window", torch.hann_window(int(cfg.win_size * cfg.sample_rate))) - self.register_buffer("hop_length", torch.tensor(int(cfg.hop_size * cfg.sample_rate))) - self.register_buffer("min_amp", torch.tensor(cfg.min_amp)) self.center = cfg.center + self.hop_length = int(cfg.hop_size * cfg.sample_rate) + self.min_amp = cfg.min_amp + self.n_fft = cfg.n_fft + self.win_length = int(cfg.win_size * cfg.sample_rate) + self.register_buffer( "mel_basis", torch.tensor( @@ -75,6 +75,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): ) ), ) + self.register_buffer("window", torch.hann_window(self.win_length)) def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: """ From 4928c02d203d3bc1040fa45b079383f8fe16c24d Mon Sep 17 00:00:00 2001 From: Wilfried Michel Date: Thu, 24 Aug 2023 18:43:29 +0200 Subject: [PATCH 3/4] change max to clamp --- i6_models/primitives/feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 77ff4fec..db5d3954 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -102,7 +102,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again power_spectrum = torch.unsqueeze(power_spectrum, 0) melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) - log_melspec = torch.log10(torch.max(self.min_amp, melspec)) + log_melspec = torch.log10(torch.clamp(melspec, max=self.min_amp)) feature_data = torch.transpose(log_melspec, 1, 2) if self.center: From 97f04408ea8e0dd68f163297ffb7a4e8112d0ebe Mon Sep 17 00:00:00 2001 From: michelwi Date: Thu, 24 Aug 2023 18:45:16 +0200 Subject: [PATCH 4/4] Update i6_models/primitives/feature_extraction.py Co-authored-by: Albert Zeyer --- i6_models/primitives/feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index db5d3954..ead52dd5 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -102,7 +102,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again power_spectrum = torch.unsqueeze(power_spectrum, 0) melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) - log_melspec = torch.log10(torch.clamp(melspec, max=self.min_amp)) + log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) feature_data = torch.transpose(log_melspec, 1, 2) if self.center: