diff --git a/rvc/train/train.py b/rvc/train/train.py index cf1c3485..5b6a9b4a 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -125,7 +125,9 @@ def verify_checkpoint_shapes(checkpoint_path, model): else: model_state_dict = model.load_state_dict(checkpoint_state_dict) except RuntimeError: - print("The parameters of the pretrain model such as the sample rate or architecture do not match the selected model.") + print( + "The parameters of the pretrain model such as the sample rate or architecture do not match the selected model." + ) sys.exit(1) else: del checkpoint @@ -324,8 +326,8 @@ def run( DistributedBucketSampler, TextAudioCollateMultiNSFsid, TextAudioLoaderMultiNSFsid, - ) - + ) + train_dataset = TextAudioLoaderMultiNSFsid(config.data) collate_fn = TextAudioCollateMultiNSFsid() train_sampler = DistributedBucketSampler( @@ -351,8 +353,8 @@ def run( # Initialize models and optimizers from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminatorV2 - from rvc.lib.algorithm.synthesizers import Synthesizer - + from rvc.lib.algorithm.synthesizers import Synthesizer + net_g = Synthesizer( config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, @@ -468,6 +470,7 @@ def run( scheduler_g.step() scheduler_d.step() + def train_and_evaluate( rank, epoch, @@ -537,17 +540,31 @@ def train_and_evaluate( info = [tensor.cuda(rank, non_blocking=True) for tensor in info] elif device.type != "cuda": info = [tensor.to(device) for tensor in info] - # else iterator is going thru a cached list with a device already assigned - - phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid = info + # else iterator is going thru a cached list with a device already assigned + + ( + phone, + phone_lengths, + pitch, + pitchf, + spec, + spec_lengths, + wave, + wave_lengths, + sid, + ) = info pitch = pitch if pitch_guidance else None pitchf = pitchf if pitch_guidance else None # Forward pass use_amp = config.train.fp16_run and device.type == "cuda" with autocast(enabled=use_amp): - model_output = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid) - y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = model_output + model_output = net_g( + phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid + ) + y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = ( + model_output + ) # used for tensorboard chart - all/mel mel = spec_to_mel_torch( spec,