
Commit 3f10005

Fix session prompt, align to upstream changes (#79)
1 parent 10caf37 commit 3f10005

1 file changed: +7 -8 lines

binding.cpp (+7 -8)
```diff
@@ -158,7 +158,7 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
     }
 
     std::vector<llama_token> embd_inp;
-    if (session_tokens.empty()) {
+    if ( !params.prompt.empty() || session_tokens.empty() ) {
         // Add a space in front of the first character to match OG llama tokenizer behavior
         params.prompt.insert(0, 1, ' ');
 
```
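This hunk makes an explicit prompt take precedence over a restored session cache: before the change, the prompt was only tokenized when no session tokens had been loaded, so a restored session silently ignored any new prompt. Below is a minimal sketch of the gating logic, not the literal binding.cpp code; the `llama_tokenize` call is elided, and the `else` branch reusing `session_tokens` is an assumption based on the upstream llama.cpp main example that this commit aligns to.

```cpp
// Sketch of the prompt-vs-session decision after this change (assumptions noted inline).
#include <string>
#include <vector>

using llama_token = int;  // stand-in for the real llama.cpp token type

std::vector<llama_token> build_embd_inp(std::string prompt,
                                        const std::vector<llama_token>& session_tokens) {
    std::vector<llama_token> embd_inp;
    if (!prompt.empty() || session_tokens.empty()) {
        // New condition: a non-empty prompt is always tokenized, even when a
        // session cache exists; an empty prompt with no cache is tokenized too.
        prompt.insert(0, 1, ' ');  // match OG llama tokenizer behavior
        // embd_inp = ::llama_tokenize(ctx, prompt, true);  // real call, elided here
    } else {
        // Assumed branch (mirrors upstream): reuse the cached prompt wholesale.
        embd_inp = session_tokens;
    }
    return embd_inp;
}
```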
```diff
@@ -190,7 +190,12 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
             }
         }
     }
-
+    // if we will use the cache for the full prompt without reaching the end of the cache, force
+    // reevaluation of the last token token to recalculate the cached logits
+    if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() &&
+            session_tokens.size() > embd_inp.size()) {
+        session_tokens.resize(embd_inp.size() - 1);
+    }
     // number of tokens to keep when resetting context
     if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
         params.n_keep = (int)embd_inp.size();
```
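The added block handles the case where the restored cache matches the entire prompt and extends past it: sampling needs fresh logits for the last prompt token, but a full cache hit would skip its evaluation entirely. Truncating the cache to `embd_inp.size() - 1` forces that one token to be re-evaluated. A standalone toy walk-through of the guard, with made-up sizes and token values:

```cpp
// Illustration of the resize guard added above; sizes here are assumed for demonstration.
#include <cassert>
#include <vector>

int main() {
    std::vector<int> embd_inp(10, 1);        // prompt tokenized to 10 tokens
    std::vector<int> session_tokens(16, 1);  // cache covers 16 tokens, all matching
    size_t n_matching_session_tokens = embd_inp.size();  // full-prompt match

    if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() &&
            session_tokens.size() > embd_inp.size()) {
        // Drop back to 9 cached tokens so prompt token #10 is evaluated again,
        // regenerating the logits that sampling needs.
        session_tokens.resize(embd_inp.size() - 1);
    }
    assert(session_tokens.size() == 9);
    return 0;
}
```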
```diff
@@ -258,12 +263,6 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
                 }
             }
             if (i > 0) {
-                // check if we've used up all the prompt but not all cached tokens
-                if (embd.size() == i && n_session_consumed < (int) session_tokens.size()) {
-                    // force revaluation of the last token to recalculate logits
-                    i--;
-                    n_past--;
-                }
                 embd.erase(embd.begin(), embd.begin() + i);
             }
         }
```
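With the upfront `session_tokens.resize(embd_inp.size() - 1)` guard in place, the old per-iteration fixup inside the evaluation loop (rewinding `i` and `n_past` to force re-evaluation of the last cached token) becomes redundant, so this hunk deletes it to match upstream; only the erase of already-consumed tokens remains.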
