diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 000000000..524f04feb --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,14 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..9ac5ee15d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,14 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.7.0 + hooks: + # Run the linter. + - id: ruff + args: [--fix] + # Run the formatter. + - id: ruff-format + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml diff --git a/README.md b/README.md index 5e825b6d2..9b7b3c6f5 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,26 @@ pip install -r requirements.txt docker build -t f5tts:v1 . ``` +### Development + +When making a pull request, please use pre-commit to ensure code quality: + +```bash +pip install pre-commit +pre-commit install +``` + +This will run linters and formatters automatically before each commit. + +Manually run using: + +```bash +pre-commit run --all-files +``` + +Note: Some model components have linting exceptions for E722 to accommodate tensor notation + + ## Prepare Dataset Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`. diff --git a/finetune-cli.py b/finetune-cli.py index 79ce9bb04..bc11ee2cf 100644 --- a/finetune-cli.py +++ b/finetune-cli.py @@ -1,42 +1,57 @@ import argparse -from model import CFM, UNetT, DiT, MMDiT, Trainer +from model import CFM, UNetT, DiT, Trainer from model.utils import get_tokenizer from model.dataset import load_dataset from cached_path import cached_path -import shutil,os +import shutil +import os + # -------------------------- Dataset Settings --------------------------- # target_sample_rate = 24000 n_mel_channels = 100 hop_length = 256 -tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' -tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) # -------------------------- Argument Parsing --------------------------- # def parse_args(): - parser = argparse.ArgumentParser(description='Train CFM Model') - - parser.add_argument('--exp_name', type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"],help='Experiment name') - parser.add_argument('--dataset_name', type=str, default="Emilia_ZH_EN", help='Name of the dataset to use') - parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for training') - parser.add_argument('--batch_size_per_gpu', type=int, default=256, help='Batch size per GPU') - parser.add_argument('--batch_size_type', type=str, default="frame", choices=["frame", "sample"],help='Batch size type') - parser.add_argument('--max_samples', type=int, default=16, help='Max sequences per batch') - parser.add_argument('--grad_accumulation_steps', type=int, default=1,help='Gradient accumulation steps') - parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping') - parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs') - parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps') - parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps') - parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps') - parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune') - + parser = argparse.ArgumentParser(description="Train CFM Model") + + parser.add_argument( + "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name" + ) + parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use") + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training") + parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU") + parser.add_argument( + "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type" + ) + parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch") + parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps") + parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping") + parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs") + parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps") + parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps") + parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps") + parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune") + + parser.add_argument( + "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type" + ) + parser.add_argument( + "--tokenizer_path", + type=str, + default=None, + help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')", + ) + return parser.parse_args() + # -------------------------- Training Settings -------------------------- # + def main(): args = parse_args() - # Model parameters based on experiment name if args.exp_name == "F5TTS_Base": @@ -44,24 +59,31 @@ def main(): model_cls = DiT model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) if args.finetune: - ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) + ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) elif args.exp_name == "E2TTS_Base": wandb_resume_id = None model_cls = UNetT model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) if args.finetune: - ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) - + ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) + if args.finetune: - path_ckpt = os.path.join("ckpts",args.dataset_name) - if os.path.isdir(path_ckpt)==False: - os.makedirs(path_ckpt,exist_ok=True) - shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path))) - - checkpoint_path=os.path.join("ckpts",args.dataset_name) - - # Use the dataset_name provided in the command line - tokenizer_path = args.dataset_name if tokenizer != "custom" else tokenizer_path + path_ckpt = os.path.join("ckpts", args.dataset_name) + if not os.path.isdir(path_ckpt): + os.makedirs(path_ckpt, exist_ok=True) + shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path))) + + checkpoint_path = os.path.join("ckpts", args.dataset_name) + + # Use the tokenizer and tokenizer_path provided in the command line arguments + tokenizer = args.tokenizer + if tokenizer == "custom": + if not args.tokenizer_path: + raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.") + tokenizer_path = args.tokenizer_path + else: + tokenizer_path = args.dataset_name + vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) mel_spec_kwargs = dict( @@ -71,11 +93,7 @@ def main(): ) e2tts = CFM( - transformer=model_cls( - **model_cfg, - text_num_embeds=vocab_size, - mel_dim=n_mel_channels - ), + transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=mel_spec_kwargs, vocab_char_map=vocab_char_map, ) @@ -99,10 +117,11 @@ def main(): ) train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) - trainer.train(train_dataset, - resumable_with_seed=666 # seed for shuffling dataset - ) + trainer.train( + train_dataset, + resumable_with_seed=666, # seed for shuffling dataset + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/finetune_gradio.py b/finetune_gradio.py index 3f442dc0f..e61fc53cc 100644 --- a/finetune_gradio.py +++ b/finetune_gradio.py @@ -1,4 +1,5 @@ -import os,sys +import os +import sys from transformers import pipeline import gradio as gr @@ -20,34 +21,37 @@ import subprocess from datasets.arrow_writer import ArrowWriter -import json -training_process = None +training_process = None system = platform.system() python_executable = sys.executable or "python" -path_data="data" +path_data = "data" -device = ( -"cuda" - if torch.cuda.is_available() - else "mps" if torch.backends.mps.is_available() else "cpu" -) +device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" pipe = None + # Load metadata def get_audio_duration(audio_path): """Calculate the duration of an audio file.""" audio, sample_rate = torchaudio.load(audio_path) - num_channels = audio.shape[0] + num_channels = audio.shape[0] return audio.shape[1] / (sample_rate * num_channels) + def clear_text(text): """Clean and prepare text by lowering the case and stripping whitespace.""" return text.lower().strip() -def get_rms(y,frame_length=2048,hop_length=512,pad_mode="constant",): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py + +def get_rms( + y, + frame_length=2048, + hop_length=512, + pad_mode="constant", +): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py padding = (int(frame_length // 2), int(frame_length // 2)) y = np.pad(y, padding, mode=pad_mode) @@ -74,7 +78,8 @@ def get_rms(y,frame_length=2048,hop_length=512,pad_mode="constant",): # https:// return np.sqrt(power) -class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py + +class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py def __init__( self, sr: int, @@ -85,13 +90,9 @@ def __init__( max_sil_kept: int = 2000, ): if not min_length >= min_interval >= hop_size: - raise ValueError( - "The following condition must be satisfied: min_length >= min_interval >= hop_size" - ) + raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size") if not max_sil_kept >= hop_size: - raise ValueError( - "The following condition must be satisfied: max_sil_kept >= hop_size" - ) + raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size") min_interval = sr * min_interval / 1000 self.threshold = 10 ** (threshold / 20.0) self.hop_size = round(sr * hop_size / 1000) @@ -102,13 +103,9 @@ def __init__( def _apply_slice(self, waveform, begin, end): if len(waveform.shape) > 1: - return waveform[ - :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size) - ] + return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)] else: - return waveform[ - begin * self.hop_size : min(waveform.shape[0], end * self.hop_size) - ] + return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)] # @timeit def slice(self, waveform): @@ -118,9 +115,7 @@ def slice(self, waveform): samples = waveform if samples.shape[0] <= self.min_length: return [waveform] - rms_list = get_rms( - y=samples, frame_length=self.win_size, hop_length=self.hop_size - ).squeeze(0) + rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) sil_tags = [] silence_start = None clip_start = 0 @@ -136,10 +131,7 @@ def slice(self, waveform): continue # Clear recorded silence start if interval is not enough or clip is too short is_leading_silence = silence_start == 0 and i > self.max_sil_kept - need_slice_middle = ( - i - silence_start >= self.min_interval - and i - clip_start >= self.min_length - ) + need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length if not is_leading_silence and not need_slice_middle: silence_start = None continue @@ -152,21 +144,10 @@ def slice(self, waveform): sil_tags.append((pos, pos)) clip_start = pos elif i - silence_start <= self.max_sil_kept * 2: - pos = rms_list[ - i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 - ].argmin() + pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin() pos += i - self.max_sil_kept - pos_l = ( - rms_list[ - silence_start : silence_start + self.max_sil_kept + 1 - ].argmin() - + silence_start - ) - pos_r = ( - rms_list[i - self.max_sil_kept : i + 1].argmin() - + i - - self.max_sil_kept - ) + pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start + pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept if silence_start == 0: sil_tags.append((0, pos_r)) clip_start = pos_r @@ -174,17 +155,8 @@ def slice(self, waveform): sil_tags.append((min(pos_l, pos), max(pos_r, pos))) clip_start = max(pos_r, pos) else: - pos_l = ( - rms_list[ - silence_start : silence_start + self.max_sil_kept + 1 - ].argmin() - + silence_start - ) - pos_r = ( - rms_list[i - self.max_sil_kept : i + 1].argmin() - + i - - self.max_sil_kept - ) + pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start + pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept if silence_start == 0: sil_tags.append((0, pos_r)) else: @@ -193,33 +165,39 @@ def slice(self, waveform): silence_start = None # Deal with trailing silence. total_frames = rms_list.shape[0] - if ( - silence_start is not None - and total_frames - silence_start >= self.min_interval - ): + if silence_start is not None and total_frames - silence_start >= self.min_interval: silence_end = min(total_frames, silence_start + self.max_sil_kept) pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start sil_tags.append((pos, total_frames + 1)) # Apply and return slices. ####音频+起始时间+终止时间 if len(sil_tags) == 0: - return [[waveform,0,int(total_frames*self.hop_size)]] + return [[waveform, 0, int(total_frames * self.hop_size)]] else: chunks = [] if sil_tags[0][0] > 0: - chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]),0,int(sil_tags[0][0]*self.hop_size)]) + chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)]) for i in range(len(sil_tags) - 1): chunks.append( - [self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),int(sil_tags[i][1]*self.hop_size),int(sil_tags[i + 1][0]*self.hop_size)] + [ + self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]), + int(sil_tags[i][1] * self.hop_size), + int(sil_tags[i + 1][0] * self.hop_size), + ] ) if sil_tags[-1][1] < total_frames: chunks.append( - [self._apply_slice(waveform, sil_tags[-1][1], total_frames),int(sil_tags[-1][1]*self.hop_size),int(total_frames*self.hop_size)] + [ + self._apply_slice(waveform, sil_tags[-1][1], total_frames), + int(sil_tags[-1][1] * self.hop_size), + int(total_frames * self.hop_size), + ] ) return chunks -#terminal -def terminate_process_tree(pid, including_parent=True): + +# terminal +def terminate_process_tree(pid, including_parent=True): try: parent = psutil.Process(pid) except psutil.NoSuchProcess: @@ -238,6 +216,7 @@ def terminate_process_tree(pid, including_parent=True): except OSError: pass + def terminate_process(pid): if system == "Windows": cmd = f"taskkill /t /f /pid {pid}" @@ -245,132 +224,154 @@ def terminate_process(pid): else: terminate_process_tree(pid) -def start_training(dataset_name="", - exp_name="F5TTS_Base", - learning_rate=1e-4, - batch_size_per_gpu=400, - batch_size_type="frame", - max_samples=64, - grad_accumulation_steps=1, - max_grad_norm=1.0, - epochs=11, - num_warmup_updates=200, - save_per_updates=400, - last_per_steps=800, - finetune=True, - ): - +def start_training( + dataset_name="", + exp_name="F5TTS_Base", + learning_rate=1e-4, + batch_size_per_gpu=400, + batch_size_type="frame", + max_samples=64, + grad_accumulation_steps=1, + max_grad_norm=1.0, + epochs=11, + num_warmup_updates=200, + save_per_updates=400, + last_per_steps=800, + finetune=True, +): global training_process path_project = os.path.join(path_data, dataset_name + "_pinyin") - if os.path.isdir(path_project)==False: - yield f"There is not project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False) + if not os.path.isdir(path_project): + yield ( + f"There is not project with name {dataset_name}", + gr.update(interactive=True), + gr.update(interactive=False), + ) return - file_raw = os.path.join(path_project,"raw.arrow") - if os.path.isfile(file_raw)==False: - yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False) - return + file_raw = os.path.join(path_project, "raw.arrow") + if not os.path.isfile(file_raw): + yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False) + return # Check if a training process is already running if training_process is not None: - return "Train run already!",gr.update(interactive=False),gr.update(interactive=True) + return "Train run already!", gr.update(interactive=False), gr.update(interactive=True) - yield "start train",gr.update(interactive=False),gr.update(interactive=False) + yield "start train", gr.update(interactive=False), gr.update(interactive=False) # Command to run the training script with the specified arguments - cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \ - f"--learning_rate {learning_rate} " \ - f"--batch_size_per_gpu {batch_size_per_gpu} " \ - f"--batch_size_type {batch_size_type} " \ - f"--max_samples {max_samples} " \ - f"--grad_accumulation_steps {grad_accumulation_steps} " \ - f"--max_grad_norm {max_grad_norm} " \ - f"--epochs {epochs} " \ - f"--num_warmup_updates {num_warmup_updates} " \ - f"--save_per_updates {save_per_updates} " \ - f"--last_per_steps {last_per_steps} " \ - f"--dataset_name {dataset_name}" - if finetune:cmd += f" --finetune {finetune}" + cmd = ( + f"accelerate launch finetune-cli.py --exp_name {exp_name} " + f"--learning_rate {learning_rate} " + f"--batch_size_per_gpu {batch_size_per_gpu} " + f"--batch_size_type {batch_size_type} " + f"--max_samples {max_samples} " + f"--grad_accumulation_steps {grad_accumulation_steps} " + f"--max_grad_norm {max_grad_norm} " + f"--epochs {epochs} " + f"--num_warmup_updates {num_warmup_updates} " + f"--save_per_updates {save_per_updates} " + f"--last_per_steps {last_per_steps} " + f"--dataset_name {dataset_name}" + ) + if finetune: + cmd += f" --finetune {finetune}" print(cmd) - + try: - # Start the training process - training_process = subprocess.Popen(cmd, shell=True) + # Start the training process + training_process = subprocess.Popen(cmd, shell=True) - time.sleep(5) - yield "check terminal for wandb",gr.update(interactive=False),gr.update(interactive=True) - - # Wait for the training process to finish - training_process.wait() - time.sleep(1) - - if training_process is None: - text_info = 'train stop' - else: - text_info = "train complete !" + time.sleep(5) + yield "check terminal for wandb", gr.update(interactive=False), gr.update(interactive=True) + + # Wait for the training process to finish + training_process.wait() + time.sleep(1) + + if training_process is None: + text_info = "train stop" + else: + text_info = "train complete !" except Exception as e: # Catch all exceptions # Ensure that we reset the training process variable in case of an error - text_info=f"An error occurred: {str(e)}" - - training_process=None + text_info = f"An error occurred: {str(e)}" + + training_process = None + + yield text_info, gr.update(interactive=True), gr.update(interactive=False) - yield text_info,gr.update(interactive=True),gr.update(interactive=False) def stop_training(): global training_process - if training_process is None:return f"Train not run !",gr.update(interactive=True),gr.update(interactive=False) + if training_process is None: + return "Train not run !", gr.update(interactive=True), gr.update(interactive=False) terminate_process_tree(training_process.pid) training_process = None - return 'train stop',gr.update(interactive=True),gr.update(interactive=False) + return "train stop", gr.update(interactive=True), gr.update(interactive=False) + def create_data_project(name): - name+="_pinyin" - os.makedirs(os.path.join(path_data,name),exist_ok=True) - os.makedirs(os.path.join(path_data,name,"dataset"),exist_ok=True) - -def transcribe(file_audio,language="english"): + name += "_pinyin" + os.makedirs(os.path.join(path_data, name), exist_ok=True) + os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True) + + +def transcribe(file_audio, language="english"): global pipe if pipe is None: - pipe = pipeline("automatic-speech-recognition",model="openai/whisper-large-v3-turbo", torch_dtype=torch.float16,device=device) + pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3-turbo", + torch_dtype=torch.float16, + device=device, + ) text_transcribe = pipe( file_audio, chunk_length_s=30, batch_size=128, - generate_kwargs={"task": "transcribe","language": language}, + generate_kwargs={"task": "transcribe", "language": language}, return_timestamps=False, )["text"].strip() return text_transcribe -def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()): - name_project+="_pinyin" - path_project= os.path.join(path_data,name_project) - path_dataset = os.path.join(path_project,"dataset") - path_project_wavs = os.path.join(path_project,"wavs") - file_metadata = os.path.join(path_project,"metadata.csv") - if audio_files is None:return "You need to load an audio file." +def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()): + name_project += "_pinyin" + path_project = os.path.join(path_data, name_project) + path_dataset = os.path.join(path_project, "dataset") + path_project_wavs = os.path.join(path_project, "wavs") + file_metadata = os.path.join(path_project, "metadata.csv") + + if audio_files is None: + return "You need to load an audio file." if os.path.isdir(path_project_wavs): - shutil.rmtree(path_project_wavs) + shutil.rmtree(path_project_wavs) if os.path.isfile(file_metadata): - os.remove(file_metadata) + os.remove(file_metadata) + + os.makedirs(path_project_wavs, exist_ok=True) - os.makedirs(path_project_wavs,exist_ok=True) - if user: - file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))] - if file_audios==[]:return "No audio file was found in the dataset." + file_audios = [ + file + for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac") + for file in glob(os.path.join(path_dataset, format)) + ] + if file_audios == []: + return "No audio file was found in the dataset." else: - file_audios = audio_files - + file_audios = audio_files alpha = 0.5 _max = 1.0 @@ -378,181 +379,202 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog num = 0 error_num = 0 - data="" - for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))): - - audio, _ = librosa.load(file_audio, sr=24000, mono=True) - - list_slicer=slicer.slice(audio) - for chunk, start, end in progress.tqdm(list_slicer,total=len(list_slicer), desc="slicer files"): - + data = "" + for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))): + audio, _ = librosa.load(file_audio, sr=24000, mono=True) + + list_slicer = slicer.slice(audio) + for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"): name_segment = os.path.join(f"segment_{num}") - file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav") - + file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav") + tmp_max = np.abs(chunk).max() - if(tmp_max>1):chunk/=tmp_max + if tmp_max > 1: + chunk /= tmp_max chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk - wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16)) - + wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16)) + try: - text=transcribe(file_segment,language) - text = text.lower().strip().replace('"',"") + text = transcribe(file_segment, language) + text = text.lower().strip().replace('"', "") - data+= f"{name_segment}|{text}\n" + data += f"{name_segment}|{text}\n" - num+=1 - except: - error_num +=1 + num += 1 + except: # noqa: E722 + error_num += 1 - with open(file_metadata,"w",encoding="utf-8") as f: + with open(file_metadata, "w", encoding="utf-8") as f: f.write(data) - - if error_num!=[]: - error_text=f"\nerror files : {error_num}" + + if error_num != []: + error_text = f"\nerror files : {error_num}" else: - error_text="" - + error_text = "" + return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}" + def format_seconds_to_hms(seconds): hours = int(seconds / 3600) minutes = int((seconds % 3600) / 60) seconds = seconds % 60 return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds)) -def create_metadata(name_project,progress=gr.Progress()): - name_project+="_pinyin" - path_project= os.path.join(path_data,name_project) - path_project_wavs = os.path.join(path_project,"wavs") - file_metadata = os.path.join(path_project,"metadata.csv") - file_raw = os.path.join(path_project,"raw.arrow") - file_duration = os.path.join(path_project,"duration.json") - file_vocab = os.path.join(path_project,"vocab.txt") - - if os.path.isfile(file_metadata)==False: return "The file was not found in " + file_metadata - - with open(file_metadata,"r",encoding="utf-8") as f: - data=f.read() - - audio_path_list=[] - text_list=[] - duration_list=[] - - count=data.split("\n") - lenght=0 - result=[] - error_files=[] - for line in progress.tqdm(data.split("\n"),total=count): - sp_line=line.split("|") - if len(sp_line)!=2:continue - name_audio,text = sp_line[:2] + +def create_metadata(name_project, progress=gr.Progress()): + name_project += "_pinyin" + path_project = os.path.join(path_data, name_project) + path_project_wavs = os.path.join(path_project, "wavs") + file_metadata = os.path.join(path_project, "metadata.csv") + file_raw = os.path.join(path_project, "raw.arrow") + file_duration = os.path.join(path_project, "duration.json") + file_vocab = os.path.join(path_project, "vocab.txt") + + if not os.path.isfile(file_metadata): + return "The file was not found in " + file_metadata + + with open(file_metadata, "r", encoding="utf-8") as f: + data = f.read() + + audio_path_list = [] + text_list = [] + duration_list = [] + + count = data.split("\n") + lenght = 0 + result = [] + error_files = [] + for line in progress.tqdm(data.split("\n"), total=count): + sp_line = line.split("|") + if len(sp_line) != 2: + continue + name_audio, text = sp_line[:2] file_audio = os.path.join(path_project_wavs, name_audio + ".wav") - if os.path.isfile(file_audio)==False: + if not os.path.isfile(file_audio): error_files.append(file_audio) continue duraction = get_audio_duration(file_audio) - if duraction<2 and duraction>15:continue - if len(text)<4:continue + if duraction < 2 and duraction > 15: + continue + if len(text) < 4: + continue text = clear_text(text) - text = convert_char_to_pinyin([text], polyphone = True)[0] + text = convert_char_to_pinyin([text], polyphone=True)[0] audio_path_list.append(file_audio) duration_list.append(duraction) text_list.append(text) - + result.append({"audio_path": file_audio, "text": text, "duration": duraction}) - lenght+=duraction + lenght += duraction - if duration_list==[]: - error_files_text="\n".join(error_files) + if duration_list == []: + error_files_text = "\n".join(error_files) return f"Error: No audio files found in the specified path : \n{error_files_text}" - - min_second = round(min(duration_list),2) - max_second = round(max(duration_list),2) + + min_second = round(min(duration_list), 2) + max_second = round(max(duration_list), 2) with ArrowWriter(path=file_raw, writer_batch_size=1) as writer: - for line in progress.tqdm(result,total=len(result), desc=f"prepare data"): + for line in progress.tqdm(result, total=len(result), desc="prepare data"): writer.write(line) - with open(file_duration, 'w', encoding='utf-8') as f: + with open(file_duration, "w", encoding="utf-8") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) - - file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt" - if os.path.isfile(file_vocab_finetune==False):return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!" + + file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt" + if not os.path.isfile(file_vocab_finetune): + return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!" shutil.copy2(file_vocab_finetune, file_vocab) - - if error_files!=[]: - error_text="error files\n" + "\n".join(error_files) + + if error_files != []: + error_text = "error files\n" + "\n".join(error_files) else: - error_text="" - + error_text = "" + return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}" + def check_user(value): - return gr.update(visible=not value),gr.update(visible=value) + return gr.update(visible=not value), gr.update(visible=value) + + +def calculate_train( + name_project, + batch_size_type, + max_samples, + learning_rate, + num_warmup_updates, + save_per_updates, + last_per_steps, + finetune, +): + name_project += "_pinyin" + path_project = os.path.join(path_data, name_project) + file_duraction = os.path.join(path_project, "duration.json") -def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,finetune): - name_project+="_pinyin" - path_project= os.path.join(path_data,name_project) - file_duraction = os.path.join(path_project,"duration.json") + with open(file_duraction, "r") as file: + data = json.load(file) - with open(file_duraction, 'r') as file: - data = json.load(file) - - duration_list = data['duration'] + duration_list = data["duration"] samples = len(duration_list) if torch.cuda.is_available(): gpu_properties = torch.cuda.get_device_properties(0) - total_memory = gpu_properties.total_memory / (1024 ** 3) + total_memory = gpu_properties.total_memory / (1024**3) elif torch.backends.mps.is_available(): - total_memory = psutil.virtual_memory().available / (1024 ** 3) - - if batch_size_type=="frame": - batch = int(total_memory * 0.5) - batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch) - batch_size_per_gpu = int(38400 / batch ) + total_memory = psutil.virtual_memory().available / (1024**3) + + if batch_size_type == "frame": + batch = int(total_memory * 0.5) + batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch) + batch_size_per_gpu = int(38400 / batch) else: - batch_size_per_gpu = int(total_memory / 8) - batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu) - batch = batch_size_per_gpu + batch_size_per_gpu = int(total_memory / 8) + batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu) + batch = batch_size_per_gpu - if batch_size_per_gpu<=0:batch_size_per_gpu=1 + if batch_size_per_gpu <= 0: + batch_size_per_gpu = 1 - if samples<64: - max_samples = int(samples * 0.25) + if samples < 64: + max_samples = int(samples * 0.25) else: - max_samples = 64 - - num_warmup_updates = int(samples * 0.10) - save_per_updates = int(samples * 0.25) - last_per_steps =int(save_per_updates * 5) - + max_samples = 64 + + num_warmup_updates = int(samples * 0.10) + save_per_updates = int(samples * 0.25) + last_per_steps = int(save_per_updates * 5) + max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples) num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates) save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates) last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps) - if finetune:learning_rate=1e-4 - else:learning_rate=7.5e-5 + if finetune: + learning_rate = 1e-4 + else: + learning_rate = 7.5e-5 + + return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate - return batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,samples,learning_rate def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None: try: checkpoint = torch.load(checkpoint_path) print("Original Checkpoint Keys:", checkpoint.keys()) - - ema_model_state_dict = checkpoint.get('ema_model_state_dict', None) + + ema_model_state_dict = checkpoint.get("ema_model_state_dict", None) if ema_model_state_dict is not None: - new_checkpoint = {'ema_model_state_dict': ema_model_state_dict} + new_checkpoint = {"ema_model_state_dict": ema_model_state_dict} torch.save(new_checkpoint, new_checkpoint_path) return f"New checkpoint saved at: {new_checkpoint_path}" else: @@ -561,65 +583,61 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) - except Exception as e: return f"An error occurred: {e}" + def vocab_check(project_name): name_project = project_name + "_pinyin" path_project = os.path.join(path_data, name_project) file_metadata = os.path.join(path_project, "metadata.csv") - - file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt" - if os.path.isfile(file_vocab)==False: + + file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt" + if not os.path.isfile(file_vocab): return f"the file {file_vocab} not found !" - - with open(file_vocab,"r",encoding="utf-8") as f: - data=f.read() + + with open(file_vocab, "r", encoding="utf-8") as f: + data = f.read() vocab = data.split("\n") - if os.path.isfile(file_metadata)==False: + if not os.path.isfile(file_metadata): return f"the file {file_metadata} not found !" - with open(file_metadata,"r",encoding="utf-8") as f: - data=f.read() + with open(file_metadata, "r", encoding="utf-8") as f: + data = f.read() - miss_symbols=[] - miss_symbols_keep={} + miss_symbols = [] + miss_symbols_keep = {} for item in data.split("\n"): - sp=item.split("|") - if len(sp)!=2:continue - text=sp[1].lower().strip() - - for t in text: - if (t in vocab)==False and (t in miss_symbols_keep)==False: - miss_symbols.append(t) - miss_symbols_keep[t]=t - - - if miss_symbols==[]:info ="You can train using your language !" - else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols) + sp = item.split("|") + if len(sp) != 2: + continue + text = sp[1].lower().strip() + + for t in text: + if t not in vocab and t not in miss_symbols_keep: + miss_symbols.append(t) + miss_symbols_keep[t] = t + if miss_symbols == []: + info = "You can train using your language !" + else: + info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols) return info - with gr.Blocks() as app: - with gr.Row(): - project_name=gr.Textbox(label="project name",value="my_speak") - bt_create=gr.Button("create new project") - - bt_create.click(fn=create_data_project,inputs=[project_name]) - - with gr.Tabs(): - + project_name = gr.Textbox(label="project name", value="my_speak") + bt_create = gr.Button("create new project") - with gr.TabItem("transcribe Data"): + bt_create.click(fn=create_data_project, inputs=[project_name]) + with gr.Tabs(): + with gr.TabItem("transcribe Data"): + ch_manual = gr.Checkbox(label="user", value=False) - ch_manual = gr.Checkbox(label="user",value=False) - - mark_info_transcribe=gr.Markdown( - """```plaintext + mark_info_transcribe = gr.Markdown( + """```plaintext Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory. my_speak/ @@ -628,18 +646,24 @@ def vocab_check(project_name): ├── audio1.wav └── audio2.wav ... - ```""",visible=False) - - audio_speaker = gr.File(label="voice",type="filepath",file_count="multiple") - txt_lang = gr.Text(label="Language",value="english") - bt_transcribe=bt_create=gr.Button("transcribe") - txt_info_transcribe=gr.Text(label="info",value="") - bt_transcribe.click(fn=transcribe_all,inputs=[project_name,audio_speaker,txt_lang,ch_manual],outputs=[txt_info_transcribe]) - ch_manual.change(fn=check_user,inputs=[ch_manual],outputs=[audio_speaker,mark_info_transcribe]) - - with gr.TabItem("prepare Data"): - gr.Markdown( - """```plaintext + ```""", + visible=False, + ) + + audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple") + txt_lang = gr.Text(label="Language", value="english") + bt_transcribe = bt_create = gr.Button("transcribe") + txt_info_transcribe = gr.Text(label="info", value="") + bt_transcribe.click( + fn=transcribe_all, + inputs=[project_name, audio_speaker, txt_lang, ch_manual], + outputs=[txt_info_transcribe], + ) + ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe]) + + with gr.TabItem("prepare Data"): + gr.Markdown( + """```plaintext place all your wavs folder and your metadata.csv file in {your name project} my_speak/ │ @@ -656,61 +680,104 @@ def vocab_check(project_name): audio2|text1 ... - ```""") - - bt_prepare=bt_create=gr.Button("prepare") - txt_info_prepare=gr.Text(label="info",value="") - bt_prepare.click(fn=create_metadata,inputs=[project_name],outputs=[txt_info_prepare]) - - with gr.TabItem("train Data"): - - with gr.Row(): - bt_calculate=bt_create=gr.Button("Auto Settings") - ch_finetune=bt_create=gr.Checkbox(label="finetune",value=True) - lb_samples = gr.Label(label="samples") - batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame") - - with gr.Row(): - exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base") - learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4) - - with gr.Row(): - batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000) - max_samples = gr.Number(label="Max Samples", value=16) - - with gr.Row(): - grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1) - max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0) - - with gr.Row(): - epochs = gr.Number(label="Epochs", value=10) - num_warmup_updates = gr.Number(label="Warmup Updates", value=5) - - with gr.Row(): - save_per_updates = gr.Number(label="Save per Updates", value=10) - last_per_steps = gr.Number(label="Last per Steps", value=50) - - with gr.Row(): - start_button = gr.Button("Start Training") - stop_button = gr.Button("Stop Training",interactive=False) - - txt_info_train=gr.Text(label="info",value="") - start_button.click(fn=start_training,inputs=[project_name,exp_name,learning_rate,batch_size_per_gpu,batch_size_type,max_samples,grad_accumulation_steps,max_grad_norm,epochs,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[txt_info_train,start_button,stop_button]) - stop_button.click(fn=stop_training,outputs=[txt_info_train,start_button,stop_button]) - bt_calculate.click(fn=calculate_train,inputs=[project_name,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,lb_samples,learning_rate]) - - with gr.TabItem("reduse checkpoint"): - txt_path_checkpoint = gr.Text(label="path checkpoint :") - txt_path_checkpoint_small = gr.Text(label="path output :") - txt_info_reduse = gr.Text(label="info",value="") - reduse_button = gr.Button("reduse") - reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse]) - - with gr.TabItem("vocab check experiment"): - check_button = gr.Button("check vocab") - txt_info_check=gr.Text(label="info",value="") - check_button.click(fn=vocab_check,inputs=[project_name],outputs=[txt_info_check]) - + ```""" + ) + + bt_prepare = bt_create = gr.Button("prepare") + txt_info_prepare = gr.Text(label="info", value="") + bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare]) + + with gr.TabItem("train Data"): + with gr.Row(): + bt_calculate = bt_create = gr.Button("Auto Settings") + ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True) + lb_samples = gr.Label(label="samples") + batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame") + + with gr.Row(): + exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base") + learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4) + + with gr.Row(): + batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000) + max_samples = gr.Number(label="Max Samples", value=16) + + with gr.Row(): + grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1) + max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0) + + with gr.Row(): + epochs = gr.Number(label="Epochs", value=10) + num_warmup_updates = gr.Number(label="Warmup Updates", value=5) + + with gr.Row(): + save_per_updates = gr.Number(label="Save per Updates", value=10) + last_per_steps = gr.Number(label="Last per Steps", value=50) + + with gr.Row(): + start_button = gr.Button("Start Training") + stop_button = gr.Button("Stop Training", interactive=False) + + txt_info_train = gr.Text(label="info", value="") + start_button.click( + fn=start_training, + inputs=[ + project_name, + exp_name, + learning_rate, + batch_size_per_gpu, + batch_size_type, + max_samples, + grad_accumulation_steps, + max_grad_norm, + epochs, + num_warmup_updates, + save_per_updates, + last_per_steps, + ch_finetune, + ], + outputs=[txt_info_train, start_button, stop_button], + ) + stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button]) + bt_calculate.click( + fn=calculate_train, + inputs=[ + project_name, + batch_size_type, + max_samples, + learning_rate, + num_warmup_updates, + save_per_updates, + last_per_steps, + ch_finetune, + ], + outputs=[ + batch_size_per_gpu, + max_samples, + num_warmup_updates, + save_per_updates, + last_per_steps, + lb_samples, + learning_rate, + ], + ) + + with gr.TabItem("reduse checkpoint"): + txt_path_checkpoint = gr.Text(label="path checkpoint :") + txt_path_checkpoint_small = gr.Text(label="path output :") + txt_info_reduse = gr.Text(label="info", value="") + reduse_button = gr.Button("reduse") + reduse_button.click( + fn=extract_and_save_ema_model, + inputs=[txt_path_checkpoint, txt_path_checkpoint_small], + outputs=[txt_info_reduse], + ) + + with gr.TabItem("vocab check experiment"): + check_button = gr.Button("check vocab") + txt_info_check = gr.Text(label="info", value="") + check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check]) + @click.command() @click.option("--port", "-p", default=None, type=int, help="Port to run the app on") @@ -725,10 +792,9 @@ def vocab_check(project_name): @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access") def main(port, host, share, api): global app - print(f"Starting app...") - app.queue(api_open=api).launch( - server_name=host, server_port=port, share=share, show_api=api - ) + print("Starting app...") + app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api) + if __name__ == "__main__": main() diff --git a/gradio_app.py b/gradio_app.py index f40c60b97..112aaaca4 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -1,3 +1,6 @@ +# ruff: noqa: E402 +# Above allows ruff to ignore E402: module level import not at top of file + import re import tempfile @@ -11,16 +14,19 @@ try: import spaces + USING_SPACES = True except ImportError: USING_SPACES = False + def gpu_decorator(func): if USING_SPACES: return spaces.GPU(func) else: return func + from model import DiT, UNetT from model.utils import ( save_spectrogram, @@ -38,15 +44,18 @@ def gpu_decorator(func): # load models F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) -F5TTS_ema_model = load_model(DiT, F5TTS_model_cfg, str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))) +F5TTS_ema_model = load_model( + DiT, F5TTS_model_cfg, str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")) +) E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) -E2TTS_ema_model = load_model(UNetT, E2TTS_model_cfg, str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))) +E2TTS_ema_model = load_model( + UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")) +) @gpu_decorator def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1): - ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info) if model == "F5-TTS": @@ -54,7 +63,16 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_ elif model == "E2-TTS": ema_model = E2TTS_ema_model - final_wave, final_sample_rate, combined_spectrogram = infer_process(ref_audio, ref_text, gen_text, ema_model, cross_fade_duration=cross_fade_duration, speed=speed, show_info=gr.Info, progress=gr.Progress()) + final_wave, final_sample_rate, combined_spectrogram = infer_process( + ref_audio, + ref_text, + gen_text, + ema_model, + cross_fade_duration=cross_fade_duration, + speed=speed, + show_info=gr.Info, + progress=gr.Progress(), + ) # Remove silence if remove_silence: @@ -73,17 +91,19 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_ @gpu_decorator -def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence): +def generate_podcast( + script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence +): # Split the script into speaker blocks speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE) speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element - + generated_audio_segments = [] - + for i in range(0, len(speaker_blocks), 2): speaker = speaker_blocks[i] - text = speaker_blocks[i+1].strip() - + text = speaker_blocks[i + 1].strip() + # Determine which speaker is talking if speaker == speaker1_name: ref_audio = ref_audio1 @@ -93,51 +113,52 @@ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name ref_text = ref_text2 else: continue # Skip if the speaker is neither speaker1 nor speaker2 - + # Generate audio for this block audio, _ = infer(ref_audio, ref_text, text, model, remove_silence) - + # Convert the generated audio to a numpy array sr, audio_data = audio - + # Save the audio data as a WAV file with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: sf.write(temp_file.name, audio_data, sr) audio_segment = AudioSegment.from_wav(temp_file.name) - + generated_audio_segments.append(audio_segment) - + # Add a short pause between speakers pause = AudioSegment.silent(duration=500) # 500ms pause generated_audio_segments.append(pause) - + # Concatenate all audio segments final_podcast = sum(generated_audio_segments) - + # Export the final podcast with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: podcast_path = temp_file.name final_podcast.export(podcast_path, format="wav") - + return podcast_path + def parse_speechtypes_text(gen_text): # Pattern to find (Emotion) - pattern = r'\((.*?)\)' + pattern = r"\((.*?)\)" # Split the text by the pattern tokens = re.split(pattern, gen_text) segments = [] - current_emotion = 'Regular' + current_emotion = "Regular" for i in range(len(tokens)): if i % 2 == 0: # This is text text = tokens[i].strip() if text: - segments.append({'emotion': current_emotion, 'text': text}) + segments.append({"emotion": current_emotion, "text": text}) else: # This is emotion emotion = tokens[i].strip() @@ -158,9 +179,7 @@ def parse_speechtypes_text(gen_text): gr.Markdown("# Batched TTS") ref_audio_input = gr.Audio(label="Reference Audio", type="filepath") gen_text_input = gr.Textbox(label="Text to Generate", lines=10) - model_choice = gr.Radio( - choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS" - ) + model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS") generate_btn = gr.Button("Synthesize", variant="primary") with gr.Accordion("Advanced Settings", open=False): ref_text_input = gr.Textbox( @@ -206,23 +225,24 @@ def parse_speechtypes_text(gen_text): ], outputs=[audio_output, spectrogram_output], ) - + with gr.Blocks() as app_podcast: gr.Markdown("# Podcast Generation") speaker1_name = gr.Textbox(label="Speaker 1 Name") ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath") ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2) - + speaker2_name = gr.Textbox(label="Speaker 2 Name") ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath") ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2) - - script_input = gr.Textbox(label="Podcast Script", lines=10, - placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...") - - podcast_model_choice = gr.Radio( - choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS" + + script_input = gr.Textbox( + label="Podcast Script", + lines=10, + placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...", ) + + podcast_model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS") podcast_remove_silence = gr.Checkbox( label="Remove Silences", value=True, @@ -230,8 +250,12 @@ def parse_speechtypes_text(gen_text): generate_podcast_btn = gr.Button("Generate Podcast", variant="primary") podcast_output = gr.Audio(label="Generated Podcast") - def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence): - return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence) + def podcast_generation( + script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence + ): + return generate_podcast( + script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence + ) generate_podcast_btn.click( podcast_generation, @@ -249,23 +273,24 @@ def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_au outputs=podcast_output, ) + def parse_emotional_text(gen_text): # Pattern to find (Emotion) - pattern = r'\((.*?)\)' + pattern = r"\((.*?)\)" # Split the text by the pattern tokens = re.split(pattern, gen_text) segments = [] - current_emotion = 'Regular' + current_emotion = "Regular" for i in range(len(tokens)): if i % 2 == 0: # This is text text = tokens[i].strip() if text: - segments.append({'emotion': current_emotion, 'text': text}) + segments.append({"emotion": current_emotion, "text": text}) else: # This is emotion emotion = tokens[i].strip() @@ -273,6 +298,7 @@ def parse_emotional_text(gen_text): return segments + with gr.Blocks() as app_emotional: # New section for emotional generation gr.Markdown( @@ -287,13 +313,15 @@ def parse_emotional_text(gen_text): """ ) - gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.") + gr.Markdown( + "Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button." + ) # Regular speech type (mandatory) with gr.Row(): - regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False) - regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath') - regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2) + regular_name = gr.Textbox(value="Regular", label="Speech Type Name", interactive=False) + regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath") + regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2) # Additional speech types (up to 99 more) max_speech_types = 100 @@ -304,9 +332,9 @@ def parse_emotional_text(gen_text): for i in range(max_speech_types - 1): with gr.Row(): - name_input = gr.Textbox(label='Speech Type Name', visible=False) - audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False) - ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False) + name_input = gr.Textbox(label="Speech Type Name", visible=False) + audio_input = gr.Audio(label="Reference Audio", type="filepath", visible=False) + ref_text_input = gr.Textbox(label="Reference Text", lines=2, visible=False) delete_btn = gr.Button("Delete", variant="secondary", visible=False) speech_type_names.append(name_input) speech_type_audios.append(audio_input) @@ -351,7 +379,11 @@ def add_speech_type_fn(speech_type_count): add_speech_type_btn.click( add_speech_type_fn, inputs=speech_type_count, - outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns + outputs=[speech_type_count] + + speech_type_names + + speech_type_audios + + speech_type_ref_texts + + speech_type_delete_btns, ) # Function to delete a speech type @@ -365,9 +397,9 @@ def delete_speech_type_fn(speech_type_count): for i in range(max_speech_types - 1): if i == index: - name_updates.append(gr.update(visible=False, value='')) + name_updates.append(gr.update(visible=False, value="")) audio_updates.append(gr.update(visible=False, value=None)) - ref_text_updates.append(gr.update(visible=False, value='')) + ref_text_updates.append(gr.update(visible=False, value="")) delete_btn_updates.append(gr.update(visible=False)) else: name_updates.append(gr.update()) @@ -386,16 +418,18 @@ def delete_speech_type_fn(speech_type_count): delete_btn.click( delete_fn, inputs=speech_type_count, - outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns + outputs=[speech_type_count] + + speech_type_names + + speech_type_audios + + speech_type_ref_texts + + speech_type_delete_btns, ) # Text input for the prompt gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10) # Model choice - model_choice_emotional = gr.Radio( - choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS" - ) + model_choice_emotional = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS") with gr.Accordion("Advanced Settings", open=False): remove_silence_emotional = gr.Checkbox( @@ -408,6 +442,7 @@ def delete_speech_type_fn(speech_type_count): # Output audio audio_output_emotional = gr.Audio(label="Synthesized Audio") + @gpu_decorator def generate_emotional_speech( regular_audio, @@ -417,37 +452,39 @@ def generate_emotional_speech( ): num_additional_speech_types = max_speech_types - 1 speech_type_names_list = args[:num_additional_speech_types] - speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types] - speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types] + speech_type_audios_list = args[num_additional_speech_types : 2 * num_additional_speech_types] + speech_type_ref_texts_list = args[2 * num_additional_speech_types : 3 * num_additional_speech_types] model_choice = args[3 * num_additional_speech_types] remove_silence = args[3 * num_additional_speech_types + 1] # Collect the speech types and their audios into a dict - speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}} + speech_types = {"Regular": {"audio": regular_audio, "ref_text": regular_ref_text}} - for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list): + for name_input, audio_input, ref_text_input in zip( + speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list + ): if name_input and audio_input: - speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input} + speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input} # Parse the gen_text into segments segments = parse_speechtypes_text(gen_text) # For each segment, generate speech generated_audio_segments = [] - current_emotion = 'Regular' + current_emotion = "Regular" for segment in segments: - emotion = segment['emotion'] - text = segment['text'] + emotion = segment["emotion"] + text = segment["text"] if emotion in speech_types: current_emotion = emotion else: # If emotion not available, default to Regular - current_emotion = 'Regular' + current_emotion = "Regular" - ref_audio = speech_types[current_emotion]['audio'] - ref_text = speech_types[current_emotion].get('ref_text', '') + ref_audio = speech_types[current_emotion]["audio"] + ref_text = speech_types[current_emotion].get("ref_text", "") # Generate speech for this segment audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0) @@ -469,7 +506,11 @@ def generate_emotional_speech( regular_audio, regular_ref_text, gen_text_input_emotional, - ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [ + ] + + speech_type_names + + speech_type_audios + + speech_type_ref_texts + + [ model_choice_emotional, remove_silence_emotional, ], @@ -477,11 +518,7 @@ def generate_emotional_speech( ) # Validation function to disable Generate button if speech types are missing - def validate_speech_types( - gen_text, - regular_name, - *args - ): + def validate_speech_types(gen_text, regular_name, *args): num_additional_speech_types = max_speech_types - 1 speech_type_names_list = args[:num_additional_speech_types] @@ -495,7 +532,7 @@ def validate_speech_types( # Parse the gen_text to get the speech types used segments = parse_emotional_text(gen_text) - speech_types_in_text = set(segment['emotion'] for segment in segments) + speech_types_in_text = set(segment["emotion"] for segment in segments) # Check if all speech types in text are available missing_speech_types = speech_types_in_text - speech_types_available @@ -510,7 +547,7 @@ def validate_speech_types( gen_text_input_emotional.change( validate_speech_types, inputs=[gen_text_input_emotional, regular_name] + speech_type_names, - outputs=generate_emotional_btn + outputs=generate_emotional_btn, ) with gr.Blocks() as app: gr.Markdown( @@ -531,6 +568,7 @@ def validate_speech_types( ) gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"]) + @click.command() @click.option("--port", "-p", default=None, type=int, help="Port to run the app on") @click.option("--host", "-H", default=None, help="Host to run the app on") @@ -544,10 +582,8 @@ def validate_speech_types( @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access") def main(port, host, share, api): global app - print(f"Starting app...") - app.queue(api_open=api).launch( - server_name=host, server_port=port, share=share, show_api=api - ) + print("Starting app...") + app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api) if __name__ == "__main__": diff --git a/inference-cli.py b/inference-cli.py index ae004c798..3d4bd1539 100644 --- a/inference-cli.py +++ b/inference-cli.py @@ -44,19 +44,8 @@ "--vocab_file", help="The vocab .txt", ) -parser.add_argument( - "-r", - "--ref_audio", - type=str, - help="Reference audio file < 15 seconds." -) -parser.add_argument( - "-s", - "--ref_text", - type=str, - default="666", - help="Subtitle for the reference audio." -) +parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.") +parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.") parser.add_argument( "-t", "--gen_text", @@ -99,8 +88,8 @@ ckpt_file = args.ckpt_file if args.ckpt_file else "" vocab_file = args.vocab_file if args.vocab_file else "" remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"] -wave_path = Path(output_dir)/"out.wav" -spectrogram_path = Path(output_dir)/"out.png" +wave_path = Path(output_dir) / "out.wav" +spectrogram_path = Path(output_dir) / "out.png" vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path) @@ -110,44 +99,46 @@ if model == "F5-TTS": model_cls = DiT model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) - if ckpt_file == "": - repo_name= "F5-TTS" + if ckpt_file == "": + repo_name = "F5-TTS" exp_name = "F5TTS_Base" - ckpt_step= 1200000 + ckpt_step = 1200000 ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path elif model == "E2-TTS": model_cls = UNetT model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) - if ckpt_file == "": - repo_name= "E2-TTS" + if ckpt_file == "": + repo_name = "E2-TTS" exp_name = "E2TTS_Base" - ckpt_step= 1200000 + ckpt_step = 1200000 ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path print(f"Using {model}...") ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file) - + def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence): - main_voice = {"ref_audio":ref_audio, "ref_text":ref_text} + main_voice = {"ref_audio": ref_audio, "ref_text": ref_text} if "voices" not in config: voices = {"main": main_voice} else: voices = config["voices"] voices["main"] = main_voice for voice in voices: - voices[voice]['ref_audio'], voices[voice]['ref_text'] = preprocess_ref_audio_text(voices[voice]['ref_audio'], voices[voice]['ref_text']) + voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text( + voices[voice]["ref_audio"], voices[voice]["ref_text"] + ) print("Voice:", voice) - print("Ref_audio:", voices[voice]['ref_audio']) - print("Ref_text:", voices[voice]['ref_text']) + print("Ref_audio:", voices[voice]["ref_audio"]) + print("Ref_text:", voices[voice]["ref_text"]) generated_audio_segments = [] - reg1 = r'(?=\[\w+\])' + reg1 = r"(?=\[\w+\])" chunks = re.split(reg1, text_gen) - reg2 = r'\[(\w+)\]' + reg2 = r"\[(\w+)\]" for text in chunks: match = re.match(reg2, text) if match: @@ -160,8 +151,8 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence): voice = "main" text = re.sub(reg2, "", text) gen_text = text.strip() - ref_audio = voices[voice]['ref_audio'] - ref_text = voices[voice]['ref_text'] + ref_audio = voices[voice]["ref_audio"] + ref_text = voices[voice]["ref_text"] print(f"Voice: {voice}") audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj) generated_audio_segments.append(audio) diff --git a/model/__init__.py b/model/__init__.py index d505b15dc..4b4f03139 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -5,3 +5,6 @@ from model.backbones.mmdit import MMDiT from model.trainer import Trainer + + +__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"] diff --git a/model/backbones/dit.py b/model/backbones/dit.py index c91040c60..9ff535134 100644 --- a/model/backbones/dit.py +++ b/model/backbones/dit.py @@ -21,14 +21,16 @@ ConvPositionEmbedding, DiTBlock, AdaLayerNormZero_Final, - precompute_freqs_cis, get_pos_embed_indices, + precompute_freqs_cis, + get_pos_embed_indices, ) # Text embedding + class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token @@ -36,20 +38,22 @@ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2): self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) - self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) + self.text_blocks = nn.Sequential( + *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + ) else: self.extra_modeling = False - def forward(self, text: int['b nt'], seq_len, drop_text = False): + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 batch, text_len = text.shape[0], text.shape[1] text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens - text = F.pad(text, (0, seq_len - text_len), value = 0) + text = F.pad(text, (0, seq_len - text_len), value=0) if drop_text: # cfg for text text = torch.zeros_like(text) - text = self.text_embed(text) # b n -> b n d + text = self.text_embed(text) # b n -> b n d # possible extra modeling if self.extra_modeling: @@ -67,88 +71,91 @@ def forward(self, text: int['b nt'], seq_len, drop_text = False): # noised input audio and context mixing embedding + class InputEmbedding(nn.Module): def __init__(self, mel_dim, text_dim, out_dim): super().__init__() self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) - self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False): + def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 if drop_audio_cond: # cfg for cond audio cond = torch.zeros_like(cond) - x = self.proj(torch.cat((x, cond, text_embed), dim = -1)) + x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) x = self.conv_pos_embed(x) + x return x - + # Transformer backbone using DiT blocks + class DiT(nn.Module): - def __init__(self, *, - dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4, - mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0, - long_skip_connection = False, + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_num_embeds=256, + text_dim=None, + conv_layers=0, + long_skip_connection=False, ): super().__init__() self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim - self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers) + self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) self.dim = dim self.depth = depth - + self.transformer_blocks = nn.ModuleList( - [ - DiTBlock( - dim = dim, - heads = heads, - dim_head = dim_head, - ff_mult = ff_mult, - dropout = dropout - ) - for _ in range(depth) - ] + [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)] ) - self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None - + self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) def forward( self, - x: float['b n d'], # nosied input audio - cond: float['b n d'], # masked cond audio - text: int['b nt'], # text - time: float['b'] | float[''], # time step + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 drop_audio_cond, # cfg for cond audio - drop_text, # cfg for text - mask: bool['b n'] | None = None, + drop_text, # cfg for text + mask: bool["b n"] | None = None, # noqa: F722 ): batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: time = time.repeat(batch) - + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) - text_embed = self.text_embed(text, seq_len, drop_text = drop_text) - x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond) - + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + rope = self.rotary_embed.forward_from_seq_len(seq_len) if self.long_skip_connection is not None: residual = x for block in self.transformer_blocks: - x = block(x, t, mask = mask, rope = rope) + x = block(x, t, mask=mask, rope=rope) if self.long_skip_connection is not None: - x = self.long_skip_connection(torch.cat((x, residual), dim = -1)) + x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) x = self.norm_out(x, t) output = self.proj_out(x) diff --git a/model/backbones/mmdit.py b/model/backbones/mmdit.py index a0ff3b0ab..86313b136 100644 --- a/model/backbones/mmdit.py +++ b/model/backbones/mmdit.py @@ -19,12 +19,14 @@ ConvPositionEmbedding, MMDiTBlock, AdaLayerNormZero_Final, - precompute_freqs_cis, get_pos_embed_indices, + precompute_freqs_cis, + get_pos_embed_indices, ) # text embedding + class TextEmbedding(nn.Module): def __init__(self, out_dim, text_num_embeds): super().__init__() @@ -33,7 +35,7 @@ def __init__(self, out_dim, text_num_embeds): self.precompute_max_pos = 1024 self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) - def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']: + def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722 text = text + 1 if drop_text: text = torch.zeros_like(text) @@ -52,27 +54,37 @@ def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']: # noised input & masked cond audio embedding + class AudioEmbedding(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.linear = nn.Linear(2 * in_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(out_dim) - def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False): + def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722 if drop_audio_cond: cond = torch.zeros_like(cond) - x = torch.cat((x, cond), dim = -1) + x = torch.cat((x, cond), dim=-1) x = self.linear(x) x = self.conv_pos_embed(x) + x return x - + # Transformer backbone using MM-DiT blocks + class MMDiT(nn.Module): - def __init__(self, *, - dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4, - text_num_embeds = 256, mel_dim = 100, + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + text_num_embeds=256, + mel_dim=100, ): super().__init__() @@ -84,16 +96,16 @@ def __init__(self, *, self.dim = dim self.depth = depth - + self.transformer_blocks = nn.ModuleList( [ MMDiTBlock( - dim = dim, - heads = heads, - dim_head = dim_head, - dropout = dropout, - ff_mult = ff_mult, - context_pre_only = i == depth - 1, + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ff_mult=ff_mult, + context_pre_only=i == depth - 1, ) for i in range(depth) ] @@ -103,13 +115,13 @@ def __init__(self, *, def forward( self, - x: float['b n d'], # nosied input audio - cond: float['b n d'], # masked cond audio - text: int['b nt'], # text - time: float['b'] | float[''], # time step + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 drop_audio_cond, # cfg for cond audio - drop_text, # cfg for text - mask: bool['b n'] | None = None, + drop_text, # cfg for text + mask: bool["b n"] | None = None, # noqa: F722 ): batch = x.shape[0] if time.ndim == 0: @@ -117,16 +129,16 @@ def forward( # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) - c = self.text_embed(text, drop_text = drop_text) - x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond) + c = self.text_embed(text, drop_text=drop_text) + x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond) seq_len = x.shape[1] text_len = text.shape[1] rope_audio = self.rotary_embed.forward_from_seq_len(seq_len) rope_text = self.rotary_embed.forward_from_seq_len(text_len) - + for block in self.transformer_blocks: - c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text) + c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text) x = self.norm_out(x, t) output = self.proj_out(x) diff --git a/model/backbones/unett.py b/model/backbones/unett.py index 89718632c..c4ce2c64d 100644 --- a/model/backbones/unett.py +++ b/model/backbones/unett.py @@ -24,14 +24,16 @@ Attention, AttnProcessor, FeedForward, - precompute_freqs_cis, get_pos_embed_indices, + precompute_freqs_cis, + get_pos_embed_indices, ) # Text embedding + class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token @@ -39,20 +41,22 @@ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2): self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) - self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) + self.text_blocks = nn.Sequential( + *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + ) else: self.extra_modeling = False - def forward(self, text: int['b nt'], seq_len, drop_text = False): + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 batch, text_len = text.shape[0], text.shape[1] text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens - text = F.pad(text, (0, seq_len - text_len), value = 0) + text = F.pad(text, (0, seq_len - text_len), value=0) if drop_text: # cfg for text text = torch.zeros_like(text) - text = self.text_embed(text) # b n -> b n d + text = self.text_embed(text) # b n -> b n d # possible extra modeling if self.extra_modeling: @@ -70,28 +74,40 @@ def forward(self, text: int['b nt'], seq_len, drop_text = False): # noised input audio and context mixing embedding + class InputEmbedding(nn.Module): def __init__(self, mel_dim, text_dim, out_dim): super().__init__() self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) - self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False): + def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 if drop_audio_cond: # cfg for cond audio cond = torch.zeros_like(cond) - x = self.proj(torch.cat((x, cond, text_embed), dim = -1)) + x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) x = self.conv_pos_embed(x) + x return x # Flat UNet Transformer backbone + class UNetT(nn.Module): - def __init__(self, *, - dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4, - mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0, - skip_connect_type: Literal['add', 'concat', 'none'] = 'concat', + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_num_embeds=256, + text_dim=None, + conv_layers=0, + skip_connect_type: Literal["add", "concat", "none"] = "concat", ): super().__init__() assert depth % 2 == 0, "UNet-Transformer's depth should be even." @@ -99,7 +115,7 @@ def __init__(self, *, self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim - self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers) + self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -108,7 +124,7 @@ def __init__(self, *, self.dim = dim self.skip_connect_type = skip_connect_type - needs_skip_proj = skip_connect_type == 'concat' + needs_skip_proj = skip_connect_type == "concat" self.depth = depth self.layers = nn.ModuleList([]) @@ -118,53 +134,57 @@ def __init__(self, *, attn_norm = RMSNorm(dim) attn = Attention( - processor = AttnProcessor(), - dim = dim, - heads = heads, - dim_head = dim_head, - dropout = dropout, - ) + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) ff_norm = RMSNorm(dim) - ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh") - - skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None - - self.layers.append(nn.ModuleList([ - skip_proj, - attn_norm, - attn, - ff_norm, - ff, - ])) + ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None + + self.layers.append( + nn.ModuleList( + [ + skip_proj, + attn_norm, + attn, + ff_norm, + ff, + ] + ) + ) self.norm_out = RMSNorm(dim) self.proj_out = nn.Linear(dim, mel_dim) def forward( self, - x: float['b n d'], # nosied input audio - cond: float['b n d'], # masked cond audio - text: int['b nt'], # text - time: float['b'] | float[''], # time step + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 drop_audio_cond, # cfg for cond audio - drop_text, # cfg for text - mask: bool['b n'] | None = None, + drop_text, # cfg for text + mask: bool["b n"] | None = None, # noqa: F722 ): batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: time = time.repeat(batch) - + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) - text_embed = self.text_embed(text, seq_len, drop_text = drop_text) - x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond) + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) # postfix time t to input x, [b n d] -> [b n+1 d] x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x if mask is not None: mask = F.pad(mask, (1, 0), value=1) - + rope = self.rotary_embed.forward_from_seq_len(seq_len + 1) # flat unet transformer @@ -182,14 +202,14 @@ def forward( if is_later_half: skip = skips.pop() - if skip_connect_type == 'concat': - x = torch.cat((x, skip), dim = -1) + if skip_connect_type == "concat": + x = torch.cat((x, skip), dim=-1) x = maybe_skip_proj(x) - elif skip_connect_type == 'add': + elif skip_connect_type == "add": x = x + skip # attention and feedforward blocks - x = attn(attn_norm(x), rope = rope, mask = mask) + x + x = attn(attn_norm(x), rope=rope, mask=mask) + x x = ff(ff_norm(x)) + x assert len(skips) == 0 diff --git a/model/cfm.py b/model/cfm.py index aa6d1c854..e2d7f726f 100644 --- a/model/cfm.py +++ b/model/cfm.py @@ -20,29 +20,32 @@ from model.modules import MelSpec from model.utils import ( - default, exists, - list_str_to_idx, list_str_to_tensor, - lens_to_mask, mask_from_frac_lengths, -) + default, + exists, + list_str_to_idx, + list_str_to_tensor, + lens_to_mask, + mask_from_frac_lengths, +) class CFM(nn.Module): def __init__( self, transformer: nn.Module, - sigma = 0., + sigma=0.0, odeint_kwargs: dict = dict( # atol = 1e-5, # rtol = 1e-5, - method = 'euler' # 'midpoint' + method="euler" # 'midpoint' ), - audio_drop_prob = 0.3, - cond_drop_prob = 0.2, - num_channels = None, + audio_drop_prob=0.3, + cond_drop_prob=0.2, + num_channels=None, mel_spec_module: nn.Module | None = None, mel_spec_kwargs: dict = dict(), - frac_lengths_mask: tuple[float, float] = (0.7, 1.), - vocab_char_map: dict[str: int] | None = None + frac_lengths_mask: tuple[float, float] = (0.7, 1.0), + vocab_char_map: dict[str:int] | None = None, ): super().__init__() @@ -78,21 +81,21 @@ def device(self): @torch.no_grad() def sample( self, - cond: float['b n d'] | float['b nw'], - text: int['b nt'] | list[str], - duration: int | int['b'], + cond: float["b n d"] | float["b nw"], # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + duration: int | int["b"], # noqa: F821 *, - lens: int['b'] | None = None, - steps = 32, - cfg_strength = 1., - sway_sampling_coef = None, + lens: int["b"] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, seed: int | None = None, - max_duration = 4096, - vocoder: Callable[[float['b d n']], float['b nw']] | None = None, - no_ref_audio = False, - duplicate_test = False, - t_inter = 0.1, - edit_mask = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + no_ref_audio=False, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, ): self.eval() @@ -108,7 +111,7 @@ def sample( batch, cond_seq_len, device = *cond.shape[:2], cond.device if not exists(lens): - lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long) + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) # text @@ -120,8 +123,8 @@ def sample( assert text.shape[0] == batch if exists(text): - text_lens = (text != -1).sum(dim = -1) - lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters + text_lens = (text != -1).sum(dim=-1) + lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters # duration @@ -130,20 +133,22 @@ def sample( cond_mask = cond_mask & edit_mask if isinstance(duration, int): - duration = torch.full((batch,), duration, device = device, dtype = torch.long) + duration = torch.full((batch,), duration, device=device, dtype=torch.long) - duration = torch.maximum(lens + 1, duration) # just add one token so something is generated - duration = duration.clamp(max = max_duration) + duration = torch.maximum(lens + 1, duration) # just add one token so something is generated + duration = duration.clamp(max=max_duration) max_duration = duration.amax() - + # duplicate test corner for inner time step oberservation if duplicate_test: - test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.) - - cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.) - cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False) + test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) + + cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) cond_mask = cond_mask.unsqueeze(-1) - step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in + step_cond = torch.where( + cond_mask, cond, torch.zeros_like(cond) + ) # allow direct control (cut cond audio) with lens passed in if batch > 1: mask = lens_to_mask(duration) @@ -161,11 +166,15 @@ def fn(t, x): # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # predict flow - pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False) + pred = self.transformer( + x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False + ) if cfg_strength < 1e-5: return pred - - null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True) + + null_pred = self.transformer( + x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True + ) return pred + (pred - null_pred) * cfg_strength # noise input @@ -175,8 +184,8 @@ def fn(t, x): for dur in duration: if exists(seed): torch.manual_seed(seed) - y0.append(torch.randn(dur, self.num_channels, device = self.device, dtype=step_cond.dtype)) - y0 = pad_sequence(y0, padding_value = 0, batch_first = True) + y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) + y0 = pad_sequence(y0, padding_value=0, batch_first=True) t_start = 0 @@ -186,12 +195,12 @@ def fn(t, x): y0 = (1 - t_start) * y0 + t_start * test_cond steps = int(steps * (1 - t_start)) - t = torch.linspace(t_start, 1, steps, device = self.device, dtype=step_cond.dtype) + t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) trajectory = odeint(fn, y0, t, **self.odeint_kwargs) - + sampled = trajectory[-1] out = sampled out = torch.where(cond_mask, cond, out) @@ -204,10 +213,10 @@ def fn(t, x): def forward( self, - inp: float['b n d'] | float['b nw'], # mel or raw wave - text: int['b nt'] | list[str], + inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 *, - lens: int['b'] | None = None, + lens: int["b"] | None = None, # noqa: F821 noise_scheduler: str | None = None, ): # handle raw wave @@ -216,7 +225,7 @@ def forward( inp = inp.permute(0, 2, 1) assert inp.shape[-1] == self.num_channels - batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma + batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma # handle text as string if isinstance(text, list): @@ -228,12 +237,12 @@ def forward( # lens and mask if not exists(lens): - lens = torch.full((batch,), seq_len, device = device) - - mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch + lens = torch.full((batch,), seq_len, device=device) + + mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch # get a random span to mask out for training conditionally - frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask) + frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) if exists(mask): @@ -246,7 +255,7 @@ def forward( x0 = torch.randn_like(x1) # time step - time = torch.rand((batch,), dtype = dtype, device = self.device) + time = torch.rand((batch,), dtype=dtype, device=self.device) # TODO. noise_scheduler # sample xt (φ_t(x) in the paper) @@ -255,10 +264,7 @@ def forward( flow = x1 - x0 # only predict what is within the random mask span for infilling - cond = torch.where( - rand_span_mask[..., None], - torch.zeros_like(x1), x1 - ) + cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) # transformer and cfg training with a drop rate drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper @@ -267,13 +273,15 @@ def forward( drop_text = True else: drop_text = False - + # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences - pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text) + pred = self.transformer( + x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text + ) # flow matching loss - loss = F.mse_loss(pred, flow, reduction = 'none') + loss = F.mse_loss(pred, flow, reduction="none") loss = loss[rand_span_mask] return loss.mean(), cond, pred diff --git a/model/dataset.py b/model/dataset.py index 00e91006d..03ed473fa 100644 --- a/model/dataset.py +++ b/model/dataset.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torch.utils.data import Dataset, Sampler import torchaudio -from datasets import load_dataset, load_from_disk +from datasets import load_from_disk from datasets import Dataset as Dataset_ from model.modules import MelSpec @@ -16,53 +16,55 @@ class HFDataset(Dataset): def __init__( self, hf_dataset: Dataset, - target_sample_rate = 24_000, - n_mel_channels = 100, - hop_length = 256, + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, ): self.data = hf_dataset self.target_sample_rate = target_sample_rate self.hop_length = hop_length - self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length) - + self.mel_spectrogram = MelSpec( + target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length + ) + def get_frame_len(self, index): row = self.data[index] - audio = row['audio']['array'] - sample_rate = row['audio']['sampling_rate'] + audio = row["audio"]["array"] + sample_rate = row["audio"]["sampling_rate"] return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length def __len__(self): return len(self.data) - + def __getitem__(self, index): row = self.data[index] - audio = row['audio']['array'] + audio = row["audio"]["array"] # logger.info(f"Audio shape: {audio.shape}") - sample_rate = row['audio']['sampling_rate'] + sample_rate = row["audio"]["sampling_rate"] duration = audio.shape[-1] / sample_rate if duration > 30 or duration < 0.3: return self.__getitem__((index + 1) % len(self.data)) - + audio_tensor = torch.from_numpy(audio).float() - + if sample_rate != self.target_sample_rate: resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) audio_tensor = resampler(audio_tensor) - + audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t') - + mel_spec = self.mel_spectrogram(audio_tensor) - + mel_spec = mel_spec.squeeze(0) # '1 d t -> d t' - - text = row['text'] - + + text = row["text"] + return dict( - mel_spec = mel_spec, - text = text, + mel_spec=mel_spec, + text=text, ) @@ -70,11 +72,11 @@ class CustomDataset(Dataset): def __init__( self, custom_dataset: Dataset, - durations = None, - target_sample_rate = 24_000, - hop_length = 256, - n_mel_channels = 100, - preprocessed_mel = False, + durations=None, + target_sample_rate=24_000, + hop_length=256, + n_mel_channels=100, + preprocessed_mel=False, ): self.data = custom_dataset self.durations = durations @@ -82,16 +84,20 @@ def __init__( self.hop_length = hop_length self.preprocessed_mel = preprocessed_mel if not preprocessed_mel: - self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels) + self.mel_spectrogram = MelSpec( + target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels + ) def get_frame_len(self, index): - if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM + if ( + self.durations is not None + ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM return self.durations[index] * self.target_sample_rate / self.hop_length return self.data[index]["duration"] * self.target_sample_rate / self.hop_length - + def __len__(self): return len(self.data) - + def __getitem__(self, index): row = self.data[index] audio_path = row["audio_path"] @@ -108,45 +114,52 @@ def __getitem__(self, index): if duration > 30 or duration < 0.3: return self.__getitem__((index + 1) % len(self.data)) - + if source_sample_rate != self.target_sample_rate: resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) audio = resampler(audio) - + mel_spec = self.mel_spectrogram(audio) mel_spec = mel_spec.squeeze(0) # '1 d t -> d t') - + return dict( - mel_spec = mel_spec, - text = text, + mel_spec=mel_spec, + text=text, ) - + # Dynamic Batch Sampler + class DynamicBatchSampler(Sampler[list[int]]): - """ Extension of Sampler that will do the following: - 1. Change the batch size (essentially number of sequences) - in a batch to ensure that the total number of frames are less - than a certain threshold. - 2. Make sure the padding efficiency in the batch is high. + """Extension of Sampler that will do the following: + 1. Change the batch size (essentially number of sequences) + in a batch to ensure that the total number of frames are less + than a certain threshold. + 2. Make sure the padding efficiency in the batch is high. """ - def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False): + def __init__( + self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False + ): self.sampler = sampler self.frames_threshold = frames_threshold self.max_samples = max_samples indices, batches = [], [] data_source = self.sampler.data_source - - for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"): + + for idx in tqdm( + self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration" + ): indices.append((idx, data_source.get_frame_len(idx))) - indices.sort(key=lambda elem : elem[1]) + indices.sort(key=lambda elem: elem[1]) batch = [] batch_frames = 0 - for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"): + for idx, frame_len in tqdm( + indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu" + ): if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples): batch.append(idx) batch_frames += frame_len @@ -182,76 +195,86 @@ def __len__(self): # Load dataset + def load_dataset( - dataset_name: str, - tokenizer: str = "pinyin", - dataset_type: str = "CustomDataset", - audio_type: str = "raw", - mel_spec_kwargs: dict = dict() - ) -> CustomDataset | HFDataset: - ''' + dataset_name: str, + tokenizer: str = "pinyin", + dataset_type: str = "CustomDataset", + audio_type: str = "raw", + mel_spec_kwargs: dict = dict(), +) -> CustomDataset | HFDataset: + """ dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer - ''' - + """ + print("Loading dataset ...") if dataset_type == "CustomDataset": if audio_type == "raw": try: train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw") - except: + except: # noqa: E722 train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow") preprocessed_mel = False elif audio_type == "mel": train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow") preprocessed_mel = True - with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f: + with open(f"data/{dataset_name}_{tokenizer}/duration.json", "r", encoding="utf-8") as f: data_dict = json.load(f) durations = data_dict["duration"] - train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs) - + train_dataset = CustomDataset( + train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs + ) + elif dataset_type == "CustomDatasetPath": try: train_dataset = load_from_disk(f"{dataset_name}/raw") - except: + except: # noqa: E722 train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow") - - with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f: + + with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f: data_dict = json.load(f) durations = data_dict["duration"] - train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs) - + train_dataset = CustomDataset( + train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs + ) + elif dataset_type == "HFDataset": - print("Should manually modify the path of huggingface dataset to your need.\n" + - "May also the corresponding script cuz different dataset may have different format.") + print( + "Should manually modify the path of huggingface dataset to your need.\n" + + "May also the corresponding script cuz different dataset may have different format." + ) pre, post = dataset_name.split("_") - train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),) + train_dataset = HFDataset( + load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"), + ) return train_dataset # collation + def collate_fn(batch): - mel_specs = [item['mel_spec'].squeeze(0) for item in batch] + mel_specs = [item["mel_spec"].squeeze(0) for item in batch] mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs]) max_mel_length = mel_lengths.amax() padded_mel_specs = [] for spec in mel_specs: # TODO. maybe records mask for attention here padding = (0, max_mel_length - spec.size(-1)) - padded_spec = F.pad(spec, padding, value = 0) + padded_spec = F.pad(spec, padding, value=0) padded_mel_specs.append(padded_spec) - + mel_specs = torch.stack(padded_mel_specs) - text = [item['text'] for item in batch] + text = [item["text"] for item in batch] text_lengths = torch.LongTensor([len(item) for item in text]) return dict( - mel = mel_specs, - mel_lengths = mel_lengths, - text = text, - text_lengths = text_lengths, + mel=mel_specs, + mel_lengths=mel_lengths, + text=text, + text_lengths=text_lengths, ) diff --git a/model/ecapa_tdnn.py b/model/ecapa_tdnn.py index 30b611eda..6bc431eb9 100644 --- a/model/ecapa_tdnn.py +++ b/model/ecapa_tdnn.py @@ -9,13 +9,14 @@ import torch.nn.functional as F -''' Res2Conv1d + BatchNorm1d + ReLU -''' +""" Res2Conv1d + BatchNorm1d + ReLU +""" + class Res2Conv1dReluBn(nn.Module): - ''' + """ in_channels == out_channels == channels - ''' + """ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4): super().__init__() @@ -51,8 +52,9 @@ def forward(self, x): return out -''' Conv1d + BatchNorm1d + ReLU -''' +""" Conv1d + BatchNorm1d + ReLU +""" + class Conv1dReluBn(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True): @@ -64,8 +66,9 @@ def forward(self, x): return self.bn(F.relu(self.conv(x))) -''' The SE connection of 1D case. -''' +""" The SE connection of 1D case. +""" + class SE_Connect(nn.Module): def __init__(self, channels, se_bottleneck_dim=128): @@ -82,8 +85,8 @@ def forward(self, x): return out -''' SE-Res2Block of the ECAPA-TDNN architecture. -''' +""" SE-Res2Block of the ECAPA-TDNN architecture. +""" # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale): # return nn.Sequential( @@ -93,6 +96,7 @@ def forward(self, x): # SE_Connect(channels) # ) + class SE_Res2Block(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim): super().__init__() @@ -122,8 +126,9 @@ def forward(self, x): return x + residual -''' Attentive weighted mean and standard deviation pooling. -''' +""" Attentive weighted mean and standard deviation pooling. +""" + class AttentiveStatsPool(nn.Module): def __init__(self, in_dim, attention_channels=128, global_context_att=False): @@ -138,7 +143,6 @@ def __init__(self, in_dim, attention_channels=128, global_context_att=False): self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper def forward(self, x): - if self.global_context_att: context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) @@ -151,38 +155,52 @@ def forward(self, x): # alpha = F.relu(self.linear1(x_in)) alpha = torch.softmax(self.linear2(alpha), dim=2) mean = torch.sum(alpha * x, dim=2) - residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2 + residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 std = torch.sqrt(residuals.clamp(min=1e-9)) return torch.cat([mean, std], dim=1) class ECAPA_TDNN(nn.Module): - def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False, - feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None): + def __init__( + self, + feat_dim=80, + channels=512, + emb_dim=192, + global_context_att=False, + feat_type="wavlm_large", + sr=16000, + feature_selection="hidden_states", + update_extract=False, + config_path=None, + ): super().__init__() self.feat_type = feat_type self.feature_selection = feature_selection self.update_extract = update_extract self.sr = sr - - torch.hub._validate_not_a_forked_repo=lambda a,b,c: True + + torch.hub._validate_not_a_forked_repo = lambda a, b, c: True try: local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main") - self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path) - except: - self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type) + self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path) + except: # noqa: E722 + self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type) - if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"): + if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( + self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention" + ): self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False - if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"): + if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( + self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention" + ): self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False self.feat_num = self.get_feat_num() self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) - if feat_type != 'fbank' and feat_type != 'mfcc': - freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer'] + if feat_type != "fbank" and feat_type != "mfcc": + freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"] for name, param in self.feature_extract.named_parameters(): for freeze_val in freeze_list: if freeze_val in name: @@ -198,18 +216,46 @@ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=Fa self.channels = [channels] * 4 + [1536] self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2) - self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128) - self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128) - self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128) + self.layer2 = SE_Res2Block( + self.channels[0], + self.channels[1], + kernel_size=3, + stride=1, + padding=2, + dilation=2, + scale=8, + se_bottleneck_dim=128, + ) + self.layer3 = SE_Res2Block( + self.channels[1], + self.channels[2], + kernel_size=3, + stride=1, + padding=3, + dilation=3, + scale=8, + se_bottleneck_dim=128, + ) + self.layer4 = SE_Res2Block( + self.channels[2], + self.channels[3], + kernel_size=3, + stride=1, + padding=4, + dilation=4, + scale=8, + se_bottleneck_dim=128, + ) # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1) cat_channels = channels * 3 self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) - self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att) + self.pooling = AttentiveStatsPool( + self.channels[-1], attention_channels=128, global_context_att=global_context_att + ) self.bn = nn.BatchNorm1d(self.channels[-1] * 2) self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) - def get_feat_num(self): self.feature_extract.eval() wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)] @@ -226,12 +272,12 @@ def get_feat(self, x): x = self.feature_extract([sample for sample in x]) else: with torch.no_grad(): - if self.feat_type == 'fbank' or self.feat_type == 'mfcc': + if self.feat_type == "fbank" or self.feat_type == "mfcc": x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len else: x = self.feature_extract([sample for sample in x]) - if self.feat_type == 'fbank': + if self.feat_type == "fbank": x = x.log() if self.feat_type != "fbank" and self.feat_type != "mfcc": @@ -263,6 +309,22 @@ def forward(self, x): return out -def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None): - return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim, - feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path) +def ECAPA_TDNN_SMALL( + feat_dim, + emb_dim=256, + feat_type="wavlm_large", + sr=16000, + feature_selection="hidden_states", + update_extract=False, + config_path=None, +): + return ECAPA_TDNN( + feat_dim=feat_dim, + channels=512, + emb_dim=emb_dim, + feat_type=feat_type, + sr=sr, + feature_selection=feature_selection, + update_extract=update_extract, + config_path=config_path, + ) diff --git a/model/modules.py b/model/modules.py index 7fe3a0110..c026eff10 100644 --- a/model/modules.py +++ b/model/modules.py @@ -21,39 +21,40 @@ # raw wav to mel spec + class MelSpec(nn.Module): def __init__( self, - filter_length = 1024, - hop_length = 256, - win_length = 1024, - n_mel_channels = 100, - target_sample_rate = 24_000, - normalize = False, - power = 1, - norm = None, - center = True, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + normalize=False, + power=1, + norm=None, + center=True, ): super().__init__() self.n_mel_channels = n_mel_channels self.mel_stft = torchaudio.transforms.MelSpectrogram( - sample_rate = target_sample_rate, - n_fft = filter_length, - win_length = win_length, - hop_length = hop_length, - n_mels = n_mel_channels, - power = power, - center = center, - normalized = normalize, - norm = norm, + sample_rate=target_sample_rate, + n_fft=filter_length, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=power, + center=center, + normalized=normalize, + norm=norm, ) - self.register_buffer('dummy', torch.tensor(0), persistent = False) + self.register_buffer("dummy", torch.tensor(0), persistent=False) def forward(self, inp): if len(inp.shape) == 3: - inp = inp.squeeze(1) # 'b 1 nw -> b nw' + inp = inp.squeeze(1) # 'b 1 nw -> b nw' assert len(inp.shape) == 2 @@ -61,12 +62,13 @@ def forward(self, inp): self.to(inp.device) mel = self.mel_stft(inp) - mel = mel.clamp(min = 1e-5).log() + mel = mel.clamp(min=1e-5).log() return mel - + # sinusoidal position embedding + class SinusPositionEmbedding(nn.Module): def __init__(self, dim): super().__init__() @@ -84,35 +86,37 @@ def forward(self, x, scale=1000): # convolutional position embedding + class ConvPositionEmbedding(nn.Module): - def __init__(self, dim, kernel_size = 31, groups = 16): + def __init__(self, dim, kernel_size=31, groups=16): super().__init__() assert kernel_size % 2 != 0 self.conv1d = nn.Sequential( - nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), nn.Mish(), - nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), nn.Mish(), ) - def forward(self, x: float['b n d'], mask: bool['b n'] | None = None): + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 if mask is not None: mask = mask[..., None] - x = x.masked_fill(~mask, 0.) + x = x.masked_fill(~mask, 0.0) x = x.permute(0, 2, 1) x = self.conv1d(x) out = x.permute(0, 2, 1) if mask is not None: - out = out.masked_fill(~mask, 0.) + out = out.masked_fill(~mask, 0.0) return out # rotary positional embedding related -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.): + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ @@ -125,12 +129,14 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca freqs_sin = torch.sin(freqs) # imaginary part return torch.cat([freqs_cos, freqs_sin], dim=-1) -def get_pos_embed_indices(start, length, max_pos, scale=1.): + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): # length = length if isinstance(length, int) else length.max() scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar - pos = start.unsqueeze(1) + ( - torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * - scale.unsqueeze(1)).long() + pos = ( + start.unsqueeze(1) + + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() + ) # avoid extra long error. pos = torch.where(pos < max_pos, pos, max_pos - 1) return pos @@ -138,6 +144,7 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.): # Global Response Normalization layer (Instance Normalization ?) + class GRN(nn.Module): def __init__(self, dim): super().__init__() @@ -153,6 +160,7 @@ def forward(self, x): # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 + class ConvNeXtV2Block(nn.Module): def __init__( self, @@ -162,7 +170,9 @@ def __init__( ): super().__init__() padding = (dilation * (7 - 1)) // 2 - self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv self.norm = nn.LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() @@ -185,6 +195,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # AdaLayerNormZero # return with modulated x for attn input, and params for later mlp modulation + class AdaLayerNormZero(nn.Module): def __init__(self, dim): super().__init__() @@ -194,7 +205,7 @@ def __init__(self, dim): self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - def forward(self, x, emb = None): + def forward(self, x, emb=None): emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) @@ -205,6 +216,7 @@ def forward(self, x, emb = None): # AdaLayerNormZero for final layer # return only with modulated x for attn input, cuz no more mlp modulation + class AdaLayerNormZero_Final(nn.Module): def __init__(self, dim): super().__init__() @@ -224,22 +236,16 @@ def forward(self, x, emb): # FeedForward + class FeedForward(nn.Module): - def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim activation = nn.GELU(approximate=approximate) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - activation - ) - self.ff = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.ff(x) @@ -248,6 +254,7 @@ def forward(self, x): # Attention with possible joint part # modified from diffusers/src/diffusers/models/attention_processor.py + class Attention(nn.Module): def __init__( self, @@ -256,8 +263,8 @@ def __init__( heads: int = 8, dim_head: int = 64, dropout: float = 0.0, - context_dim: Optional[int] = None, # if not None -> joint attention - context_pre_only = None, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only=None, ): super().__init__() @@ -293,20 +300,21 @@ def __init__( def forward( self, - x: float['b n d'], # noised input x - c: float['b n d'] = None, # context c - mask: bool['b n'] | None = None, - rope = None, # rotary position embedding for x - c_rope = None, # rotary position embedding for c + x: float["b n d"], # noised input x # noqa: F722 + c: float["b n d"] = None, # context c # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c ) -> torch.Tensor: if c is not None: - return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope) + return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) else: - return self.processor(self, x, mask = mask, rope = rope) + return self.processor(self, x, mask=mask, rope=rope) # Attention processor + class AttnProcessor: def __init__(self): pass @@ -314,11 +322,10 @@ def __init__(self): def __call__( self, attn: Attention, - x: float['b n d'], # noised input x - mask: bool['b n'] | None = None, - rope = None, # rotary position embedding + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding ) -> torch.FloatTensor: - batch_size = x.shape[0] # `sample` projections. @@ -329,7 +336,7 @@ def __call__( # apply rotary position embedding if rope is not None: freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.) + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) @@ -360,14 +367,15 @@ def __call__( if mask is not None: mask = mask.unsqueeze(-1) - x = x.masked_fill(~mask, 0.) + x = x.masked_fill(~mask, 0.0) return x - + # Joint Attention processor for MM-DiT # modified from diffusers/src/diffusers/models/attention_processor.py + class JointAttnProcessor: def __init__(self): pass @@ -375,11 +383,11 @@ def __init__(self): def __call__( self, attn: Attention, - x: float['b n d'], # noised input x - c: float['b nt d'] = None, # context c, here text - mask: bool['b n'] | None = None, - rope = None, # rotary position embedding for x - c_rope = None, # rotary position embedding for c + x: float["b n d"], # noised input x # noqa: F722 + c: float["b nt d"] = None, # context c, here text # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c ) -> torch.FloatTensor: residual = x @@ -398,12 +406,12 @@ def __call__( # apply rope for context and noised input independently if rope is not None: freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.) + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) if c_rope is not None: freqs, xpos_scale = c_rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.) + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) @@ -420,7 +428,7 @@ def __call__( # mask. e.g. inference got a batch with different target durations, mask out the padding if mask is not None: - attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text) + attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) else: @@ -432,8 +440,8 @@ def __call__( # Split the attention outputs. x, c = ( - x[:, :residual.shape[1]], - x[:, residual.shape[1]:], + x[:, : residual.shape[1]], + x[:, residual.shape[1] :], ) # linear proj @@ -445,7 +453,7 @@ def __call__( if mask is not None: mask = mask.unsqueeze(-1) - x = x.masked_fill(~mask, 0.) + x = x.masked_fill(~mask, 0.0) # c = c.masked_fill(~mask, 0.) # no mask for c (text) return x, c @@ -453,24 +461,24 @@ def __call__( # DiT Block -class DiTBlock(nn.Module): - def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1): +class DiTBlock(nn.Module): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): super().__init__() - + self.attn_norm = AdaLayerNormZero(dim) self.attn = Attention( - processor = AttnProcessor(), - dim = dim, - heads = heads, - dim_head = dim_head, - dropout = dropout, - ) - + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh") + self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") - def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding # pre-norm & modulation for attention input norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) @@ -479,7 +487,7 @@ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time em # process attention output for input x x = x + gate_msa.unsqueeze(1) * attn_output - + norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm) x = x + gate_mlp.unsqueeze(1) * ff_output @@ -489,8 +497,9 @@ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time em # MMDiT Block https://arxiv.org/abs/2403.03206 + class MMDiTBlock(nn.Module): - r""" + r""" modified from diffusers/src/diffusers/models/attention.py notes. @@ -499,33 +508,33 @@ class MMDiTBlock(nn.Module): context_pre_only: last layer only do prenorm + modulation cuz no more ffn """ - def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False): super().__init__() self.context_pre_only = context_pre_only - + self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) self.attn_norm_x = AdaLayerNormZero(dim) self.attn = Attention( - processor = JointAttnProcessor(), - dim = dim, - heads = heads, - dim_head = dim_head, - dropout = dropout, - context_dim = dim, - context_pre_only = context_pre_only, - ) + processor=JointAttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + context_dim=dim, + context_pre_only=context_pre_only, + ) if not context_pre_only: self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh") + self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") else: self.ff_norm_c = None self.ff_c = None self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh") + self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") - def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding + def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding # pre-norm & modulation for attention input if self.context_pre_only: norm_c = self.attn_norm_c(c, t) @@ -539,7 +548,7 @@ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised # process attention output for context c if self.context_pre_only: c = None - else: # if not last layer + else: # if not last layer c = c + c_gate_msa.unsqueeze(1) * c_attn_output norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] @@ -548,7 +557,7 @@ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised # process attention output for input x x = x + x_gate_msa.unsqueeze(1) * x_attn_output - + norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] x_ff_output = self.ff_x(norm_x) x = x + x_gate_mlp.unsqueeze(1) * x_ff_output @@ -558,17 +567,14 @@ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised # time step conditioning embedding + class TimestepEmbedding(nn.Module): def __init__(self, dim, freq_embed_dim=256): super().__init__() self.time_embed = SinusPositionEmbedding(freq_embed_dim) - self.time_mlp = nn.Sequential( - nn.Linear(freq_embed_dim, dim), - nn.SiLU(), - nn.Linear(dim, dim) - ) + self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) - def forward(self, timestep: float['b']): + def forward(self, timestep: float["b"]): # noqa: F821 time_hidden = self.time_embed(timestep) time_hidden = time_hidden.to(timestep.dtype) time = self.time_mlp(time_hidden) # b d diff --git a/model/trainer.py b/model/trainer.py index 35a5f642e..470c8b376 100644 --- a/model/trainer.py +++ b/model/trainer.py @@ -22,71 +22,69 @@ # trainer + class Trainer: def __init__( self, model: CFM, epochs, learning_rate, - num_warmup_updates = 20000, - save_per_updates = 1000, - checkpoint_path = None, - batch_size = 32, + num_warmup_updates=20000, + save_per_updates=1000, + checkpoint_path=None, + batch_size=32, batch_size_type: str = "sample", - max_samples = 32, - grad_accumulation_steps = 1, - max_grad_norm = 1.0, + max_samples=32, + grad_accumulation_steps=1, + max_grad_norm=1.0, noise_scheduler: str | None = None, duration_predictor: torch.nn.Module | None = None, - wandb_project = "test_e2-tts", - wandb_run_name = "test_run", + wandb_project="test_e2-tts", + wandb_run_name="test_run", wandb_resume_id: str = None, - last_per_steps = None, + last_per_steps=None, accelerate_kwargs: dict = dict(), ema_kwargs: dict = dict(), bnb_optimizer: bool = False, ): - - ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) logger = "wandb" if wandb.api.api_key else None print(f"Using logger: {logger}") self.accelerator = Accelerator( - log_with = logger, - kwargs_handlers = [ddp_kwargs], - gradient_accumulation_steps = grad_accumulation_steps, - **accelerate_kwargs + log_with=logger, + kwargs_handlers=[ddp_kwargs], + gradient_accumulation_steps=grad_accumulation_steps, + **accelerate_kwargs, ) if logger == "wandb": if exists(wandb_resume_id): - init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}} + init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}} else: - init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}} + init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} self.accelerator.init_trackers( - project_name = wandb_project, + project_name=wandb_project, init_kwargs=init_kwargs, - config={"epochs": epochs, - "learning_rate": learning_rate, - "num_warmup_updates": num_warmup_updates, - "batch_size": batch_size, - "batch_size_type": batch_size_type, - "max_samples": max_samples, - "grad_accumulation_steps": grad_accumulation_steps, - "max_grad_norm": max_grad_norm, - "gpus": self.accelerator.num_processes, - "noise_scheduler": noise_scheduler} - ) + config={ + "epochs": epochs, + "learning_rate": learning_rate, + "num_warmup_updates": num_warmup_updates, + "batch_size": batch_size, + "batch_size_type": batch_size_type, + "max_samples": max_samples, + "grad_accumulation_steps": grad_accumulation_steps, + "max_grad_norm": max_grad_norm, + "gpus": self.accelerator.num_processes, + "noise_scheduler": noise_scheduler, + }, + ) self.model = model if self.is_main: - self.ema_model = EMA( - model, - include_online_model = False, - **ema_kwargs - ) + self.ema_model = EMA(model, include_online_model=False, **ema_kwargs) self.ema_model.to(self.accelerator.device) @@ -94,7 +92,7 @@ def __init__( self.num_warmup_updates = num_warmup_updates self.save_per_updates = save_per_updates self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps) - self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts') + self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts") self.batch_size = batch_size self.batch_size_type = batch_size_type @@ -108,12 +106,11 @@ def __init__( if bnb_optimizer: import bitsandbytes as bnb + self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate) else: self.optimizer = AdamW(model.parameters(), lr=learning_rate) - self.model, self.optimizer = self.accelerator.prepare( - self.model, self.optimizer - ) + self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) @property def is_main(self): @@ -123,81 +120,112 @@ def save_checkpoint(self, step, last=False): self.accelerator.wait_for_everyone() if self.is_main: checkpoint = dict( - model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(), - optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(), - ema_model_state_dict = self.ema_model.state_dict(), - scheduler_state_dict = self.scheduler.state_dict(), - step = step + model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), + optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(), + ema_model_state_dict=self.ema_model.state_dict(), + scheduler_state_dict=self.scheduler.state_dict(), + step=step, ) if not os.path.exists(self.checkpoint_path): os.makedirs(self.checkpoint_path) - if last == True: + if last: self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") print(f"Saved last checkpoint at step {step}") else: self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt") def load_checkpoint(self): - if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path): + if ( + not exists(self.checkpoint_path) + or not os.path.exists(self.checkpoint_path) + or not os.listdir(self.checkpoint_path) + ): return 0 - + self.accelerator.wait_for_everyone() if "model_last.pt" in os.listdir(self.checkpoint_path): latest_checkpoint = "model_last.pt" else: - latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1] + latest_checkpoint = sorted( + [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")], + key=lambda x: int("".join(filter(str.isdigit, x))), + )[-1] # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") if self.is_main: - self.ema_model.load_state_dict(checkpoint['ema_model_state_dict']) + self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"]) - if 'step' in checkpoint: - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict']) - self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict']) + if "step" in checkpoint: + self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) + self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"]) if self.scheduler: - self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) - step = checkpoint['step'] + self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + step = checkpoint["step"] else: - checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]} - self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict']) + checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "step"] + } + self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) step = 0 - del checkpoint; gc.collect() + del checkpoint + gc.collect() return step def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None): - if exists(resumable_with_seed): generator = torch.Generator() generator.manual_seed(resumable_with_seed) - else: + else: generator = None if self.batch_size_type == "sample": - train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True, - batch_size=self.batch_size, shuffle=True, generator=generator) + train_dataloader = DataLoader( + train_dataset, + collate_fn=collate_fn, + num_workers=num_workers, + pin_memory=True, + persistent_workers=True, + batch_size=self.batch_size, + shuffle=True, + generator=generator, + ) elif self.batch_size_type == "frame": self.accelerator.even_batches = False sampler = SequentialSampler(train_dataset) - batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False) - train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True, - batch_sampler=batch_sampler) + batch_sampler = DynamicBatchSampler( + sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False + ) + train_dataloader = DataLoader( + train_dataset, + collate_fn=collate_fn, + num_workers=num_workers, + pin_memory=True, + persistent_workers=True, + batch_sampler=batch_sampler, + ) else: raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") - + # accelerator.prepare() dispatches batches to devices; # which means the length of dataloader calculated before, should consider the number of devices - warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp - # otherwise by default with split_batches=False, warmup steps change with num_processes + warmup_steps = ( + self.num_warmup_updates * self.accelerator.num_processes + ) # consider a fixed warmup steps while using accelerate multi-gpu ddp + # otherwise by default with split_batches=False, warmup steps change with num_processes total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps decay_steps = total_steps - warmup_steps warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) - self.scheduler = SequentialLR(self.optimizer, - schedulers=[warmup_scheduler, decay_scheduler], - milestones=[warmup_steps]) - train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus + self.scheduler = SequentialLR( + self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps] + ) + train_dataloader, self.scheduler = self.accelerator.prepare( + train_dataloader, self.scheduler + ) # actual steps = 1 gpu steps / gpus start_step = self.load_checkpoint() global_step = start_step @@ -212,23 +240,36 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int for epoch in range(skipped_epoch, self.epochs): self.model.train() if exists(resumable_with_seed) and epoch == skipped_epoch: - progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process, - initial=skipped_batch, total=orig_epoch_step) + progress_bar = tqdm( + skipped_dataloader, + desc=f"Epoch {epoch+1}/{self.epochs}", + unit="step", + disable=not self.accelerator.is_local_main_process, + initial=skipped_batch, + total=orig_epoch_step, + ) else: - progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process) + progress_bar = tqdm( + train_dataloader, + desc=f"Epoch {epoch+1}/{self.epochs}", + unit="step", + disable=not self.accelerator.is_local_main_process, + ) for batch in progress_bar: with self.accelerator.accumulate(self.model): - text_inputs = batch['text'] - mel_spec = batch['mel'].permute(0, 2, 1) + text_inputs = batch["text"] + mel_spec = batch["mel"].permute(0, 2, 1) mel_lengths = batch["mel_lengths"] # TODO. add duration predictor training if self.duration_predictor is not None and self.accelerator.is_local_main_process: - dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations')) + dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations")) self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step) - loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler) + loss, cond, pred = self.model( + mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler + ) self.accelerator.backward(loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: @@ -245,13 +286,13 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int if self.accelerator.is_local_main_process: self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step) - + progress_bar.set_postfix(step=str(global_step), loss=loss.item()) - + if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0: self.save_checkpoint(global_step) - + if global_step % self.last_per_steps == 0: self.save_checkpoint(global_step, last=True) - + self.accelerator.end_training() diff --git a/model/utils.py b/model/utils.py index c898d9161..2253cb812 100644 --- a/model/utils.py +++ b/model/utils.py @@ -8,6 +8,7 @@ from collections import defaultdict import matplotlib + matplotlib.use("Agg") import matplotlib.pylab as plt @@ -25,109 +26,102 @@ # seed everything -def seed_everything(seed = 0): + +def seed_everything(seed=0): random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False + # helpers + def exists(v): return v is not None + def default(v, d): return v if exists(v) else d + # tensor helpers -def lens_to_mask( - t: int['b'], - length: int | None = None -) -> bool['b n']: +def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 if not exists(length): length = t.amax() - seq = torch.arange(length, device = t.device) + seq = torch.arange(length, device=t.device) return seq[None, :] < t[:, None] -def mask_from_start_end_indices( - seq_len: int['b'], - start: int['b'], - end: int['b'] -): - max_seq_len = seq_len.max().item() - seq = torch.arange(max_seq_len, device = start.device).long() + +def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821 + max_seq_len = seq_len.max().item() + seq = torch.arange(max_seq_len, device=start.device).long() start_mask = seq[None, :] >= start[:, None] end_mask = seq[None, :] < end[:, None] return start_mask & end_mask -def mask_from_frac_lengths( - seq_len: int['b'], - frac_lengths: float['b'] -): + +def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 lengths = (frac_lengths * seq_len).long() max_start = seq_len - lengths rand = torch.rand_like(frac_lengths) - start = (max_start * rand).long().clamp(min = 0) + start = (max_start * rand).long().clamp(min=0) end = start + lengths return mask_from_start_end_indices(seq_len, start, end) -def maybe_masked_mean( - t: float['b n d'], - mask: bool['b n'] = None -) -> float['b d']: +def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722 if not exists(mask): - return t.mean(dim = 1) + return t.mean(dim=1) - t = torch.where(mask[:, :, None], t, torch.tensor(0., device=t.device)) + t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) num = t.sum(dim=1) den = mask.float().sum(dim=1) - return num / den.clamp(min=1.) + return num / den.clamp(min=1.0) # simple utf-8 tokenizer, since paper went character based -def list_str_to_tensor( - text: list[str], - padding_value = -1 -) -> int['b nt']: - list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style - text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True) +def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 + list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style + text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) return text + # char tokenizer, based on custom dataset's extracted .txt file def list_str_to_idx( text: list[str] | list[list[str]], vocab_char_map: dict[str, int], # {char: idx} - padding_value = -1 -) -> int['b nt']: + padding_value=-1, +) -> int["b nt"]: # noqa: F722 list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style - text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True) + text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) return text # Get tokenizer + def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): - ''' + """ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file - "char" for char-wise tokenizer, need .txt vocab_file - "byte" for utf-8 tokenizer - "custom" if you're directly passing in a path to the vocab.txt you want to use vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols - if use "char", derived from unfiltered character & symbol counts of custom dataset - - if use "byte", set to 256 (unicode byte range) - ''' + - if use "byte", set to 256 (unicode byte range) + """ if tokenizer in ["pinyin", "char"]: - with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f: + with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f: vocab_char_map = {} for i, char in enumerate(f): vocab_char_map[char[:-1]] = i @@ -138,7 +132,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): vocab_char_map = None vocab_size = 256 elif tokenizer == "custom": - with open (dataset_name, "r", encoding="utf-8") as f: + with open(dataset_name, "r", encoding="utf-8") as f: vocab_char_map = {} for i, char in enumerate(f): vocab_char_map[char[:-1]] = i @@ -149,16 +143,19 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): # convert char to pinyin -def convert_char_to_pinyin(text_list, polyphone = True): + +def convert_char_to_pinyin(text_list, polyphone=True): final_text_list = [] - god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean - custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov + god_knows_why_en_testset_contains_zh_quote = str.maketrans( + {"“": '"', "”": '"', "‘": "'", "’": "'"} + ) # in case librispeech (orig no-pc) test-clean + custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov for text in text_list: char_list = [] text = text.translate(god_knows_why_en_testset_contains_zh_quote) text = text.translate(custom_trans) for seg in jieba.cut(text): - seg_byte_len = len(bytes(seg, 'UTF-8')) + seg_byte_len = len(bytes(seg, "UTF-8")) if seg_byte_len == len(seg): # if pure alphabets and symbols if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": char_list.append(" ") @@ -187,7 +184,7 @@ def convert_char_to_pinyin(text_list, polyphone = True): # save spectrogram def save_spectrogram(spectrogram, path): plt.figure(figsize=(12, 4)) - plt.imshow(spectrogram, origin='lower', aspect='auto') + plt.imshow(spectrogram, origin="lower", aspect="auto") plt.colorbar() plt.savefig(path) plt.close() @@ -195,13 +192,15 @@ def save_spectrogram(spectrogram, path): # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav def get_seedtts_testset_metainfo(metalst): - f = open(metalst); lines = f.readlines(); f.close() + f = open(metalst) + lines = f.readlines() + f.close() metainfo = [] for line in lines: - if len(line.strip().split('|')) == 5: - utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|') - elif len(line.strip().split('|')) == 4: - utt, prompt_text, prompt_wav, gt_text = line.strip().split('|') + if len(line.strip().split("|")) == 5: + utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") + elif len(line.strip().split("|")) == 4: + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") if not os.path.isabs(prompt_wav): prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) @@ -211,18 +210,20 @@ def get_seedtts_testset_metainfo(metalst): # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path): - f = open(metalst); lines = f.readlines(); f.close() + f = open(metalst) + lines = f.readlines() + f.close() metainfo = [] for line in lines: - ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t') + ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t") # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc) - ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-') - ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac') + ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") + ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc) - gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-') - gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac') + gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") + gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav)) @@ -234,7 +235,7 @@ def padded_mel_batch(ref_mels): max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() padded_ref_mels = [] for mel in ref_mels: - padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0) + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) padded_ref_mels.append(padded_ref_mel) padded_ref_mels = torch.stack(padded_ref_mels) padded_ref_mels = padded_ref_mels.permute(0, 2, 1) @@ -243,12 +244,21 @@ def padded_mel_batch(ref_mels): # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav + def get_inference_prompt( - metainfo, - speed = 1., tokenizer = "pinyin", polyphone = True, - target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1, - use_truth_duration = False, - infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40, + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_mel_channels=100, + hop_length=256, + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, ): prompts_all = [] @@ -256,13 +266,15 @@ def get_inference_prompt( max_tokens = max_secs * target_sample_rate // hop_length batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \ - ([[] for _ in range(num_buckets)] for _ in range(6)) + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) - mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length) + mel_spectrogram = MelSpec( + target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length + ) for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."): - # Audio ref_audio, ref_sr = torchaudio.load(prompt_wav) ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) @@ -274,11 +286,11 @@ def get_inference_prompt( ref_audio = resampler(ref_audio) # Text - if len(prompt_text[-1].encode('utf-8')) == 1: + if len(prompt_text[-1].encode("utf-8")) == 1: prompt_text = prompt_text + " " text = [prompt_text + gt_text] if tokenizer == "pinyin": - text_list = convert_char_to_pinyin(text, polyphone = polyphone) + text_list = convert_char_to_pinyin(text, polyphone=polyphone) else: text_list = text @@ -294,8 +306,8 @@ def get_inference_prompt( # # test vocoder resynthesis # ref_audio = gt_audio else: - ref_text_len = len(prompt_text.encode('utf-8')) - gen_text_len = len(gt_text.encode('utf-8')) + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed) # to mel spectrogram @@ -304,8 +316,9 @@ def get_inference_prompt( # deal with batch assert infer_batch_size > 0, "infer_batch_size should be greater than 0." - assert min_tokens <= total_mel_len <= max_tokens, \ - f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets) utts[bucket_i].append(utt) @@ -319,28 +332,39 @@ def get_inference_prompt( if batch_accum[bucket_i] >= infer_batch_size: # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}") - prompts_all.append(( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i] - )) + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) batch_accum[bucket_i] = 0 - utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], [] + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = [], [], [], [], [], [] # add residual for bucket_i, bucket_frames in enumerate(batch_accum): if bucket_frames > 0: - prompts_all.append(( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i] - )) + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) # not only leave easy work for last workers random.seed(666) random.shuffle(prompts_all) @@ -351,6 +375,7 @@ def get_inference_prompt( # get wav_res_ref_text of seed-tts test metalst # https://github.com/BytedanceSpeech/seed-tts-eval + def get_seed_tts_test(metalst, gen_wav_dir, gpus): f = open(metalst) lines = f.readlines() @@ -358,14 +383,14 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus): test_set_ = [] for line in tqdm(lines): - if len(line.strip().split('|')) == 5: - utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|') - elif len(line.strip().split('|')) == 4: - utt, prompt_text, prompt_wav, gt_text = line.strip().split('|') + if len(line.strip().split("|")) == 5: + utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") + elif len(line.strip().split("|")) == 4: + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") - if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')): + if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")): continue - gen_wav = os.path.join(gen_wav_dir, utt + '.wav') + gen_wav = os.path.join(gen_wav_dir, utt + ".wav") if not os.path.isabs(prompt_wav): prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) @@ -374,65 +399,69 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus): num_jobs = len(gpus) if num_jobs == 1: return [(gpus[0], test_set_)] - + wav_per_job = len(test_set_) // num_jobs + 1 test_set = [] for i in range(num_jobs): - test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job])) + test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job])) return test_set # get librispeech test-clean cross sentence test -def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False): + +def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False): f = open(metalst) lines = f.readlines() f.close() test_set_ = [] for line in tqdm(lines): - ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t') + ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t") if eval_ground_truth: - gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-') - gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac') + gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") + gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") else: - if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')): + if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")): raise FileNotFoundError(f"Generated wav not found: {gen_utt}") - gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav') + gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav") - ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-') - ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac') + ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") + ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") test_set_.append((gen_wav, ref_wav, gen_txt)) num_jobs = len(gpus) if num_jobs == 1: return [(gpus[0], test_set_)] - + wav_per_job = len(test_set_) // num_jobs + 1 test_set = [] for i in range(num_jobs): - test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job])) + test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job])) return test_set # load asr model -def load_asr_model(lang, ckpt_dir = ""): + +def load_asr_model(lang, ckpt_dir=""): if lang == "zh": from funasr import AutoModel + model = AutoModel( - model = os.path.join(ckpt_dir, "paraformer-zh"), - # vad_model = os.path.join(ckpt_dir, "fsmn-vad"), + model=os.path.join(ckpt_dir, "paraformer-zh"), + # vad_model = os.path.join(ckpt_dir, "fsmn-vad"), # punc_model = os.path.join(ckpt_dir, "ct-punc"), - # spk_model = os.path.join(ckpt_dir, "cam++"), + # spk_model = os.path.join(ckpt_dir, "cam++"), disable_update=True, - ) # following seed-tts setting + ) # following seed-tts setting elif lang == "en": from faster_whisper import WhisperModel + model_size = "large-v3" if ckpt_dir == "" else ckpt_dir model = WhisperModel(model_size, device="cuda", compute_type="float16") return model @@ -440,44 +469,50 @@ def load_asr_model(lang, ckpt_dir = ""): # WER Evaluation, the way Seed-TTS does + def run_asr_wer(args): rank, lang, test_set, ckpt_dir = args if lang == "zh": import zhconv + torch.cuda.set_device(rank) elif lang == "en": os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) else: - raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.") + raise NotImplementedError( + "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now." + ) + + asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir) - asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir) - from zhon.hanzi import punctuation + punctuation_all = punctuation + string.punctuation wers = [] from jiwer import compute_measures + for gen_wav, prompt_wav, truth in tqdm(test_set): if lang == "zh": res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True) hypo = res[0]["text"] - hypo = zhconv.convert(hypo, 'zh-cn') + hypo = zhconv.convert(hypo, "zh-cn") elif lang == "en": segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en") - hypo = '' + hypo = "" for segment in segments: - hypo = hypo + ' ' + segment.text + hypo = hypo + " " + segment.text # raw_truth = truth # raw_hypo = hypo for x in punctuation_all: - truth = truth.replace(x, '') - hypo = hypo.replace(x, '') + truth = truth.replace(x, "") + hypo = hypo.replace(x, "") - truth = truth.replace(' ', ' ') - hypo = hypo.replace(' ', ' ') + truth = truth.replace(" ", " ") + hypo = hypo.replace(" ", " ") if lang == "zh": truth = " ".join([x for x in truth]) @@ -501,22 +536,22 @@ def run_asr_wer(args): # SIM Evaluation + def run_sim(args): rank, test_set, ckpt_dir = args device = f"cuda:{rank}" - model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None) + model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None) state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage) - model.load_state_dict(state_dict['model'], strict=False) + model.load_state_dict(state_dict["model"], strict=False) - use_gpu=True if torch.cuda.is_available() else False + use_gpu = True if torch.cuda.is_available() else False if use_gpu: model = model.cuda(device) model.eval() sim_list = [] for wav1, wav2, truth in tqdm(test_set): - wav1, sr1 = torchaudio.load(wav1) wav2, sr2 = torchaudio.load(wav2) @@ -531,20 +566,21 @@ def run_sim(args): with torch.no_grad(): emb1 = model(wav1) emb2 = model(wav2) - + sim = F.cosine_similarity(emb1, emb2)[0].item() # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).") sim_list.append(sim) - + return sim_list # filter func for dirty data with many repetitions -def repetition_found(text, length = 2, tolerance = 10): + +def repetition_found(text, length=2, tolerance=10): pattern_count = defaultdict(int) for i in range(len(text) - length + 1): - pattern = text[i:i + length] + pattern = text[i : i + length] pattern_count[pattern] += 1 for pattern, count in pattern_count.items(): if count > tolerance: @@ -554,25 +590,31 @@ def repetition_found(text, length = 2, tolerance = 10): # load model checkpoint for inference -def load_checkpoint(model, ckpt_path, device, use_ema = True): + +def load_checkpoint(model, ckpt_path, device, use_ema=True): if device == "cuda": model = model.half() ckpt_type = ckpt_path.split(".")[-1] if ckpt_type == "safetensors": from safetensors.torch import load_file + checkpoint = load_file(ckpt_path) else: checkpoint = torch.load(ckpt_path, weights_only=True) if use_ema: if ckpt_type == "safetensors": - checkpoint = {'ema_model_state_dict': checkpoint} - checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]} - model.load_state_dict(checkpoint['model_state_dict']) + checkpoint = {"ema_model_state_dict": checkpoint} + checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "step"] + } + model.load_state_dict(checkpoint["model_state_dict"]) else: if ckpt_type == "safetensors": - checkpoint = {'model_state_dict': checkpoint} - model.load_state_dict(checkpoint['model_state_dict']) + checkpoint = {"model_state_dict": checkpoint} + model.load_state_dict(checkpoint["model_state_dict"]) return model.to(device) diff --git a/model/utils_infer.py b/model/utils_infer.py index 80560c817..faea34cdf 100644 --- a/model/utils_infer.py +++ b/model/utils_infer.py @@ -19,11 +19,7 @@ convert_char_to_pinyin, ) -device = ( - "cuda" - if torch.cuda.is_available() - else "mps" if torch.backends.mps.is_available() else "cpu" -) +device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" print(f"Using {device} device") asr_pipe = pipeline( @@ -54,6 +50,7 @@ # chunk text into smaller pieces + def chunk_text(text, max_chars=135): """ Splits the input text into chunks, each with a maximum number of characters. @@ -68,15 +65,15 @@ def chunk_text(text, max_chars=135): chunks = [] current_chunk = "" # Split the text into sentences based on punctuation followed by whitespace - sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text) + sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text) for sentence in sentences: - if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars: - current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence + if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars: + current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence else: if current_chunk: chunks.append(current_chunk.strip()) - current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence + current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence if current_chunk: chunks.append(current_chunk.strip()) @@ -86,6 +83,7 @@ def chunk_text(text, max_chars=135): # load vocoder + def load_vocoder(is_local=False, local_path=""): if is_local: print(f"Load vocos from local path {local_path}") @@ -101,23 +99,21 @@ def load_vocoder(is_local=False, local_path=""): # load model for inference + def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""): - if vocab_file == "": vocab_file = "Emilia_ZH_EN" tokenizer = "pinyin" else: tokenizer = "custom" - print("\nvocab : ", vocab_file, tokenizer) - print("tokenizer : ", tokenizer) - print("model : ", ckpt_path,"\n") + print("\nvocab : ", vocab_file, tokenizer) + print("tokenizer : ", tokenizer) + print("model : ", ckpt_path, "\n") vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer) model = CFM( - transformer=model_cls( - **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels - ), + transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, @@ -129,21 +125,20 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""): vocab_char_map=vocab_char_map, ).to(device) - model = load_checkpoint(model, ckpt_path, device, use_ema = True) + model = load_checkpoint(model, ckpt_path, device, use_ema=True) return model # preprocess reference audio and text + def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print): show_info("Converting audio...") with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: aseg = AudioSegment.from_file(ref_audio_orig) - non_silent_segs = silence.split_on_silence( - aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000 - ) + non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: non_silent_wave += non_silent_seg @@ -181,22 +176,27 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print): # infer process: chunk text -> infer batches [i.e. infer_batch_process()] -def infer_process(ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm): +def infer_process( + ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm +): # Split the input text into batches audio, sr = torchaudio.load(ref_audio) - max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr)) + max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr)) gen_text_batches = chunk_text(gen_text, max_chars=max_chars) for i, gen_text in enumerate(gen_text_batches): - print(f'gen_text {i}', gen_text) - + print(f"gen_text {i}", gen_text) + show_info(f"Generating audio in {len(gen_text_batches)} batches...") return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress) # infer batches -def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm): + +def infer_batch_process( + ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm +): audio, sr = ref_audio if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) @@ -212,7 +212,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_ generated_waves = [] spectrograms = [] - if len(ref_text[-1].encode('utf-8')) == 1: + if len(ref_text[-1].encode("utf-8")) == 1: ref_text = ref_text + " " for i, gen_text in enumerate(progress.tqdm(gen_text_batches)): # Prepare the text @@ -221,8 +221,8 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_ # Calculate duration ref_audio_len = audio.shape[-1] // hop_length - ref_text_len = len(ref_text.encode('utf-8')) - gen_text_len = len(gen_text.encode('utf-8')) + ref_text_len = len(ref_text.encode("utf-8")) + gen_text_len = len(gen_text.encode("utf-8")) duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) # inference @@ -245,7 +245,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_ # wav -> numpy generated_wave = generated_wave.squeeze().cpu().numpy() - + generated_waves.append(generated_wave) spectrograms.append(generated_mel_spec[0].cpu().numpy()) @@ -280,11 +280,9 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in # Combine - new_wave = np.concatenate([ - prev_wave[:-cross_fade_samples], - cross_faded_overlap, - next_wave[cross_fade_samples:] - ]) + new_wave = np.concatenate( + [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]] + ) final_wave = new_wave @@ -296,6 +294,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_ # remove silence from generated wav + def remove_silence_for_generated_wav(filename): aseg = AudioSegment.from_file(filename) non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500) diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..4c3887643 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,10 @@ +line-length = 120 +target-version = "py310" + +[lint] +# Only ignore variables with names starting with "_". +dummy-variable-rgx = "^_.*$" + +[lint.isort] +force-single-line = true +lines-after-imports = 2 diff --git a/scripts/count_max_epoch.py b/scripts/count_max_epoch.py index 2a4f3e7cf..7cd7332df 100644 --- a/scripts/count_max_epoch.py +++ b/scripts/count_max_epoch.py @@ -1,6 +1,7 @@ -'''ADAPTIVE BATCH SIZE''' -print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in') -print(' -> least padding, gather wavs with accumulated frames in a batch\n') +"""ADAPTIVE BATCH SIZE""" + +print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in") +print(" -> least padding, gather wavs with accumulated frames in a batch\n") # data total_hours = 95282 diff --git a/scripts/count_params_gflops.py b/scripts/count_params_gflops.py index 737c6dcef..7fc493a8d 100644 --- a/scripts/count_params_gflops.py +++ b/scripts/count_params_gflops.py @@ -1,13 +1,15 @@ -import sys, os +import sys +import os + sys.path.append(os.getcwd()) -from model import M2_TTS, UNetT, DiT, MMDiT +from model import M2_TTS, DiT import torch import thop -''' ~155M ''' +""" ~155M """ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4) # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4) # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2) @@ -15,11 +17,11 @@ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True) # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2) -''' ~335M ''' +""" ~335M """ # FLOPs: 622.1 G, Params: 333.2 M # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4) # FLOPs: 363.4 G, Params: 335.8 M -transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) +transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) model = M2_TTS(transformer=transformer) @@ -30,6 +32,8 @@ frame_length = int(duration * target_sample_rate / hop_length) text_length = 150 -flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))) +flops, params = thop.profile( + model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)) +) print(f"FLOPs: {flops / 1e9} G") print(f"Params: {params / 1e6} M") diff --git a/scripts/eval_infer_batch.py b/scripts/eval_infer_batch.py index 2f051c642..3ca4a2809 100644 --- a/scripts/eval_infer_batch.py +++ b/scripts/eval_infer_batch.py @@ -1,4 +1,6 @@ -import sys, os +import sys +import os + sys.path.append(os.getcwd()) import time @@ -14,9 +16,9 @@ from model import CFM, UNetT, DiT from model.utils import ( load_checkpoint, - get_tokenizer, - get_seedtts_testset_metainfo, - get_librispeech_test_clean_metainfo, + get_tokenizer, + get_seedtts_testset_metainfo, + get_librispeech_test_clean_metainfo, get_inference_prompt, ) @@ -38,16 +40,16 @@ parser = argparse.ArgumentParser(description="batch inference") -parser.add_argument('-s', '--seed', default=None, type=int) -parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN") -parser.add_argument('-n', '--expname', required=True) -parser.add_argument('-c', '--ckptstep', default=1200000, type=int) +parser.add_argument("-s", "--seed", default=None, type=int) +parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN") +parser.add_argument("-n", "--expname", required=True) +parser.add_argument("-c", "--ckptstep", default=1200000, type=int) -parser.add_argument('-nfe', '--nfestep', default=32, type=int) -parser.add_argument('-o', '--odemethod', default="euler") -parser.add_argument('-ss', '--swaysampling', default=-1, type=float) +parser.add_argument("-nfe", "--nfestep", default=32, type=int) +parser.add_argument("-o", "--odemethod", default="euler") +parser.add_argument("-ss", "--swaysampling", default=-1, type=float) -parser.add_argument('-t', '--testset', required=True) +parser.add_argument("-t", "--testset", required=True) args = parser.parse_args() @@ -66,26 +68,26 @@ infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended) -cfg_strength = 2. -speed = 1. +cfg_strength = 2.0 +speed = 1.0 use_truth_duration = False no_ref_audio = False if exp_name == "F5TTS_Base": model_cls = DiT - model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) + model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) elif exp_name == "E2TTS_Base": model_cls = UNetT - model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4) + model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) if testset == "ls_pc_test_clean": metalst = "data/librispeech_pc_test_clean_cross_sentence.lst" librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path) - + elif testset == "seedtts_test_zh": metalst = "data/seedtts_testset/zh/meta.lst" metainfo = get_seedtts_testset_metainfo(metalst) @@ -96,13 +98,16 @@ # path to save genereted wavs -if seed is None: seed = random.randint(-10000, 10000) -output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \ - f"seed{seed}_{ode_method}_nfe{nfe_step}" \ - f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \ - f"_cfg{cfg_strength}_speed{speed}" \ - f"{'_gt-dur' if use_truth_duration else ''}" \ +if seed is None: + seed = random.randint(-10000, 10000) +output_dir = ( + f"results/{exp_name}_{ckpt_step}/{testset}/" + f"seed{seed}_{ode_method}_nfe{nfe_step}" + f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" + f"_cfg{cfg_strength}_speed{speed}" + f"{'_gt-dur' if use_truth_duration else ''}" f"{'_no-ref-audio' if no_ref_audio else ''}" +) # -------------------------------------------------# @@ -110,15 +115,15 @@ use_ema = True prompts_all = get_inference_prompt( - metainfo, - speed = speed, - tokenizer = tokenizer, - target_sample_rate = target_sample_rate, - n_mel_channels = n_mel_channels, - hop_length = hop_length, - target_rms = target_rms, - use_truth_duration = use_truth_duration, - infer_batch_size = infer_batch_size, + metainfo, + speed=speed, + tokenizer=tokenizer, + target_sample_rate=target_sample_rate, + n_mel_channels=n_mel_channels, + hop_length=hop_length, + target_rms=target_rms, + use_truth_duration=use_truth_duration, + infer_batch_size=infer_batch_size, ) # Vocoder model @@ -137,23 +142,19 @@ # Model model = CFM( - transformer = model_cls( - **model_cfg, - text_num_embeds = vocab_size, - mel_dim = n_mel_channels + transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + mel_spec_kwargs=dict( + target_sample_rate=target_sample_rate, + n_mel_channels=n_mel_channels, + hop_length=hop_length, ), - mel_spec_kwargs = dict( - target_sample_rate = target_sample_rate, - n_mel_channels = n_mel_channels, - hop_length = hop_length, + odeint_kwargs=dict( + method=ode_method, ), - odeint_kwargs = dict( - method = ode_method, - ), - vocab_char_map = vocab_char_map, + vocab_char_map=vocab_char_map, ).to(device) -model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema) +model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema) if not os.path.exists(output_dir) and accelerator.is_main_process: os.makedirs(output_dir) @@ -163,29 +164,28 @@ start = time.time() with accelerator.split_between_processes(prompts_all) as prompts: - for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt ref_mels = ref_mels.to(device) - ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device) - total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device) - + ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) + total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) + # Inference with torch.inference_mode(): generated, _ = model.sample( - cond = ref_mels, - text = final_text_list, - duration = total_mel_lens, - lens = ref_mel_lens, - steps = nfe_step, - cfg_strength = cfg_strength, - sway_sampling_coef = sway_sampling_coef, - no_ref_audio = no_ref_audio, - seed = seed, + cond=ref_mels, + text=final_text_list, + duration=total_mel_lens, + lens=ref_mel_lens, + steps=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + no_ref_audio=no_ref_audio, + seed=seed, ) # Final result for i, gen in enumerate(generated): - gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0) + gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) gen_mel_spec = gen.permute(0, 2, 1) generated_wave = vocos.decode(gen_mel_spec.cpu()) if ref_rms_list[i] < target_rms: diff --git a/scripts/eval_librispeech_test_clean.py b/scripts/eval_librispeech_test_clean.py index 2f5820f7a..a1ce8b7b3 100644 --- a/scripts/eval_librispeech_test_clean.py +++ b/scripts/eval_librispeech_test_clean.py @@ -1,6 +1,8 @@ # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation) -import sys, os +import sys +import os + sys.path.append(os.getcwd()) import multiprocessing as mp @@ -19,7 +21,7 @@ librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path gen_wav_dir = "PATH_TO_GENERATED" # generated wavs -gpus = [0,1,2,3,4,5,6,7] +gpus = [0, 1, 2, 3, 4, 5, 6, 7] test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path) ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book, @@ -46,7 +48,7 @@ for wers_ in results: wers.extend(wers_) - wer = round(np.mean(wers)*100, 3) + wer = round(np.mean(wers) * 100, 3) print(f"\nTotal {len(wers)} samples") print(f"WER : {wer}%") @@ -62,6 +64,6 @@ for sim_ in results: sim_list.extend(sim_) - sim = round(sum(sim_list)/len(sim_list), 3) + sim = round(sum(sim_list) / len(sim_list), 3) print(f"\nTotal {len(sim_list)} samples") print(f"SIM : {sim}") diff --git a/scripts/eval_seedtts_testset.py b/scripts/eval_seedtts_testset.py index c50bd501d..e70534e11 100644 --- a/scripts/eval_seedtts_testset.py +++ b/scripts/eval_seedtts_testset.py @@ -1,6 +1,8 @@ # Evaluate with Seed-TTS testset -import sys, os +import sys +import os + sys.path.append(os.getcwd()) import multiprocessing as mp @@ -14,21 +16,21 @@ eval_task = "wer" # sim | wer -lang = "zh" # zh | en +lang = "zh" # zh | en metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs -gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs +gen_wav_dir = "PATH_TO_GENERATED" # generated wavs # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different -# zh 1.254 seems a result of 4 workers wer_seed_tts -gpus = [0,1,2,3,4,5,6,7] +# zh 1.254 seems a result of 4 workers wer_seed_tts +gpus = [0, 1, 2, 3, 4, 5, 6, 7] test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus) local = False if local: # use local custom checkpoint dir if lang == "zh": - asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr + asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr elif lang == "en": asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" else: @@ -48,7 +50,7 @@ for wers_ in results: wers.extend(wers_) - wer = round(np.mean(wers)*100, 3) + wer = round(np.mean(wers) * 100, 3) print(f"\nTotal {len(wers)} samples") print(f"WER : {wer}%") @@ -64,6 +66,6 @@ for sim_ in results: sim_list.extend(sim_) - sim = round(sum(sim_list)/len(sim_list), 3) + sim = round(sum(sim_list) / len(sim_list), 3) print(f"\nTotal {len(sim_list)} samples") print(f"SIM : {sim}") diff --git a/scripts/prepare_csv_wavs.py b/scripts/prepare_csv_wavs.py index 59dbaf211..6e56774dc 100644 --- a/scripts/prepare_csv_wavs.py +++ b/scripts/prepare_csv_wavs.py @@ -1,4 +1,6 @@ -import sys, os +import sys +import os + sys.path.append(os.getcwd()) from pathlib import Path @@ -17,10 +19,11 @@ PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt" + def is_csv_wavs_format(input_dataset_dir): fpath = Path(input_dataset_dir) metadata = fpath / "metadata.csv" - wavs = fpath / 'wavs' + wavs = fpath / "wavs" return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir() @@ -46,22 +49,24 @@ def prepare_csv_wavs_dir(input_dir): return sub_result, durations, vocab_set + def get_audio_duration(audio_path): audio, sample_rate = torchaudio.load(audio_path) num_channels = audio.shape[0] return audio.shape[1] / (sample_rate * num_channels) + def read_audio_text_pairs(csv_file_path): audio_text_pairs = [] parent = Path(csv_file_path).parent - with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile: - reader = csv.reader(csvfile, delimiter='|') + with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile: + reader = csv.reader(csvfile, delimiter="|") next(reader) # Skip the header row for row in reader: if len(row) >= 2: audio_file = row[0].strip() # First column: audio file path - text = row[1].strip() # Second column: text + text = row[1].strip() # Second column: text audio_file_path = parent / audio_file audio_text_pairs.append((audio_file_path.as_posix(), text)) @@ -78,12 +83,12 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB") raw_arrow_path = out_dir / "raw.arrow" with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer: - for line in tqdm(result, desc=f"Writing to raw.arrow ..."): + for line in tqdm(result, desc="Writing to raw.arrow ..."): writer.write(line) # dup a json separately saving duration in case for DynamicBatchSampler ease dur_json_path = out_dir / "duration.json" - with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f: + with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) # vocab map, i.e. tokenizer @@ -120,13 +125,14 @@ def cli(): # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain parser = argparse.ArgumentParser(description="Prepare and save dataset.") - parser.add_argument('inp_dir', type=str, help="Input directory containing the data.") - parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.") - parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune") + parser.add_argument("inp_dir", type=str, help="Input directory containing the data.") + parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.") + parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune") args = parser.parse_args() prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain) + if __name__ == "__main__": cli() diff --git a/scripts/prepare_emilia.py b/scripts/prepare_emilia.py index f268e7243..6461f30ac 100644 --- a/scripts/prepare_emilia.py +++ b/scripts/prepare_emilia.py @@ -4,7 +4,9 @@ # generate audio text map for Emilia ZH & EN # evaluate for vocab size -import sys, os +import sys +import os + sys.path.append(os.getcwd()) from pathlib import Path @@ -12,7 +14,6 @@ from tqdm import tqdm from concurrent.futures import ProcessPoolExecutor -from datasets import Dataset from datasets.arrow_writer import ArrowWriter from model.utils import ( @@ -21,13 +22,89 @@ ) -out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"} +out_zh = { + "ZH_B00041_S06226", + "ZH_B00042_S09204", + "ZH_B00065_S09430", + "ZH_B00065_S09431", + "ZH_B00066_S09327", + "ZH_B00066_S09328", +} zh_filters = ["い", "て"] # seems synthesized audios, or heavily code-switched out_en = { - "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375", - - "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995", + "EN_B00013_S00913", + "EN_B00042_S00120", + "EN_B00055_S04111", + "EN_B00061_S00693", + "EN_B00061_S01494", + "EN_B00061_S03375", + "EN_B00059_S00092", + "EN_B00111_S04300", + "EN_B00100_S03759", + "EN_B00087_S03811", + "EN_B00059_S00950", + "EN_B00089_S00946", + "EN_B00078_S05127", + "EN_B00070_S04089", + "EN_B00074_S09659", + "EN_B00061_S06983", + "EN_B00061_S07060", + "EN_B00059_S08397", + "EN_B00082_S06192", + "EN_B00091_S01238", + "EN_B00089_S07349", + "EN_B00070_S04343", + "EN_B00061_S02400", + "EN_B00076_S01262", + "EN_B00068_S06467", + "EN_B00076_S02943", + "EN_B00064_S05954", + "EN_B00061_S05386", + "EN_B00066_S06544", + "EN_B00076_S06944", + "EN_B00072_S08620", + "EN_B00076_S07135", + "EN_B00076_S09127", + "EN_B00065_S00497", + "EN_B00059_S06227", + "EN_B00063_S02859", + "EN_B00075_S01547", + "EN_B00061_S08286", + "EN_B00079_S02901", + "EN_B00092_S03643", + "EN_B00096_S08653", + "EN_B00063_S04297", + "EN_B00063_S04614", + "EN_B00079_S04698", + "EN_B00104_S01666", + "EN_B00061_S09504", + "EN_B00061_S09694", + "EN_B00065_S05444", + "EN_B00063_S06860", + "EN_B00065_S05725", + "EN_B00069_S07628", + "EN_B00083_S03875", + "EN_B00071_S07665", + "EN_B00071_S07665", + "EN_B00062_S04187", + "EN_B00065_S09873", + "EN_B00065_S09922", + "EN_B00084_S02463", + "EN_B00067_S05066", + "EN_B00106_S08060", + "EN_B00073_S06399", + "EN_B00073_S09236", + "EN_B00087_S00432", + "EN_B00085_S05618", + "EN_B00064_S01262", + "EN_B00072_S01739", + "EN_B00059_S03913", + "EN_B00069_S04036", + "EN_B00067_S05623", + "EN_B00060_S05389", + "EN_B00060_S07290", + "EN_B00062_S08995", } en_filters = ["ا", "い", "て"] @@ -43,18 +120,24 @@ def deal_with_audio_dir(audio_dir): for line in tqdm(lines, desc=f"{audio_jsonl.stem}"): obj = json.loads(line) text = obj["text"] - if obj['language'] == "zh": + if obj["language"] == "zh": if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text): bad_case_zh += 1 continue else: - text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched - if obj['language'] == "en": - if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4): + text = text.translate( + str.maketrans({",": ",", "!": "!", "?": "?"}) + ) # not "。" cuz much code-switched + if obj["language"] == "en": + if ( + obj["wav"].split("/")[1] in out_en + or any(f in text for f in en_filters) + or repetition_found(text, length=4) + ): bad_case_en += 1 continue if tokenizer == "pinyin": - text = convert_char_to_pinyin([text], polyphone = polyphone)[0] + text = convert_char_to_pinyin([text], polyphone=polyphone)[0] duration = obj["duration"] sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration}) durations.append(duration) @@ -96,11 +179,11 @@ def main(): # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB") with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer: - for line in tqdm(result, desc=f"Writing to raw.arrow ..."): + for line in tqdm(result, desc="Writing to raw.arrow ..."): writer.write(line) # dup a json separately saving duration in case for DynamicBatchSampler ease - with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f: + with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) # vocab map, i.e. tokenizer @@ -114,12 +197,13 @@ def main(): print(f"\nFor {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") - if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}") - if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n") + if "ZH" in langs: + print(f"Bad zh transcription case: {total_bad_case_zh}") + if "EN" in langs: + print(f"Bad en transcription case: {total_bad_case_en}\n") if __name__ == "__main__": - max_workers = 32 tokenizer = "pinyin" # "pinyin" | "char" diff --git a/scripts/prepare_wenetspeech4tts.py b/scripts/prepare_wenetspeech4tts.py index 0403ad7f0..2763fda98 100644 --- a/scripts/prepare_wenetspeech4tts.py +++ b/scripts/prepare_wenetspeech4tts.py @@ -1,7 +1,9 @@ # generate audio text map for WenetSpeech4TTS # evaluate for vocab size -import sys, os +import sys +import os + sys.path.append(os.getcwd()) import json @@ -23,7 +25,7 @@ def deal_with_sub_path_files(dataset_path, sub_path): audio_paths, texts, durations = [], [], [] for text_file in tqdm(text_files): - with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file: + with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file: first_line = file.readline().split("\t") audio_nm = first_line[0] audio_path = os.path.join(audio_dir, audio_nm + ".wav") @@ -32,7 +34,7 @@ def deal_with_sub_path_files(dataset_path, sub_path): audio_paths.append(audio_path) if tokenizer == "pinyin": - texts.extend(convert_char_to_pinyin([text], polyphone = polyphone)) + texts.extend(convert_char_to_pinyin([text], polyphone=polyphone)) elif tokenizer == "char": texts.append(text) @@ -46,7 +48,7 @@ def main(): assert tokenizer in ["pinyin", "char"] audio_path_list, text_list, duration_list = [], [], [] - + executor = ProcessPoolExecutor(max_workers=max_workers) futures = [] for dataset_path in dataset_paths: @@ -68,8 +70,10 @@ def main(): dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format - with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f: - json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease + with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f: + json.dump( + {"duration": duration_list}, f, ensure_ascii=False + ) # dup a json separately saving duration in case for DynamicBatchSampler ease print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...") text_vocab_set = set() @@ -85,22 +89,21 @@ def main(): f.write(vocab + "\n") print(f"\nFor {dataset_name}, sample count: {len(text_list)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n") - -if __name__ == "__main__": +if __name__ == "__main__": max_workers = 32 tokenizer = "pinyin" # "pinyin" | "char" polyphone = True dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic - dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1] + dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1] dataset_paths = [ "/WenetSpeech4TTS/Basic", "/WenetSpeech4TTS/Standard", "/WenetSpeech4TTS/Premium", - ][-dataset_choice:] + ][-dataset_choice:] print(f"\nChoose Dataset: {dataset_name}\n") main() @@ -109,8 +112,8 @@ def main(): # WenetSpeech4TTS Basic Standard Premium # samples count 3932473 1941220 407494 # pinyin vocab size 1349 1348 1344 (no polyphone) - # - - 1459 (polyphone) + # - - 1459 (polyphone) # char vocab size 5264 5219 5042 - + # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme) # please be careful if using pretrained model, make sure the vocab.txt is same diff --git a/speech_edit.py b/speech_edit.py index 0b13a79d8..82f7cc9b2 100644 --- a/speech_edit.py +++ b/speech_edit.py @@ -5,11 +5,11 @@ import torchaudio from vocos import Vocos -from model import CFM, UNetT, DiT, MMDiT +from model import CFM, UNetT, DiT from model.utils import ( load_checkpoint, - get_tokenizer, - convert_char_to_pinyin, + get_tokenizer, + convert_char_to_pinyin, save_spectrogram, ) @@ -35,18 +35,18 @@ ckpt_step = 1200000 nfe_step = 32 # 16, 32 -cfg_strength = 2. -ode_method = 'euler' # euler | midpoint -sway_sampling_coef = -1. -speed = 1. +cfg_strength = 2.0 +ode_method = "euler" # euler | midpoint +sway_sampling_coef = -1.0 +speed = 1.0 if exp_name == "F5TTS_Base": model_cls = DiT - model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) + model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) elif exp_name == "E2TTS_Base": model_cls = UNetT - model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4) + model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors" output_dir = "tests" @@ -62,8 +62,14 @@ audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav" origin_text = "Some call me nature, others call me mother nature." target_text = "Some call me optimist, others call me realist." -parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # stard_ends of "nature" & "mother nature", in seconds -fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds +parts_to_edit = [ + [1.42, 2.44], + [4.04, 4.9], +] # stard_ends of "nature" & "mother nature", in seconds +fix_duration = [ + 1.2, + 1, +] # fix duration for "optimist" & "realist", in seconds # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav" # origin_text = "对,这就是我,万人敬仰的太乙真人。" @@ -86,7 +92,7 @@ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device) vocos.load_state_dict(state_dict) - + vocos.eval() else: vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") @@ -96,23 +102,19 @@ # Model model = CFM( - transformer = model_cls( - **model_cfg, - text_num_embeds = vocab_size, - mel_dim = n_mel_channels - ), - mel_spec_kwargs = dict( - target_sample_rate = target_sample_rate, - n_mel_channels = n_mel_channels, - hop_length = hop_length, + transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + mel_spec_kwargs=dict( + target_sample_rate=target_sample_rate, + n_mel_channels=n_mel_channels, + hop_length=hop_length, ), - odeint_kwargs = dict( - method = ode_method, + odeint_kwargs=dict( + method=ode_method, ), - vocab_char_map = vocab_char_map, + vocab_char_map=vocab_char_map, ).to(device) -model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema) +model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema) # Audio audio, sr = torchaudio.load(audio_to_edit) @@ -132,14 +134,18 @@ part_dur = end - start if fix_duration is None else fix_duration.pop(0) part_dur = part_dur * target_sample_rate start = start * target_sample_rate - audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1) - edit_mask = torch.cat((edit_mask, - torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool), - torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool) - ), dim = -1) + audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1) + edit_mask = torch.cat( + ( + edit_mask, + torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool), + torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool), + ), + dim=-1, + ) offset = end * target_sample_rate # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1) -edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True) +edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True) audio = audio.to(device) edit_mask = edit_mask.to(device) @@ -159,14 +165,14 @@ # Inference with torch.inference_mode(): generated, trajectory = model.sample( - cond = audio, - text = final_text_list, - duration = duration, - steps = nfe_step, - cfg_strength = cfg_strength, - sway_sampling_coef = sway_sampling_coef, - seed = seed, - edit_mask = edit_mask, + cond=audio, + text=final_text_list, + duration=duration, + steps=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + seed=seed, + edit_mask=edit_mask, ) print(f"Generated mel: {generated.shape}") diff --git a/train.py b/train.py index 3da6a717c..b48b0f916 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,4 @@ -from model import CFM, UNetT, DiT, MMDiT, Trainer +from model import CFM, UNetT, DiT, Trainer from model.utils import get_tokenizer from model.dataset import load_dataset @@ -9,8 +9,8 @@ n_mel_channels = 100 hop_length = 256 -tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' -tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) +tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' +tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" # -------------------------- Training Settings -------------------------- # @@ -23,7 +23,7 @@ batch_size_type = "frame" # "frame" or "sample" max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps -max_grad_norm = 1. +max_grad_norm = 1.0 epochs = 11 # use linear decay, thus epochs control the slope num_warmup_updates = 20000 # warmup steps @@ -34,15 +34,16 @@ if exp_name == "F5TTS_Base": wandb_resume_id = None model_cls = DiT - model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) + model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) elif exp_name == "E2TTS_Base": wandb_resume_id = None model_cls = UNetT - model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4) + model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) # ----------------------------------------------------------------------- # + def main(): if tokenizer == "custom": tokenizer_path = tokenizer_path @@ -51,44 +52,41 @@ def main(): vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) mel_spec_kwargs = dict( - target_sample_rate = target_sample_rate, - n_mel_channels = n_mel_channels, - hop_length = hop_length, - ) - + target_sample_rate=target_sample_rate, + n_mel_channels=n_mel_channels, + hop_length=hop_length, + ) + model = CFM( - transformer = model_cls( - **model_cfg, - text_num_embeds = vocab_size, - mel_dim = n_mel_channels - ), - mel_spec_kwargs = mel_spec_kwargs, - vocab_char_map = vocab_char_map, + transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + mel_spec_kwargs=mel_spec_kwargs, + vocab_char_map=vocab_char_map, ) trainer = Trainer( model, - epochs, + epochs, learning_rate, - num_warmup_updates = num_warmup_updates, - save_per_updates = save_per_updates, - checkpoint_path = f'ckpts/{exp_name}', - batch_size = batch_size_per_gpu, - batch_size_type = batch_size_type, - max_samples = max_samples, - grad_accumulation_steps = grad_accumulation_steps, - max_grad_norm = max_grad_norm, - wandb_project = "CFM-TTS", - wandb_run_name = exp_name, - wandb_resume_id = wandb_resume_id, - last_per_steps = last_per_steps, + num_warmup_updates=num_warmup_updates, + save_per_updates=save_per_updates, + checkpoint_path=f"ckpts/{exp_name}", + batch_size=batch_size_per_gpu, + batch_size_type=batch_size_type, + max_samples=max_samples, + grad_accumulation_steps=grad_accumulation_steps, + max_grad_norm=max_grad_norm, + wandb_project="CFM-TTS", + wandb_run_name=exp_name, + wandb_resume_id=wandb_resume_id, + last_per_steps=last_per_steps, ) train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) - trainer.train(train_dataset, - resumable_with_seed = 666 # seed for shuffling dataset - ) + trainer.train( + train_dataset, + resumable_with_seed=666, # seed for shuffling dataset + ) -if __name__ == '__main__': +if __name__ == "__main__": main()