Skip to content

Commit

Permalink
Merge pull request SWivid#359 from lpscr/main
Browse files Browse the repository at this point in the history
small update gradio finetune
  • Loading branch information
SWivid authored Nov 1, 2024
2 parents 2a3deaa + b664bc7 commit 11d2886
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
3 changes: 1 addition & 2 deletions src/f5_tts/train/datasets/prepare_csv_wavs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def prepare_csv_wavs_dir(input_dir):

def get_audio_duration(audio_path):
audio, sample_rate = torchaudio.load(audio_path)
- num_channels = audio.shape[0]
- return audio.shape[1] / (sample_rate * num_channels)
+ return audio.shape[1] / sample_rate


def read_audio_text_pairs(csv_file_path):
Expand Down
17 changes: 10 additions & 7 deletions src/f5_tts/train/finetune_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,9 @@ def load_settings(project_name):

# Load metadata
def get_audio_duration(audio_path):
- """Calculate the duration of an audio file."""
+ """Calculate the duration mono of an audio file."""
audio, sample_rate = torchaudio.load(audio_path)
- num_channels = audio.shape[0]
- return audio.shape[1] / (sample_rate * num_channels)
+ return audio.shape[1] / sample_rate


def clear_text(text):
Expand Down Expand Up @@ -383,13 +382,17 @@ def start_training(
stream=False,
logger="wandb",
):
- global training_process, tts_api, stop_signal
+ global training_process, tts_api, stop_signal, pipe

- if tts_api is not None:
- del tts_api
+ if tts_api is not None or pipe is not None:
+ if tts_api is not None:
+ del tts_api
+ if pipe is not None:
+ del pipe
gc.collect()
torch.cuda.empty_cache()
tts_api = None
+ pipe = None

path_project = os.path.join(path_data, dataset_name)

Expand Down Expand Up @@ -1557,7 +1560,7 @@ def get_audio_select(file_sample):
last_per_steps = gr.Number(label="Last per Steps", value=100)

with gr.Row():
- mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
+ mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)
Expand Down

0 comments on commit 11d2886

Please sign in to comment.