Skip to content

Commit 621a001

Browse files
authored
fix(cuda): pass pointer instead of copy-by-value in llama_sample_token (#228)
Signed-off-by: mudler <[email protected]>
1 parent 40c0d3d commit 621a001

File tree

2 files changed

+272
-7
lines changed

2 files changed

+272
-7
lines changed

binding.cpp

+137-2
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
446446
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
447447
}
448448

449-
const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
449+
const llama_token id = llama_sample_token_binding(ctx, ctx_guidance, grammar, params_p, last_tokens, candidates);
450+
//const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
450451

451452
last_tokens.erase(last_tokens.begin());
452453
last_tokens.push_back(id);
@@ -645,7 +646,9 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
645646
int i_dft = 0;
646647
while (true) {
647648
// sample from the target model
648-
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
649+
650+
// const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
651+
const llama_token id = llama_sample_token_binding(ctx_tgt, NULL, grammar_tgt, params_p, last_tokens, candidates, i_dft);
649652
// remember which tokens were sampled - used for repetition penalties during sampling
650653
last_tokens.erase(last_tokens.begin());
651654
last_tokens.push_back(id);
@@ -965,6 +968,15 @@ struct llama_binding_state {
965968
966969
void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity);
967970
971+
llama_token llama_sample_token_binding(
972+
struct llama_context * ctx,
973+
struct llama_context * ctx_guidance,
974+
struct llama_grammar * grammar,
975+
const struct gpt_params * g_params,
976+
const std::vector<llama_token> & last_tokens,
977+
std::vector<llama_token_data> & candidates,
978+
int idx = 0);
979+
968980
common.cpp:
969981
970982
gpt_params* create_gpt_params(const std::string& fname,const std::string& lora,const std::string& lora_base) {
@@ -1060,4 +1072,127 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f
10601072
state->model= model;
10611073
return state;
10621074
}
1075+
1076+
// Note: the only difference here is passing params as a pointer and avoid copy-by-value
1077+
// We stick to another function to avoid patching all the llama.cpp code
1078+
// We need the function to be in the common.o object, as using it in the binding does not make effect.
1079+
llama_token llama_sample_token_binding(
1080+
struct llama_context * ctx,
1081+
struct llama_context * ctx_guidance,
1082+
struct llama_grammar * grammar,
1083+
const struct gpt_params * g_params, // NOTE: this is our patch
1084+
const std::vector<llama_token> & last_tokens,
1085+
std::vector<llama_token_data> & candidates,
1086+
int idx) {
1087+
1088+
1089+
struct gpt_params params = *g_params; // NOTE: this is our patch
1090+
const int n_ctx = llama_n_ctx(ctx);
1091+
const int n_vocab = llama_n_vocab(ctx);
1092+
1093+
const float temp = params.temp;
1094+
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
1095+
const float top_p = params.top_p;
1096+
const float tfs_z = params.tfs_z;
1097+
const float typical_p = params.typical_p;
1098+
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
1099+
const float repeat_penalty = params.repeat_penalty;
1100+
const float alpha_presence = params.presence_penalty;
1101+
const float alpha_frequency = params.frequency_penalty;
1102+
const int mirostat = params.mirostat;
1103+
const float mirostat_tau = params.mirostat_tau;
1104+
const float mirostat_eta = params.mirostat_eta;
1105+
const bool penalize_nl = params.penalize_nl;
1106+
1107+
llama_token id = 0;
1108+
1109+
float * logits = llama_get_logits(ctx) + idx * n_vocab;
1110+
1111+
// Apply params.logit_bias map
1112+
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
1113+
logits[it->first] += it->second;
1114+
}
1115+
1116+
candidates.clear();
1117+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
1118+
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
1119+
}
1120+
1121+
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
1122+
1123+
if (ctx_guidance) {
1124+
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
1125+
}
1126+
1127+
// apply penalties
1128+
if (!last_tokens.empty()) {
1129+
const float nl_logit = logits[llama_token_nl(ctx)];
1130+
const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
1131+
1132+
llama_sample_repetition_penalty(ctx, &cur_p,
1133+
last_tokens.data() + last_tokens.size() - last_n_repeat,
1134+
last_n_repeat, repeat_penalty);
1135+
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
1136+
last_tokens.data() + last_tokens.size() - last_n_repeat,
1137+
last_n_repeat, alpha_frequency, alpha_presence);
1138+
1139+
if (!penalize_nl) {
1140+
for (size_t idx = 0; idx < cur_p.size; idx++) {
1141+
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
1142+
cur_p.data[idx].logit = nl_logit;
1143+
break;
1144+
}
1145+
}
1146+
}
1147+
}
1148+
1149+
if (grammar != NULL) {
1150+
llama_sample_grammar(ctx, &cur_p, grammar);
1151+
}
1152+
1153+
if (temp <= 0) {
1154+
// Greedy sampling
1155+
id = llama_sample_token_greedy(ctx, &cur_p);
1156+
} else {
1157+
if (mirostat == 1) {
1158+
static float mirostat_mu = 2.0f * mirostat_tau;
1159+
const int mirostat_m = 100;
1160+
llama_sample_temperature(ctx, &cur_p, temp);
1161+
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
1162+
} else if (mirostat == 2) {
1163+
static float mirostat_mu = 2.0f * mirostat_tau;
1164+
llama_sample_temperature(ctx, &cur_p, temp);
1165+
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
1166+
} else {
1167+
// Temperature sampling
1168+
llama_sample_top_k (ctx, &cur_p, top_k, 1);
1169+
llama_sample_tail_free (ctx, &cur_p, tfs_z, 1);
1170+
llama_sample_typical (ctx, &cur_p, typical_p, 1);
1171+
llama_sample_top_p (ctx, &cur_p, top_p, 1);
1172+
llama_sample_temperature(ctx, &cur_p, temp);
1173+
1174+
{
1175+
const int n_top = 10;
1176+
LOG("top %d candidates:\n", n_top);
1177+
1178+
for (int i = 0; i < n_top; i++) {
1179+
const llama_token id = cur_p.data[i].id;
1180+
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
1181+
}
1182+
}
1183+
1184+
id = llama_sample_token(ctx, &cur_p);
1185+
1186+
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
1187+
}
1188+
}
1189+
// printf("`%d`", candidates_p.size);
1190+
1191+
if (grammar != NULL) {
1192+
llama_grammar_accept_token(ctx, grammar, id);
1193+
}
1194+
1195+
return id;
1196+
}
1197+
10631198
*/

patches/1902-cuda.patch

+135-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
diff --git a/common/common.cpp b/common/common.cpp
2-
index d4f9dbf..9a01627 100644
2+
index 2597ba0..e42ae73 100644
33
--- a/common/common.cpp
44
+++ b/common/common.cpp
5-
@@ -1259,3 +1259,97 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
5+
@@ -1268,3 +1268,218 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
66
fprintf(stream, "typical_p: %f # default: 1.0\n", params.typical_p);
77
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
88
}
@@ -69,7 +69,7 @@ index d4f9dbf..9a01627 100644
6969
+ if (maingpu[0] != '\0') {
7070
+ lparams->main_gpu = std::stoi(maingpu);
7171
+ }
72-
+
72+
+
7373
+ if (tensorsplit[0] != '\0') {
7474
+ std::string arg_next = tensorsplit;
7575
+ // split string by , and /
@@ -100,12 +100,133 @@ index d4f9dbf..9a01627 100644
100100
+ state->model= model;
101101
+ return state;
102102
+}
103+
+
104+
+// Note: the only difference here is passing params as a pointer and avoid copy-by-value
105+
+// We stick to another function to avoid patching all the llama.cpp code
106+
+// We need the function to be in the common.o object, as using it in the binding does not make effect.
107+
+llama_token llama_sample_token_binding(
108+
+ struct llama_context * ctx,
109+
+ struct llama_context * ctx_guidance,
110+
+ struct llama_grammar * grammar,
111+
+ const struct gpt_params * g_params,
112+
+ const std::vector<llama_token> & last_tokens,
113+
+ std::vector<llama_token_data> & candidates,
114+
+ int idx) {
115+
+
116+
+ struct gpt_params params = *g_params;
117+
+ const int n_ctx = llama_n_ctx(ctx);
118+
+ const int n_vocab = llama_n_vocab(ctx);
119+
+
120+
+ const float temp = params.temp;
121+
+ const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
122+
+ const float top_p = params.top_p;
123+
+ const float tfs_z = params.tfs_z;
124+
+ const float typical_p = params.typical_p;
125+
+ const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
126+
+ const float repeat_penalty = params.repeat_penalty;
127+
+ const float alpha_presence = params.presence_penalty;
128+
+ const float alpha_frequency = params.frequency_penalty;
129+
+ const int mirostat = params.mirostat;
130+
+ const float mirostat_tau = params.mirostat_tau;
131+
+ const float mirostat_eta = params.mirostat_eta;
132+
+ const bool penalize_nl = params.penalize_nl;
133+
+
134+
+ llama_token id = 0;
135+
+
136+
+ float * logits = llama_get_logits(ctx) + idx * n_vocab;
137+
+
138+
+ // Apply params.logit_bias map
139+
+ for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
140+
+ logits[it->first] += it->second;
141+
+ }
142+
+
143+
+ candidates.clear();
144+
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
145+
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
146+
+ }
147+
+
148+
+ llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
149+
+
150+
+ if (ctx_guidance) {
151+
+ llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
152+
+ }
153+
+
154+
+ // apply penalties
155+
+ if (!last_tokens.empty()) {
156+
+ const float nl_logit = logits[llama_token_nl(ctx)];
157+
+ const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
158+
+
159+
+ llama_sample_repetition_penalty(ctx, &cur_p,
160+
+ last_tokens.data() + last_tokens.size() - last_n_repeat,
161+
+ last_n_repeat, repeat_penalty);
162+
+ llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
163+
+ last_tokens.data() + last_tokens.size() - last_n_repeat,
164+
+ last_n_repeat, alpha_frequency, alpha_presence);
165+
+
166+
+ if (!penalize_nl) {
167+
+ for (size_t idx = 0; idx < cur_p.size; idx++) {
168+
+ if (cur_p.data[idx].id == llama_token_nl(ctx)) {
169+
+ cur_p.data[idx].logit = nl_logit;
170+
+ break;
171+
+ }
172+
+ }
173+
+ }
174+
+ }
175+
+
176+
+ if (grammar != NULL) {
177+
+ llama_sample_grammar(ctx, &cur_p, grammar);
178+
+ }
179+
+
180+
+ if (temp <= 0) {
181+
+ // Greedy sampling
182+
+ id = llama_sample_token_greedy(ctx, &cur_p);
183+
+ } else {
184+
+ if (mirostat == 1) {
185+
+ static float mirostat_mu = 2.0f * mirostat_tau;
186+
+ const int mirostat_m = 100;
187+
+ llama_sample_temperature(ctx, &cur_p, temp);
188+
+ id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
189+
+ } else if (mirostat == 2) {
190+
+ static float mirostat_mu = 2.0f * mirostat_tau;
191+
+ llama_sample_temperature(ctx, &cur_p, temp);
192+
+ id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
193+
+ } else {
194+
+ // Temperature sampling
195+
+ llama_sample_top_k (ctx, &cur_p, top_k, 1);
196+
+ llama_sample_tail_free (ctx, &cur_p, tfs_z, 1);
197+
+ llama_sample_typical (ctx, &cur_p, typical_p, 1);
198+
+ llama_sample_top_p (ctx, &cur_p, top_p, 1);
199+
+ llama_sample_temperature(ctx, &cur_p, temp);
200+
+
201+
+ {
202+
+ const int n_top = 10;
203+
+ LOG("top %d candidates:\n", n_top);
204+
+
205+
+ for (int i = 0; i < n_top; i++) {
206+
+ const llama_token id = cur_p.data[i].id;
207+
+ LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
208+
+ }
209+
+ }
210+
+
211+
+ id = llama_sample_token(ctx, &cur_p);
212+
+
213+
+ LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
214+
+ }
215+
+ }
216+
+ // printf("`%d`", candidates_p.size);
217+
+
218+
+ if (grammar != NULL) {
219+
+ llama_grammar_accept_token(ctx, grammar, id);
220+
+ }
221+
+
222+
+ return id;
223+
+}
103224
\ No newline at end of file
104225
diff --git a/common/common.h b/common/common.h
105-
index 85ac0df..eb9d24b 100644
226+
index 18aea38..ca7a168 100644
106227
--- a/common/common.h
107228
+++ b/common/common.h
108-
@@ -201,3 +201,10 @@ std::string get_sortable_timestamp();
229+
@@ -209,3 +209,19 @@ std::string get_sortable_timestamp();
109230
void dump_non_result_info_yaml(
110231
FILE * stream, const gpt_params & params, const llama_context * lctx,
111232
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
@@ -116,3 +237,12 @@ index 85ac0df..eb9d24b 100644
116237
+};
117238
+
118239
+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity);
240+
+
241+
+llama_token llama_sample_token_binding(
242+
+ struct llama_context * ctx,
243+
+ struct llama_context * ctx_guidance,
244+
+ struct llama_grammar * grammar,
245+
+ const struct gpt_params * g_params,
246+
+ const std::vector<llama_token> & last_tokens,
247+
+ std::vector<llama_token_data> & candidates,
248+
+ int idx = 0);

0 commit comments

Comments
 (0)