diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index a6eb0bf1..ead52dd5 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -37,7 +37,7 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): def __post_init__(self) -> None: super().__post_init__() assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate" - assert self.f_min > 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive" + assert self.f_min >= 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive" assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive" assert self.num_filters > 0, "number of filters needs to be positive" assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense" @@ -57,23 +57,25 @@ class LogMelFeatureExtractionV1(nn.Module): def __init__(self, cfg: LogMelFeatureExtractionV1Config): super().__init__() - self.register_buffer("n_fft", torch.tensor(cfg.n_fft)) - self.register_buffer("window", torch.hann_window(int(cfg.win_size * cfg.sample_rate))) - self.register_buffer("hop_length", torch.tensor(int(cfg.hop_size * cfg.sample_rate))) - self.register_buffer("min_amp", torch.tensor(cfg.min_amp)) self.center = cfg.center + self.hop_length = int(cfg.hop_size * cfg.sample_rate) + self.min_amp = cfg.min_amp + self.n_fft = cfg.n_fft + self.win_length = int(cfg.win_size * cfg.sample_rate) + self.register_buffer( "mel_basis", torch.tensor( filters.mel( sr=cfg.sample_rate, - n_fft=int(cfg.sample_rate * cfg.win_size), + n_fft=cfg.n_fft, n_mels=cfg.num_filters, fmin=cfg.f_min, fmax=cfg.f_max, ) ), ) + self.register_buffer("window", torch.hann_window(self.win_length)) def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -87,6 +89,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: raw_audio, n_fft=self.n_fft, hop_length=self.hop_length, + win_length=self.win_length, window=self.window, center=self.center, pad_mode="constant", @@ -99,7 +102,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again power_spectrum = torch.unsqueeze(power_spectrum, 0) melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) - log_melspec = torch.log10(torch.max(self.min_amp, melspec)) + log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) feature_data = torch.transpose(log_melspec, 1, 2) if self.center: