diff --git a/model/modules.py b/model/modules.py
index 9933725ffc..b9000bf25a 100644
--- a/model/modules.py
+++ b/model/modules.py
@@ -11,7 +11,8 @@
 
 from utils.tools import get_mask_from_lengths, pad
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
 
 
 class VarianceAdaptor(nn.Module):
@@ -77,6 +78,14 @@ def __init__(self, preprocess_config, model_config):
             n_bins, model_config["transformer"]["encoder_hidden"]
         )
 
+    def manual_bucketize(self, input_tensor, boundaries):
+        # torch.bucketize has no MPS kernel; count boundaries below each value.
+        boundaries = boundaries.detach().float().to(input_tensor.device)
+        expanded_input = input_tensor.unsqueeze(-1)
+        comparison = (expanded_input > boundaries).float()
+        bucket_indices = comparison.sum(-1).long()
+        return bucket_indices
+
     def get_pitch_embedding(self, x, target, mask, control):
         prediction = self.pitch_predictor(x, mask)
         if target is not None:
@@ -84,7 +93,7 @@ def get_pitch_embedding(self, x, target, mask, control):
         else:
             prediction = prediction * control
             embedding = self.pitch_embedding(
-                torch.bucketize(prediction, self.pitch_bins)
+                self.manual_bucketize(prediction, self.pitch_bins)
             )
         return prediction, embedding
 
@@ -95,7 +104,7 @@ def get_energy_embedding(self, x, target, mask, control):
         else:
             prediction = prediction * control
             embedding = self.energy_embedding(
-                torch.bucketize(prediction, self.energy_bins)
+                self.manual_bucketize(prediction, self.energy_bins)
             )
         return prediction, embedding
 
diff --git a/requirements.txt b/requirements.txt
index c2bbeed7f5..2da7474b69 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,15 +3,16 @@ inflect == 4.1.0
 librosa == 0.7.2
 matplotlib == 3.2.2
 numba == 0.48
-numpy == 1.19.0
+# numpy == 1.19.0
 pypinyin==0.39.0
-pyworld == 0.2.10
-PyYAML==5.4.1
-scikit-learn==0.23.2
-scipy == 1.5.0
-soundfile==0.10.3.post1
-tensorboard == 2.2.2
-tgt == 1.4.4
-torch == 1.7.0
-tqdm==4.46.1
-unidecode == 1.1.1
\ No newline at end of file
+# pyworld == 0.2.10
+pyworld==0.3.4
+# PyYAML==5.4.1
+scikit-learn
+scipy
+soundfile
+tensorboard
+tgt
+torch
+tqdm
+unidecode
\ No newline at end of file
diff --git a/synthesize.py b/synthesize.py
index 59a682aa7d..e346c55691 100644
--- a/synthesize.py
+++ b/synthesize.py
@@ -14,8 +14,9 @@
 from dataset import TextDataset
 from text import text_to_sequence
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+print('device:', device)
 
 def read_lexicon(lex_path):
     lexicon = {}
diff --git a/utils/model.py b/utils/model.py
index 45e1f41431..5357b3c90a 100644
--- a/utils/model.py
+++ b/utils/model.py
@@ -17,7 +17,7 @@ def get_model(args, configs, device, train=False):
             train_config["path"]["ckpt_path"],
             "{}.pth.tar".format(args.restore_step),
         )
-        ckpt = torch.load(ckpt_path)
+        ckpt = torch.load(ckpt_path, map_location=device)
         model.load_state_dict(ckpt["model"])
 
     if train:
@@ -55,12 +55,14 @@ def get_vocoder(config, device):
         vocoder.mel2wav.eval()
         vocoder.mel2wav.to(device)
     elif name == "HiFi-GAN":
+        print('device:', device)
         with open("hifigan/config.json", "r") as f:
             config = json.load(f)
         config = hifigan.AttrDict(config)
         vocoder = hifigan.Generator(config)
         if speaker == "LJSpeech":
-            ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar")
+            ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar",
+                              map_location=device)
         elif speaker == "universal":
             ckpt = torch.load("hifigan/generator_universal.pth.tar")
         vocoder.load_state_dict(ckpt["generator"])
diff --git a/utils/tools.py b/utils/tools.py
index f897430e69..c1713aee93 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -12,7 +12,8 @@
 
 matplotlib.use("Agg")
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
 
 
 def to_device(data, device):
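
As a sanity check outside the patch: the manual_bucketize fallback should agree with torch.bucketize's default (right=False) semantics, since both return the number of boundaries strictly below each value. The snippet below is a minimal sketch, run on CPU where torch.bucketize is implemented; the standalone helper mirrors the method added in model/modules.py, and the bin count and tensor shapes are made up for illustration.

import torch

def manual_bucketize(input_tensor, boundaries):
    # Bucket index = number of boundaries strictly below each value,
    # i.e. torch.bucketize(input_tensor, boundaries, right=False).
    boundaries = boundaries.detach().float().to(input_tensor.device)
    expanded_input = input_tensor.unsqueeze(-1)
    comparison = (expanded_input > boundaries).float()
    return comparison.sum(-1).long()

bins = torch.linspace(-2.0, 2.0, steps=255)  # pitch_bins-style boundaries
values = torch.randn(4, 17)                  # stand-in pitch predictions
assert torch.equal(manual_bucketize(values, bins),
                   torch.bucketize(values, bins))
print("manual_bucketize matches torch.bucketize")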