diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index 85c1cb2a4..e7a223e40 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -190,7 +190,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef vocoder = load_vocoder(vocoder_name=self.vocoder_name) - target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate + target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate log_samples_path = f"{self.checkpoint_path}/samples" os.makedirs(log_samples_path, exist_ok=True)