diff --git a/rvc/train/mel_processing.py b/rvc/train/mel_processing.py index e4a8110d..b3bb74f8 100644 --- a/rvc/train/mel_processing.py +++ b/rvc/train/mel_processing.py @@ -144,22 +144,23 @@ def mel_spectrogram_torch( melspec = spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax) return melspec - + + def compute_window_length(n_mels: int, sample_rate: int): f_min = 0 f_max = sample_rate / 2 window_length_seconds = 8 * n_mels / (f_max - f_min) window_length = int(window_length_seconds * sample_rate) - return 2**(window_length.bit_length()-1) + return 2 ** (window_length.bit_length() - 1) + class MultiScaleMelSpectrogramLoss(torch.nn.Module): def __init__( self, sample_rate: int = 24000, - n_mels = [5, 10, 20, 40, 80, 160, 320, 480], - - loss_fn = torch.nn.L1Loss(), + n_mels=[5, 10, 20, 40, 80, 160, 320, 480], + loss_fn=torch.nn.L1Loss(), ): super().__init__() self.sample_rate = sample_rate @@ -167,15 +168,15 @@ def __init__( self.log_base = torch.log(torch.tensor(10.0)) self.stft_params = {} self.mel_banks = {} - + window_lengths = [compute_window_length(mel, sample_rate) for mel in n_mels] - #print(window_lengths) - + # print(window_lengths) + for n_mels, window_length in zip(n_mels, window_lengths): self.stft_params[n_mels] = { - "n_mels": n_mels, - "window_length": window_length, - "hop_length": self.sample_rate // 100, + "n_mels": n_mels, + "window_length": window_length, + "hop_length": self.sample_rate // 100, } self.mel_banks[n_mels] = torch.from_numpy( librosa_mel_fn( @@ -187,8 +188,14 @@ def __init__( ) ) - def mel_spectrogram(self, wav, n_mels, window_length, hop_length,): - wav = wav.squeeze(1) # -> torch(B, T) + def mel_spectrogram( + self, + wav, + n_mels, + window_length, + hop_length, + ): + wav = wav.squeeze(1) # -> torch(B, T) window = torch.hann_window(window_length).to(wav.device).to(wav.dtype) stft = torch.stft( wav.float(), @@ -196,13 +203,17 @@ def mel_spectrogram(self, wav, n_mels, window_length, hop_length,): hop_length=hop_length, window=window, return_complex=True, - ) # -> torch (B, window_length // 2 + 1, (T - window_length)/hop_length + 1) - magnitude = torch.sqrt(stft.real.pow(2) + stft.imag.pow(2) + 1e-6) - mel_basis = self.mel_banks[n_mels].to(wav.device) # torch(n_mels, window_length // 2 + 1) - mel_spectrogram = torch.matmul(mel_basis, magnitude) # torch(B, n_mels, stft.frames) + ) # -> torch (B, window_length // 2 + 1, (T - window_length)/hop_length + 1) + magnitude = torch.sqrt(stft.real.pow(2) + stft.imag.pow(2) + 1e-6) + mel_basis = self.mel_banks[n_mels].to( + wav.device + ) # torch(n_mels, window_length // 2 + 1) + mel_spectrogram = torch.matmul( + mel_basis, magnitude + ) # torch(B, n_mels, stft.frames) return mel_spectrogram - def forward(self, real, fake): # real: torch(B, 1, T) , fake: torch(B, 1, T) + def forward(self, real, fake): # real: torch(B, 1, T) , fake: torch(B, 1, T) loss = 0.0 for p in self.stft_params.values(): real_mels = self.mel_spectrogram(real, **p) @@ -211,4 +222,3 @@ def forward(self, real, fake): # real: torch(B, 1, T) , fake: torch(B, 1, T) fake_logmels = torch.log(fake_mels.clamp(min=1e-5).pow(1)) / self.log_base loss += self.loss_fn(real_logmels, fake_logmels) return loss - diff --git a/rvc/train/train.py b/rvc/train/train.py index 6f615038..c1f64fbe 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -43,7 +43,11 @@ generator_loss, kl_loss, ) -from mel_processing import mel_spectrogram_torch, spec_to_mel_torch, MultiScaleMelSpectrogramLoss +from mel_processing import ( + mel_spectrogram_torch, + spec_to_mel_torch, + MultiScaleMelSpectrogramLoss, +) from rvc.train.process.extract_model import extract_model @@ -377,7 +381,7 @@ def run( betas=config.train.betas, eps=config.train.eps, ) - + fn_mel_loss = MultiScaleMelSpectrogramLoss(sample_rate=sample_rate) # Wrap models with DDP for multi-gpu processing @@ -491,7 +495,7 @@ def run( custom_total_epoch, device, reference, - fn_mel_loss + fn_mel_loss, ) scheduler_g.step() @@ -512,7 +516,7 @@ def train_and_evaluate( custom_total_epoch, device, reference, - fn_mel_loss + fn_mel_loss, ): """ Trains and evaluates the model for one epoch. @@ -617,13 +621,19 @@ def train_and_evaluate( _, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat) with autocast(enabled=False): loss_mel = fn_mel_loss(wave, y_hat) * config.train.c_mel / 3.0 - loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl + loss_kl = ( + kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl + ) loss_fm = feature_loss(fmap_r, fmap_g) loss_gen, losses_gen = generator_loss(y_d_hat_g) loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl if loss_gen_all < lowest_value["value"]: - lowest_value = {"step": global_step, "value": loss_gen_all, "epoch": epoch} + lowest_value = { + "step": global_step, + "value": loss_gen_all, + "epoch": epoch, + } optim_g.zero_grad() scaler.scale(loss_gen_all).backward() @@ -645,7 +655,7 @@ def train_and_evaluate( config.data.sample_rate, config.data.mel_fmin, config.data.mel_fmax, - ) + ) # used for tensorboard chart - slice/mel_org y_mel = commons.slice_segments( mel, @@ -664,10 +674,10 @@ def train_and_evaluate( config.data.win_length, config.data.mel_fmin, config.data.mel_fmax, - ) + ) if use_amp: - y_hat_mel = y_hat_mel.half() - + y_hat_mel = y_hat_mel.half() + lr = optim_g.param_groups[0]["lr"] if loss_mel > 75: loss_mel = 75