diff --git a/src/f5_tts/train/datasets/prepare_csv_wavs.py b/src/f5_tts/train/datasets/prepare_csv_wavs.py index f39001a6c..dd51ef098 100644 --- a/src/f5_tts/train/datasets/prepare_csv_wavs.py +++ b/src/f5_tts/train/datasets/prepare_csv_wavs.py @@ -54,8 +54,7 @@ def prepare_csv_wavs_dir(input_dir): def get_audio_duration(audio_path): audio, sample_rate = torchaudio.load(audio_path) - num_channels = audio.shape[0] - return audio.shape[1] / (sample_rate * num_channels) + return audio.shape[1] / sample_rate def read_audio_text_pairs(csv_file_path): diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 007dad828..9b46b0102 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -172,10 +172,9 @@ def load_settings(project_name): # Load metadata def get_audio_duration(audio_path): - """Calculate the duration of an audio file.""" + """Calculate the duration mono of an audio file.""" audio, sample_rate = torchaudio.load(audio_path) - num_channels = audio.shape[0] - return audio.shape[1] / (sample_rate * num_channels) + return audio.shape[1] / sample_rate def clear_text(text): @@ -383,13 +382,17 @@ def start_training( stream=False, logger="wandb", ): - global training_process, tts_api, stop_signal + global training_process, tts_api, stop_signal, pipe - if tts_api is not None: - del tts_api + if tts_api is not None or pipe is not None: + if tts_api is not None: + del tts_api + if pipe is not None: + del pipe gc.collect() torch.cuda.empty_cache() tts_api = None + pipe = None path_project = os.path.join(path_data, dataset_name) @@ -1557,7 +1560,7 @@ def get_audio_select(file_sample): last_per_steps = gr.Number(label="Last per Steps", value=100) with gr.Row(): - mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none") + mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none") cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb") start_button = gr.Button("Start Training") stop_button = gr.Button("Stop Training", interactive=False)