add ability to use guide tokens for TTS, ref: ggerganov#11186
LostRuins committed Jan 11, 2025
1 parent bd38665 commit 07173e8
Showing 4 changed files with 53 additions and 3 deletions.
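In brief: this change adds an optional guide-token mechanism to the OuteTTS pipeline. The input text is split on the <|text_sep|> separator, the first token of each word fragment is queued, and during decoding the sampled token is overridden with the next queued token whenever the previous token was the word-boundary id (198, per the comment in the diff). This pins each generated word to the input text and reduces skipped or hallucinated words.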
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -2215,6 +2215,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.vocoder.model = value;
         }
     ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--tts-use-guide-tokens"},
+        "Use guide tokens to improve TTS word recall",
+        [](common_params & params) {
+            params.vocoder.use_guide_tokens = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
 
     // model-specific
     add_opt(common_arg(
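With the option registered above, the feature can be toggled from the command line of the tts example (and, per set_examples, the server example), along the lines of `llama-tts -m outetts.gguf -mv wavtokenizer.gguf --tts-use-guide-tokens -p "..."`; note that the binary name and the flags other than --tts-use-guide-tokens are illustrative assumptions, not part of this commit.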
2 changes: 2 additions & 0 deletions common/common.h
@@ -174,6 +174,8 @@ struct common_params_vocoder {
 
     std::string model = ""; // model path // NOLINT
     std::string model_url = ""; // model url to download // NOLINT
+
+    bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
 };
 
 struct common_params {
43 changes: 42 additions & 1 deletion examples/tts/tts.cpp
@@ -425,6 +425,29 @@ static void prompt_init(llama_tokens & prompt, const llama_model * model) {
     prompt_add(prompt, model, "<|im_start|>\n", true, true);
 }
 
+static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, const std::string& str)
+{
+    const std::string& delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(model, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    // Add the last part
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(model, current_word, false, true);
+    result.push_back(tmp[0]);
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
 
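For intuition, here is a minimal standalone sketch of the same splitting logic with the tokenizer call left out (plain C++, no llama.cpp dependencies; the sample prompt string is invented):

#include <iostream>
#include <string>
#include <vector>

// Split an OuteTTS-formatted prompt on "<|text_sep|>" and print each word
// fragment; prepare_guide_tokens additionally keeps only the first token id
// of each fragment.
int main() {
    const std::string str = "hello<|text_sep|>beautiful<|text_sep|>world";
    const std::string delimiter = "<|text_sep|>";
    std::vector<std::string> words;
    size_t start = 0;
    size_t end = str.find(delimiter);
    while (end != std::string::npos) {
        words.push_back(str.substr(start, end - start));
        start = end + delimiter.length();
        end = str.find(delimiter, start);
    }
    words.push_back(str.substr(start)); // the part after the last separator
    for (const auto & w : words) {
        std::cout << w << '\n'; // hello, beautiful, world
    }
    return 0;
}

Note that, like the original, this assumes the text is non-empty; prepare_guide_tokens likewise indexes tmp[0] on whatever common_tokenize returns for each fragment.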
@@ -492,6 +515,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;
 
     // process prompt and generate voice codes
     {
@@ -506,6 +530,10 @@
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if(params.vocoder.use_guide_tokens)
+            {
+                guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
+            }
 
             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
 
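For context: process_text is the example's existing preprocessing step, which converts the raw prompt into the OuteTTS input format with <|text_sep|> between words; this is why prepare_guide_tokens splits on that same delimiter.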
@@ -715,6 +743,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     int n_past = batch.n_tokens;
     int n_decode = 0;
 
+    bool next_token_uses_guide_token = true;
+
     while (n_decode <= n_predict) {
         // prepare the next batch
         common_batch_clear(batch);
@@ -726,7 +756,18 @@
                 continue;
             }
 
-            const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+            llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+            //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+            if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+            {
+                llama_token guide_token = guide_tokens[0];
+                guide_tokens.erase(guide_tokens.begin());
+                new_token_id = guide_token; //ensure correct word fragment is used
+            }
+
+            //this is the token id that always precedes a new word
+            next_token_uses_guide_token = (new_token_id == 198);
 
             common_sampler_accept(smpl[i], new_token_id, true);
 
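To see the forcing pattern in isolation, here is a toy trace with a stubbed-out sampler (every token id below is invented except 198, the word-boundary id from the diff; the control-token and end-of-generation exclusions are omitted for brevity):

#include <deque>
#include <iostream>

// Toy trace of the guide-token mechanism: whenever the previous token was
// the word-boundary id (198), the next sampled token is overridden with the
// next queued guide token.
int main() {
    std::deque<int> guide_tokens = {501, 502, 503};  // first token of each input word (invented ids)
    int sampled[] = {42, 198, 77, 198, 88, 198, 99}; // pretend sampler output
    bool next_uses_guide = true;
    for (int tok : sampled) {
        if (!guide_tokens.empty() && next_uses_guide) {
            tok = guide_tokens.front(); // force the correct word fragment
            guide_tokens.pop_front();
        }
        next_uses_guide = (tok == 198);
        std::cout << tok << ' ';
    }
    std::cout << '\n'; // prints: 501 198 502 198 503 198 99
    return 0;
}

Once the guide queue is exhausted, sampling proceeds untouched, which is why the final 99 passes through unmodified.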
4 changes: 2 additions & 2 deletions src/llama.cpp
@@ -11122,9 +11122,9 @@ static int llama_decode_impl(
         }
     }
 
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+    GGML_ASSERT_CONTINUE(n_tokens_all <= cparams.n_batch);
 
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+    GGML_ASSERT_CONTINUE((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
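GGML_ASSERT_CONTINUE is not an upstream ggml macro; it appears to come from the committer's fork. A plausible definition, offered purely as an assumption about its intent, would log the failed condition and fall through instead of aborting:

#include <cstdio>

// Hypothetical sketch only; the real macro lives in the fork's ggml headers.
// A non-fatal assert: report the failed condition and continue, rather than
// calling abort() as GGML_ASSERT does.
#define GGML_ASSERT_CONTINUE(x)                                         \
    do {                                                                \
        if (!(x)) {                                                     \
            fprintf(stderr, "GGML_ASSERT_CONTINUE failed: %s:%d: %s\n", \
                    __FILE__, __LINE__, #x);                            \
        }                                                               \
    } while (0)

This keeps llama_decode_impl from hard-aborting on oversized batches, at the cost of continuing past a state that upstream treats as fatal.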
