Skip to content

Commit

Permalink
Merge branch 'SWivid:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
lpscr authored Nov 1, 2024
2 parents 199c56c + 2a3deaa commit 9984a48
Show file tree
Hide file tree
Showing 15 changed files with 394 additions and 189 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "src/third_party/BigVGAN"]
path = src/third_party/BigVGAN
url = https://github.com/NVIDIA/BigVGAN.git
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,15 @@ pip install git+https://github.com/SWivid/F5-TTS.git
```bash
git clone https://github.com/SWivid/F5-TTS.git
cd F5-TTS
# git submodule update --init --recursive  # (optional, only if you need bigvgan)
pip install -e .
```
If you initialize the submodule, add the following code at the beginning of `src/third_party/BigVGAN/bigvgan.py`:
```python
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
```

### 3. Docker usage
```bash
Expand Down
36 changes: 22 additions & 14 deletions src/f5_tts/api.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import random
import sys
import tqdm
from importlib.resources import files

import soundfile as sf
import torch
import tqdm
from cached_path import cached_path

from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import seed_everything
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
hop_length,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
preprocess_ref_audio_text,
target_sample_rate,
hop_length,
)
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import seed_everything


class F5TTS:
Expand All @@ -29,6 +29,7 @@ def __init__(
vocab_file="",
ode_method="euler",
use_ema=True,
vocoder_name="vocos",
local_path=None,
device=None,
):
Expand All @@ -37,23 +38,27 @@ def __init__(
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.seed = -1
self.mel_spec_type = vocoder_name

# Set device
self.device = device or (
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Load models
self.load_vocoder_model(local_path)
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
self.load_vocoder_model(vocoder_name, local_path)
self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)

def load_vocoder_model(self, local_path):
self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
def load_vocoder_model(self, vocoder_name, local_path):
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)

def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
if model_type == "F5-TTS":
if not ckpt_file:
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
if mel_spec_type == "vocos":
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
elif mel_spec_type == "bigvgan":
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
elif model_type == "E2-TTS":
Expand All @@ -64,7 +69,9 @@ def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema)
else:
raise ValueError(f"Unknown model type: {model_type}")

self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
self.ema_model = load_model(
model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
)

def export_wav(self, wav, file_wave, remove_silence=False):
sf.write(file_wave, wav, self.target_sample_rate)
Expand Down Expand Up @@ -107,6 +114,7 @@ def infer(
gen_text,
self.ema_model,
self.vocoder,
self.mel_spec_type,
show_info=show_info,
progress=progress,
target_rms=target_rms,
Expand Down
68 changes: 39 additions & 29 deletions src/f5_tts/eval/eval_infer_batch.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
import sys
import os
import sys

sys.path.append(os.getcwd())

import time
from tqdm import tqdm
import argparse
import time
from importlib.resources import files

import torch
import torchaudio
from accelerate import Accelerator
from vocos import Vocos
from tqdm import tqdm

from f5_tts.model import CFM, UNetT, DiT
from f5_tts.model.utils import get_tokenizer
from f5_tts.infer.utils_infer import load_checkpoint
from f5_tts.eval.utils_eval import (
get_seedtts_testset_metainfo,
get_librispeech_test_clean_metainfo,
get_inference_prompt,
get_librispeech_test_clean_metainfo,
get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.utils import get_tokenizer

accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
Expand All @@ -31,8 +30,11 @@
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
target_rms = 0.1


tokenizer = "pinyin"
rel_path = str(files("f5_tts").joinpath("../../"))

Expand All @@ -46,6 +48,7 @@ def main():
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])

parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
Expand All @@ -60,6 +63,7 @@ def main():
exp_name = args.expname
ckpt_step = args.ckptstep
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
mel_spec_type = args.mel_spec_type

nfe_step = args.nfestep
ode_method = args.odemethod
Expand Down Expand Up @@ -98,7 +102,7 @@ def main():
output_dir = (
f"{rel_path}/"
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}"
f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
f"_cfg{cfg_strength}_speed{speed}"
f"{'_gt-dur' if use_truth_duration else ''}"
Expand All @@ -116,21 +120,19 @@ def main():
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
mel_spec_type=mel_spec_type,
target_rms=target_rms,
use_truth_duration=use_truth_duration,
infer_batch_size=infer_batch_size,
)

# Vocoder model
local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
if mel_spec_type == "vocos":
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
elif mel_spec_type == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)

# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
Expand All @@ -139,17 +141,21 @@ def main():
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)

model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
Expand Down Expand Up @@ -178,14 +184,18 @@ def main():
no_ref_audio=no_ref_audio,
seed=seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1)
generated_wave = vocos.decode(gen_mel_spec.cpu())
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(gen_mel_spec)
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(gen_mel_spec)

if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)

accelerator.wait_for_everyone()
if accelerator.is_main_process:
Expand Down
14 changes: 11 additions & 3 deletions src/f5_tts/eval/utils_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import os
import random
import string
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm

from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL


# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
Expand Down Expand Up @@ -74,8 +74,11 @@ def get_inference_prompt(
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="vocos",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
Expand All @@ -94,7 +97,12 @@ def get_inference_prompt(
)

mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)

for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
Expand Down
4 changes: 4 additions & 0 deletions src/f5_tts/infer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ f5-tts_infer-cli \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."

# Choose Vocoder
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
```

A `.toml` file allows for more flexible usage.
Expand Down
Loading

0 comments on commit 9984a48

Please sign in to comment.