@@ -619,14 +619,32 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
     // used to determine end of generation
     bool has_eos = false;
 
+    // grammar stuff
+    struct llama_grammar * grammar_dft = NULL;
+    struct llama_grammar * grammar_tgt = NULL;
+
+    grammar_parser::parse_state parsed_grammar;
+
+    // if requested - load the grammar, error checking is omitted for brevity
+    if (!params.grammar.empty()) {
+        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+        // will be empty (default) if there are parse errors
+        if (parsed_grammar.rules.empty()) {
+            return 1;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+        grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    }
+
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
-        // sample from the drafted tokens if any
         int i_dft = 0;
         while (true) {
-            const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
-
+            // sample from the target model
+            const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+            // remember which tokens were sampled - used for repetition penalties during sampling
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
 
@@ -644,6 +662,7 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
 
             ++n_predict;
 
+            // check if the draft matches the target
             if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
                 LOG("drafted token %d accepted\n", id);
                 ++n_accept;
@@ -654,6 +673,13 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
                 continue;
             }
 
+            if (i_dft < (int) drafted.size()) {
+                LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n",
+                    i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str());
+            } else {
+                LOG("out of drafted tokens\n");
+            }
+
             // the drafted token was rejected or we are out of drafted tokens
             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
             ++n_past_dft;
@@ -668,7 +694,16 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
             break;
         }
 
-        // sample n_draft tokens from the draft model picking the best token
+        if (grammar_tgt) {
+            if (grammar_dft) {
+                llama_grammar_free(grammar_dft);
+            }
+            grammar_dft = llama_grammar_copy(grammar_tgt);
+
+            LOG("copied target grammar to draft grammar\n");
+        }
+
+        // sample n_draft tokens from the draft model using greedy decoding
         int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
             float * logits = llama_get_logits(ctx_dft);
@@ -680,32 +715,48 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
 
             llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
 
+            if (grammar_dft != NULL) {
+                llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
+            }
+
             // computes softmax and sorts the candidates
             llama_sample_softmax(ctx_dft, &cur_p);
 
             for (int i = 0; i < 3; ++i) {
                 LOG(" - draft candidate %d: %d (%.3f)\n", i, cur_p.data[i].id, cur_p.data[i].p);
             }
 
-            // too low probability, stop drafting
+            // TODO: better logic?
             if (cur_p.data[0].p < 2*cur_p.data[1].p) {
+                LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
                 break;
             }
 
-            drafted.push_back(cur_p.data[0].id);
+            // drafted token
+            const llama_token id = cur_p.data[0].id;
+
+            drafted.push_back(id);
             ++n_drafted;
 
-            if (i < n_draft - 1) {
-                // evaluate the drafted token on the draft model
-                llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
-                ++n_past_cur;
+            // no need to evaluate the last drafted token, since we won't use the result
+            if (i == n_draft - 1) {
+                break;
+            }
+
+            // evaluate the drafted token on the draft model
+            llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
+            ++n_past_cur;
+
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
             }
         }
 
         // evaluate the target model on the drafted tokens
         llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
         ++n_past_tgt;
-
+
+        // the first token is always proposed by the target model before the speculation loop
         drafted.erase(drafted.begin());
     }
     if (debug) {
@@ -732,7 +783,10 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
 
         fprintf(stderr, "\n\n");
     }
-
+    if (grammar_dft != NULL) {
+        llama_grammar_free(grammar_dft);
+        llama_grammar_free(grammar_tgt);
+    }
     strcpy(result, res.c_str());
     return 0;
 }
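The hunks above wire llama.cpp's stateful grammar objects into the speculative loop: the target grammar constrains every token the target model commits, and at the start of each drafting round the draft grammar is rebuilt with `llama_grammar_copy` so the draft model resumes from exactly the grammar state the target has reached. As a minimal sketch (not part of this commit), the parse → init → constrain → accept → free lifecycle looks like this in a plain single-model setting; the helper name `constrained_greedy_token` and the explicit `n_vocab` parameter are hypothetical, and the calls assume the llama.cpp grammar API of this era:

```cpp
#include <vector>

#include "llama.h"
#include "grammar-parser.h" // from llama.cpp's common/ (path varies by version)

// Greedily sample one token from `ctx`, filtered by `grammar` if present.
// Mirrors the draft-side sampling in the diff above.
static llama_token constrained_greedy_token(llama_context * ctx, llama_grammar * grammar, int n_vocab) {
    float * logits = llama_get_logits(ctx);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }

    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

    // zero out tokens the grammar forbids, then softmax + sort the rest
    if (grammar != NULL) {
        llama_sample_grammar(ctx, &cur_p, grammar);
    }
    llama_sample_softmax(ctx, &cur_p);

    const llama_token id = cur_p.data[0].id;

    // grammars are stateful: advancing here is what keeps future steps legal
    if (grammar != NULL) {
        llama_grammar_accept_token(ctx, grammar, id);
    }

    return id;
}

// Building the grammar itself follows the same steps as the diff:
//   grammar_parser::parse_state ps = grammar_parser::parse("root ::= \"yes\" | \"no\"");
//   std::vector<const llama_grammar_element *> rules(ps.c_rules());
//   llama_grammar * g = llama_grammar_init(rules.data(), rules.size(), ps.symbol_ids.at("root"));
//   ... generate with g ...
//   llama_grammar_free(g);
```

Because `llama_grammar_accept_token` mutates the grammar, the draft and target copies diverge as soon as the draft speculates past the last committed token; freeing `grammar_dft` and re-copying `grammar_tgt` each round discards the speculative grammar state along with the rejected tokens.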