From 878d7b8d8c58c55a2b15a80255a89aa9b305ce1c Mon Sep 17 00:00:00 2001
From: Djallal Fekirine <45321671+djallalzoldik@users.noreply.github.com>
Date: Tue, 10 Dec 2024 16:51:25 +0100
Subject: [PATCH] Update infer_cli.py

---
 src/f5_tts/infer/infer_cli.py | 129 +++++++++++++++++++---------------
 1 file changed, 71 insertions(+), 58 deletions(-)

diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py
index abf6ecef..8babb204 100644
--- a/src/f5_tts/infer/infer_cli.py
+++ b/src/f5_tts/infer/infer_cli.py
@@ -87,6 +87,12 @@
     default=1.0,
     help="Adjust the speed of the audio generation (default: 1.0)",
 )
+parser.add_argument(
+    "--cross_fade_duration",
+    type=float,
+    default=0.15,
+    help="Duration of cross-fade between audio segments in seconds (default: 0.15)",
+)
 args = parser.parse_args()
 
 config = tomli.load(open(args.config, "rb"))
@@ -95,7 +101,7 @@
 ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
 gen_text = args.gen_text if args.gen_text else config["gen_text"]
 gen_file = args.gen_file if args.gen_file else config["gen_file"]
-
+cross_fade_duration = args.cross_fade_duration if hasattr(args, 'cross_fade_duration') else config.get('cross_fade_duration', 0.15)
 # patches for pip pkg user
 if "infer/examples/" in ref_audio:
     ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
@@ -163,63 +169,70 @@
 ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
 
 
-def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
-    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
-    if "voices" not in config:
-        voices = {"main": main_voice}
-    else:
-        voices = config["voices"]
-        voices["main"] = main_voice
-    for voice in voices:
-        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
-            voices[voice]["ref_audio"], voices[voice]["ref_text"]
-        )
-        print("Voice:", voice)
-        print("Ref_audio:", voices[voice]["ref_audio"])
-        print("Ref_text:", voices[voice]["ref_text"])
-
-    generated_audio_segments = []
-    reg1 = r"(?=\[\w+\])"
-    chunks = re.split(reg1, text_gen)
-    reg2 = r"\[(\w+)\]"
-    for text in chunks:
-        if not text.strip():
-            continue
-        match = re.match(reg2, text)
-        if match:
-            voice = match[1]
-        else:
-            print("No voice tag found, using main.")
-            voice = "main"
-        if voice not in voices:
-            print(f"Voice {voice} not found, using main.")
-            voice = "main"
-        text = re.sub(reg2, "", text)
-        gen_text = text.strip()
-        ref_audio = voices[voice]["ref_audio"]
-        ref_text = voices[voice]["ref_text"]
-        print(f"Voice: {voice}")
-        audio, final_sample_rate, spectragram = infer_process(
-            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
-        )
-        generated_audio_segments.append(audio)
-
-    if generated_audio_segments:
-        final_wave = np.concatenate(generated_audio_segments)
-
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-
-        with open(wave_path, "wb") as f:
-            sf.write(f.name, final_wave, final_sample_rate)
-            # Remove silence
-            if remove_silence:
-                remove_silence_for_generated_wav(f.name)
-            print(f.name)
-
-
-def main():
-    main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
+def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed, cross_fade_duration):
+    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
+    if "voices" not in config:
+        voices = {"main": main_voice}
+    else:
+        voices = config["voices"]
+        voices["main"] = main_voice
+    for voice in voices:
+        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
+            voices[voice]["ref_audio"], voices[voice]["ref_text"]
+        )
+        print("Voice:", voice)
+        print("Ref_audio:", voices[voice]["ref_audio"])
+        print("Ref_text:", voices[voice]["ref_text"])
+
+    generated_audio_segments = []
+    reg1 = r"(?=\[\w+\])"
+    chunks = re.split(reg1, text_gen)
+    reg2 = r"\[(\w+)\]"
+    for text in chunks:
+        if not text.strip():
+            continue
+        match = re.match(reg2, text)
+        if match:
+            voice = match[1]
+        else:
+            print("No voice tag found, using main.")
+            voice = "main"
+        if voice not in voices:
+            print(f"Voice {voice} not found, using main.")
+            voice = "main"
+        text = re.sub(reg2, "", text)
+        gen_text = text.strip()
+        ref_audio = voices[voice]["ref_audio"]
+        ref_text = voices[voice]["ref_text"]
+        print(f"Voice: {voice}")
+        audio, final_sample_rate, spectragram = infer_process(
+            ref_audio,
+            ref_text,
+            gen_text,
+            model_obj,
+            vocoder,
+            mel_spec_type=mel_spec_type,
+            speed=speed,
+            cross_fade_duration=cross_fade_duration
+        )
+        generated_audio_segments.append(audio)
+
+    if generated_audio_segments:
+        final_wave = np.concatenate(generated_audio_segments)
+
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        with open(wave_path, "wb") as f:
+            sf.write(f.name, final_wave, final_sample_rate)
+            # Remove silence
+            if remove_silence:
+                remove_silence_for_generated_wav(f.name)
+            print(f.name)
+
+
+def main():
+    main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed, cross_fade_duration)
 
 
 if __name__ == "__main__":
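
Usage note (illustrative, not part of the patch): the value resolved by the new cross_fade_duration line is forwarded to infer_process as the cross_fade_duration keyword, so the overlap between generated audio segments can be set either with the new --cross_fade_duration flag or left at the 0.15 s default. Below is a minimal, self-contained sketch of the resolution order the added line expresses (parsed CLI value, then config key, then hard-coded 0.15); the example argv and the stand-in config dict are hypothetical, standing in for the dict that tomli.load() returns in infer_cli.py.

    # Illustrative sketch only -- mirrors the precedence logic added in the patch.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cross_fade_duration", type=float, default=0.15)
    args = parser.parse_args(["--cross_fade_duration", "0.25"])  # as if the user passed 0.25

    config = {"cross_fade_duration": 0.4}  # stand-in for the tomli-loaded config

    cross_fade_duration = (
        args.cross_fade_duration
        if hasattr(args, "cross_fade_duration")
        else config.get("cross_fade_duration", 0.15)
    )
    print(cross_fade_duration)  # 0.25

    # Note: because argparse always sets the attribute (default=0.15), hasattr() is
    # always True here, so the config.get() fallback is only reached if the argument
    # definition itself is absent.

On the command line the option is passed like any other infer_cli.py flag, e.g. --cross_fade_duration 0.25.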