Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
lpscr committed Oct 21, 2024
1 parent 99eb50d commit 6a333fa
Showing 1 changed file with 65 additions and 47 deletions.
112 changes: 65 additions & 47 deletions api.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,25 @@
import argparse
import codecs
import re
import tempfile
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
import torch
import torchaudio
import tqdm
from cached_path import cached_path
from einops import rearrange
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos

from model import CFM, DiT, MMDiT, UNetT
from model.utils import (convert_char_to_pinyin, get_tokenizer,
load_checkpoint, save_spectrogram)

from model.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
chunk_text
)

from model import DiT, UNetT
from model.utils import save_spectrogram

from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav


class F5TTS:
def __init__(self, model_type="F5-TTS", ckpt_file="", vocab_file="", ode_method="euler", use_ema=True, local_path=None, device=None):
def __init__(
self,
model_type="F5-TTS",
ckpt_file="",
vocab_file="",
ode_method="euler",
use_ema=True,
local_path=None,
device=None,
):
# Initialize parameters
self.final_wave = None
self.target_sample_rate = 24000
Expand All @@ -39,17 +28,18 @@ def __init__(self, model_type="F5-TTS", ckpt_file="", vocab_file="", ode_method=
self.target_rms = 0.1

# Set device
self.device = device or ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

self.device = device or (
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Load models
self.load_vecoder_model(local_path)
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

def load_vecoder_model(self, local_path):
self.vocos = load_vocoder(local_path is not None,local_path,self.device)
self.vocos = load_vocoder(local_path is not None, local_path, self.device)

def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):

if model_type == "F5-TTS":
if not ckpt_file:
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
Expand All @@ -66,37 +56,65 @@ def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema)
self.ema_model = self.load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema)

def load_model(self, model_cls, model_cfg, ckpt_path, file_vocab, ode_method, use_ema):
return load_model(model_cls, model_cfg, ckpt_path, file_vocab,ode_method,use_ema,self.device)
return load_model(model_cls, model_cfg, ckpt_path, file_vocab, ode_method, use_ema, self.device)

def export_wav(self,wav, file_wave, remove_silence=False):
def export_wav(self, wav, file_wave, remove_silence=False):
if remove_silence:
remove_silence_for_generated_wav(file_wave)

sf.write(file_wave, wav, self.target_sample_rate)
def export_spectrogram(self,spect, file_spect):

def export_spectrogram(self, spect, file_spect):
save_spectrogram(spect, file_spect)

def infer(self, ref_file, ref_text, gen_text, sway_sampling_coef=-1, cfg_strength=2, nfe_step=32, speed=1.0, fix_duration=None, remove_silence=False, file_wave=None, file_spect=None, cross_fade_duration=0.15, show_info=print, progress=tqdm):

wav,sr,spect = infer_process(ref_file, ref_text, gen_text, self.ema_model, cross_fade_duration, speed, show_info, progress,nfe_step,cfg_strength,sway_sampling_coef,fix_duration)
def infer(
self,
ref_file,
ref_text,
gen_text,
sway_sampling_coef=-1,
cfg_strength=2,
nfe_step=32,
speed=1.0,
fix_duration=None,
remove_silence=False,
file_wave=None,
file_spect=None,
cross_fade_duration=0.15,
show_info=print,
progress=tqdm,
):
wav, sr, spect = infer_process(
ref_file,
ref_text,
gen_text,
self.ema_model,
cross_fade_duration,
speed,
show_info,
progress,
nfe_step,
cfg_strength,
sway_sampling_coef,
fix_duration,
)

if file_wave is not None:
self.export_wav(wav,file_wave,remove_silence)
self.export_wav(wav, file_wave, remove_silence)

if file_spect is not None:
self.export_spectrogram(spect,file_spect)
self.export_spectrogram(spect, file_spect)

return wav,sr,spect
return wav, sr, spect

if __name__ == "__main__":

f5tts=F5TTS()
if __name__ == "__main__":
f5tts = F5TTS()

wav,sr,spect=f5tts.infer(
wav, sr, spect = f5tts.infer(
ref_file="tests/ref_audio/test_en_1_ref_short.wav",
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave="tests/out.wav",
file_spect="tests/out.png"
)
file_spect="tests/out.png",
)

0 comments on commit 6a333fa

Please sign in to comment.