Update infer_cli.py #611

Open · wants to merge 1 commit into base: main
src/f5_tts/infer/infer_cli.py: 129 changes (71 additions, 58 deletions)
@@ -87,6 +87,12 @@
default=1.0,
help="Adjust the speed of the audio generation (default: 1.0)",
)
parser.add_argument(
"--cross_fade_duration",
type=float,
default=0.15,
help="Duration of cross-fade between audio segments in seconds (default: 0.15)",
)
args = parser.parse_args()

config = tomli.load(open(args.config, "rb"))
@@ -95,7 +101,7 @@
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]
gen_file = args.gen_file if args.gen_file else config["gen_file"]

cross_fade_duration = args.cross_fade_duration if hasattr(args, "cross_fade_duration") else config.get("cross_fade_duration", 0.15)
# patches for pip pkg user
if "infer/examples/" in ref_audio:
ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
@@ -163,63 +169,70 @@
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)


def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
if "voices" not in config:
voices = {"main": main_voice}
else:
voices = config["voices"]
voices["main"] = main_voice
for voice in voices:
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
voices[voice]["ref_audio"], voices[voice]["ref_text"]
)
print("Voice:", voice)
print("Ref_audio:", voices[voice]["ref_audio"])
print("Ref_text:", voices[voice]["ref_text"])

generated_audio_segments = []
reg1 = r"(?=\[\w+\])"
chunks = re.split(reg1, text_gen)
reg2 = r"\[(\w+)\]"
for text in chunks:
if not text.strip():
continue
match = re.match(reg2, text)
if match:
voice = match[1]
else:
print("No voice tag found, using main.")
voice = "main"
if voice not in voices:
print(f"Voice {voice} not found, using main.")
voice = "main"
text = re.sub(reg2, "", text)
gen_text = text.strip()
ref_audio = voices[voice]["ref_audio"]
ref_text = voices[voice]["ref_text"]
print(f"Voice: {voice}")
audio, final_sample_rate, spectragram = infer_process(
ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
)
generated_audio_segments.append(audio)

if generated_audio_segments:
final_wave = np.concatenate(generated_audio_segments)

if not os.path.exists(output_dir):
os.makedirs(output_dir)

with open(wave_path, "wb") as f:
sf.write(f.name, final_wave, final_sample_rate)
# Remove silence
if remove_silence:
remove_silence_for_generated_wav(f.name)
print(f.name)


def main():
main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed, cross_fade_duration):
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
if "voices" not in config:
voices = {"main": main_voice}
else:
voices = config["voices"]
voices["main"] = main_voice
for voice in voices:
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
voices[voice]["ref_audio"], voices[voice]["ref_text"]
)
print("Voice:", voice)
print("Ref_audio:", voices[voice]["ref_audio"])
print("Ref_text:", voices[voice]["ref_text"])

generated_audio_segments = []
reg1 = r"(?=\[\w+\])"
chunks = re.split(reg1, text_gen)
reg2 = r"\[(\w+)\]"
for text in chunks:
if not text.strip():
continue
match = re.match(reg2, text)
if match:
voice = match[1]
else:
print("No voice tag found, using main.")
voice = "main"
if voice not in voices:
print(f"Voice {voice} not found, using main.")
voice = "main"
text = re.sub(reg2, "", text)
gen_text = text.strip()
ref_audio = voices[voice]["ref_audio"]
ref_text = voices[voice]["ref_text"]
print(f"Voice: {voice}")
audio, final_sample_rate, spectragram = infer_process(
ref_audio,
ref_text,
gen_text,
model_obj,
vocoder,
mel_spec_type=mel_spec_type,
speed=speed,
cross_fade_duration=cross_fade_duration
)
generated_audio_segments.append(audio)

if generated_audio_segments:
final_wave = np.concatenate(generated_audio_segments)

if not os.path.exists(output_dir):
os.makedirs(output_dir)

with open(wave_path, "wb") as f:
sf.write(f.name, final_wave, final_sample_rate)
# Remove silence
if remove_silence:
remove_silence_for_generated_wav(f.name)
print(f.name)


def main():
main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed, cross_fade_duration)


if __name__ == "__main__":
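
For context on what the new flag controls: cross_fade_duration is passed through to infer_process, which blends the end of one generated segment into the start of the next before the segments are concatenated. The sketch below is illustrative only (it is not the PR's code and not the actual infer_process implementation); it assumes mono float32 audio and shows a simple linear cross-fade over a window of duration seconds, using a hypothetical helper name:

import numpy as np

def linear_cross_fade(a, b, sample_rate, duration=0.15):
    """Join segments a and b, blending the last `duration` seconds of a
    with the first `duration` seconds of b (illustrative helper only)."""
    n = min(int(sample_rate * duration), len(a), len(b))
    if n <= 0:
        # Window too short to blend: fall back to plain concatenation.
        return np.concatenate([a, b])
    fade_out = np.linspace(1.0, 0.0, n)   # tail of a ramps down
    fade_in = 1.0 - fade_out              # head of b ramps up
    overlap = a[-n:] * fade_out + b[:n] * fade_in
    return np.concatenate([a[:-n], overlap, b[n:]])

With this change applied, the value can be passed on the command line (e.g. --cross_fade_duration 0.3); the script also looks for a cross_fade_duration key in the TOML config as a fallback, otherwise it uses the default of 0.15.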