-
Notifications
You must be signed in to change notification settings - Fork 0
/
stft_loss.py
95 lines (72 loc) · 2.83 KB
/
stft_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""STFT-based Loss modules."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import TorchSTFT
class STFTLoss(TorchSTFT):
    """Single-resolution STFT loss built on top of `TorchSTFT`.

    Compares the STFT magnitudes of two signals with a spectral convergence
    term and a log-magnitude L1 term.
    """

    def __init__(self, fft_size, hop_size, win_size):
        """Set up the underlying STFT with the given analysis parameters.

        Args:
            fft_size: FFT size, forwarded unchanged to `TorchSTFT`.
            hop_size: Hop size, forwarded unchanged to `TorchSTFT`.
            win_size: Window length, forwarded unchanged to `TorchSTFT`.

        """
        super().__init__(fft_size, hop_size, win_size)

    def spec2mag(self, real, imag):
        """Turn real/imaginary spectrum parts into a magnitude.

        Args:
            real (Tensor): Real part of the spectrum.
            imag (Tensor): Imaginary part of the spectrum.

        Returns:
            Tensor: ``sqrt(real**2 + imag**2 + 1e-7)``.

        """
        # The small constant keeps the value under the square root strictly
        # positive for finite inputs, so the log taken in `forward` is finite.
        squared = real**2 + imag**2 + 1e-7
        return torch.sqrt(squared)

    def forward(self, x, y):
        """Compute both STFT losses between a prediction and its target.

        Args:
            x (Tensor): Predicted signal (B, t).
            y (Tensor): Groundtruth signal (B, t).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
            Tensor: Magnitude of `x` (B, F, T).
            Tensor: Magnitude of `y` (B, F, T).

        """
        # `self.stft` is inherited; its result is unpacked as (real, imag).
        x_mag = self.spec2mag(*self.stft(x))
        y_mag = self.spec2mag(*self.stft(y))

        # Spectral convergence: norm of the magnitude error relative to the
        # norm of the groundtruth magnitude.
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        target_norm = torch.norm(y_mag, p="fro")
        sc_loss = error_norm / target_norm

        # Mean absolute difference of the log magnitudes.
        log_x_mag = torch.log(x_mag)
        log_y_mag = torch.log(y_mag)
        mag_loss = F.l1_loss(log_x_mag, log_y_mag)

        return sc_loss, mag_loss, x_mag, y_mag
class MultiResolutionSTFTLoss(nn.Module):
    """Multi resolution STFT loss module.

    Holds one `STFTLoss` per (FFT size, hop size, window length) triple and
    averages their spectral convergence and log-magnitude losses.
    """

    def __init__(
        self,
        fft_sizes=(128, 256, 512, 1024, 2048),
        hop_sizes=(32, 64, 128, 256, 512),
        win_sizes=(128, 256, 512, 1024, 2048),
    ):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (sequence of int): FFT sizes, one per resolution.
            hop_sizes (sequence of int): Hop sizes, one per resolution.
            win_sizes (sequence of int): Window lengths, one per resolution.

        Raises:
            ValueError: If the three sequences differ in length, or if they
                are empty.

        """
        super().__init__()
        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, after which `zip` would silently truncate
        # mismatched sequences to the shortest one.
        if not len(fft_sizes) == len(hop_sizes) == len(win_sizes):
            raise ValueError(
                "fft_sizes, hop_sizes and win_sizes must have the same length, "
                f"got {len(fft_sizes)}, {len(hop_sizes)} and {len(win_sizes)}"
            )
        # With zero resolutions `forward` would divide by zero when averaging.
        if len(fft_sizes) == 0:
            raise ValueError("at least one STFT resolution is required")
        self.stft_losses = nn.ModuleList(
            [
                STFTLoss(fft_size, hop_size, win_size)
                for fft_size, hop_size, win_size in zip(
                    fft_sizes, hop_sizes, win_sizes
                )
            ]
        )

    def forward(self, x, y):
        """Calculate forward propagation.

        `x` and `y` are passed positionally to every per-resolution loss, so
        they carry the same roles as in `STFTLoss.forward`.

        Args:
            x (Tensor): Predicted signal (B, t).
            y (Tensor): Groundtruth signal (B, t).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
            list: Magnitudes of `x`, one per resolution [(B, F, T), ...].
            list: Magnitudes of `y`, one per resolution [(B, F, T), ...].

        """
        sc_loss, mag_loss = 0.0, 0.0
        xs_mag, ys_mag = [], []
        for stft_loss in self.stft_losses:
            sc_l, mag_l, x_mag, y_mag = stft_loss(x, y)
            sc_loss = sc_loss + sc_l
            mag_loss = mag_loss + mag_l
            xs_mag.append(x_mag)
            ys_mag.append(y_mag)

        # Average over resolutions so the scale does not depend on how many
        # resolutions are configured.
        num_resolutions = len(self.stft_losses)
        sc_loss = sc_loss / num_resolutions
        mag_loss = mag_loss / num_resolutions
        return sc_loss, mag_loss, xs_mag, ys_mag