forked from xcmyz/FastSpeech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
29 lines (21 loc) · 1016 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
class FastSpeechLoss(nn.Module):
    """FastSpeech training loss.

    Combines three terms:
      * MSE between the decoder mel output and the target mel,
      * MSE between the post-net mel output and the target mel,
      * L1 between predicted durations and target durations.

    All three criteria use PyTorch's default ``reduction="mean"``.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def forward(self, mel, mel_postnet, duration_predicted, mel_target, duration_predictor_target):
        """Compute the three FastSpeech loss terms.

        Args:
            mel: decoder mel-spectrogram output; should match the shape of
                ``mel_target`` (``nn.MSELoss`` broadcasts with a warning
                otherwise). Presumably (batch, frames, n_mels) -- confirm
                against the model.
            mel_postnet: post-net mel output, same shape as ``mel_target``.
            duration_predicted: duration predictor output (floating point).
            mel_target: ground-truth mel-spectrogram.
            duration_predictor_target: ground-truth durations; may be an
                integer tensor, it is cast to float before the L1 loss.

        Returns:
            Tuple ``(mel_loss, mel_postnet_loss, duration_predictor_loss)``
            of scalar tensors.
        """
        # Targets must not receive gradients. ``detach()`` guarantees this
        # without mutating the caller's tensors; assigning
        # ``requires_grad = False`` instead raises a RuntimeError whenever the
        # target is a non-leaf tensor produced by a differentiable op.
        mel_target = mel_target.detach()
        duration_predictor_target = duration_predictor_target.detach()

        mel_loss = self.mse_loss(mel, mel_target)
        mel_postnet_loss = self.mse_loss(mel_postnet, mel_target)

        # Durations are compared in the linear domain (no log transform);
        # integer duration targets are cast to float to match the prediction.
        duration_predictor_loss = self.l1_loss(
            duration_predicted, duration_predictor_target.float())

        return mel_loss, mel_postnet_loss, duration_predictor_loss