【Hackathon 8th No.9】Reproduce the losses required for DAC training in PaddleSpeech #3954

Open · wants to merge 5 commits into base: develop
247 changes: 245 additions & 2 deletions paddlespeech/t2s/modules/losses.py
@@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable
from typing import List
from typing import Tuple
from typing import Union

import librosa
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddleaudio.audiotools import AudioSignal
from paddleaudio.audiotools import STFTParams
Collaborator: Missing SISDRLoss? (See the SI-SDR sketch after this import block.)

from scipy import signal
from scipy.stats import betabinom
from typeguard import check_argument_types
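Regarding the collaborator's question above: the upstream DAC loss module also defines an SI-SDR loss (taken from audiotools), and it does not appear in this diff. As a rough, hypothetical sketch only, not part of this PR and not necessarily how the paddleaudio port exposes it, a scale-invariant SDR loss in Paddle could look like this:

import paddle
from paddle import nn


class SISDRLoss(nn.Layer):
    """Negative scale-invariant SDR between an estimate and a reference (sketch)."""

    def __init__(self, eps: float=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, estimate: paddle.Tensor, reference: paddle.Tensor):
        # remove DC offset so the metric is offset-invariant
        estimate = estimate - estimate.mean(axis=-1, keepdim=True)
        reference = reference - reference.mean(axis=-1, keepdim=True)
        # optimal scaling of the reference towards the estimate
        alpha = (estimate * reference).sum(axis=-1, keepdim=True) / (
            (reference**2).sum(axis=-1, keepdim=True) + self.eps)
        target = alpha * reference
        noise = estimate - target
        sisdr = 10 * paddle.log10(
            (target**2).sum(axis=-1) / ((noise**2).sum(axis=-1) + self.eps) +
            self.eps)
        # return the negative mean so that minimizing the loss maximizes SI-SDR
        return -sisdr.mean()

# example: loss = SISDRLoss()(paddle.randn([2, 16000]), paddle.randn([2, 16000]))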
@@ -456,7 +461,8 @@ def stft(x,
         win_length=None,
         window='hann',
         center=True,
-        pad_mode='reflect'):
+        pad_mode='reflect',
+        clamp_eps=1e-7):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x(Tensor):
@@ -496,7 +502,7 @@
    real = x_stft.real()
    imag = x_stft.imag()

-    return paddle.sqrt(paddle.clip(real**2 + imag**2, min=1e-7)).transpose(
+    return paddle.sqrt(paddle.clip(real**2 + imag**2, min=clamp_eps)).transpose(
        [0, 2, 1])
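For context, a small usage sketch of the updated helper (illustrative only; it assumes the fft_size/hop_length keyword names used by the MultiScaleSTFTLoss calls further down, and the new clamp_eps simply replaces the previously hard-coded 1e-7 floor on the magnitude):

import paddle
from paddlespeech.t2s.modules.losses import stft

x = paddle.randn([4, 22050])  # (batch, samples)
mag = stft(x, fft_size=1024, hop_length=256, win_length=1024,
           window='hann', clamp_eps=1e-5)  # magnitude floor is now configurable
print(mag.shape)  # expected: (batch, n_frames, fft_size // 2 + 1)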


@@ -984,6 +990,118 @@ def forward(self, y_hat, y):
        return mel_loss


class MultiMelSpectrogramLoss(nn.Layer):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80]
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
            self,
            n_mels: List[int]=[150, 80],
            window_lengths: List[int]=[2048, 512],
            loss_fn: Callable=nn.L1Loss(),
            clamp_eps: float=1e-5,
            mag_weight: float=1.0,
            log_weight: float=1.0,
            pow: float=2.0,
            weight: float=1.0,
            match_stride: bool=False,
            mel_fmin: List[float]=[0.0, 0.0],
            mel_fmax: List[float]=[None, None],
            window_type: str=None,
            fs: int=44100, ):
        super().__init__()

        self.mel_loss_fns = [
            MelSpectrogramLoss(
                fs=fs,
                fft_size=w,
                hop_size=w // 4,
                num_mels=n_mel,
                fmin=fmin,
                fmax=fmax,
                eps=clamp_eps, )
            for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin,
                                            mel_fmax)
        ]

        self.n_mels = n_mels
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

    def forward(self,
                x: Union[AudioSignal, paddle.Tensor],
                y: Union[AudioSignal, paddle.Tensor]):
        """Computes multi mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal or Tensor
            Estimate signal
        y : AudioSignal or Tensor
            Reference signal

        Returns
        -------
        paddle.Tensor
            Mel loss.
        """
        loss = 0.0
        for i, mel_loss_fn in enumerate(self.mel_loss_fns):
            if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor):
                loss += self.log_weight * mel_loss_fn(x, y)
            elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
                s = mel_loss_fn.mel_spectrogram.stft_params
                x_mels = x.mel_spectrogram(
                    self.n_mels[i],
                    mel_fmin=self.mel_fmin[i],
                    mel_fmax=self.mel_fmax[i],
                    window_length=s['n_fft'],
                    hop_length=s['hop_length'])
                y_mels = y.mel_spectrogram(
                    self.n_mels[i],
                    mel_fmin=self.mel_fmin[i],
                    mel_fmax=self.mel_fmax[i],
                    window_length=s['n_fft'],
                    hop_length=s['hop_length'])
                loss += self.log_weight * self.loss_fn(
                    paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(),
                    paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10())
            else:
                raise ValueError('\'x\' and \'y\' should be the same type')
        return loss
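A minimal usage sketch (illustration only, not part of the diff); it assumes paddleaudio's AudioSignal accepts a (tensor, sample_rate) pair like the upstream audiotools API:

import paddle
from paddleaudio.audiotools import AudioSignal
from paddlespeech.t2s.modules.losses import MultiMelSpectrogramLoss

mel_loss = MultiMelSpectrogramLoss(fs=44100)
recons = AudioSignal(paddle.randn([1, 1, 44100]), 44100)  # 1 s estimate
signal = AudioSignal(paddle.randn([1, 1, 44100]), 44100)  # 1 s reference
loss = mel_loss(recons, signal)  # scalar paddle.Tensor; raw tensors are accepted too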


class FeatureMatchLoss(nn.Layer):
"""Feature matching loss module."""

@@ -1326,3 +1444,128 @@ def _generate_prior(self, text_lengths, feats_lengths,
            bb_prior[bidx, :T, :N] = prob

        return bb_prior


class MultiScaleSTFTLoss(nn.Layer):
    """Multi resolution STFT loss module."""

    def __init__(
            self,
            window_lengths: List[int]=[2048, 512],
            loss_fn: Callable=nn.L1Loss(),
            clamp_eps: float=1e-5,
            mag_weight: float=1.0,
            log_weight: float=1.0,
            pow: float=2.0,
            weight: float=1.0,
            match_stride: bool=False,
            window_type: str='hann', ):
        super().__init__()
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type, ) for w in window_lengths
        ]
        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self,
                x: Union[AudioSignal, paddle.Tensor],
                y: Union[AudioSignal, paddle.Tensor]):
        """Computes multi-scale STFT between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal or Tensor
            Estimate signal
        y : AudioSignal or Tensor
            Reference signal

        Returns
        -------
        paddle.Tensor
            Multi-scale STFT loss.
        """
        loss = 0.0

        for s in self.stft_params:
            if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor):
                x_mag = stft(
                    x.reshape([-1, x.shape[-1]]),
                    fft_size=s.window_length,
                    hop_length=s.hop_length,
                    win_length=s.window_length,
                    window=s.window_type,
                    clamp_eps=1e-5)
                y_mag = stft(
                    y.reshape([-1, y.shape[-1]]),
                    fft_size=s.window_length,
                    hop_length=s.hop_length,
                    win_length=s.window_length,
                    window=s.window_type,
                    clamp_eps=1e-5)
                x_mag = x_mag.transpose([0, 2, 1])
                y_mag = y_mag.transpose([0, 2, 1])
            elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
                x.stft(s.window_length, s.hop_length, s.window_type)
                y.stft(s.window_length, s.hop_length, s.window_type)
                x_mag = x.magnitude
                y_mag = y.magnitude
            else:
                raise ValueError('\'x\' and \'y\' should be the same type')

            loss += self.log_weight * self.loss_fn(
                paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(),
                paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), )

            loss += self.mag_weight * self.loss_fn(x_mag, y_mag)
        return loss
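A quick sketch of the raw-tensor path (illustrative only; waveform shape assumed to be (batch, samples)):

import paddle
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss

stft_loss = MultiScaleSTFTLoss(window_lengths=[2048, 512])
y_hat = paddle.randn([2, 16384])  # estimated waveforms
y = paddle.randn([2, 16384])      # reference waveforms
loss = stft_loss(y_hat, y)        # log-magnitude and magnitude L1 terms summed over both scales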


class GANLoss(nn.Layer):
    """
    Computes a discriminator loss, given a discriminator on
    generated waveforms/spectrograms compared to ground truth
    waveforms/spectrograms. Computes the loss for both the
    discriminator and the generator in separate functions.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        d_fake = self.discriminator(fake.audio_data)
        d_real = self.discriminator(real.audio_data)
        return d_fake, d_real

    def discriminator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake.clone().detach(), real)

        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d += paddle.mean(x_fake[-1]**2)
            loss_d += paddle.mean((1 - x_real[-1])**2)
        return loss_d

    def generator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake, real)

        loss_g = 0
        for x_fake in d_fake:
            loss_g += paddle.mean((1 - x_fake[-1])**2)

        loss_feature = 0

        for i in range(len(d_fake)):
            for j in range(len(d_fake[i]) - 1):
                loss_feature += paddle.nn.functional.l1_loss(
                    d_fake[i][j], d_real[i][j].detach())
        return loss_g, loss_feature
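A rough sketch of how GANLoss is typically wired into a codec training step (illustrative only). ToyDiscriminator is a hypothetical stand-in for DAC's multi-scale/multi-period discriminator: GANLoss only assumes the discriminator returns a list of per-scale feature lists whose last element is the logits map. The 2.0 feature-matching weight is just an example, and recons/signal must be AudioSignal objects because the class reads .audio_data.

import paddle
from paddle import nn
from paddleaudio.audiotools import AudioSignal
from paddlespeech.t2s.modules.losses import GANLoss


class ToyDiscriminator(nn.Layer):
    """Stand-in discriminator: one scale, features plus a final logits map."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1D(1, 4, 15, stride=4, padding=7)
        self.out = nn.Conv1D(4, 1, 3, padding=1)

    def forward(self, x):  # x: (B, 1, T)
        feat = self.conv(x)
        logits = self.out(feat)
        return [[feat, logits]]  # last element per scale is the logits map


gan_loss = GANLoss(ToyDiscriminator())
recons = AudioSignal(paddle.randn([1, 1, 16384]), 16000)  # generated audio
signal = AudioSignal(paddle.randn([1, 1, 16384]), 16000)  # ground truth

# Discriminator step: the generated signal is detached inside discriminator_loss.
loss_d = gan_loss.discriminator_loss(recons, signal)
loss_d.backward()

# Generator step: least-squares adversarial term plus feature-matching term.
loss_g, loss_feat = gan_loss.generator_loss(recons, signal)
(loss_g + 2.0 * loss_feat).backward()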
87 changes: 87 additions & 0 deletions tests/unit/tts/test_losses.py
@@ -0,0 +1,87 @@
import torch
from paddleaudio.audiotools.core.audio_signal import AudioSignal

from paddlespeech.t2s.modules.losses import MultiMelSpectrogramLoss
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss


def test_dac_losses():
    for i in range(10):
        loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt')
        recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav')
        signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav')

        recons.audio_data.stop_gradient = False
        signal.audio_data.stop_gradient = False

        loss_fn_1 = MultiScaleSTFTLoss()
        loss_fn_2 = MultiMelSpectrogramLoss(
            n_mels=[5, 10, 20, 40, 80, 160, 320],
            window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
            mag_weight=0.0,
            pow=1.0,
            mel_fmin=[0, 0, 0, 0, 0, 0, 0],
            mel_fmax=[None, None, None, None, None, None, None])

        #
        # Test AudioSignal
        #

        loss_1 = loss_fn_1(recons, signal)
        loss_1.backward()
        loss_1_grad = signal.audio_data.grad.sum()
        assert abs(
            (loss_1.item() - loss_origin['stft/loss'].item()) /
            loss_1.item()) < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'"
        assert abs(
            (loss_1_grad.item() - loss_origin['stft/grad'].sum().item()) /
            loss_1_grad.item()) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'"

Collaborator: Suggest using np.testing.assert_allclose; can these losses pass 1e-6?

Author: Not currently. After debugging, I found that the discrepancy comes from 'paddle.signal.stft' (without CUDA), so I have to compare the '_VF' implementation against Paddle's. I'm confident the error can be driven down to 0 once that is fixed.

        signal.audio_data.clear_grad()
        recons.audio_data.clear_grad()

        loss_2 = loss_fn_2(recons, signal)
        loss_2.backward()
        loss_2_grad = signal.audio_data.grad.sum()

        assert abs(
            (loss_2.item() - loss_origin['mel/loss'].item()) /
            loss_2.item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'"
        assert abs(
            (signal.audio_data.grad.sum().item() -
             loss_origin['mel/grad'].sum().item()) /
            loss_2_grad.item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"

        signal.audio_data.clear_grad()
        recons.audio_data.clear_grad()

        #
        # Test Tensor
        #

        loss_1 = loss_fn_1(recons.audio_data, signal.audio_data)
        loss_1.backward()
        loss_1_grad = signal.audio_data.grad.sum()

        assert abs(
            (loss_1.item() - loss_origin['stft/loss'].item()) /
            loss_1.item()) < 1e-5, r"value incorrect for 'MultiScaleSTFTLoss'"
        assert abs(
            (loss_1_grad.item() - loss_origin['stft/grad'].sum().item()) /
            loss_1_grad.item()) < 1e-5, r"gradient incorrect for 'MultiScaleSTFTLoss'"

        signal.audio_data.clear_grad()
        recons.audio_data.clear_grad()

        loss_2 = loss_fn_2(recons.audio_data, signal.audio_data)
        loss_2.backward()
        loss_2_grad = signal.audio_data.grad.sum()

        assert abs(
            (loss_2.item() - loss_origin['mel/loss'].item()) /
            loss_2.item()) < 1e-5, r"value incorrect for 'MultiMelSpectrogramLoss'"
        assert abs(
            (loss_2_grad.item() - loss_origin['mel/grad'].sum().item()) /
            loss_2_grad.item()) < 1e-5, r"gradient incorrect for 'MultiMelSpectrogramLoss'"