Commit bd27600

feat(speculative-sampling): add grammar support (#203)
Signed-off-by: mudler <[email protected]>
1 parent d143221 commit bd27600
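
The grammar added by this commit reaches binding.cpp as a plain string in params.grammar and is parsed by grammar_parser::parse (see the binding.cpp diff below), so it is expected to be written in llama.cpp's GBNF format. As an illustration only (this snippet is not part of the commit), a caller could hold a toy grammar that restricts generation to "yes" or "no" like this:

#include <string>

// Hypothetical example (not from this commit): a GBNF grammar string of the kind
// that would be supplied via params.grammar and consumed by grammar_parser::parse().
static const std::string kYesNoGrammar = R"(root ::= "yes" | "no")";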

File tree

3 files changed, +68 -14 lines changed

.github/workflows/test.yaml

+1 -1
@@ -46,7 +46,7 @@ jobs:
         run: go version
       - name: Test
         run: |
-          CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" make test
+          CMAKE_ARGS="-DLLAMA_METAL=OFF -DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" make test
 
   macOS-metal-latest:
     runs-on: macOS-latest

binding.cpp

+66 -12
@@ -619,14 +619,32 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
     // used to determine end of generation
     bool has_eos = false;
 
+    // grammar stuff
+    struct llama_grammar * grammar_dft = NULL;
+    struct llama_grammar * grammar_tgt = NULL;
+
+    grammar_parser::parse_state parsed_grammar;
+
+    // if requested - load the grammar, error checking is omitted for brevity
+    if (!params.grammar.empty()) {
+        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+        // will be empty (default) if there are parse errors
+        if (parsed_grammar.rules.empty()) {
+            return 1;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+        grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    }
+
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
-        // sample from the drafted tokens if any
         int i_dft = 0;
         while (true) {
-            const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
-
+            // sample from the target model
+            const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+            // remember which tokens were sampled - used for repetition penalties during sampling
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
 
@@ -644,6 +662,7 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
 
             ++n_predict;
 
+            // check if the draft matches the target
             if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
                 LOG("drafted token %d accepted\n", id);
                 ++n_accept;
@@ -654,6 +673,13 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
                 continue;
             }
 
+            if (i_dft < (int) drafted.size()) {
+                LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n",
+                    i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str());
+            } else {
+                LOG("out of drafted tokens\n");
+            }
+
             // the drafted token was rejected or we are out of drafted tokens
             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
             ++n_past_dft;
@@ -668,7 +694,16 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
             break;
         }
 
-        // sample n_draft tokens from the draft model picking the best token
+        if (grammar_tgt) {
+            if (grammar_dft) {
+                llama_grammar_free(grammar_dft);
+            }
+            grammar_dft = llama_grammar_copy(grammar_tgt);
+
+            LOG("copied target grammar to draft grammar\n");
+        }
+
+        // sample n_draft tokens from the draft model using greedy decoding
         int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
             float * logits = llama_get_logits(ctx_dft);
@@ -680,32 +715,48 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
 
             llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
 
+            if (grammar_dft != NULL) {
+                llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
+            }
+
             // computes softmax and sorts the candidates
             llama_sample_softmax(ctx_dft, &cur_p);
 
             for (int i = 0; i < 3; ++i) {
                 LOG(" - draft candidate %d: %d (%.3f)\n", i, cur_p.data[i].id, cur_p.data[i].p);
             }
 
-            // too low probability, stop drafting
+            // TODO: better logic?
             if (cur_p.data[0].p < 2*cur_p.data[1].p) {
+                LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
                 break;
             }
 
-            drafted.push_back(cur_p.data[0].id);
+            // drafted token
+            const llama_token id = cur_p.data[0].id;
+
+            drafted.push_back(id);
             ++n_drafted;
 
-            if (i < n_draft - 1) {
-                // evaluate the drafted token on the draft model
-                llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
-                ++n_past_cur;
+            // no need to evaluate the last drafted token, since we won't use the result
+            if (i == n_draft - 1) {
+                break;
+            }
+
+            // evaluate the drafted token on the draft model
+            llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
+            ++n_past_cur;
+
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
            }
         }
 
         // evaluate the target model on the drafted tokens
         llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
         ++n_past_tgt;
-
+
+        // the first token is always proposed by the traget model before the speculation loop
         drafted.erase(drafted.begin());
     }
     if (debug) {
@@ -732,7 +783,10 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
 
         fprintf(stderr, "\n\n");
     }
-
+    if (grammar_dft != NULL) {
+        llama_grammar_free(grammar_dft);
+        llama_grammar_free(grammar_tgt);
+    }
     strcpy(result, res.c_str());
     return 0;
 }
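
Taken together, the binding.cpp hunks above follow one grammar lifecycle: parse the GBNF text once, build a grammar for the target model, constrain target-side sampling with it, re-copy it into a fresh draft grammar at the start of every speculation round, mask and advance the draft grammar while drafting, and free both objects when generation ends. The condensed sketch below restates that flow in one place as a reading aid; it is not a drop-in replacement for the function, and the include paths and the helper name speculate_with_grammar are assumptions.

// Condensed sketch of the grammar lifecycle added by this commit; the speculative
// decoding plumbing (drafting, acceptance checks, KV handling) is elided.
// Include paths are assumptions; binding.cpp's actual includes may differ.
#include <string>
#include <vector>

#include "grammar-parser.h"   // grammar_parser::parse / parse_state, from llama.cpp's common code
#include "llama.h"

int speculate_with_grammar(llama_context * ctx_tgt, llama_context * ctx_dft,
                           const std::string & grammar_text, int n_draft) {
    llama_grammar * grammar_tgt = NULL;
    llama_grammar * grammar_dft = NULL;

    // 1) Parse the GBNF text once and build the target-side grammar (hunk at -619/+619).
    if (!grammar_text.empty()) {
        grammar_parser::parse_state parsed = grammar_parser::parse(grammar_text.c_str());
        if (parsed.rules.empty()) {
            return 1; // parse error
        }
        std::vector<const llama_grammar_element *> rules(parsed.c_rules());
        grammar_tgt = llama_grammar_init(rules.data(), rules.size(), parsed.symbol_ids.at("root"));
    }

    bool done = false;
    while (!done) {
        // 2) Target sampling is constrained by grammar_tgt:
        //    llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);

        // 3) Before drafting, discard last round's draft grammar and re-copy the target
        //    grammar so both models share the same grammar state (hunk at -668/+694).
        if (grammar_tgt) {
            if (grammar_dft) {
                llama_grammar_free(grammar_dft);
            }
            grammar_dft = llama_grammar_copy(grammar_tgt);
        }

        for (int i = 0; i < n_draft; ++i) {
            // 4) Mask the draft candidates before picking the greedy token:
            //    llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
            // 5) After evaluating a drafted token id, advance the draft grammar:
            //    llama_grammar_accept_token(ctx_dft, grammar_dft, id);
        }

        done = true; // end-of-generation handling elided
    }

    // 6) Release both grammars once generation finishes (hunk at -732/+783).
    if (grammar_dft != NULL) {
        llama_grammar_free(grammar_dft);
        llama_grammar_free(grammar_tgt);
    }

    (void) ctx_tgt; (void) ctx_dft; // only referenced in the commented-out calls above
    return 0;
}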

llama.cpp

+1 -1
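
The llama.cpp entry is the vendored submodule; its pinned revision presumably moves here so the grammar API the binding now calls is available at build time. For reference, the declarations involved are approximately as follows (an assumption based on the calls in the diff above; the pinned submodule's llama.h is authoritative):

// Approximate llama.cpp declarations used by the grammar support above; check the
// pinned submodule's llama.h for the authoritative signatures.
struct llama_grammar * llama_grammar_init(const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index);
struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
void llama_grammar_free(struct llama_grammar * grammar);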
