add map_location for torch.load #172

Open · wants to merge 10 commits into master
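The hunks below cover librosa-compatibility fixes, device handling in audio/stft.py, config tweaks, and a new attention-map notebook; the torch.load change named in the title is not part of this excerpt. As a hedged illustration only, the map_location pattern referred to in the title generally looks like the following (the checkpoint path is a placeholder, not taken from this diff):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps tensor storages at load time, so a checkpoint saved on a
# GPU machine can be restored on a CPU-only machine instead of raising an error.
ckpt = torch.load("output/ckpt/LJSpeech/135000.pth.tar", map_location=device)  # path is illustrative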
4 changes: 4 additions & 0 deletions .gitignore
@@ -116,3 +116,7 @@ output/
*.npy
TextGrid/
hifigan/*.pth.tar

+ # some test script
+ look_function.py
+ run.sh
9 changes: 5 additions & 4 deletions audio/stft.py
@@ -39,7 +39,7 @@ def __init__(self, filter_length, hop_length, win_length, window="hann"):
assert filter_length >= win_length
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
- fft_window = pad_center(fft_window, filter_length)
+ fft_window = pad_center(fft_window, size=filter_length)
fft_window = torch.from_numpy(fft_window).float()

# window the bases
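Context for the pad_center change above: in librosa 0.10 and later, librosa.util.pad_center takes the target length as a keyword-only argument named size, so the old positional call raises a TypeError. A minimal standalone check, with window sizes chosen only for illustration:

from scipy.signal import get_window
from librosa.util import pad_center

win = get_window("hann", 800, fftbins=True)  # win_length = 800 (illustrative)
win = pad_center(win, size=1024)             # filter_length = 1024; the positional form fails on librosa >= 0.10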
@@ -65,8 +65,9 @@ def transform(self, input_data):
input_data = input_data.squeeze(1)

forward_transform = F.conv1d(
- input_data.cuda(),
- torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(),
+ input_data.cuda() if torch.cuda.is_available() else input_data.cpu(),
+ torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda()
+ if torch.cuda.is_available() else torch.autograd.Variable(self.forward_basis, requires_grad=False).cpu(),
stride=self.hop_length,
padding=0,
).cpu()
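The new code keeps the original .cuda() calls and falls back to .cpu() when no GPU is present. A more compact, device-agnostic alternative (a sketch assuming forward_basis is a registered buffer, not the PR's actual change) would follow the input tensor's device and drop the deprecated Variable wrapper:

# Hypothetical refactor of the same call: run the convolution on whichever
# device the input tensor already lives on.
device = input_data.device
forward_transform = F.conv1d(
    input_data,
    self.forward_basis.to(device),  # buffers do not require grad, so no Variable wrapper is needed
    stride=self.hop_length,
    padding=0,
)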
@@ -143,7 +144,7 @@ def __init__(
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length)
mel_basis = librosa_mel_fn(
- sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
+ sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer("mel_basis", mel_basis)
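As with pad_center, librosa 0.10 made the parameters of librosa.filters.mel (imported here as librosa_mel_fn) keyword-only, which is why the call is rewritten with explicit argument names. A standalone sketch; the numeric values are typical LJSpeech settings used purely for illustration, not read from this diff:

import librosa

mel_basis = librosa.filters.mel(
    sr=22050,      # sampling_rate
    n_fft=1024,    # filter_length
    n_mels=80,     # n_mel_channels
    fmin=0.0,      # mel_fmin
    fmax=8000.0,   # mel_fmax
)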
3 changes: 2 additions & 1 deletion config/LJSpeech/preprocess.yaml
@@ -1,7 +1,8 @@
dataset: "LJSpeech"

path:
corpus_path: "/home/ming/Data/LJSpeech-1.1"
# corpus_path: "/home/ming/Data/LJSpeech-1.1" # modify the path for your own data
corpus_path: "/home/wangyuancheng/Data/LJSpeech-1.1"
lexicon_path: "lexicon/librispeech-lexicon.txt"
raw_path: "./raw_data/LJSpeech"
preprocessed_path: "./preprocessed_data/LJSpeech"
2 changes: 1 addition & 1 deletion config/LJSpeech/train.yaml
@@ -3,7 +3,7 @@ path:
log_path: "./output/log/LJSpeech"
result_path: "./output/result/LJSpeech"
optimizer:
- batch_size: 16
+ batch_size: 48
betas: [0.9, 0.98]
eps: 0.000000001
weight_decay: 0.0
2 changes: 1 addition & 1 deletion dataset.py
@@ -203,7 +203,7 @@ def collate_fn(self, data):
import torch
import yaml
from torch.utils.data import DataLoader
- from utils.utils import to_device
+ from utils.tools import to_device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
preprocess_config = yaml.load(
197 changes: 197 additions & 0 deletions get_attn_map.ipynb
@@ -0,0 +1,197 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import os\n",
"\n",
"import torch\n",
"import yaml\n",
"import numpy as np\n",
"\n",
"from utils.model import get_model\n",
"from utils.tools import to_device, get_mask_from_lengths\n",
"from synthesize import preprocess_english, preprocess_mandarin\n",
"\n",
"from matplotlib import pyplot as plt\n",
"from text import _id_to_symbol\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"parser = argparse.ArgumentParser()\n",
"parser.add_argument(\"--restore_step\", type=int, required=True)\n",
"parser.add_argument(\n",
" \"--mode\",\n",
" type=str,\n",
" choices=[\"single\"], # only support single mode\n",
" required=True,\n",
" help=\"Synthesize a whole dataset or a single sentence\",\n",
")\n",
"parser.add_argument(\n",
" \"--source\",\n",
" type=str,\n",
" default=None,\n",
" help=\"path to a source file with format like train.txt and val.txt, for batch mode only\",\n",
")\n",
"parser.add_argument(\n",
" \"--text\",\n",
" type=str,\n",
" default=None,\n",
" help=\"raw text to synthesize, for single-sentence mode only\",\n",
")\n",
"parser.add_argument(\n",
" \"--speaker_id\",\n",
" type=int,\n",
" default=0,\n",
" help=\"speaker ID for multi-speaker synthesis, for single-sentence mode only\",\n",
")\n",
"parser.add_argument(\n",
" \"-p\",\n",
" \"--preprocess_config\",\n",
" type=str,\n",
" required=True,\n",
" help=\"path to preprocess.yaml\",\n",
")\n",
"parser.add_argument(\n",
" \"-m\", \"--model_config\", type=str, required=True, help=\"path to model.yaml\"\n",
")\n",
"parser.add_argument(\n",
" \"-t\", \"--train_config\", type=str, required=True, help=\"path to train.yaml\"\n",
")\n",
"parser.add_argument(\n",
" \"--pitch_control\",\n",
" type=float,\n",
" default=1.0,\n",
" help=\"control the pitch of the whole utterance, larger value for higher pitch\",\n",
")\n",
"parser.add_argument(\n",
" \"--energy_control\",\n",
" type=float,\n",
" default=1.0,\n",
" help=\"control the energy of the whole utterance, larger value for larger volume\",\n",
")\n",
"parser.add_argument(\n",
" \"--duration_control\",\n",
" type=float,\n",
" default=1.0,\n",
" help=\"control the speed of the whole utterance, larger value for slower speaking rate\",\n",
")\n",
" \n",
"args = parser.parse_args(args=[\"--text\", \"This is a simple long sentence test, Hello world\",\n",
" \"--restore_step\", \"135000\", \"--mode\", \"single\",\n",
" \"-p\", \"config/LJSpeech/preprocess.yaml\", \n",
" \"-m\", \"config/LJSpeech/model.yaml\", \n",
" \"-t\", \"config/LJSpeech/train.yaml\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def get_attention_map(ids, phoneme_seq, attns_list, attn_map_path):\n",
" os.makedirs(attn_map_path, exist_ok=True)\n",
" for layer_num in range(len(attns_list)):\n",
" for head_num in range(attns_list[layer_num].shape[0]):\n",
" attn_matrix = attns_list[layer_num][head_num].numpy()\n",
" plt.figure(figsize=(10, 10))\n",
" im = plt.imshow(attn_matrix, interpolation='none')\n",
" im.axes.set_title('layer {}, head {}'.format(layer_num+1, head_num+1))\n",
" im.axes.set_xticks(range(len(phoneme_seq)))\n",
" im.axes.set_xticklabels(phoneme_seq, fontsize=200/len(phoneme_seq))\n",
" im.axes.set_yticks(range(len(phoneme_seq)))\n",
" im.axes.set_yticklabels(phoneme_seq, fontsize=200/len(phoneme_seq))\n",
" im_cb = plt.colorbar(im)\n",
" plt.savefig(os.path.join(attn_map_path, \"layer{}_head{}_{}.png\".format(layer_num+1, head_num+1, ids[0])))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read Config\n",
"preprocess_config = yaml.load(\n",
" open(args.preprocess_config, \"r\"), Loader=yaml.FullLoader\n",
")\n",
"model_config = yaml.load(open(args.model_config, \"r\"), Loader=yaml.FullLoader)\n",
"train_config = yaml.load(open(args.train_config, \"r\"), Loader=yaml.FullLoader)\n",
"configs = (preprocess_config, model_config, train_config)\n",
"\n",
"# Get model\n",
"model = get_model(args, configs, device, train=False)\n",
"\n",
"if args.mode == \"single\":\n",
" ids = raw_texts = [args.text[:100]]\n",
" speakers = np.array([args.speaker_id])\n",
" if preprocess_config[\"preprocessing\"][\"text\"][\"language\"] == \"en\":\n",
" texts = np.array([preprocess_english(args.text, preprocess_config)])\n",
" elif preprocess_config[\"preprocessing\"][\"text\"][\"language\"] == \"zh\":\n",
" texts = np.array([preprocess_mandarin(args.text, preprocess_config)])\n",
" text_lens = np.array([len(texts[0])])\n",
" batchs = [(ids, raw_texts, speakers, texts, text_lens, max(text_lens))]\n",
"\n",
"batch = batchs[0]\n",
"\n",
"batch = to_device(batch, device)\n",
"with torch.no_grad():\n",
" # Forward\n",
" ids = batch[0]\n",
" texts = batch[3]\n",
" phoneme_seq = [_id_to_symbol[s].replace('@', '') for s in texts[0].numpy()]\n",
"\n",
" src_lens = batch[4]\n",
" max_src_len = batch[5]\n",
" src_masks = get_mask_from_lengths(src_lens, max_src_len)\n",
"\n",
" encode, attns_list = model.encoder.forward(texts, src_masks, return_attns=True)\n",
"\n",
" attn_map_path = train_config['path']['log_path'].replace('log', 'attention')\n",
" # print(attn_map_path)\n",
"\n",
" get_attention_map(phoneme_seq, attns_list, attn_map_path)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.0 ('audio')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "aab49b93b15df8204604f499ca4edf92b00ea5ce48c75a41557219b18aaec957"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}