Skip to content

Commit

Permalink
Merge pull request #896 from IAHispano/formatter/main
Browse files Browse the repository at this point in the history
chore(format): run black on main
  • Loading branch information
blaisewf authored Dec 2, 2024
2 parents 3dbcf51 + 964a65d commit 2b630e7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 29 deletions.
48 changes: 29 additions & 19 deletions rvc/train/mel_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,38 +144,39 @@ def mel_spectrogram_torch(
melspec = spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax)

return melspec



def compute_window_length(n_mels: int, sample_rate: int):
f_min = 0
f_max = sample_rate / 2
window_length_seconds = 8 * n_mels / (f_max - f_min)
window_length = int(window_length_seconds * sample_rate)
return 2**(window_length.bit_length()-1)
return 2 ** (window_length.bit_length() - 1)


class MultiScaleMelSpectrogramLoss(torch.nn.Module):

def __init__(
self,
sample_rate: int = 24000,
n_mels = [5, 10, 20, 40, 80, 160, 320, 480],

loss_fn = torch.nn.L1Loss(),
n_mels=[5, 10, 20, 40, 80, 160, 320, 480],
loss_fn=torch.nn.L1Loss(),
):
super().__init__()
self.sample_rate = sample_rate
self.loss_fn = loss_fn
self.log_base = torch.log(torch.tensor(10.0))
self.stft_params = {}
self.mel_banks = {}

window_lengths = [compute_window_length(mel, sample_rate) for mel in n_mels]
#print(window_lengths)
# print(window_lengths)

for n_mels, window_length in zip(n_mels, window_lengths):
self.stft_params[n_mels] = {
"n_mels": n_mels,
"window_length": window_length,
"hop_length": self.sample_rate // 100,
"n_mels": n_mels,
"window_length": window_length,
"hop_length": self.sample_rate // 100,
}
self.mel_banks[n_mels] = torch.from_numpy(
librosa_mel_fn(
Expand All @@ -187,22 +188,32 @@ def __init__(
)
)

def mel_spectrogram(self, wav, n_mels, window_length, hop_length,):
wav = wav.squeeze(1) # -> torch(B, T)
def mel_spectrogram(
self,
wav,
n_mels,
window_length,
hop_length,
):
wav = wav.squeeze(1) # -> torch(B, T)
window = torch.hann_window(window_length).to(wav.device).to(wav.dtype)
stft = torch.stft(
wav.float(),
n_fft=window_length,
hop_length=hop_length,
window=window,
return_complex=True,
) # -> torch (B, window_length // 2 + 1, (T - window_length)/hop_length + 1)
magnitude = torch.sqrt(stft.real.pow(2) + stft.imag.pow(2) + 1e-6)
mel_basis = self.mel_banks[n_mels].to(wav.device) # torch(n_mels, window_length // 2 + 1)
mel_spectrogram = torch.matmul(mel_basis, magnitude) # torch(B, n_mels, stft.frames)
) # -> torch (B, window_length // 2 + 1, (T - window_length)/hop_length + 1)
magnitude = torch.sqrt(stft.real.pow(2) + stft.imag.pow(2) + 1e-6)
mel_basis = self.mel_banks[n_mels].to(
wav.device
) # torch(n_mels, window_length // 2 + 1)
mel_spectrogram = torch.matmul(
mel_basis, magnitude
) # torch(B, n_mels, stft.frames)
return mel_spectrogram

def forward(self, real, fake): # real: torch(B, 1, T) , fake: torch(B, 1, T)
def forward(self, real, fake): # real: torch(B, 1, T) , fake: torch(B, 1, T)
loss = 0.0
for p in self.stft_params.values():
real_mels = self.mel_spectrogram(real, **p)
Expand All @@ -211,4 +222,3 @@ def forward(self, real, fake): # real: torch(B, 1, T) , fake: torch(B, 1, T)
fake_logmels = torch.log(fake_mels.clamp(min=1e-5).pow(1)) / self.log_base
loss += self.loss_fn(real_logmels, fake_logmels)
return loss

30 changes: 20 additions & 10 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
generator_loss,
kl_loss,
)
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch, MultiScaleMelSpectrogramLoss
from mel_processing import (
mel_spectrogram_torch,
spec_to_mel_torch,
MultiScaleMelSpectrogramLoss,
)

from rvc.train.process.extract_model import extract_model

Expand Down Expand Up @@ -377,7 +381,7 @@ def run(
betas=config.train.betas,
eps=config.train.eps,
)

fn_mel_loss = MultiScaleMelSpectrogramLoss(sample_rate=sample_rate)

# Wrap models with DDP for multi-gpu processing
Expand Down Expand Up @@ -491,7 +495,7 @@ def run(
custom_total_epoch,
device,
reference,
fn_mel_loss
fn_mel_loss,
)

scheduler_g.step()
Expand All @@ -512,7 +516,7 @@ def train_and_evaluate(
custom_total_epoch,
device,
reference,
fn_mel_loss
fn_mel_loss,
):
"""
Trains and evaluates the model for one epoch.
Expand Down Expand Up @@ -617,13 +621,19 @@ def train_and_evaluate(
_, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
with autocast(enabled=False):
loss_mel = fn_mel_loss(wave, y_hat) * config.train.c_mel / 3.0
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
loss_kl = (
kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
)
loss_fm = feature_loss(fmap_r, fmap_g)
loss_gen, losses_gen = generator_loss(y_d_hat_g)
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl

if loss_gen_all < lowest_value["value"]:
lowest_value = {"step": global_step, "value": loss_gen_all, "epoch": epoch}
lowest_value = {
"step": global_step,
"value": loss_gen_all,
"epoch": epoch,
}

optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
Expand All @@ -645,7 +655,7 @@ def train_and_evaluate(
config.data.sample_rate,
config.data.mel_fmin,
config.data.mel_fmax,
)
)
# used for tensorboard chart - slice/mel_org
y_mel = commons.slice_segments(
mel,
Expand All @@ -664,10 +674,10 @@ def train_and_evaluate(
config.data.win_length,
config.data.mel_fmin,
config.data.mel_fmax,
)
)
if use_amp:
y_hat_mel = y_hat_mel.half()
y_hat_mel = y_hat_mel.half()

lr = optim_g.param_groups[0]["lr"]
if loss_mel > 75:
loss_mel = 75
Expand Down

0 comments on commit 2b630e7

Please sign in to comment.