From f2f939646f22fb6302f5c78434a72e98d096cb81 Mon Sep 17 00:00:00 2001
From: unknown
Date: Thu, 24 Oct 2024 02:37:51 +0300
Subject: [PATCH] Update the Gradio finetune UI to be easier to use

---
 finetune-cli.py    |  35 +++++--
 finetune_gradio.py | 225 ++++++++++++++++++++++++++++++++++++---------
 2 files changed, 205 insertions(+), 55 deletions(-)

diff --git a/finetune-cli.py b/finetune-cli.py
index bc11ee2cf..ba4924bfc 100644
--- a/finetune-cli.py
+++ b/finetune-cli.py
@@ -14,26 +14,35 @@
 # -------------------------- Argument Parsing --------------------------- #
 def parse_args():
+    # batch_size_per_gpu = 1000 is a setting for an 8GB GPU
+    # batch_size_per_gpu = 1600 is a setting for a 12GB GPU
+    # batch_size_per_gpu = 2000 is a setting for a 16GB GPU
+    # batch_size_per_gpu = 3200 is a setting for a 24GB GPU
+
+    # num_warmup_updates = 500 is a reasonable value for roughly 10000 samples
+
+    # adjust save_per_updates and last_per_steps to whatever you need
+
     parser = argparse.ArgumentParser(description="Train CFM Model")
     parser.add_argument(
         "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
     )
     parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
-    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
-    parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
+    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
+    parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
     parser.add_argument(
         "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
     )
-    parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
+    parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
     parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
-    parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
-    parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
-    parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
+    parser.add_argument("--num_warmup_updates", type=int, default=500, help="Warmup steps")
+    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
+    parser.add_argument("--last_per_steps", type=int, default=20000, help="Save last checkpoint every X steps")
     parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
-
+    parser.add_argument("--pretrain", type=str, default=None, help="Path to a pretrained checkpoint to finetune from")
     parser.add_argument(
         "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
     )
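The VRAM tiers in the comments above map GPU memory to a frame budget; a minimal sketch of that mapping as a helper (illustrative only, not part of the patch):

```python
# Illustrative helper mirroring the comment table above (assumed tiers).
def suggest_batch_size_per_gpu(vram_gb: int) -> int:
    tiers = {8: 1000, 12: 1600, 16: 2000, 24: 3200}
    # pick the largest tier that fits in the given VRAM
    fitting = [frames for gb, frames in sorted(tiers.items()) if gb <= vram_gb]
    return fitting[-1] if fitting else 1000


print(suggest_batch_size_per_gpu(16))  # 2000
```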
@@ -59,14 +68,19 @@ def main():
         model_cls = DiT
         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
         if args.finetune:
-            ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+            if args.pretrain is None:
+                ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+            else:
+                ckpt_path = args.pretrain
     elif args.exp_name == "E2TTS_Base":
         wandb_resume_id = None
         model_cls = UNetT
         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
         if args.finetune:
-            ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
-
+            if args.pretrain is None:
+                ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+            else:
+                ckpt_path = args.pretrain
     if args.finetune:
         path_ckpt = os.path.join("ckpts", args.dataset_name)
         if not os.path.isdir(path_ckpt):
@@ -117,6 +131,7 @@ def main():
     )
 
     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+
     trainer.train(
         train_dataset,
         resumable_with_seed=666,  # seed for shuffling dataset

diff --git a/finetune_gradio.py b/finetune_gradio.py
index a61d95ac2..724cf032f 100644
--- a/finetune_gradio.py
+++ b/finetune_gradio.py
@@ -25,7 +25,7 @@
 from datasets.arrow_writer import ArrowWriter
 from datasets import Dataset as Dataset_
 from api import F5TTS
-
+from safetensors.torch import save_file
 
 training_process = None
 system = platform.system()
@@ -247,6 +247,9 @@ def start_training(
     save_per_updates=400,
     last_per_steps=800,
     finetune=True,
+    file_checkpoint_train="",
+    tokenizer_type="pinyin",
+    tokenizer_file="",
 ):
     global training_process, tts_api
@@ -256,7 +259,7 @@
         torch.cuda.empty_cache()
         tts_api = None
 
-    path_project = os.path.join(path_data, dataset_name + "_pinyin")
+    path_project = os.path.join(path_data, dataset_name)
 
     if not os.path.isdir(path_project):
         yield (
@@ -295,6 +298,13 @@
     if finetune:
         cmd += f" --finetune {finetune}"
 
+    if file_checkpoint_train != "":
+        cmd += f" --pretrain {file_checkpoint_train}"
+
+    if tokenizer_file != "":
+        cmd += f" --tokenizer_path {tokenizer_file}"
+    cmd += f" --tokenizer {tokenizer_type} "
+
     print(cmd)
 
     try:
@@ -331,10 +341,28 @@ def stop_training():
     return "train stop", gr.update(interactive=True), gr.update(interactive=False)
 
 
-def create_data_project(name):
-    name += "_pinyin"
+def get_list_projects():
+    project_list = []
+    for folder in os.listdir("data"):
+        path_folder = os.path.join("data", folder)
+        if not os.path.isdir(path_folder):
+            continue
+        folder = folder.lower()
+        if folder == "emilia_zh_en_pinyin":
+            continue
+        project_list.append(folder)
+
+    projects_select = None if not project_list else project_list[-1]
+
+    return project_list, projects_select
+
+
+def create_data_project(name, tokenizer_type):
+    name += "_" + tokenizer_type
     os.makedirs(os.path.join(path_data, name), exist_ok=True)
     os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
+    project_list, projects_select = get_list_projects()
+    return gr.update(choices=project_list, value=name)
 
 
 def transcribe(file_audio, language="english"):
@@ -359,14 +387,14 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
-    name_project += "_pinyin"
     path_project = os.path.join(path_data, name_project)
     path_dataset = os.path.join(path_project, "dataset")
     path_project_wavs = os.path.join(path_project, "wavs")
     file_metadata = os.path.join(path_project, "metadata.csv")
 
-    if audio_files is None:
-        return "You need to load an audio file."
+    if not user:
+        if audio_files is None:
+            return "You need to load an audio file."
 
     if os.path.isdir(path_project_wavs):
         shutil.rmtree(path_project_wavs)
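For reference, here is what the new create_data_project does on disk, as a standalone sketch (assuming path_data points at the repo's data directory and the default project name shown in the UI below):

```python
import os

# Equivalent of create_data_project("my_speak", "pinyin") with path_data = "data".
name = "my_speak" + "_" + "pinyin"
os.makedirs(os.path.join("data", name), exist_ok=True)             # data/my_speak_pinyin/
os.makedirs(os.path.join("data", name, "dataset"), exist_ok=True)  # data/my_speak_pinyin/dataset/
```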
@@ -418,7 +446,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
             except:  # noqa: E722
                 error_num += 1
 
-    with open(file_metadata, "w", encoding="utf-8") as f:
+    with open(file_metadata, "w", encoding="utf-8-sig") as f:
         f.write(data)
 
     if error_num != []:
@@ -437,7 +465,6 @@ def format_seconds_to_hms(seconds):
 
 
 def create_metadata(name_project, progress=gr.Progress()):
-    name_project += "_pinyin"
     path_project = os.path.join(path_data, name_project)
     path_project_wavs = os.path.join(path_project, "wavs")
     file_metadata = os.path.join(path_project, "metadata.csv")
@@ -448,7 +475,7 @@ def create_metadata(name_project, progress=gr.Progress()):
     if not os.path.isfile(file_metadata):
         return "The file was not found in " + file_metadata
 
-    with open(file_metadata, "r", encoding="utf-8") as f:
+    with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()
 
     audio_path_list = []
@@ -499,7 +526,7 @@ def create_metadata(name_project, progress=gr.Progress()):
         for line in progress.tqdm(result, total=len(result), desc="prepare data"):
             writer.write(line)
 
-    with open(file_duration, "w", encoding="utf-8") as f:
+    with open(file_duration, "w") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
     file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
@@ -529,7 +556,6 @@ def calculate_train(
     last_per_steps,
     finetune,
 ):
-    name_project += "_pinyin"
     path_project = os.path.join(path_data, name_project)
     file_duraction = os.path.join(path_project, "duration.json")
 
@@ -548,8 +574,8 @@ def calculate_train(
         data = json.load(file)
         duration_list = data["duration"]
-
     samples = len(duration_list)
+    hours = sum(duration_list) / 3600
 
     if torch.cuda.is_available():
         gpu_properties = torch.cuda.get_device_properties(0)
@@ -583,34 +609,67 @@ def calculate_train(
     save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
     last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
 
+    total_hours = hours
+    mel_hop_length = 256
+    mel_sampling_rate = 24000
+
+    # target
+    wanted_max_updates = 1000000
+
+    # train params
+    gpus = 1
+    frames_per_gpu = batch_size_per_gpu  # frames in one batch on a single GPU
+    grad_accum = 1
+
+    # intermediate
+    mini_batch_frames = frames_per_gpu * grad_accum * gpus
+    mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
+    updates_per_epoch = total_hours / mini_batch_hours
+    steps_per_epoch = updates_per_epoch * grad_accum
+    epochs = wanted_max_updates / updates_per_epoch
+
     if finetune:
         learning_rate = 1e-5
     else:
         learning_rate = 7.5e-5
 
-    return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
+    return (
+        batch_size_per_gpu,
+        max_samples,
+        num_warmup_updates,
+        save_per_updates,
+        last_per_steps,
+        samples,
+        learning_rate,
+        int(epochs),
+    )
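A worked example of the epoch estimate above, with assumed inputs (10 hours of audio, one GPU, the 24GB batch size, no gradient accumulation):

```python
# All inputs here are assumptions for illustration.
mel_hop_length = 256
mel_sampling_rate = 24000
total_hours = 10                    # assumed dataset size
mini_batch_frames = 3200 * 1 * 1    # batch_size_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours  # ~1055 updates per epoch
epochs = 1000000 / updates_per_epoch                # ~948 epochs to reach 1M updates
print(int(epochs))
```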
 
 
-def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
+def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
     try:
         checkpoint = torch.load(checkpoint_path)
         print("Original Checkpoint Keys:", checkpoint.keys())
 
         ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
+        if ema_model_state_dict is None:
+            return "No 'ema_model_state_dict' found in the checkpoint."
 
-        if ema_model_state_dict is not None:
+        if safetensors:
+            new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
+            save_file(ema_model_state_dict, new_checkpoint_path)
+        else:
+            new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
             new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
             torch.save(new_checkpoint, new_checkpoint_path)
-            return f"New checkpoint saved at: {new_checkpoint_path}"
-        else:
-            return "No 'ema_model_state_dict' found in the checkpoint."
+
+        return f"New checkpoint saved at: {new_checkpoint_path}"
 
     except Exception as e:
         return f"An error occurred: {e}"
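A minimal usage sketch for the reducer above; the paths are hypothetical:

```python
# Hypothetical checkpoint paths; extract_and_save_ema_model is defined above.
msg = extract_and_save_ema_model(
    "ckpts/my_speak_pinyin/model_last.pt",     # full training checkpoint
    "ckpts/my_speak_pinyin/model_reduced.pt",  # EMA-only output path
    safetensors=True,  # extension is rewritten to .safetensors and save_file is used
)
print(msg)
```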
 
 
 def vocab_check(project_name):
-    name_project = project_name + "_pinyin"
+    name_project = project_name
     path_project = os.path.join(path_data, name_project)
 
     file_metadata = os.path.join(path_project, "metadata.csv")
@@ -619,15 +678,15 @@ def vocab_check(project_name):
     if not os.path.isfile(file_vocab):
         return f"the file {file_vocab} not found !"
 
-    with open(file_vocab, "r", encoding="utf-8") as f:
+    with open(file_vocab, "r", encoding="utf-8-sig") as f:
         data = f.read()
-
-    vocab = data.split("\n")
+    vocab = data.split("\n")
+    vocab = set(vocab)
 
     if not os.path.isfile(file_metadata):
         return f"the file {file_metadata} not found !"
 
-    with open(file_metadata, "r", encoding="utf-8") as f:
+    with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()
 
     miss_symbols = []
@@ -652,7 +711,7 @@ def vocab_check(project_name):
 
 
 def get_random_sample_prepare(project_name):
-    name_project = project_name + "_pinyin"
+    name_project = project_name
     path_project = os.path.join(path_data, name_project)
     file_arrow = os.path.join(path_project, "raw.arrow")
     if not os.path.isfile(file_arrow):
@@ -665,14 +724,14 @@ def get_random_sample_prepare(project_name):
 
 
 def get_random_sample_transcribe(project_name):
-    name_project = project_name + "_pinyin"
+    name_project = project_name
     path_project = os.path.join(path_data, name_project)
     file_metadata = os.path.join(path_project, "metadata.csv")
     if not os.path.isfile(file_metadata):
         return "", None
 
     data = ""
-    with open(file_metadata, "r", encoding="utf-8") as f:
+    with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()
 
     list_data = []
@@ -703,13 +762,14 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
     global last_checkpoint, last_device, tts_api
 
     if not os.path.isfile(file_checkpoint):
-        return None
+        return None, "checkpoint not found!"
 
     if training_process is not None:
         device_test = "cpu"
     else:
         device_test = None
+    device_test = "cpu"  # force CPU for test inference (overrides the check above)
 
     if last_checkpoint != file_checkpoint or last_device != device_test:
         if last_checkpoint != file_checkpoint:
             last_checkpoint = file_checkpoint
@@ -722,19 +782,66 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
-        return f.name
+        return f.name, tts_api.device
+
+
+def check_finetune(finetune):
+    return gr.update(interactive=finetune), gr.update(interactive=finetune), gr.update(interactive=finetune)
+
+
+def get_checkpoints_project(project_name, is_gradio=True):
+    if project_name is None:
+        return [], ""
+    path_project_ckpts = os.path.join("ckpts", project_name)
+
+    if os.path.isdir(path_project_ckpts):
+        files_checkpoints = glob(os.path.join(path_project_ckpts, "*.pt"))
+        files_checkpoints = sorted(
+            files_checkpoints,
+            key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
+            if os.path.basename(x) != "model_last.pt"
+            else float("inf"),
+        )
+    else:
+        files_checkpoints = []
+
+    select_checkpoint = None if not files_checkpoints else files_checkpoints[0]
+
+    if is_gradio:
+        return gr.update(choices=files_checkpoints, value=select_checkpoint)
+
+    return files_checkpoints, select_checkpoint
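The sort key in get_checkpoints_project orders numbered checkpoints ascending and pins model_last.pt to the end; a standalone check with made-up file names:

```python
import os

files = ["ckpts/p/model_last.pt", "ckpts/p/model_20000.pt", "ckpts/p/model_10000.pt"]
ordered = sorted(
    files,
    key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
    if os.path.basename(x) != "model_last.pt"
    else float("inf"),
)
print(ordered)  # [... model_10000.pt, ... model_20000.pt, ... model_last.pt]
```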
 
 
 with gr.Blocks() as app:
+    gr.Markdown(
+        """
+# E2/F5 TTS Automatic Finetune
+
+This is a local web UI for finetuning F5 TTS, with advanced batch processing support. This app supports the following TTS models:
+
+* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
+* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
+
+The checkpoints support English and Chinese.
+
+For tutorials and updates, check here: https://github.com/SWivid/F5-TTS/discussions/143
+"""
+    )
+
     with gr.Row():
+        projects, projects_select = get_list_projects()
+        tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char"], value="pinyin")
         project_name = gr.Textbox(label="project name", value="my_speak")
         bt_create = gr.Button("create new project")
-    bt_create.click(fn=create_data_project, inputs=[project_name])
+        cm_project = gr.Dropdown(choices=projects, value=projects_select, label="Project", allow_custom_value=True)
+
+    bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])
 
     with gr.Tabs():
         with gr.TabItem("transcribe Data"):
-            ch_manual = gr.Checkbox(label="user", value=False)
+            ch_manual = gr.Checkbox(label="audio from path", value=False)
 
             mark_info_transcribe = gr.Markdown(
                 """```plaintext
@@ -756,7 +863,7 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
             txt_info_transcribe = gr.Text(label="info", value="")
             bt_transcribe.click(
                 fn=transcribe_all,
-                inputs=[project_name, audio_speaker, txt_lang, ch_manual],
+                inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
                 outputs=[txt_info_transcribe],
             )
             ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
@@ -769,7 +876,7 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
 
             random_sample_transcribe.click(
                 fn=get_random_sample_transcribe,
-                inputs=[project_name],
+                inputs=[cm_project],
                 outputs=[random_text_transcribe, random_audio_transcribe],
             )
@@ -797,7 +904,7 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
             bt_prepare = bt_create = gr.Button("prepare")
             txt_info_prepare = gr.Text(label="info", value="")
-            bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
+            bt_prepare.click(fn=create_metadata, inputs=[cm_project], outputs=[txt_info_prepare])
 
             random_sample_prepare = gr.Button("random sample")
 
@@ -806,16 +913,20 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
                 random_audio_prepare = gr.Audio(label="Audio", type="filepath")
 
             random_sample_prepare.click(
-                fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
+                fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
             )
 
         with gr.TabItem("train Data"):
             with gr.Row():
                 bt_calculate = bt_create = gr.Button("Auto Settings")
-                ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
                 lb_samples = gr.Label(label="samples")
                 batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
 
+            with gr.Row():
+                ch_finetune = gr.Checkbox(label="finetune", value=True)
+                tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
+                file_checkpoint_train = gr.Textbox(label="Checkpoint", value="")
+
             with gr.Row():
                 exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
                 learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
@@ -844,7 +955,7 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
             start_button.click(
                 fn=start_training,
                 inputs=[
-                    project_name,
+                    cm_project,
                     exp_name,
                     learning_rate,
                     batch_size_per_gpu,
@@ -857,14 +968,18 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
                     save_per_updates,
                     last_per_steps,
                     ch_finetune,
+                    file_checkpoint_train,
+                    tokenizer_type,
+                    tokenizer_file,
                 ],
                 outputs=[txt_info_train, start_button, stop_button],
             )
             stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
+
             bt_calculate.click(
                 fn=calculate_train,
                 inputs=[
-                    project_name,
+                    cm_project,
                     batch_size_type,
                     max_samples,
                     learning_rate,
@@ -881,29 +996,42 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
                     save_per_updates,
                     last_per_steps,
                     lb_samples,
                     learning_rate,
+                    epochs,
                 ],
             )
+            ch_finetune.change(
+                check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
+            )
 
         with gr.TabItem("reduse checkpoint"):
             txt_path_checkpoint = gr.Text(label="path checkpoint :")
             txt_path_checkpoint_small = gr.Text(label="path output :")
+            ch_safetensors = gr.Checkbox(label="safetensors", value=False)
             txt_info_reduse = gr.Text(label="info", value="")
             reduse_button = gr.Button("reduse")
             reduse_button.click(
                 fn=extract_and_save_ema_model,
-                inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
+                inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
                 outputs=[txt_info_reduse],
             )
 
         with gr.TabItem("vocab check experiment"):
             check_button = gr.Button("check vocab")
             txt_info_check = gr.Text(label="info", value="")
-            check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
+            check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
 
         with gr.TabItem("test model"):
             exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
+            list_checkpoints, checkpoint_select = get_checkpoints_project(projects_select, False)
+
             nfe_step = gr.Number(label="n_step", value=32)
-            file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
+
+            with gr.Row():
+                cm_checkpoint = gr.Dropdown(
+                    choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
+                )
+                bt_checkpoint_refresh = gr.Button("refresh")
 
             random_sample_infer = gr.Button("random sample")
 
@@ -911,17 +1039,24 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
                 ref_audio = gr.Audio(label="audio ref", type="filepath")
                 gen_text = gr.Textbox(label="gen text")
             random_sample_infer.click(
-                fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
+                fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
             )
-            check_button_infer = gr.Button("infer")
+
+            with gr.Row():
+                txt_info_gpu = gr.Textbox("", label="device")
+                check_button_infer = gr.Button("infer")
+
             gen_audio = gr.Audio(label="audio gen", type="filepath")
 
             check_button_infer.click(
                 fn=infer,
-                inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
-                outputs=[gen_audio],
+                inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step],
+                outputs=[gen_audio, txt_info_gpu],
             )
 
+            bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
+            cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
+
 
 @click.command()
 @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
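With these changes, the headless equivalent of the UI's start-training step is a direct finetune-cli run; a sketch with assumed dataset and checkpoint names (the accelerate launcher is an assumption about how start_training shells out):

```python
import subprocess

# Assumed invocation; dataset name and --pretrain path are placeholders.
subprocess.run(
    [
        "accelerate", "launch", "finetune-cli.py",
        "--exp_name", "F5TTS_Base",
        "--dataset_name", "my_speak_pinyin",
        "--batch_size_per_gpu", "3200",
        "--num_warmup_updates", "500",
        "--pretrain", "ckpts/my_speak_pinyin/model_last.pt",
    ],
    check=True,
)
```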