-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] feat: support cosyvoice2 inference #3966
Open
yinfan98
wants to merge
4
commits into
PaddlePaddle:develop
Choose a base branch
from
yinfan98:add_cosyvoice_inference
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
import random | ||
from typing import Dict, Optional | ||
import paddle | ||
|
||
from omegaconf import DictConfig | ||
from paddlespeech.t2s.utils.common import make_pad_mask | ||
|
||
class CausalMaskedDiffWithXvec(paddle.nn.Layer): | ||
def __init__(self, | ||
input_size: int = 512, | ||
output_size: int = 80, | ||
spk_embed_dim: int = 192, | ||
output_type: str = "mel", | ||
vocab_size: int = 4096, | ||
input_frame_rate: int = 50, | ||
only_mask_loss: bool = True, | ||
token_mel_ratio: int = 2, | ||
pre_lookahead_len: int = 3, | ||
encoder: paddle.nn.Layer = None, | ||
decoder: paddle.nn.Layer = None, | ||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, | ||
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', | ||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), | ||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, | ||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, | ||
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, | ||
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): | ||
super().__init__() | ||
self.input_size = input_size | ||
self.output_size = output_size | ||
self.decoder_conf = decoder_conf | ||
self.mel_feat_conf = mel_feat_conf | ||
self.vocab_size = vocab_size | ||
self.output_type = output_type | ||
self.input_frame_rate = input_frame_rate | ||
logging.info(f"input frame rate={self.input_frame_rate}") | ||
self.input_embedding = paddle.nn.Embedding(vocab_size, input_size) | ||
self.spk_embed_affine_layer = paddle.nn.Linear(spk_embed_dim, output_size) | ||
self.encoder = encoder | ||
self.encoder_proj = paddle.nn.Linear(self.encoder.output_size(), output_size) | ||
self.decoder = decoder | ||
self.only_mask_loss = only_mask_loss | ||
self.token_mel_ratio = token_mel_ratio | ||
self.pre_lookahead_len = pre_lookahead_len | ||
|
||
def inference(self, | ||
token, | ||
token_len, | ||
prompt_token, | ||
prompt_token_len, | ||
prompt_feat, | ||
prompt_feat_len, | ||
embedding, | ||
finalize): | ||
if self.fp16 is True: | ||
prompt_feat = prompt_feat.half() | ||
embedding = embedding.half() | ||
|
||
assert token.shape[0] == 1 | ||
# xvec projection | ||
embedding = paddle.nn.functional.normalize(embedding, dim=1) | ||
embedding = self.spk_embed_affine_layer(embedding) | ||
|
||
# concat text and prompt_text | ||
token, token_len = paddle.concat([prompt_token, token], dim=1), prompt_token_len + token_len | ||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) | ||
token = self.input_embedding(paddle.clamp(token, min=0)) * mask | ||
|
||
# text encode | ||
h, h_lengths = self.encoder(token, token_len) | ||
if finalize is False: | ||
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio] | ||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] | ||
h = self.encoder_proj(h) | ||
|
||
# get conditions | ||
conds = paddle.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype) | ||
conds[:, :mel_len1] = prompt_feat | ||
conds = conds.transpose(1, 2) | ||
|
||
mask = (~make_pad_mask(paddle.tensor([mel_len1 + mel_len2]))).to(h) | ||
feat, _ = self.decoder( | ||
mu=h.transpose(1, 2).contiguous(), | ||
mask=mask.unsqueeze(1), | ||
spks=embedding, | ||
cond=conds, | ||
n_timesteps=10 | ||
) | ||
feat = feat[:, :, mel_len1:] | ||
assert feat.shape[2] == mel_len2 | ||
return feat.float(), None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Dict, Optional, Callable, List, Generator | ||
|
||
import paddle | ||
from paddlenlp.transformers import Qwen2ForCausalLM | ||
from paddle.nn import Pad1D | ||
|
||
class Qwen2Encoder(paddle.nn.Layer): | ||
def __init__(self, pretrain_path): | ||
super().__init__() | ||
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) | ||
|
||
def forward_one_step(self, xs, masks, cache=None): | ||
input_masks = masks[:, -1, :] | ||
outs = self.model( | ||
inputs_embeds=xs, | ||
attention_mask=input_masks, | ||
output_hidden_states=True, | ||
return_dict=True, | ||
use_cache=True, | ||
past_key_values=cache, | ||
) | ||
xs = outs.hidden_states[-1] | ||
new_cache = outs.past_key_values | ||
return xs, new_cache | ||
|
||
|
||
class Qwen2LM(paddle.nn.Layer): | ||
def __init__( | ||
self, | ||
llm_input_size: int, | ||
llm_output_size: int, | ||
speech_token_size: int, | ||
llm: paddle.nn.Layer, | ||
sampling: Callable, | ||
length_normalized_loss: bool = True, | ||
lsm_weight: float = 0.0, | ||
): | ||
super().__init__() | ||
self.llm_input_size = llm_input_size | ||
self.llm_output_size = llm_output_size | ||
self.speech_token_size = speech_token_size | ||
|
||
# 2. build speech token language model related modules | ||
self.sos_eos = 0 | ||
self.task_id = 1 | ||
self.fill_token = 2 | ||
|
||
self.llm_embedding = paddle.nn.Embedding(2, llm_input_size) | ||
self.llm = llm | ||
self.llm_decoder = paddle.nn.Linear(llm_output_size, speech_token_size + 3) | ||
|
||
# 3. [Optional] build speech token related modules | ||
self.speech_embedding = paddle.nn.Embedding(speech_token_size + 3, llm_input_size) | ||
|
||
# 4. sampling method | ||
self.sampling = sampling | ||
|
||
def sampling_ids( | ||
self, | ||
weighted_scores: paddle.Tensor, | ||
decoded_tokens: List, | ||
sampling: int, | ||
ignore_eos: bool = True, | ||
): | ||
num_trials, max_trials = 0, 100 | ||
while True: | ||
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) | ||
if (not ignore_eos) or (self.speech_token_size not in top_ids): | ||
break | ||
num_trials += 1 | ||
if num_trials > max_trials: | ||
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials)) | ||
return top_ids | ||
|
||
def inference( | ||
self, | ||
text: paddle.Tensor, | ||
text_len: paddle.Tensor, | ||
prompt_text: paddle.Tensor, | ||
prompt_text_len: paddle.Tensor, | ||
prompt_speech_token: paddle.Tensor, | ||
prompt_speech_token_len: paddle.Tensor, | ||
embedding: paddle.Tensor, | ||
sampling: int = 25, | ||
max_token_text_ratio: float = 20, | ||
min_token_text_ratio: float = 2, | ||
) -> Generator[paddle.Tensor, None, None]: | ||
device = text.device | ||
text = paddle.concat([prompt_text, text], dim=1) | ||
text_len += prompt_text_len | ||
text = self.llm.model.model.embed_tokens(text) | ||
|
||
# 2. encode embedding | ||
embedding = paddle.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype) | ||
|
||
# 3. concat llm_input | ||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) | ||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) | ||
if prompt_speech_token_len != 0: | ||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) | ||
else: | ||
prompt_speech_token_emb = paddle.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) | ||
lm_input = paddle.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) | ||
|
||
# 4. cal min/max_length | ||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio) | ||
max_len = int((text_len - prompt_text_len) * max_token_text_ratio) | ||
|
||
# 5. step by step decode | ||
out_tokens = [] | ||
cache = None | ||
for i in range(max_len): | ||
y_pred, cache = self.llm.forward_one_step(lm_input, | ||
masks=paddle.tril(paddle.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(paddle.bool), | ||
cache=cache) | ||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) | ||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() | ||
if top_ids == self.speech_token_size: | ||
break | ||
if top_ids > self.speech_token_size: | ||
continue | ||
# in stream mode, yield token one by one | ||
yield top_ids | ||
out_tokens.append(top_ids) | ||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
|
||
def make_pad_mask(lengths: paddle.Tensor, max_len: int = 0) -> paddle.Tensor: | ||
"""Make mask tensor containing indices of padded part. | ||
|
||
See description of make_non_pad_mask. | ||
|
||
Args: | ||
lengths (torch.Tensor): Batch of lengths (B,). | ||
Returns: | ||
torch.Tensor: Mask tensor containing indices of padded part. | ||
|
||
Examples: | ||
>>> lengths = [5, 3, 2] | ||
>>> make_pad_mask(lengths) | ||
masks = [[0, 0, 0, 0 ,0], | ||
[0, 0, 0, 1, 1], | ||
[0, 0, 1, 1, 1]] | ||
""" | ||
batch_size = lengths.size(0) | ||
max_len = max_len if max_len > 0 else lengths.max().item() | ||
seq_range = paddle.arange(0, | ||
max_len, | ||
dtype=paddle.int64, | ||
device=lengths.device) | ||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) | ||
seq_length_expand = lengths.unsqueeze(-1) | ||
mask = seq_range_expand >= seq_length_expand | ||
return mask |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
年记得更新一下