update export_codec_vllm
lyblsgo committed Feb 26, 2025
1 parent f280558 commit 54e9384
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 4 additions & 0 deletions cosyvoice/cli/model.py
@@ -331,6 +331,7 @@ def export_codec_vllm(self, model_path):
         pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
 
         dtype = torch.bfloat16
+        # lm_head
         new_lm_head = nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
         with torch.no_grad():
             new_lm_head.weight[:vocab_size] = self.llm.llm_decoder.weight
@@ -339,6 +340,8 @@ def export_codec_vllm(self, model_path):
             new_lm_head.bias[vocab_size:] = 0
         self.llm.llm.model.lm_head = new_lm_head
         new_codec_embed = nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+        # embed_tokens
+        embed_tokens = self.llm.llm.model.model.embed_tokens
         with torch.no_grad():
             new_codec_embed.weight[:vocab_size] = self.llm.speech_embedding.weight
             new_codec_embed.weight[vocab_size:] = 0
@@ -356,6 +359,7 @@ def export_codec_vllm(self, model_path):
         self.llm.llm.model.save_pretrained(model_path)
         self.llm.llm.model.config.vocab_size = tmp_vocab_size
         self.llm.llm.model.config.tie_word_embeddings = tmp_tie_embedding
+        self.llm.llm.model.set_input_embeddings(embed_tokens)
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
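Note on the model.py change: before the LM is exported for vLLM, the output head and the speech embedding are padded so the vocabulary size is a multiple of pad_to, with the padded rows zeroed; the original embed_tokens module is saved beforehand and put back with set_input_embeddings after save_pretrained, mirroring how the temporary vocab_size and tie_word_embeddings config values are restored. A minimal standalone sketch of the padding step follows; the helper name pad_head, the pad_to default of 64, and the example sizes are assumptions for illustration, not values from the repository.

import torch
import torch.nn as nn

def pad_head(old: nn.Linear, pad_to: int = 64) -> nn.Linear:
    # Copy `old` into a head whose output dim is rounded up to a multiple of
    # pad_to; padded rows are zeroed so they cannot affect real-token logits.
    vocab_size, feature_size = old.out_features, old.in_features
    pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
    new = nn.Linear(feature_size, pad_vocab_size, bias=old.bias is not None)
    with torch.no_grad():
        new.weight[:vocab_size] = old.weight
        new.weight[vocab_size:] = 0
        if old.bias is not None:
            new.bias[:vocab_size] = old.bias
            new.bias[vocab_size:] = 0
    return new

# Example: a 4096-dim head over 6563 tokens gets 6592 output rows (6592 = 103 * 64).
head = nn.Linear(4096, 6563)
print(pad_head(head).weight.shape)  # torch.Size([6592, 4096])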
4 changes: 3 additions & 1 deletion cosyvoice/llm/llm.py
@@ -353,12 +353,14 @@ def inference(
                 if str(request_output.request_id) != str(request_id):
                     continue
                 if not request_output.finished:
-                    print(f"Partial request output: {request_output}")
+                    # print(f"Partial request output: {request_output}")
                     out_token = list(request_output.outputs[0].token_ids)[-1]
                     yield out_token
                     out_token_ids.append(out_token)
                 else:
                     break
+            if not vllm_codec_engine.has_unfinished_requests():
+                break
 
     @torch.inference_mode()
     def inference_bistream(
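Note on the llm.py change: the per-step debug print is commented out, and the polling loop now also exits once the vLLM engine reports no unfinished requests, rather than relying only on the finished flag of the matching request. A hedged sketch of that loop shape follows; the function name, the engine handle, and the omission of the else/break branch are simplifications, not the repository's actual inference method. LLMEngine.step(), has_unfinished_requests(), and the RequestOutput fields used here are standard vLLM engine APIs.

from vllm import LLMEngine  # engine construction and add_request omitted for brevity

def stream_codec_tokens(engine: LLMEngine, request_id: str):
    # Yield each new speech token for `request_id` as the engine produces it.
    out_token_ids = []
    while True:
        for request_output in engine.step():
            if str(request_output.request_id) != str(request_id):
                continue
            if not request_output.finished:
                # only the newest token of this partial output is new
                out_token = list(request_output.outputs[0].token_ids)[-1]
                yield out_token
                out_token_ids.append(out_token)
        # guard equivalent to the one added in this commit: stop polling once
        # the engine has nothing left to schedule, instead of spinning on step()
        if not engine.has_unfinished_requests():
            break

The caller is assumed to have already queued the prompt with engine.add_request(request_id, ...) before iterating this generator.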
