From 6dd39a291a0c6d73b04be3ab5675b339c6c6a909 Mon Sep 17 00:00:00 2001
From: liyulingyue <852433440@qq.com>
Date: Sat, 23 Nov 2024 01:44:03 +0800
Subject: [PATCH] fix Tacotron2 with CSMSC

---
 examples/csmsc/tts0/README.md                    | 8 ++++++++
 paddlespeech/t2s/exps/tacotron2/preprocess.py    | 6 +++---
 paddlespeech/t2s/modules/nets_utils.py           | 3 +++
 paddlespeech/t2s/modules/tacotron2/attentions.py | 3 ++-
 paddlespeech/t2s/modules/tacotron2/encoder.py    | 3 +++
 5 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/examples/csmsc/tts0/README.md b/examples/csmsc/tts0/README.md
index ce682495e97..c554524a772 100644
--- a/examples/csmsc/tts0/README.md
+++ b/examples/csmsc/tts0/README.md
@@ -5,6 +5,14 @@ This example contains code used to train a [Tacotron2](https://arxiv.org/abs/171
 ### Download and Extract
 Download CSMSC from it's [Official Website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`.
+The structure of the folder is listed below.
+
+```text
+datasets/BZNSYP
+└── Wave
+    └── .wav files
+```
+
 ### Get MFA Result and Extract
 We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for Tacotron2, the durations of MFA are not needed here.
 You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) of our repo.
 
diff --git a/paddlespeech/t2s/exps/tacotron2/preprocess.py b/paddlespeech/t2s/exps/tacotron2/preprocess.py
index 46b72591693..96eb64616e6 100644
--- a/paddlespeech/t2s/exps/tacotron2/preprocess.py
+++ b/paddlespeech/t2s/exps/tacotron2/preprocess.py
@@ -228,9 +228,9 @@ def main():
 
     if args.dataset == "baker":
         wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
-        # split data into 3 sections
-        num_train = 9800
-        num_dev = 100
+        # split data into 3 sections; dev/test each take min(10% of the data, 100) files
+        num_dev = min(int(len(wav_files) * 0.1), 100)
+        num_train = len(wav_files) - num_dev * 2
         train_wav_files = wav_files[:num_train]
         dev_wav_files = wav_files[num_train:num_train + num_dev]
         test_wav_files = wav_files[num_train + num_dev:]
diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py
index 57c46e3a859..cc3df8a1921 100644
--- a/paddlespeech/t2s/modules/nets_utils.py
+++ b/paddlespeech/t2s/modules/nets_utils.py
@@ -181,6 +181,9 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
     if length_dim == 0:
         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
 
+    # if lengths is a 0-dim tensor, add a batch dimension
+    if lengths.ndim == 0:
+        lengths = lengths.unsqueeze(0)
     bs = paddle.shape(lengths)
     if xs is None:
         maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
diff --git a/paddlespeech/t2s/modules/tacotron2/attentions.py b/paddlespeech/t2s/modules/tacotron2/attentions.py
index 5d1a2484536..86407e7786e 100644
--- a/paddlespeech/t2s/modules/tacotron2/attentions.py
+++ b/paddlespeech/t2s/modules/tacotron2/attentions.py
@@ -171,7 +171,8 @@ def forward(
         if paddle.sum(att_prev) == 0:
             # if no bias, 0 0-pad goes 0
             att_prev = 1.0 - make_pad_mask(enc_hs_len)
-            att_prev = att_prev / enc_hs_len.unsqueeze(-1)
+            att_prev = att_prev / enc_hs_len.unsqueeze(-1).astype(
+                att_prev.dtype)
 
         # att_prev: (utt, frame) -> (utt, 1, 1, frame)
         # -> (utt, att_conv_chans, 1, frame)
diff --git a/paddlespeech/t2s/modules/tacotron2/encoder.py b/paddlespeech/t2s/modules/tacotron2/encoder.py
index 224c82400d2..ac942be0f13 100644
--- a/paddlespeech/t2s/modules/tacotron2/encoder.py
+++ b/paddlespeech/t2s/modules/tacotron2/encoder.py
@@ -162,6 +162,9 @@ def forward(self, xs, ilens=None):
             return xs.transpose([0, 2, 1])
         if not isinstance(ilens, paddle.Tensor):
             ilens = paddle.to_tensor(ilens)
+        # if ilens is a 0-dim tensor, add a batch dimension
+        if ilens.ndim == 0:
+            ilens = ilens.unsqueeze(0)
         xs = xs.transpose([0, 2, 1])
        # for dygraph to static graph
        # self.blstm.flatten_parameters()
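
For context, the preprocess.py change replaces the hardcoded 9800/100 split with one derived from however many `.wav` files are actually found, so the recipe still produces non-empty dev/test sets on a partial download of CSMSC. Below is a minimal sketch of the arithmetic; the generated file list is a hypothetical stand-in for the real `rglob` result:

```python
# Stand-in for wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")));
# the full CSMSC corpus ships 10000 utterances.
wav_files = [f"{i:06d}.wav" for i in range(1, 10001)]

# dev and test each take min(10% of the data, 100) files
num_dev = min(int(len(wav_files) * 0.1), 100)
num_train = len(wav_files) - num_dev * 2

train_wav_files = wav_files[:num_train]
dev_wav_files = wav_files[num_train:num_train + num_dev]
test_wav_files = wav_files[num_train + num_dev:]

print(len(train_wav_files), len(dev_wav_files), len(test_wav_files))
# 9800 100 100 on the full corpus (same as the old hardcoded split);
# on e.g. 50 files it becomes 40/5/5 instead of leaving dev/test empty.
```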
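
The two `ndim == 0` guards (in `make_pad_mask` and the encoder's `forward`) handle the same corner case: a length value arriving as a bare scalar rather than a 1-D batch of lengths, as can happen with a batch of one. A small sketch of the failure mode, assuming a Paddle version (2.5 or later) in which `paddle.to_tensor` on a Python scalar yields a 0-dim tensor:

```python
import paddle

ilens = paddle.to_tensor(5)  # a single utterance length passed as a bare scalar
print(ilens.ndim)            # 0 -- downstream code expects a 1-D batch of lengths

# the guard added in both files: promote the scalar to a batch of size 1
if ilens.ndim == 0:
    ilens = ilens.unsqueeze(0)
print(ilens.shape)           # [1] -- now usable for masking and length indexing
```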
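
The attentions.py change casts the integer lengths to the dtype of the attention weights before dividing. The snippet below is a sketch with hypothetical shapes, assuming Paddle's behavior of not silently promoting mixed dtypes in elementwise division (in at least some versions, `float32 / int64` raises a dtype error), which is what the explicit `astype` guards against:

```python
import paddle

att_prev = paddle.ones([2, 5], dtype="float32")  # uniform attention weights (batch, frames)
enc_hs_len = paddle.to_tensor([5, 3])            # int64 encoder lengths

# att_prev / enc_hs_len.unsqueeze(-1)            # float32 / int64: may raise a dtype error
att_prev = att_prev / enc_hs_len.unsqueeze(-1).astype(att_prev.dtype)
print(att_prev.dtype)                            # paddle.float32
```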