Skip to content

Commit

Permalink
ENH: Update fish audio (#2555)
Browse files Browse the repository at this point in the history
Co-authored-by: qinxuye <[email protected]>
  • Loading branch information
codingl2k1 and qinxuye authored Nov 15, 2024
1 parent 7a0bb60 commit 4c96475
Show file tree
Hide file tree
Showing 40 changed files with 2,505 additions and 275 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y "faster_whisper"
${{ env.SELF_HOST_PYTHON }} -m pip install -U accelerate
${{ env.SELF_HOST_PYTHON }} -m pip install -U verovio
${{ env.SELF_HOST_PYTHON }} -m pip install -U cachetools
${{ env.SELF_HOST_PYTHON }} -m pip install -U silero-vad
${{ env.SELF_HOST_PYTHON }} -m pip install -U pydantic
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
--disable-warnings \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/core/tests/test_continuous_batching.py && \
Expand Down
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ all =
natsort # For Fish Speech
loralib # For Fish Speech
ormsgpack # For Fish Speech
cachetools # For Fish Speech
silero-vad # For Fish Speech
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
Expand Down Expand Up @@ -210,6 +212,8 @@ audio =
natsort # For Fish Speech
loralib # For Fish Speech
ormsgpack # For Fish Speech
cachetools # For Fish Speech
silero-vad # For Fish Speech
doc =
ipython>=6.5.0
sphinx>=3.0.0
Expand Down
4 changes: 3 additions & 1 deletion xinference/deploy/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ click
tqdm>=4.27
tabulate
requests
pydantic
pydantic>2
fastapi>=0.110.3
uvicorn
huggingface-hub>=0.19.4
Expand Down Expand Up @@ -72,6 +72,8 @@ loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
ormsgpack # For Fish Speech
cachetools # For Fish Speech
silero-vad # For Fish Speech
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
Expand Down
4 changes: 3 additions & 1 deletion xinference/deploy/docker/requirements_cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ click
tqdm>=4.27
tabulate
requests
pydantic
pydantic>2
fastapi>=0.110.3
uvicorn
huggingface-hub>=0.19.4
Expand Down Expand Up @@ -67,6 +67,8 @@ loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
ormsgpack # For Fish Speech
cachetools # For Fish Speech
silero-vad # For Fish Speech
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
Expand Down
2 changes: 1 addition & 1 deletion xinference/model/audio/model_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
"model_name": "FishSpeech-1.4",
"model_family": "FishAudio",
"model_id": "fishaudio/fish-speech-1.4",
"model_revision": "3c49651b8e583b6b13f55e375432e0d57e1aa84d",
"model_revision": "069c573759936b35191d3380deb89183c0656f59",
"model_ability": "text-to-audio",
"multilingual": true
}
Expand Down
Empty file.
Empty file.
254 changes: 254 additions & 0 deletions xinference/thirdparty/fish_speech/fish_speech/conversation.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,256 @@
from dataclasses import dataclass, field
from typing import Literal

import torch
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast

# Chat-markup delimiters placed around each encoded message.
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"

# Placeholder tokens: Message.encode emits one of these per column of a
# VQPart's codes / MelPart's mels.
SEMANTIC_TOKEN = "<|semantic|>"
MEL_TOKEN = "<|mel|>"

# Phoneme span delimiters (not referenced elsewhere in this module).
PHONEME_START_TOKEN = "<|phoneme_start|>"
PHONEME_END_TOKEN = "<|phoneme_end|>"

# Every special token declared above, in declaration order.
ALL_SPECIAL_TOKENS = [
    IM_START_TOKEN,
    IM_END_TOKEN,
    SEMANTIC_TOKEN,
    MEL_TOKEN,
    PHONEME_START_TOKEN,
    PHONEME_END_TOKEN,
]

# NOTE(review): Message.encode shifts all codebook values by +1, which
# presumably keeps id 0 free for padding -- confirm against the model code.
CODEBOOK_PAD_TOKEN_ID = 0


class FishTokenizerConfig(PretrainedConfig):
    """Config type used as the key when registering FishTokenizerFast.

    The class-level attributes below are the default codebook settings.
    """

    # Whether all codebooks share a single embedding range.
    share_codebook_embeddings: bool = True
    # Number of entries in each codebook.
    codebook_size: int = 1024
    # Number of codebooks.
    num_codebooks: int = 8


class FishTokenizerFast(PreTrainedTokenizerFast):
    """Fast tokenizer that also carries the codebook settings as attributes."""

    def __init__(self, *args, **kwargs):
        """Initialise the base tokenizer, then record the codebook settings.

        All positional and keyword arguments are forwarded unchanged to
        ``PreTrainedTokenizerFast.__init__``. The three codebook options are
        afterwards read from ``kwargs``, falling back to their defaults.
        """
        super().__init__(*args, **kwargs)

        # ``kwargs`` is local to this call, so a plain lookup is sufficient.
        self.share_codebook_embeddings = kwargs.get("share_codebook_embeddings", True)
        self.codebook_size = kwargs.get("codebook_size", 1024)
        self.num_codebooks = kwargs.get("num_codebooks", 8)


# Make AutoTokenizer resolve FishTokenizerConfig to the fast tokenizer above.
AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)


@dataclass(kw_only=True)
class BasePart:
pass


@dataclass(kw_only=True)
class VQPart(BasePart):
    """Message part holding VQ codes."""

    # Message.encode iterates the first axis as codebooks and uses
    # ``codes.shape[1]`` as the number of placeholder tokens to emit.
    codes: torch.Tensor


@dataclass(kw_only=True)
class TextPart(BasePart):
    """Message part holding plain text to be tokenized."""

    # Raw text; Message.encode passes it to ``tokenizer.encode``.
    text: str


@dataclass(kw_only=True)
class MelPart(BasePart):
    """Message part holding mel features."""

    # Message.encode uses ``mels.shape[1]`` as the number of placeholder
    # tokens to emit.
    mels: torch.Tensor


@dataclass(kw_only=True)
class EncodedMessage:
tokens: torch.Tensor
labels: torch.Tensor
vq_parts: list[torch.Tensor]
mel_parts: list[torch.Tensor]
vq_require_losses: torch.Tensor | None = None


@dataclass(kw_only=True)
class Message:
    """A single chat turn made of text, VQ and mel parts."""

    role: Literal["system", "user", "assistant"]
    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
    # Whether to wrap the parts in the im_start header / im_end footer.
    add_im_start: bool = True
    add_im_end: bool = True
    # Whether this message's tokens contribute to the loss.
    cal_loss: bool = False

    # The auto-generated im_start header is excluded from the loss by default.
    ignore_im_start_loss: bool = True

    def encode(
        self: "Message",
        tokenizer: AutoTokenizer,
    ) -> EncodedMessage:
        """Turn this message into token ids, labels and multi-modal payloads.

        Text parts are tokenized directly. Each VQ or mel part contributes a
        run of placeholder token ids (one per column of its tensor), while
        the tensor itself is collected into ``vq_parts`` / ``mel_parts``.

        Args:
            tokenizer: Tokenizer providing ``convert_tokens_to_ids`` and
                ``encode``; ``share_codebook_embeddings`` and
                ``codebook_size`` are read from it when present/needed.

        Returns:
            An ``EncodedMessage`` with ``vq_require_losses`` left unset.

        Raises:
            ValueError: If a part is not a TextPart, VQPart or MelPart.
        """
        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
            [SEMANTIC_TOKEN, MEL_TOKEN]
        )

        # Work on a fresh list so ``self.parts`` is never mutated.
        sequence = list(self.parts)
        if self.add_im_start:
            sequence = [TextPart(text=f"<|im_start|>{self.role}\n"), *sequence]
        if self.add_im_end:
            sequence = [*sequence, TextPart(text="<|im_end|>")]

        token_chunks = []
        label_chunks = []
        vq_parts = []
        mel_parts = []

        for part in sequence:
            if isinstance(part, TextPart):
                encoded_text = tokenizer.encode(
                    part.text,
                    add_special_tokens=False,
                    truncation=False,
                    return_tensors="pt",
                )
                chunk = encoded_text.int()[0]
            elif isinstance(part, VQPart):
                # One semantic placeholder per column of the code tensor.
                chunk = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id

                # Shift every code up by one before storing it.
                shifted = part.codes.clone() + 1
                if getattr(tokenizer, "share_codebook_embeddings", True) is False:
                    # Separate embeddings: offset each codebook row into its
                    # own id range.
                    for book_idx in range(len(shifted)):
                        shifted[book_idx] += tokenizer.codebook_size * book_idx

                vq_parts.append(shifted)
            elif isinstance(part, MelPart):
                # One mel placeholder per column of the mel tensor.
                chunk = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
                mel_parts.append(part.mels)
            else:
                raise ValueError(f"Unsupported part type: {type(part)}")

            token_chunks.append(chunk)
            label_chunks.append(
                chunk.clone() if self.cal_loss else torch.full_like(chunk, -100)
            )

        tokens = torch.cat(token_chunks, dim=0)
        labels = torch.cat(label_chunks, dim=0)
        assert tokens.shape == labels.shape

        if self.ignore_im_start_loss and self.add_im_start:
            # The first chunk is the auto-generated header; mask it out.
            header_len = len(token_chunks[0])
            labels[:header_len] = -100

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            mel_parts=mel_parts,
        )


@dataclass
class Conversation:
    """An ordered list of messages that can be encoded as one sequence."""

    messages: list[Message]

    def encode(
        self: "Conversation",
        tokenizer: AutoTokenizer,
        add_shift: bool = True,
    ) -> EncodedMessage:
        """Encode every message and concatenate the results.

        Args:
            tokenizer: Tokenizer forwarded to each ``Message.encode`` call.
            add_shift: When true, drop the last token and the first label so
                that each remaining token is paired with the label of the
                token that follows it.

        Returns:
            An ``EncodedMessage`` whose ``vq_require_losses`` holds one
            boolean per VQ part, copied from the owning message's
            ``cal_loss``.

        Raises:
            AssertionError: If the concatenated tokens are not an integer
                tensor.
        """
        # Build the input_ids and labels
        tokens = []
        labels = []
        vq_parts = []
        mel_parts = []
        vq_require_losses = []

        for message in self.messages:
            encoded = message.encode(tokenizer)
            tokens.append(encoded.tokens)
            labels.append(encoded.labels)
            vq_parts.extend(encoded.vq_parts)
            mel_parts.extend(encoded.mel_parts)
            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))

        tokens = torch.cat(tokens, dim=0)
        labels = torch.cat(labels, dim=0)
        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)

        if add_shift:
            tokens = tokens[:-1]
            labels = labels[1:]

        # Report ``self`` in the message: the previous code referenced a
        # module-level ``conversation`` name that only exists when this file
        # is run as a script, so a failing check raised NameError instead.
        assert tokens.dtype in [
            torch.int,
            torch.long,
        ], f"Invalid dtype: {tokens.dtype}, conv: {self}"

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            mel_parts=mel_parts,
            vq_require_losses=vq_require_losses,
        )

    def encode_for_inference(
        self: "Conversation",
        tokenizer: AutoTokenizer,
        num_codebooks: int,
    ) -> torch.Tensor:
        """Build the integer prompt tensor used for inference.

        Args:
            tokenizer: Tokenizer forwarded to ``encode``.
            num_codebooks: Number of codebook rows to allocate below the
                token row.

        Returns:
            An int tensor of shape ``(num_codebooks + 1, len(tokens))``.
            Row 0 holds the token ids (encoded without shifting); rows 1 and
            onwards are zero except at semantic-placeholder positions, where
            they hold the concatenated VQ codes.
        """
        encoded = self.encode(tokenizer, add_shift=False)
        tokens = encoded.tokens
        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
        values[0] = tokens

        # Text-only conversation: nothing to scatter into the codebook rows.
        if not encoded.vq_parts:
            return values

        semantic_id, _ = tokenizer.convert_tokens_to_ids(
            [SEMANTIC_TOKEN, MEL_TOKEN]
        )
        # Join all VQ parts along the time axis and write them into the
        # columns occupied by semantic placeholder tokens.
        all_codes = torch.cat(encoded.vq_parts, dim=1)
        values[1:, tokens == semantic_id] = all_codes
        return values

    def visualize(self: "Conversation", tokenizer: AutoTokenizer):
        """Print the decoded tokens, coloured by whether they carry a loss.

        Tokens whose label is -100 (ignored) are printed in green, all other
        tokens in blue. A newline token is shown as a literal ``\\n``
        followed by a real line break.
        """
        encoded = self.encode(tokenizer, add_shift=False)

        def _print_colored(text, color_code):
            # Wrap the text in an ANSI colour sequence without a newline.
            print(color_code + text + "\033[0m", end="")

        for tok, lab in zip(encoded.tokens, encoded.labels):
            val = tokenizer.decode(tok, skip_special_tokens=False)
            if val == "\n":
                val = "\\n\n"

            if lab == -100:
                _print_colored(val, "\033[92m")  # green: ignored by the loss
            else:
                _print_colored(val, "\033[94m")  # blue: contributes to the loss

        print()


if __name__ == "__main__":
    # Small manual demo: a user turn with text plus VQ codes, followed by an
    # assistant turn whose tokens contribute to the loss.
    user_message = Message(
        role="user",
        parts=[
            TextPart(text="Hello, how are you?"),
            VQPart(codes=torch.zeros((4, 10))),
        ],
        cal_loss=False,
    )
    assistant_message = Message(
        role="assistant",
        parts=[TextPart(text="I'm fine, thank you.")],
        cal_loss=True,
    )

    conversation = Conversation([user_message, assistant_message])
    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")

    # Colour-coded dump of the encoded conversation.
    conversation.visualize(tokenizer)

    # Raw encoded result, then every token decoded individually.
    encoded = conversation.encode(tokenizer)
    print(encoded)
    print(tokenizer.batch_decode(encoded.tokens))
Empty file.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -118,5 +118,6 @@
"new": "new",
"Realtime Transform Text": "Realtime Transform Text",
"Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
"Text Normalization": "Text Normalization"
"Text Normalization": "Text Normalization",
"Select Example Audio": "Select Example Audio"
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,5 +118,6 @@
"new": "nuevo",
"Realtime Transform Text": "Transformación de Texto en Tiempo Real",
"Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
"Text Normalization": "Normalización de Texto"
"Text Normalization": "Normalización de Texto",
"Select Example Audio": "Seleccionar audio de ejemplo"
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,6 @@
"new": "新規",
"Realtime Transform Text": "リアルタイム変換テキスト",
"Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
"Text Normalization": "テキスト正規化"

"Text Normalization": "テキスト正規化",
"Select Example Audio": "サンプル音声を選択"
}
Loading

0 comments on commit 4c96475

Please sign in to comment.