From 73bf6befa04e40ad3e800ee70c1b35dc39e7a410 Mon Sep 17 00:00:00 2001
From: unknown
Date: Wed, 6 Nov 2024 03:10:17 +0200
Subject: [PATCH] add arrow dataset

---
 src/f5_tts/train/finetune_gradio.py | 81 +++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)

diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py
index bea62d368..77c637721 100644
--- a/src/f5_tts/train/finetune_gradio.py
+++ b/src/f5_tts/train/finetune_gradio.py
@@ -22,6 +22,7 @@
 import numpy as np
 import torch
 import torchaudio
+from datasets import load_dataset
 from datasets import Dataset as Dataset_
 from datasets.arrow_writer import ArrowWriter
 from safetensors.torch import save_file
@@ -31,6 +32,7 @@
 from f5_tts.api import F5TTS
 from f5_tts.model.utils import convert_char_to_pinyin
 from importlib.resources import files
+import soundfile as sf
 
 training_process = None
 system = platform.system()
@@ -44,6 +46,7 @@
 
 path_data = str(files("f5_tts").joinpath("../../data"))
 path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
+
 file_train = "src/f5_tts/train/finetune_cli.py"
 
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
@@ -786,6 +789,65 @@ def has_supported_extension(file_name):
     return file_audio
 
 
+def get_nested_value(data, key_path):
+    keys = key_path.split("/")
+
+    item = data
+    for key in keys:
+        item = item.get(key)
+        if item is None:
+            return None
+
+    return item
+
+
+def create_metadata_from_arrow(
+    name_project, arrow_type, arrow_path, arrow_name, arrow_text, arrow_audio, arrow_split, progress=gr.Progress()
+):
+    path_project = os.path.join(path_data, name_project)
+    path_project_wavs = os.path.join(path_project, "wavs")
+    file_metadata = os.path.join(path_project, "metadata.csv")
+    file_custom_dataset_dir = os.path.join(path_project, "custom_dataset_dir")
+    os.makedirs(file_custom_dataset_dir, exist_ok=True)
+    os.makedirs(path_project_wavs, exist_ok=True)
+    data = ""
+    num = 0
+    if arrow_type == "Local" or arrow_type == "Online":
+        if arrow_type == "Local":
+            dataset = Dataset_.from_file(arrow_path)
+
+        if arrow_type == "Online":
+            if arrow_split == "":
+                arrow_split = None
+            if arrow_name == "":
+                arrow_name = None
+            dataset = load_dataset(arrow_path, arrow_name, split=arrow_split, cache_dir=file_custom_dataset_dir)
+
+    is_audio_path = None
+    for item in progress.tqdm(dataset):
+        text = get_nested_value(item, arrow_text)
+        audio = get_nested_value(item, arrow_audio)
+
+        if is_audio_path is None:
+            if isinstance(audio, str):
+                is_audio_path = True
+            else:
+                is_audio_path = False
+
+        if not is_audio_path:
+            namefile = f"segment_{num}"
+            filename = os.path.join(path_project_wavs, namefile + ".wav")
+            sf.write(filename, audio, 24000)
+            num += 1
+        else:
+            filename = audio
+
+        data += f"{filename}|{text}\n"
+
+    with open(file_metadata, "w", encoding="utf-8-sig") as f:
+        f.write(data)
+
+
 def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
     path_project = os.path.join(path_data, name_project)
     path_project_wavs = os.path.join(path_project, "wavs")
@@ -1505,6 +1567,18 @@ def get_audio_select(file_sample):
     Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
 ```""")
 
+    with gr.Accordion("Dataset Arrow", open=False):
+        arrow_type = gr.Radio(label="Type", choices=["Local", "Online"], value="Local")
+        with gr.Row():
+            arrow_path = gr.Textbox(label="Path", value="")
+            arrow_name = gr.Textbox(label="Name", value="")
+            arrow_split = gr.Textbox(label="Split", value="")
+
+        with gr.Row():
+            arrow_text = gr.Textbox(label="Text", value="transcript")
+            arrow_audio = gr.Textbox(label="Audio", value="audio/array")
+        bt_convert_metadata = gr.Button("Create Metadata")
+
     gr.Markdown(
         """```plaintext
     Place all your "wavs" folder and your "metadata.csv" file in your project name directory.
@@ -1530,6 +1604,7 @@
 ```"""
     )
     ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
+
     bt_prepare = bt_create = gr.Button("Prepare")
     txt_info_prepare = gr.Text(label="Info", value="")
     txt_vocab_prepare = gr.Text(label="Vocab", value="")
@@ -1538,6 +1613,12 @@
         fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
     )
 
+    bt_convert_metadata.click(
+        fn=create_metadata_from_arrow,
+        inputs=[cm_project, arrow_type, arrow_path, arrow_name, arrow_text, arrow_audio, arrow_split],
+        outputs=[],
+    )
+
     random_sample_prepare = gr.Button("Random Sample")
 
     with gr.Row():
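
Note (not part of the patch): the new "Text" and "Audio" textboxes take a slash-separated key path into each dataset row; get_nested_value splits the path on "/" and descends one level per key, returning None if any key is missing. Below is a minimal sketch of that lookup against a hypothetical Hugging Face-style row (real column names depend on the dataset); the helper is copied from the patch.

    # Hypothetical row, shaped like a decoded Hugging Face audio dataset example.
    row = {
        "transcript": "hello world",
        "audio": {"array": [0.0, 0.1, -0.1], "sampling_rate": 24000},
    }

    def get_nested_value(data, key_path):
        # Walk nested dicts along an "a/b/c" style path; None if a key is absent.
        keys = key_path.split("/")

        item = data
        for key in keys:
            item = item.get(key)
            if item is None:
                return None

        return item

    print(get_nested_value(row, "transcript"))   # -> "hello world"   ("Text" field)
    print(get_nested_value(row, "audio/array"))  # -> [0.0, 0.1, -0.1] ("Audio" field)

Because the resolved audio value here is an array rather than a path string, create_metadata_from_arrow writes it to wavs/segment_<num>.wav at 24 kHz and appends one "<path>|<text>" line per sample to metadata.csv.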