From 2c04e0ac9219ec84c197c4a76b8972c23f205a60 Mon Sep 17 00:00:00 2001 From: suzakuwcx Date: Tue, 17 Dec 2024 08:50:06 +0800 Subject: [PATCH 1/5] =?UTF-8?q?[Hackathon=207th=20No.56]=20=E5=9C=A8=20Pad?= =?UTF-8?q?dleSpeech=20=E4=B8=AD=E5=A4=8D=E7=8E=B0=20DAC=20=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E9=9C=80=E8=A6=81=E7=94=A8=E5=88=B0=E7=9A=84=20loss?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddlespeech/t2s/modules/losses.py | 221 ++++++++++++++++++++++++++++- tests/unit/tts/test_losses.py | 22 +++ 2 files changed, 242 insertions(+), 1 deletion(-) create mode 100644 tests/unit/tts/test_losses.py diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index e675dcab76d..581e17adf13 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple +from typing import Tuple, Callable, List, Union import librosa import numpy as np @@ -28,6 +28,8 @@ DurationPredictorLoss, # noqa: H301 ) +from paddleaudio.audiotools import AudioSignal, STFTParams + # Losses for WaveRNN def log_sum_exp(x): @@ -984,6 +986,108 @@ def forward(self, y_hat, y): return mel_loss +class MultiMelSpectrogramLoss(nn.Layer): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + fs: int = 44100, + ): + super().__init__() + + self.mel_loss_fns = [ + MelSpectrogramLoss( + fs=fs, + fft_size=w, + hop_size=w // 4, + num_mels=n_mel, + fmin=fmin, + fmax=fmax, + eps=clamp_eps, + ) + for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, mel_fmax) + ] + + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): + """Computes multi mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal or Tensor + Estimate signal + y : AudioSignal or Tensor + Reference signal + + Returns + ------- + paddle.Tensor + Mel loss. + """ + loss = 0.0 + for i, mel_loss_fn in enumerate(self.mel_loss_fns): + if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): + loss += self.log_weight * mel_loss_fn(x, y) + elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): + s = mel_loss_fn.mel_spectrogram.stft_params + x_mels = x.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) + y_mels = y.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) + loss += self.log_weight * self.loss_fn( + paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(), + paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10() + ) + else: + raise ValueError('\'x\' amd \'y\' should be the same type') + return loss + + class FeatureMatchLoss(nn.Layer): """Feature matching loss module.""" @@ -1326,3 +1430,118 @@ def _generate_prior(self, text_lengths, feats_lengths, bb_prior[bidx, :T, :N] = prob return bb_prior + + +class MultiScaleSTFTLoss(nn.Layer): + """Multi resolution STFT loss module.""" + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = 'hann', + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + + def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal or Tensor + Estimate signal + y : AudioSignal or Tensor + Reference signal + + Returns + ------- + paddle.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + + for s in self.stft_params: + if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): + x_mag = stft(x.reshape([-1, x.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) + y_mag = stft(y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) + x_mag = x_mag.transpose([0, 2, 1]) + y_mag = y_mag.transpose([0, 2, 1]) + elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + x_mag = x.magnitude + y_mag = y.magnitude + else: + raise ValueError('\'x\' amd \'y\' should be the same type') + + loss += self.log_weight * self.loss_fn( + paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(), + paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), + ) + + loss += self.mag_weight * self.loss_fn(x_mag, y_mag) + return loss + + +class GANLoss(nn.Layer): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += paddle.mean(x_fake[-1] ** 2) + loss_d += paddle.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += paddle.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += paddle.nn.functional.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/tests/unit/tts/test_losses.py b/tests/unit/tts/test_losses.py new file mode 100644 index 00000000000..8b379d7472f --- /dev/null +++ b/tests/unit/tts/test_losses.py @@ -0,0 +1,22 @@ +import torch +from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss, MultiMelSpectrogramLoss +from paddleaudio.audiotools.core.audio_signal import AudioSignal + +def test_dac_losses(): + for i in range(10): + loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt') + recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav') + signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav') + loss_fn_1 = MultiScaleSTFTLoss() + loss_fn_2 = MultiMelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320], window_lengths=[32, 64, 128, 256, 512, 1024, 2048], mag_weight=0.0, pow=1.0, mel_fmin=[0, 0, 0, 0, 0, 0, 0], mel_fmax=[None, None, None, None, None, None, None]) + # + # Test AudioSignal + # + assert abs(loss_fn_1(recons, signal).item() - loss_origin['stft/loss'].item()) < 1e-5 + assert abs(loss_fn_2(recons, signal).item() - loss_origin['mel/loss'].item()) < 1e-5 + + # + # Test Tensor + # + assert abs(loss_fn_1(recons.audio_data, signal.audio_data).item() - loss_origin['stft/loss'].item()) < 1e-3 + assert abs(loss_fn_2(recons.audio_data, signal.audio_data).item() - loss_origin['mel/loss'].item()) < 1e-3 From b741545f5e15ebb10c273a88e50c225f2527c474 Mon Sep 17 00:00:00 2001 From: suzakuwcx Date: Tue, 17 Dec 2024 10:08:08 +0800 Subject: [PATCH 2/5] Using pre-commit to format code --- paddlespeech/t2s/modules/losses.py | 125 +++++++++++++++++------------ tests/unit/tts/test_losses.py | 29 +++++-- 2 files changed, 96 insertions(+), 58 deletions(-) diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index 581e17adf13..029ad1be10b 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -12,13 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple, Callable, List, Union +from typing import Callable +from typing import List +from typing import Tuple +from typing import Union import librosa import numpy as np import paddle from paddle import nn from paddle.nn import functional as F +from paddleaudio.audiotools import AudioSignal +from paddleaudio.audiotools import STFTParams from scipy import signal from scipy.stats import betabinom from typeguard import check_argument_types @@ -28,8 +33,6 @@ DurationPredictorLoss, # noqa: H301 ) -from paddleaudio.audiotools import AudioSignal, STFTParams - # Losses for WaveRNN def log_sum_exp(x): @@ -1015,21 +1018,20 @@ class MultiMelSpectrogramLoss(nn.Layer): """ def __init__( - self, - n_mels: List[int] = [150, 80], - window_lengths: List[int] = [2048, 512], - loss_fn: Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 1.0, - log_weight: float = 1.0, - pow: float = 2.0, - weight: float = 1.0, - match_stride: bool = False, - mel_fmin: List[float] = [0.0, 0.0], - mel_fmax: List[float] = [None, None], - window_type: str = None, - fs: int = 44100, - ): + self, + n_mels: List[int]=[150, 80], + window_lengths: List[int]=[2048, 512], + loss_fn: Callable=nn.L1Loss(), + clamp_eps: float=1e-5, + mag_weight: float=1.0, + log_weight: float=1.0, + pow: float=2.0, + weight: float=1.0, + match_stride: bool=False, + mel_fmin: List[float]=[0.0, 0.0], + mel_fmax: List[float]=[None, None], + window_type: str=None, + fs: int=44100, ): super().__init__() self.mel_loss_fns = [ @@ -1040,11 +1042,11 @@ def __init__( num_mels=n_mel, fmin=fmin, fmax=fmax, - eps=clamp_eps, - ) - for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, mel_fmax) + eps=clamp_eps, ) + for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, + mel_fmax) ] - + self.n_mels = n_mels self.loss_fn = loss_fn self.clamp_eps = clamp_eps @@ -1055,7 +1057,9 @@ def __init__( self.mel_fmax = mel_fmax self.pow = pow - def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): + def forward(self, + x: Union[AudioSignal, paddle.Tensor], + y: Union[AudioSignal, paddle.Tensor]): """Computes multi mel loss between an estimate and a reference signal. @@ -1077,12 +1081,21 @@ def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, pa loss += self.log_weight * mel_loss_fn(x, y) elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): s = mel_loss_fn.mel_spectrogram.stft_params - x_mels = x.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) - y_mels = y.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) + x_mels = x.mel_spectrogram( + self.n_mels[i], + mel_fmin=self.mel_fmin[i], + mel_fmax=self.mel_fmax[i], + window_length=s['n_fft'], + hop_length=s['hop_length']) + y_mels = y.mel_spectrogram( + self.n_mels[i], + mel_fmin=self.mel_fmin[i], + mel_fmax=self.mel_fmax[i], + window_length=s['n_fft'], + hop_length=s['hop_length']) loss += self.log_weight * self.loss_fn( paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(), - paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10() - ) + paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10()) else: raise ValueError('\'x\' amd \'y\' should be the same type') return loss @@ -1436,26 +1449,23 @@ class MultiScaleSTFTLoss(nn.Layer): """Multi resolution STFT loss module.""" def __init__( - self, - window_lengths: List[int] = [2048, 512], - loss_fn: Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 1.0, - log_weight: float = 1.0, - pow: float = 2.0, - weight: float = 1.0, - match_stride: bool = False, - window_type: str = 'hann', - ): + self, + window_lengths: List[int]=[2048, 512], + loss_fn: Callable=nn.L1Loss(), + clamp_eps: float=1e-5, + mag_weight: float=1.0, + log_weight: float=1.0, + pow: float=2.0, + weight: float=1.0, + match_stride: bool=False, + window_type: str='hann', ): super().__init__() self.stft_params = [ STFTParams( window_length=w, hop_length=w // 4, match_stride=match_stride, - window_type=window_type, - ) - for w in window_lengths + window_type=window_type, ) for w in window_lengths ] self.loss_fn = loss_fn self.log_weight = log_weight @@ -1464,8 +1474,9 @@ def __init__( self.weight = weight self.pow = pow - - def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): + def forward(self, + x: Union[AudioSignal, paddle.Tensor], + y: Union[AudioSignal, paddle.Tensor]): """Computes multi-scale STFT between an estimate and a reference signal. @@ -1482,11 +1493,21 @@ def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, pa Multi-scale STFT loss. """ loss = 0.0 - + for s in self.stft_params: if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): - x_mag = stft(x.reshape([-1, x.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) - y_mag = stft(y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) + x_mag = stft( + x.reshape([-1, x.shape[-1]]), + fft_size=s.window_length, + hop_length=s.hop_length, + win_length=s.window_length, + window=s.window_type) + y_mag = stft( + y.reshape([-1, y.shape[-1]]), + fft_size=s.window_length, + hop_length=s.hop_length, + win_length=s.window_length, + window=s.window_type) x_mag = x_mag.transpose([0, 2, 1]) y_mag = y_mag.transpose([0, 2, 1]) elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): @@ -1499,8 +1520,7 @@ def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, pa loss += self.log_weight * self.loss_fn( paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(), - paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), - ) + paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), ) loss += self.mag_weight * self.loss_fn(x_mag, y_mag) return loss @@ -1528,8 +1548,8 @@ def discriminator_loss(self, fake, real): loss_d = 0 for x_fake, x_real in zip(d_fake, d_real): - loss_d += paddle.mean(x_fake[-1] ** 2) - loss_d += paddle.mean((1 - x_real[-1]) ** 2) + loss_d += paddle.mean(x_fake[-1]**2) + loss_d += paddle.mean((1 - x_real[-1])**2) return loss_d def generator_loss(self, fake, real): @@ -1537,11 +1557,12 @@ def generator_loss(self, fake, real): loss_g = 0 for x_fake in d_fake: - loss_g += paddle.mean((1 - x_fake[-1]) ** 2) + loss_g += paddle.mean((1 - x_fake[-1])**2) loss_feature = 0 for i in range(len(d_fake)): for j in range(len(d_fake[i]) - 1): - loss_feature += paddle.nn.functional.l1_loss(d_fake[i][j], d_real[i][j].detach()) + loss_feature += paddle.nn.functional.l1_loss( + d_fake[i][j], d_real[i][j].detach()) return loss_g, loss_feature diff --git a/tests/unit/tts/test_losses.py b/tests/unit/tts/test_losses.py index 8b379d7472f..9480c0069cf 100644 --- a/tests/unit/tts/test_losses.py +++ b/tests/unit/tts/test_losses.py @@ -1,22 +1,39 @@ import torch -from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss, MultiMelSpectrogramLoss from paddleaudio.audiotools.core.audio_signal import AudioSignal +from paddlespeech.t2s.modules.losses import MultiMelSpectrogramLoss +from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss + + def test_dac_losses(): for i in range(10): loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt') recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav') signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav') loss_fn_1 = MultiScaleSTFTLoss() - loss_fn_2 = MultiMelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320], window_lengths=[32, 64, 128, 256, 512, 1024, 2048], mag_weight=0.0, pow=1.0, mel_fmin=[0, 0, 0, 0, 0, 0, 0], mel_fmax=[None, None, None, None, None, None, None]) + loss_fn_2 = MultiMelSpectrogramLoss( + n_mels=[5, 10, 20, 40, 80, 160, 320], + window_lengths=[32, 64, 128, 256, 512, 1024, 2048], + mag_weight=0.0, + pow=1.0, + mel_fmin=[0, 0, 0, 0, 0, 0, 0], + mel_fmax=[None, None, None, None, None, None, None]) # # Test AudioSignal # - assert abs(loss_fn_1(recons, signal).item() - loss_origin['stft/loss'].item()) < 1e-5 - assert abs(loss_fn_2(recons, signal).item() - loss_origin['mel/loss'].item()) < 1e-5 + assert abs( + loss_fn_1(recons, signal).item() - loss_origin['stft/loss'] + .item()) < 1e-5 + assert abs( + loss_fn_2(recons, signal).item() - loss_origin['mel/loss'] + .item()) < 1e-5 # # Test Tensor # - assert abs(loss_fn_1(recons.audio_data, signal.audio_data).item() - loss_origin['stft/loss'].item()) < 1e-3 - assert abs(loss_fn_2(recons.audio_data, signal.audio_data).item() - loss_origin['mel/loss'].item()) < 1e-3 + assert abs( + loss_fn_1(recons.audio_data, signal.audio_data).item() - + loss_origin['stft/loss'].item()) < 1e-3 + assert abs( + loss_fn_2(recons.audio_data, signal.audio_data).item() - + loss_origin['mel/loss'].item()) < 1e-3 From 37f60d6c2a28d19bf9d52c210e3dc3287dca18d0 Mon Sep 17 00:00:00 2001 From: suzakuwcx Date: Sun, 29 Dec 2024 00:09:09 +0800 Subject: [PATCH 3/5] t2s/modules/losses.py: Add a 'clamp_eps' parameter to dynamically adjust the clipping threshold --- paddlespeech/t2s/modules/losses.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index 029ad1be10b..61dfbaf031f 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -461,7 +461,8 @@ def stft(x, win_length=None, window='hann', center=True, - pad_mode='reflect'): + pad_mode='reflect', + clamp_eps=1e-7): """Perform STFT and convert to magnitude spectrogram. Args: x(Tensor): @@ -501,7 +502,7 @@ def stft(x, real = x_stft.real() imag = x_stft.imag() - return paddle.sqrt(paddle.clip(real**2 + imag**2, min=1e-7)).transpose( + return paddle.sqrt(paddle.clip(real**2 + imag**2, min=clamp_eps)).transpose( [0, 2, 1]) @@ -1501,13 +1502,15 @@ def forward(self, fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, - window=s.window_type) + window=s.window_type, + clamp_eps=1e-5) y_mag = stft( y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, - window=s.window_type) + window=s.window_type, + clamp_eps=1e-5) x_mag = x_mag.transpose([0, 2, 1]) y_mag = y_mag.transpose([0, 2, 1]) elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): From c1a8f996f36173e4b662dc590b3909df71fb3781 Mon Sep 17 00:00:00 2001 From: suzakuwcx Date: Sun, 29 Dec 2024 13:25:01 +0800 Subject: [PATCH 4/5] tests/unit/tts/test_losses.py: Add gradient tests and update precision calculation methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Change precision threshold to ’1e-5‘ - Use relative error instead of absolute error --- tests/unit/tts/test_losses.py | 61 ++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/tests/unit/tts/test_losses.py b/tests/unit/tts/test_losses.py index 9480c0069cf..4948b20655d 100644 --- a/tests/unit/tts/test_losses.py +++ b/tests/unit/tts/test_losses.py @@ -10,6 +10,10 @@ def test_dac_losses(): loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt') recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav') signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav') + + recons.audio_data.stop_gradient = False + signal.audio_data.stop_gradient = False + loss_fn_1 = MultiScaleSTFTLoss() loss_fn_2 = MultiMelSpectrogramLoss( n_mels=[5, 10, 20, 40, 80, 160, 320], @@ -18,22 +22,57 @@ def test_dac_losses(): pow=1.0, mel_fmin=[0, 0, 0, 0, 0, 0, 0], mel_fmax=[None, None, None, None, None, None, None]) + # # Test AudioSignal # + + loss_1 = loss_fn_1(recons, signal) + loss_1.backward() + loss_1_grad = signal.audio_data.grad.sum() + + assert abs((loss_1.item() - loss_origin['stft/loss'].item()) / + loss_1.item()) < 1e-5 + assert abs((loss_1_grad.item() - loss_origin['stft/grad'].sum().item()) + / loss_1_grad.item()) < 1e-5 + + signal.audio_data.clear_grad() + recons.audio_data.clear_grad() + + loss_2 = loss_fn_2(recons, signal) + loss_2.backward() + loss_2_grad = signal.audio_data.grad.sum() + + assert abs((loss_2.item() - loss_origin['mel/loss'].item()) / + loss_2.item()) < 1e-5 assert abs( - loss_fn_1(recons, signal).item() - loss_origin['stft/loss'] - .item()) < 1e-5 - assert abs( - loss_fn_2(recons, signal).item() - loss_origin['mel/loss'] - .item()) < 1e-5 + (signal.audio_data.grad.sum().item() - + loss_origin['mel/grad'].sum().item()) / loss_2_grad.item()) < 1e-5 + + signal.audio_data.clear_grad() + recons.audio_data.clear_grad() # # Test Tensor # - assert abs( - loss_fn_1(recons.audio_data, signal.audio_data).item() - - loss_origin['stft/loss'].item()) < 1e-3 - assert abs( - loss_fn_2(recons.audio_data, signal.audio_data).item() - - loss_origin['mel/loss'].item()) < 1e-3 + + loss_1 = loss_fn_1(recons.audio_data, signal.audio_data) + loss_1.backward() + loss_1_grad = signal.audio_data.grad.sum() + + assert abs(loss_1.item() - loss_origin['stft/loss'] + .item()) / loss_1.item() < 1e-5 + assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum() + .item()) / loss_1_grad.item() < 1e-5 + + signal.audio_data.clear_grad() + recons.audio_data.clear_grad() + + loss_2 = loss_fn_2(recons.audio_data, signal.audio_data) + loss_2.backward() + loss_2_grad = signal.audio_data.grad.sum() + + assert abs((loss_2.item() - loss_origin['mel/loss'].item()) / + loss_2.item()) < 1e-5 + assert abs((loss_2_grad.item() - loss_origin['mel/grad'].sum().item()) / + loss_2_grad.item()) < 1e-5 From fd5365c5b6e42f2feab0c15fb54c89c2519e18cf Mon Sep 17 00:00:00 2001 From: suzakuwcx Date: Wed, 1 Jan 2025 23:52:28 +0800 Subject: [PATCH 5/5] tests/unit/tts/test_losses.py: Add error message on assert failed --- tests/unit/tts/test_losses.py | 37 ++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/tests/unit/tts/test_losses.py b/tests/unit/tts/test_losses.py index 4948b20655d..26b1b20c213 100644 --- a/tests/unit/tts/test_losses.py +++ b/tests/unit/tts/test_losses.py @@ -31,10 +31,13 @@ def test_dac_losses(): loss_1.backward() loss_1_grad = signal.audio_data.grad.sum() - assert abs((loss_1.item() - loss_origin['stft/loss'].item()) / - loss_1.item()) < 1e-5 - assert abs((loss_1_grad.item() - loss_origin['stft/grad'].sum().item()) - / loss_1_grad.item()) < 1e-5 + assert abs( + (loss_1.item() - loss_origin['stft/loss'].item()) / + loss_1.item()) < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'" + assert abs( + (loss_1_grad.item() - loss_origin['stft/grad'].sum().item() + ) / loss_1_grad. + item()) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'" signal.audio_data.clear_grad() recons.audio_data.clear_grad() @@ -43,11 +46,13 @@ def test_dac_losses(): loss_2.backward() loss_2_grad = signal.audio_data.grad.sum() - assert abs((loss_2.item() - loss_origin['mel/loss'].item()) / - loss_2.item()) < 1e-5 + assert abs( + (loss_2.item() - loss_origin['mel/loss'].item()) / loss_2. + item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'" assert abs( (signal.audio_data.grad.sum().item() - - loss_origin['mel/grad'].sum().item()) / loss_2_grad.item()) < 1e-5 + loss_origin['mel/grad'].sum().item()) / loss_2_grad. + item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'" signal.audio_data.clear_grad() recons.audio_data.clear_grad() @@ -60,10 +65,11 @@ def test_dac_losses(): loss_1.backward() loss_1_grad = signal.audio_data.grad.sum() - assert abs(loss_1.item() - loss_origin['stft/loss'] - .item()) / loss_1.item() < 1e-5 + assert abs(loss_1.item() - loss_origin['stft/loss'].item( + )) / loss_1.item() < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'" assert abs(loss_1_grad.item() - loss_origin['stft/grad'].sum() - .item()) / loss_1_grad.item() < 1e-5 + .item()) / loss_1_grad.item( + ) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'" signal.audio_data.clear_grad() recons.audio_data.clear_grad() @@ -72,7 +78,10 @@ def test_dac_losses(): loss_2.backward() loss_2_grad = signal.audio_data.grad.sum() - assert abs((loss_2.item() - loss_origin['mel/loss'].item()) / - loss_2.item()) < 1e-5 - assert abs((loss_2_grad.item() - loss_origin['mel/grad'].sum().item()) / - loss_2_grad.item()) < 1e-5 + assert abs( + (loss_2.item() - loss_origin['mel/loss'].item()) / loss_2. + item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'" + assert abs( + (loss_2_grad.item() - loss_origin['mel/grad'].sum().item() + ) / loss_2_grad. + item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"