@@ -158,7 +158,7 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
     }
 
     std::vector<llama_token> embd_inp;
-    if (session_tokens.empty()) {
+    if (!params.prompt.empty() || session_tokens.empty()) {
         // Add a space in front of the first character to match OG llama tokenizer behavior
         params.prompt.insert(0, 1, ' ');
 
@@ -190,7 +190,12 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
             }
         }
     }
-
+    // if we will use the cache for the full prompt without reaching the end of the cache, force
+    // reevaluation of the last token to recalculate the cached logits
+    if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() &&
+            session_tokens.size() > embd_inp.size()) {
+        session_tokens.resize(embd_inp.size() - 1);
+    }
     // number of tokens to keep when resetting context
     if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) {
         params.n_keep = (int) embd_inp.size();
@@ -258,12 +263,6 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
                 }
             }
             if (i > 0) {
-                // check if we've used up all the prompt but not all cached tokens
-                if (embd.size() == i && n_session_consumed < (int) session_tokens.size()) {
-                    // force revaluation of the last token to recalculate logits
-                    i--;
-                    n_past--;
-                }
                 embd.erase(embd.begin(), embd.begin() + i);
             }
         }
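
For context, the change moves the "force re-evaluation of the last cached token" fix out of the per-batch loop (the removed `i--` / `n_past--` block) and into a one-time truncation of `session_tokens` right after the prompt is tokenized: when the whole prompt is already covered by the session cache and the cache extends beyond it, the cache is shortened by one entry so that the final prompt token is evaluated again and fresh logits are available for sampling. The following is a minimal, self-contained C++ sketch of that truncation idea only; it does not call llama.cpp, and the token values and helper names (count_matching, prompt_tokens) are made up for illustration.

// Self-contained sketch of the cache-truncation logic added in the hunk above.
// All names and token values here are illustrative, not part of the real binding.
#include <cstdio>
#include <vector>

// Count how many leading tokens of the prompt are already covered by the session cache.
static size_t count_matching(const std::vector<int>& prompt, const std::vector<int>& cache) {
    size_t n = 0;
    while (n < prompt.size() && n < cache.size() && prompt[n] == cache[n]) {
        ++n;
    }
    return n;
}

int main() {
    std::vector<int> prompt_tokens  = {1, 15043, 29892, 3186};       // tokenized prompt
    std::vector<int> session_tokens = {1, 15043, 29892, 3186, 263};  // longer cached session

    size_t n_matching = count_matching(prompt_tokens, session_tokens);

    // Mirrors the added code: if the full prompt would be served from the cache and the
    // cache reaches past it, drop the last matching entry so that token is evaluated
    // again and its logits are recalculated.
    if (!prompt_tokens.empty() && n_matching == prompt_tokens.size() &&
            session_tokens.size() > prompt_tokens.size()) {
        session_tokens.resize(prompt_tokens.size() - 1);
    }

    // With the cache truncated to 3 entries, the 4th prompt token must be re-evaluated.
    std::printf("cached tokens reused: %zu, tokens left to evaluate: %zu\n",
                session_tokens.size(), prompt_tokens.size() - session_tokens.size());
    return 0;
}

The intended effect appears to match the removed in-loop workaround, but handled once up front, so the evaluation loop no longer needs to special-case the "prompt fully cached" situation.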