@@ -446,7 +446,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
             llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
         }
 
-        const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
+        const llama_token id = llama_sample_token_binding(ctx, ctx_guidance, grammar, params_p, last_tokens, candidates);
+        // const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
 
         last_tokens.erase(last_tokens.begin());
         last_tokens.push_back(id);
@@ -645,7 +646,9 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
     int i_dft = 0;
     while (true) {
         // sample from the target model
-        const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+
+        // const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+        const llama_token id = llama_sample_token_binding(ctx_tgt, NULL, grammar_tgt, params_p, last_tokens, candidates, i_dft);
         // remember which tokens were sampled - used for repetition penalties during sampling
         last_tokens.erase(last_tokens.begin());
         last_tokens.push_back(id);
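In both hunks the sampling call now goes through llama_sample_token_binding and receives params_p, the gpt_params pointer the binding already holds for the duration of the call, instead of a by-value gpt_params. A minimal sketch of the assumed surrounding code (params_ptr as in the function signatures above; the cast is an assumption about how the binding recovers the pointer):

    gpt_params* params_p = (gpt_params*) params_ptr; // opaque handle passed over the C boundary
    // ... evaluation loop ...
    const llama_token id = llama_sample_token_binding(ctx, ctx_guidance, grammar, params_p, last_tokens, candidates);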
@@ -965,6 +968,15 @@ struct llama_binding_state {
 
 void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity);
 
+llama_token llama_sample_token_binding(
+                  struct llama_context * ctx,
+                  struct llama_context * ctx_guidance,
+                  struct llama_grammar * grammar,
+            const struct gpt_params * g_params,
+            const std::vector<llama_token> & last_tokens,
+            std::vector<llama_token_data> & candidates,
+            int idx = 0);
+
 common.cpp:
 
 gpt_params* create_gpt_params(const std::string& fname, const std::string& lora, const std::string& lora_base) {
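The declaration gives idx a default of 0, so existing single-stream callers need no change; as the implementation below shows, idx selects which row of the logits buffer (llama_get_logits(ctx) + idx * n_vocab) is sampled. The two call shapes, as used in the hunks above:

    // regular decoding: sample from the first (default) logits row
    const llama_token id = llama_sample_token_binding(ctx, ctx_guidance, grammar, params_p, last_tokens, candidates);

    // speculative sampling: sample the logits of the i_dft-th drafted position
    const llama_token id = llama_sample_token_binding(ctx_tgt, NULL, grammar_tgt, params_p, last_tokens, candidates, i_dft);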
@@ -1060,4 +1072,127 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f
     state->model = model;
     return state;
 }
+
+// Note: the only difference here is passing params as a pointer, avoiding a copy-by-value.
+// We stick to a separate function to avoid patching all the llama.cpp code.
+// The function needs to live in the common.o object, as defining it only in the binding has no effect.
+llama_token llama_sample_token_binding(
+                  struct llama_context * ctx,
+                  struct llama_context * ctx_guidance,
+                  struct llama_grammar * grammar,
+            const struct gpt_params * g_params, // NOTE: this is our patch
+            const std::vector<llama_token> & last_tokens,
+            std::vector<llama_token_data> & candidates,
+            int idx) {
+
+    struct gpt_params params = *g_params; // NOTE: this is our patch
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
*/
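As a closing note, a sketch of how a caller on the binding side is expected to drive the patched function. This is not part of the patch: sample_next and its parameters are hypothetical, and only llama_sample_token_binding and the buffer expectations (candidates is cleared and refilled with n_vocab entries on every call; last_tokens is the repetition-penalty window) come from the code above:

    // Hypothetical caller, for illustration only.
    llama_token sample_next(llama_context * ctx, gpt_params * params_p,
                            std::vector<llama_token> & last_tokens,
                            std::vector<llama_token_data> & candidates) {
        // gpt_params crosses the call as a pointer; the single copy-by-value
        // happens inside llama_sample_token_binding, compiled into common.o
        return llama_sample_token_binding(ctx, /*ctx_guidance=*/NULL, /*grammar=*/NULL,
                                          params_p, last_tokens, candidates);
    }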