diff --git a/paddlespeech/t2s/models/cosyvoice/flow/flow.py b/paddlespeech/t2s/models/cosyvoice/flow/flow.py
new file mode 100644
index 00000000000..2e6f841bc45
--- /dev/null
+++ b/paddlespeech/t2s/models/cosyvoice/flow/flow.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict
+
+import paddle
+from omegaconf import DictConfig
+
+from paddlespeech.t2s.utils.common import make_pad_mask
+
+
+class CausalMaskedDiffWithXvec(paddle.nn.Layer):
+    def __init__(self,
+                 input_size: int = 512,
+                 output_size: int = 80,
+                 spk_embed_dim: int = 192,
+                 output_type: str = "mel",
+                 vocab_size: int = 4096,
+                 input_frame_rate: int = 50,
+                 only_mask_loss: bool = True,
+                 token_mel_ratio: int = 2,
+                 pre_lookahead_len: int = 3,
+                 encoder: paddle.nn.Layer = None,
+                 decoder: paddle.nn.Layer = None,
+                 decoder_conf: Dict = {
+                     'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                     'cfm_params': DictConfig({
+                         'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                         'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                     'decoder_params': {
+                         'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                         'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {
+                     'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                     'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.decoder_conf = decoder_conf
+        self.mel_feat_conf = mel_feat_conf
+        self.vocab_size = vocab_size
+        self.output_type = output_type
+        self.input_frame_rate = input_frame_rate
+        logging.info(f"input frame rate={self.input_frame_rate}")
+        self.input_embedding = paddle.nn.Embedding(vocab_size, input_size)
+        self.spk_embed_affine_layer = paddle.nn.Linear(spk_embed_dim, output_size)
+        self.encoder = encoder
+        self.encoder_proj = paddle.nn.Linear(self.encoder.output_size(), output_size)
+        self.decoder = decoder
+        self.only_mask_loss = only_mask_loss
+        self.token_mel_ratio = token_mel_ratio
+        self.pre_lookahead_len = pre_lookahead_len
+        # fp16 inference is off by default; `inference` checks this flag
+        self.fp16 = False
+
+    def inference(self,
+                  token,
+                  token_len,
+                  prompt_token,
+                  prompt_token_len,
+                  prompt_feat,
+                  prompt_feat_len,
+                  embedding,
+                  finalize):
+        if self.fp16 is True:
+            prompt_feat = prompt_feat.astype(paddle.float16)
+            embedding = embedding.astype(paddle.float16)
+
+        assert token.shape[0] == 1
+        # xvec projection
+        embedding = paddle.nn.functional.normalize(embedding, axis=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        token = paddle.concat([prompt_token, token], axis=1)
+        token_len = prompt_token_len + token_len
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).astype(embedding.dtype)
+        token = self.input_embedding(paddle.clip(token, min=0)) * mask
+
+        # text encode
+        h, h_lengths = self.encoder(token, token_len)
+        if finalize is False:
+            # drop lookahead frames that are not yet final in streaming mode
+            h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+        mel_len1 = prompt_feat.shape[1]
+        mel_len2 = h.shape[1] - prompt_feat.shape[1]
+        h = self.encoder_proj(h)
+
+        # get conditions: prompt mel fills the head, the tail stays zero
+        conds = paddle.zeros([1, mel_len1 + mel_len2, self.output_size], dtype=h.dtype)
+        conds[:, :mel_len1] = prompt_feat
+        conds = conds.transpose([0, 2, 1])
+
+        mask = (~make_pad_mask(paddle.to_tensor([mel_len1 + mel_len2]))).astype(h.dtype)
+        feat, _ = self.decoder(
+            mu=h.transpose([0, 2, 1]),
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=10
+        )
+        # keep only the newly generated frames after the prompt
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
+        return feat.astype(paddle.float32), None
\ No newline at end of file
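A note on how `CausalMaskedDiffWithXvec.inference` conditions the decoder: the prompt mel occupies the first `mel_len1` frames of an otherwise zero condition tensor, and the flow-matching decoder in-paints only the trailing `mel_len2` frames, which are then sliced off and returned. A minimal shape walk-through with toy sizes (the sizes are illustrative, not the real model's):

```python
import paddle

# Toy sizes standing in for a real prompt/generation split.
mel_len1, mel_len2, output_size = 4, 6, 80
prompt_feat = paddle.randn([1, mel_len1, output_size])

# Prompt frames fill the head of the condition tensor; the zeroed
# tail is what the decoder is asked to generate.
conds = paddle.zeros([1, mel_len1 + mel_len2, output_size])
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose([0, 2, 1])  # channel-first (B, n_mels, T) for the decoder
print(conds.shape)  # [1, 80, 10]
```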
diff --git a/paddlespeech/t2s/models/cosyvoice/llm/llm.py b/paddlespeech/t2s/models/cosyvoice/llm/llm.py
new file mode 100644
index 00000000000..ef3f0865375
--- /dev/null
+++ b/paddlespeech/t2s/models/cosyvoice/llm/llm.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Generator, List
+
+import paddle
+from paddlenlp.transformers import Qwen2ForCausalLM
+
+
+class Qwen2Encoder(paddle.nn.Layer):
+    def __init__(self, pretrain_path):
+        super().__init__()
+        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
+
+    def forward_one_step(self, xs, masks, cache=None):
+        # only the mask row for the current step is passed as attention_mask
+        input_masks = masks[:, -1, :]
+        outs = self.model(
+            inputs_embeds=xs,
+            attention_mask=input_masks,
+            output_hidden_states=True,
+            return_dict=True,
+            use_cache=True,
+            past_key_values=cache,
+        )
+        xs = outs.hidden_states[-1]
+        new_cache = outs.past_key_values
+        return xs, new_cache
+
+
+class Qwen2LM(paddle.nn.Layer):
+    def __init__(
+            self,
+            llm_input_size: int,
+            llm_output_size: int,
+            speech_token_size: int,
+            llm: paddle.nn.Layer,
+            sampling: Callable,
+            length_normalized_loss: bool = True,
+            lsm_weight: float = 0.0,
+    ):
+        super().__init__()
+        self.llm_input_size = llm_input_size
+        self.llm_output_size = llm_output_size
+        self.speech_token_size = speech_token_size
+
+        # 2. build speech token language model related modules
+        self.sos_eos = 0
+        self.task_id = 1
+        self.fill_token = 2
+
+        self.llm_embedding = paddle.nn.Embedding(2, llm_input_size)
+        self.llm = llm
+        self.llm_decoder = paddle.nn.Linear(llm_output_size, speech_token_size + 3)
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = paddle.nn.Embedding(speech_token_size + 3, llm_input_size)
+
+        # 4. sampling method
+        self.sampling = sampling
+
+    def sampling_ids(
+            self,
+            weighted_scores: paddle.Tensor,
+            decoded_tokens: List,
+            sampling: int,
+            ignore_eos: bool = True,
+    ):
+        num_trials, max_trials = 0, 100
+        while True:
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
+            if (not ignore_eos) or (self.speech_token_size not in top_ids):
+                break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError(
+                    'sampling reached max_trials {} and still got eos when ignore_eos is True, check your input!'.format(max_trials))
+        return top_ids
+
+    def inference(
+            self,
+            text: paddle.Tensor,
+            text_len: paddle.Tensor,
+            prompt_text: paddle.Tensor,
+            prompt_text_len: paddle.Tensor,
+            prompt_speech_token: paddle.Tensor,
+            prompt_speech_token_len: paddle.Tensor,
+            embedding: paddle.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[paddle.Tensor, None, None]:
+        # 1. concat text and prompt_text, then embed
+        text = paddle.concat([prompt_text, text], axis=1)
+        text_len += prompt_text_len
+        text = self.llm.model.model.embed_tokens(text)
+
+        # 2. encode embedding (the speaker embedding is unused by the Qwen2 LM)
+        embedding = paddle.zeros([1, 0, self.llm_input_size], dtype=text.dtype)
+
+        # 3. concat llm_input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape([1, 1, -1])
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape([1, 1, -1])
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = paddle.zeros([1, 0, self.llm_input_size], dtype=text.dtype)
+        lm_input = paddle.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], axis=1)
+
+        # 4. cal min/max_length
+        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+        # 5. step by step decode
+        out_tokens = []
+        cache = None
+        for i in range(max_len):
+            y_pred, cache = self.llm.forward_one_step(
+                lm_input,
+                masks=paddle.tril(paddle.ones([1, lm_input.shape[1], lm_input.shape[1]])).astype(paddle.bool),
+                cache=cache)
+            logp = paddle.nn.functional.log_softmax(self.llm_decoder(y_pred[:, -1]), axis=-1)
+            top_ids = self.sampling_ids(
+                logp.squeeze(axis=0), out_tokens, sampling,
+                ignore_eos=i < min_len).item()
+            if top_ids == self.speech_token_size:
+                break
+            if top_ids > self.speech_token_size:
+                continue
+            # in stream mode, yield token one by one
+            yield top_ids
+            out_tokens.append(top_ids)
+            lm_input = self.speech_embedding.weight[top_ids].reshape([1, 1, -1])
\ No newline at end of file
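`Qwen2LM` leaves the `sampling` callable to the caller; `sampling_ids` invokes it as `sampling(weighted_scores, decoded_tokens, sampling)` with the per-step log-probabilities. A hypothetical top-k sampler matching that signature (the name `topk_sampling` and the top-k strategy are illustrative assumptions, not part of this PR):

```python
import paddle

def topk_sampling(weighted_scores, decoded_tokens, top_k):
    # weighted_scores: (vocab_size,) log-probabilities for the next token;
    # decoded_tokens is unused here but kept for signature compatibility.
    probs = paddle.nn.functional.softmax(weighted_scores, axis=-1)
    top_probs, top_ids = paddle.topk(probs, k=top_k)
    # renormalize over the top-k candidates and draw one sample
    idx = paddle.multinomial(top_probs / top_probs.sum(), num_samples=1)
    return top_ids[idx]
```

Note that `sampling_ids` probes the result with `self.speech_token_size not in top_ids`, so the callable should return a 1-D tensor (or list) of candidate ids rather than a bare Python int.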
diff --git a/paddlespeech/t2s/utils/common.py b/paddlespeech/t2s/utils/common.py
new file mode 100644
index 00000000000..edb7ca3a25d
--- /dev/null
+++ b/paddlespeech/t2s/utils/common.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+
+def make_pad_mask(lengths: paddle.Tensor, max_len: int = 0) -> paddle.Tensor:
+    """Make mask tensor containing indices of padded part.
+
+    Args:
+        lengths (paddle.Tensor): Batch of lengths (B,).
+        max_len (int): Pad to this length; 0 means pad to the batch maximum.
+    Returns:
+        paddle.Tensor: Bool mask tensor containing indices of padded part.
+
+    Examples:
+        >>> lengths = paddle.to_tensor([5, 3, 2])
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    batch_size = lengths.shape[0]
+    max_len = max_len if max_len > 0 else lengths.max().item()
+    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
\ No newline at end of file
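A quick sanity check of `make_pad_mask`, assuming the module is importable at the path added above, reproduces the docstring example:

```python
import paddle
from paddlespeech.t2s.utils.common import make_pad_mask

lengths = paddle.to_tensor([5, 3, 2])
print(make_pad_mask(lengths).astype(paddle.int64))
# [[0, 0, 0, 0, 0],
#  [0, 0, 0, 1, 1],
#  [0, 0, 1, 1, 1]]
```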