feat: sync llama.cpp (#122)
* feat: sync llama.cpp

* fix(ts): exclude llama.cpp/ for tsc

* feat: sync llama.cpp

* feat: expose top_n_sigma completion param

* fix(ts): config

* chore: bump react-native-builder-bob

* ci: use node 20

* feat: add reasoning_format param & reasoning_content in completion result

* fix(android): reasoning_format
jhen0409 authored Feb 18, 2025
1 parent bb9d6e8 commit 2b1a8bd
Showing 44 changed files with 3,207 additions and 1,190 deletions.
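Note: the commit message above exposes two new user-facing options, top_n_sigma and reasoning_format, and surfaces reasoning_content in the completion result. As a rough sketch only, here is how they might be used from the JavaScript side, assuming llama.rn's usual initLlama/completion API; the option names come from the native diffs below, and the exact TypeScript surface is not shown in this commit:

    import { initLlama } from 'llama.rn';

    // Hedged sketch: option names mirror the native parameters added in this commit.
    async function demo() {
      const context = await initLlama({
        model: '/path/to/model.gguf',
        // New: 'deepseek' parses <think>...</think> reasoning; anything else maps to 'none'.
        reasoning_format: 'deepseek',
      });

      const result = await context.completion({
        prompt: 'Why is the sky blue?',
        // New sampler parameter; the native default of -1.0 leaves it disabled.
        top_n_sigma: 1.5,
      });

      console.log(result.text);              // raw generated text, as before
      console.log(result.reasoning_content); // extracted reasoning, when present
    }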
2 changes: 1 addition & 1 deletion .github/actions/setup/action.yml
@@ -7,7 +7,7 @@ runs:
     - name: Setup Node.js
       uses: actions/setup-node@v3
       with:
-        node-version: 18.x
+        node-version: 20.x

     - name: Cache dependencies
       id: yarn-cache
6 changes: 6 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -70,6 +70,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       params.getString("model"),
       // String chat_template,
       params.hasKey("chat_template") ? params.getString("chat_template") : "",
+      // String reasoning_format,
+      params.hasKey("reasoning_format") ? params.getString("reasoning_format") : "none",
       // boolean embedding,
       params.hasKey("embedding") ? params.getBoolean("embedding") : false,
       // int embd_normalize,
@@ -301,6 +303,8 @@ public WritableMap completion(ReadableMap params) {
       params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2,
       // int dry_penalty_last_n,
       params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1,
+      // float top_n_sigma,
+      params.hasKey("top_n_sigma") ? (float) params.getDouble("top_n_sigma") : -1.0f,
       // String[] dry_sequence_breakers, when undef, we use the default definition from common.h
       params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"},
       // PartialCompletionCallback partial_completion_callback
@@ -468,6 +472,7 @@ protected static native WritableMap modelInfo(
   protected static native long initContext(
     String model,
     String chat_template,
+    String reasoning_format,
     boolean embedding,
     int embd_normalize,
     int n_ctx,
@@ -550,6 +555,7 @@ protected static native WritableMap doCompletion(
     float dry_base,
     int dry_allowed_length,
     int dry_penalty_last_n,
+    float top_n_sigma,
     String[] dry_sequence_breakers,
     PartialCompletionCallback partial_completion_callback
   );
23 changes: 23 additions & 0 deletions android/src/main/jni.cpp
@@ -223,6 +223,7 @@ Java_com_rnllama_LlamaContext_initContext(
     jobject thiz,
     jstring model_path_str,
     jstring chat_template,
+    jstring reasoning_format,
     jboolean embedding,
     jint embd_normalize,
     jint n_ctx,
@@ -259,6 +260,13 @@ Java_com_rnllama_LlamaContext_initContext(
     const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
     defaultParams.chat_template = chat_template_chars;

+    const char *reasoning_format_chars = env->GetStringUTFChars(reasoning_format, nullptr);
+    if (strcmp(reasoning_format_chars, "deepseek") == 0) {
+        defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    } else {
+        defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
+    }
+
     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
     defaultParams.n_ubatch = n_ubatch;
@@ -326,6 +334,7 @@ Java_com_rnllama_LlamaContext_initContext(

     env->ReleaseStringUTFChars(model_path_str, model_path_chars);
     env->ReleaseStringUTFChars(chat_template, chat_template_chars);
+    env->ReleaseStringUTFChars(reasoning_format, reasoning_format_chars);
     env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
     env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);

@@ -664,6 +673,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jfloat dry_base,
     jint dry_allowed_length,
     jint dry_penalty_last_n,
+    jfloat top_n_sigma,
     jobjectArray dry_sequence_breakers,
     jobject partial_completion_callback
 ) {
@@ -706,6 +716,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.dry_base = dry_base;
     sparams.dry_allowed_length = dry_allowed_length;
     sparams.dry_penalty_last_n = dry_penalty_last_n;
+    sparams.top_n_sigma = top_n_sigma;

     // grammar
     auto grammar_chars = env->GetStringUTFChars(grammar, nullptr);
@@ -882,10 +893,16 @@ Java_com_rnllama_LlamaContext_doCompletion(
     llama->is_predicting = false;

     auto toolCalls = createWritableArray(env);
+    std::string reasoningContent = "";
+    std::string *content = nullptr;
     auto toolCallsSize = 0;
     if (!llama->is_interrupted) {
         try {
             common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
+            if (!message.reasoning_content.empty()) {
+                reasoningContent = message.reasoning_content;
+            }
+            content = &message.content;
             for (const auto &tc : message.tool_calls) {
                 auto toolCall = createWriteableMap(env);
                 putString(env, toolCall, "type", "function");
@@ -906,6 +923,12 @@

     auto result = createWriteableMap(env);
     putString(env, result, "text", llama->generated_text.c_str());
+    if (content) {
+        putString(env, result, "content", content->c_str());
+    }
+    if (!reasoningContent.empty()) {
+        putString(env, result, "reasoning_content", reasoningContent.c_str());
+    }
     if (toolCallsSize > 0) {
         putArray(env, result, "tool_calls", toolCalls);
     }
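Taken together, the jni.cpp changes mean the native completion result now carries content and reasoning_content alongside the raw text. An illustrative TypeScript shape for that result (field names are taken from the putString/putArray calls above; the tool-call entry layout beyond "type" is an assumption based on llama.cpp's OpenAI-style tool calls):

    // Illustrative only: shape of the completion result map built in jni.cpp above.
    interface NativeCompletionResult {
      text: string;               // full generated text, always set
      content?: string;           // parsed message content, set when chat parsing succeeds
      reasoning_content?: string; // only set when the parser extracted non-empty reasoning
      tool_calls?: Array<{
        type: 'function';         // the only type emitted above
        function: { name: string; arguments: string }; // assumed OpenAI-style layout
        id?: string;
      }>;
    }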
28 changes: 21 additions & 7 deletions cpp/chat-template.hpp
@@ -249,16 +249,30 @@ class chat_template {
                 inputs.add_generation_prompt = false;
                 full = apply(inputs);
             }

-            if (full.find(prefix) != 0) {
-                if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
-                    prefix = prefix.substr(0, prefix.size() - eos_token_.size());
+            auto eos_pos_last = full.rfind(eos_token_);
+            if (eos_pos_last == prefix.size() - eos_token_.size() ||
+                (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
+                full = full.substr(0, eos_pos_last);
             }
+            size_t common_prefix_length = 0;
+            for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+                if (prefix[i] != full[i]) {
+                    break;
+                }
+                if (prefix[i] == '<') {
+                    // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+                    // but it removes thinking tags for past messages.
+                    // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, so we avoid consuming the leading <.
+                    continue;
+                }
+                common_prefix_length = i + 1;
             }
-            if (full.find(prefix) != 0) {
+            auto example = full.substr(common_prefix_length);
+            if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
                 fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
+            } else {
+                tool_call_example_ = example;
             }
-            tool_call_example_ = full.substr(prefix.size());
         } catch (const std::exception & e) {
             fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
@@ -363,7 +377,7 @@ class chat_template {
         if (polyfill_tools) {
             adjusted_messages = add_system(inputs.messages,
                 "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
-                (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
+                (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
         } else {
             adjusted_messages = inputs.messages;
         }
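The rewritten example-inference logic above replaces a brittle exact-prefix check with a character-by-character common-prefix scan that refuses to end on a '<', so templates whose renders diverge at <think> vs. <|tool▁calls▁begin|> keep the opening bracket in the inferred example. A small TypeScript transliteration of just that loop, for illustration:

    // Sketch of the common-prefix scan from chat-template.hpp, transliterated to TypeScript.
    // Returns the length of the shared prefix, never ending on a '<' so the divergence
    // point (<think> vs. <|tool▁calls▁begin|>) keeps its opening bracket.
    function commonPrefixLength(prefix: string, full: string): number {
      let length = 0;
      const limit = Math.min(prefix.length, full.length);
      for (let i = 0; i < limit; i++) {
        if (prefix[i] !== full[i]) break;
        if (prefix[i] === '<') continue; // don't let the prefix end on a leading '<'
        length = i + 1;
      }
      return length;
    }

    // Example: the tool-call example is whatever 'full' renders beyond the shared prefix.
    const example = 'A<|tool▁calls▁begin|>...'.slice(
      commonPrefixLength('A<think>', 'A<|tool▁calls▁begin|>...'),
    );
    // example === '<|tool▁calls▁begin|>...'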