From 96c701070e777ed8b3d94014f684cf2f242b64b2 Mon Sep 17 00:00:00 2001
From: Daniel Chin
Date: Wed, 18 Oct 2023 14:56:48 +0400
Subject: [PATCH 1/3] support loading models to CPU

---
 utils/model.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/utils/model.py b/utils/model.py
index 45e1f41431..9efea47ad7 100644
--- a/utils/model.py
+++ b/utils/model.py
@@ -7,6 +7,7 @@
 import hifigan
 from model import FastSpeech2, ScheduledOptim
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 def get_model(args, configs, device, train=False):
     (preprocess_config, model_config, train_config) = configs
@@ -17,7 +18,7 @@ def get_model(args, configs, device, train=False):
             train_config["path"]["ckpt_path"],
             "{}.pth.tar".format(args.restore_step),
         )
-        ckpt = torch.load(ckpt_path)
+        ckpt = torch.load(ckpt_path, map_location=device)
         model.load_state_dict(ckpt["model"])
 
     if train:
@@ -60,7 +61,7 @@ def get_vocoder(config, device):
         config = hifigan.AttrDict(config)
         vocoder = hifigan.Generator(config)
         if speaker == "LJSpeech":
-            ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar")
+            ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar", map_location=device)
         elif speaker == "universal":
             ckpt = torch.load("hifigan/generator_universal.pth.tar")
         vocoder.load_state_dict(ckpt["generator"])

From d57e142e53f8b4e75a3d05b3e9e964c428f70b9f Mon Sep 17 00:00:00 2001
From: Daniel Chin
Date: Wed, 18 Oct 2023 14:59:16 +0400
Subject: [PATCH 2/3] add option to turn off plotting spectrogram

---
 synthesize.py  |  6 +++++-
 utils/tools.py | 25 +++++++++++++++----------
 2 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/synthesize.py b/synthesize.py
index 59a682aa7d..d68cf8559c 100644
--- a/synthesize.py
+++ b/synthesize.py
@@ -84,7 +84,10 @@ def preprocess_mandarin(text, preprocess_config):
     return np.array(sequence)
 
 
-def synthesize(model, step, configs, vocoder, batchs, control_values):
+def synthesize(
+    model, step, configs, vocoder, batchs, control_values,
+    do_plot_spectrogram=True,
+):
     preprocess_config, model_config, train_config = configs
     pitch_control, energy_control, duration_control = control_values
 
@@ -105,6 +108,7 @@ def synthesize(model, step, configs, vocoder, batchs, control_values):
                 model_config,
                 preprocess_config,
                 train_config["path"]["result_path"],
+                do_plot_spectrogram,
             )
 
 
diff --git a/utils/tools.py b/utils/tools.py
index f897430e69..3d7a5c93dd 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -161,7 +161,11 @@ def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_con
     return fig, wav_reconstruction, wav_prediction, basename
 
 
-def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
+def synth_samples(
+    targets, predictions, vocoder, model_config,
+    preprocess_config, path,
+    do_plot_spectrogram=True,
+):
 
     basenames = targets[0]
     for i in range(len(predictions[0])):
@@ -187,15 +191,16 @@ def synth_samples(targets, predictions, vocoder, model_config, preprocess_config
             stats = json.load(f)
             stats = stats["pitch"] + stats["energy"][:2]
 
-        fig = plot_mel(
-            [
-                (mel_prediction.cpu().numpy(), pitch, energy),
-            ],
-            stats,
-            ["Synthetized Spectrogram"],
-        )
-        plt.savefig(os.path.join(path, "{}.png".format(basename)))
-        plt.close()
+        if do_plot_spectrogram:
+            fig = plot_mel(
+                [
+                    (mel_prediction.cpu().numpy(), pitch, energy),
+                ],
+                stats,
+                ["Synthetized Spectrogram"],
+            )
+            plt.savefig(os.path.join(path, "{}.png".format(basename)))
+            plt.close()
 
     from .model import vocoder_infer
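Note on PATCH 1/3: the map_location pattern it relies on can be exercised outside the repository. The sketch below is illustrative only; the throwaway torch.nn.Linear module and the demo_ckpt.pth.tar path are placeholders rather than anything from FastSpeech2.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Save a tiny checkpoint, then load it back. map_location remaps any
# GPU-saved tensors onto `device`, so the same load call also works on a
# CPU-only machine.
model = torch.nn.Linear(4, 2)
torch.save({"model": model.state_dict()}, "demo_ckpt.pth.tar")
ckpt = torch.load("demo_ckpt.pth.tar", map_location=device)
model.load_state_dict(ckpt["model"])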
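Note on PATCH 2/3: the new do_plot_spectrogram keyword simply guards the plot_mel/plt.savefig block inside synth_samples. Below is a self-contained sketch of the same guard pattern; the save_spectrogram helper, the random mel matrix, and the output filename are made up for the demo and are not repository code.

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

def save_spectrogram(mel, out_path, do_plot_spectrogram=True):
    # When the flag is off, nothing is plotted and no .png is written;
    # callers that only want .wav output pass do_plot_spectrogram=False.
    if not do_plot_spectrogram:
        return
    fig, ax = plt.subplots()
    ax.imshow(mel, origin="lower", aspect="auto")
    fig.savefig(out_path)
    plt.close(fig)

save_spectrogram(np.random.rand(80, 200), "demo_mel.png", do_plot_spectrogram=False)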
From 7220cb9a9073f407fffe7a53602cb58d3a496d68 Mon Sep 17 00:00:00 2001
From: Daniel Chin
Date: Wed, 18 Oct 2023 15:09:30 +0400
Subject: [PATCH 3/3] legalize output filename in case text contains punctuation

---
 utils/tools.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/utils/tools.py b/utils/tools.py
index 3d7a5c93dd..3d0d7eb33c 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -160,6 +160,8 @@ def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_con
     return fig, wav_reconstruction, wav_prediction, basename
 
+def legalize_filename(filename: str):
+    return ''.join([x if x.isalnum() else '' for x in filename])[:30]
 
 def synth_samples(
     targets, predictions, vocoder, model_config,
     preprocess_config, path,
     do_plot_spectrogram=True,
 ):
@@ -199,7 +201,7 @@ def synth_samples(
                 stats,
                 ["Synthetized Spectrogram"],
             )
-            plt.savefig(os.path.join(path, "{}.png".format(basename)))
+            plt.savefig(os.path.join(path, "{}.png".format(legalize_filename(basename))))
             plt.close()
 
     from .model import vocoder_infer
@@ -212,7 +214,8 @@ def synth_samples(
 
     sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
     for wav, basename in zip(wav_predictions, basenames):
-        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
+        wavfile.write(os.path.join(
+            path, "{}.wav".format(legalize_filename(basename))), sampling_rate, wav)
 
 
 def plot_mel(data, stats, titles):
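Note on PATCH 3/3: legalize_filename keeps only characters for which str.isalnum() is true and caps the result at 30 characters. The snippet below checks that behaviour in isolation; the helper is copied from the hunk above and the sample strings are made up.

def legalize_filename(filename: str):
    return ''.join([x if x.isalnum() else '' for x in filename])[:30]

print(legalize_filename('Is this "legal"? Yes/no?'))  # IsthislegalYesno
print(legalize_filename("你好，世界"))                  # 你好世界 (isalnum() keeps CJK text, which matters for the Mandarin frontend)
print(len(legalize_filename("a" * 120)))              # 30, long basenames are truncated

Because punctuation and spaces are dropped and the result is capped at 30 characters, two different input texts can map to the same output filename, so a later synthesis may overwrite an earlier one that shares the same sanitized prefix.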