feat: sync llama.cpp #122

Merged · 10 commits · Feb 18, 2025
2 changes: 1 addition & 1 deletion .github/actions/setup/action.yml
@@ -7,7 +7,7 @@ runs:
    - name: Setup Node.js
      uses: actions/setup-node@v3
      with:
-       node-version: 18.x
+       node-version: 20.x

    - name: Cache dependencies
      id: yarn-cache
6 changes: 6 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -70,6 +70,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap params) {
params.getString("model"),
// String chat_template,
params.hasKey("chat_template") ? params.getString("chat_template") : "",
+ // String reasoning_format,
+ params.hasKey("reasoning_format") ? params.getString("reasoning_format") : "none",
// boolean embedding,
params.hasKey("embedding") ? params.getBoolean("embedding") : false,
// int embd_normalize,
@@ -301,6 +303,8 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2,
// int dry_penalty_last_n,
params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1,
+ // float top_n_sigma,
+ params.hasKey("top_n_sigma") ? (float) params.getDouble("top_n_sigma") : -1.0f,
// String[] dry_sequence_breakers, when undef, we use the default definition from common.h
params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"},
// PartialCompletionCallback partial_completion_callback
@@ -468,6 +472,7 @@ protected static native WritableMap modelInfo(
protected static native long initContext(
String model,
String chat_template,
+ String reasoning_format,
boolean embedding,
int embd_normalize,
int n_ctx,
@@ -550,6 +555,7 @@ protected static native WritableMap doCompletion(
float dry_base,
int dry_allowed_length,
int dry_penalty_last_n,
+ float top_n_sigma,
String[] dry_sequence_breakers,
PartialCompletionCallback partial_completion_callback
);
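On the JavaScript side, the new bindings surface as a `reasoning_format` option at context creation and a `top_n_sigma` sampler option at completion time. A minimal sketch of how they might be used, assuming llama.rn's `initLlama`/`completion` API (the parameter names are taken from the diff; the model path and values are illustrative):

```ts
import { initLlama } from 'llama.rn'

const context = await initLlama({
  model: '/path/to/model.gguf', // illustrative path
  // 'deepseek' is the only value the JNI layer maps to
  // COMMON_REASONING_FORMAT_DEEPSEEK; anything else falls back to none.
  reasoning_format: 'deepseek',
})

const result = await context.completion({
  messages: [{ role: 'user', content: 'Why is the sky blue?' }],
  // New sampler knob from this sync; the native default of -1.0 disables it.
  top_n_sigma: 1.5,
})
```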
23 changes: 23 additions & 0 deletions android/src/main/jni.cpp
@@ -223,6 +223,7 @@ Java_com_rnllama_LlamaContext_initContext(
jobject thiz,
jstring model_path_str,
jstring chat_template,
+ jstring reasoning_format,
jboolean embedding,
jint embd_normalize,
jint n_ctx,
@@ -259,6 +260,13 @@ Java_com_rnllama_LlamaContext_initContext(
const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
defaultParams.chat_template = chat_template_chars;

+ const char *reasoning_format_chars = env->GetStringUTFChars(reasoning_format, nullptr);
+ if (strcmp(reasoning_format_chars, "deepseek") == 0) {
+     defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+ } else {
+     defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
+ }

defaultParams.n_ctx = n_ctx;
defaultParams.n_batch = n_batch;
defaultParams.n_ubatch = n_ubatch;
@@ -326,6 +334,7 @@ Java_com_rnllama_LlamaContext_initContext(

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(chat_template, chat_template_chars);
+ env->ReleaseStringUTFChars(reasoning_format, reasoning_format_chars);
env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);

@@ -664,6 +673,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
jfloat dry_base,
jint dry_allowed_length,
jint dry_penalty_last_n,
+ jfloat top_n_sigma,
jobjectArray dry_sequence_breakers,
jobject partial_completion_callback
) {
@@ -706,6 +716,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
sparams.dry_base = dry_base;
sparams.dry_allowed_length = dry_allowed_length;
sparams.dry_penalty_last_n = dry_penalty_last_n;
+ sparams.top_n_sigma = top_n_sigma;

// grammar
auto grammar_chars = env->GetStringUTFChars(grammar, nullptr);
@@ -882,10 +893,16 @@ Java_com_rnllama_LlamaContext_doCompletion(
llama->is_predicting = false;

auto toolCalls = createWritableArray(env);
+ std::string reasoningContent = "";
+ std::string *content = nullptr;
auto toolCallsSize = 0;
if (!llama->is_interrupted) {
try {
common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
+ if (!message.reasoning_content.empty()) {
+     reasoningContent = message.reasoning_content;
+ }
+ content = &message.content;
for (const auto &tc : message.tool_calls) {
auto toolCall = createWriteableMap(env);
putString(env, toolCall, "type", "function");
@@ -906,6 +923,12 @@

auto result = createWriteableMap(env);
putString(env, result, "text", llama->generated_text.c_str());
+ if (content) {
+     putString(env, result, "content", content->c_str());
+ }
+ if (!reasoningContent.empty()) {
+     putString(env, result, "reasoning_content", reasoningContent.c_str());
+ }
if (toolCallsSize > 0) {
putArray(env, result, "tool_calls", toolCalls);
}
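Alongside the existing raw `text`, the completion result now carries `content` (the message body after `common_chat_parse` strips reasoning and tool-call markup) and `reasoning_content` (the extracted thinking block, when present). A sketch of consuming the new fields, reusing the `context` from the sketch above (the field names come from the diff; the `tool_calls` shape follows the OpenAI-style objects the JNI code builds):

```ts
const res = await context.completion({
  messages: [{ role: 'user', content: 'What is the weather in Paris?' }],
})

console.log(res.text)              // raw generated text, unchanged behavior
console.log(res.content)           // parsed message content, new in this sync
console.log(res.reasoning_content) // <think> body, present only if the model emitted one
for (const call of res.tool_calls ?? []) {
  // each entry is built in jni.cpp with type 'function' plus name/arguments
  console.log(call.function?.name, call.function?.arguments)
}
```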
28 changes: 21 additions & 7 deletions cpp/chat-template.hpp
@@ -249,16 +249,30 @@ class chat_template {
inputs.add_generation_prompt = false;
full = apply(inputs);
}

- if (full.find(prefix) != 0) {
-     if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
-         prefix = prefix.substr(0, prefix.size() - eos_token_.size());
+ auto eos_pos_last = full.rfind(eos_token_);
+ if (eos_pos_last == prefix.size() - eos_token_.size() ||
+     (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
+     full = full.substr(0, eos_pos_last);
}
+ size_t common_prefix_length = 0;
+ for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+     if (prefix[i] != full[i]) {
+         break;
+     }
+     if (prefix[i] == '<') {
+         // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+         // but it removes thinking tags for past messages.
+         // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+         continue;
+     }
+     common_prefix_length = i + 1;
}
- if (full.find(prefix) != 0) {
+ auto example = full.substr(common_prefix_length);
+ if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
+ } else {
+     tool_call_example_ = example;
}
- tool_call_example_ = full.substr(prefix.size());
- }
} catch (const std::exception & e) {
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
@@ -363,7 +377,7 @@ class chat_template {
if (polyfill_tools) {
adjusted_messages = add_system(inputs.messages,
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
- (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
+ (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
} else {
adjusted_messages = inputs.messages;
}
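The subtle part of the first hunk is the common-prefix scan: a matching `<` is never counted into `common_prefix_length`, so when `prefix` ends in DeepSeek R1's trailing `<think>` while `full` continues with `<|tool▁calls▁begin|>`, the inferred example keeps its complete opening tag rather than a clipped one. A small TypeScript illustration of the same scan, with made-up template output (not part of the PR):

```ts
// Hypothetical strings shaped like DeepSeek R1 template output.
const prefix = '<|User|>Hey<|Assistant|><think>'
const full = '<|User|>Hey<|Assistant|><|tool▁calls▁begin|>...<|tool▁calls▁end|>'

let commonPrefixLength = 0
for (let i = 0; i < Math.min(prefix.length, full.length); i++) {
  if (prefix[i] !== full[i]) break
  // Mirror the C++ logic: never consume a '<', so a divergence inside a
  // tag leaves the whole tag in the example.
  if (prefix[i] === '<') continue
  commonPrefixLength = i + 1
}
const example = full.slice(commonPrefixLength)
// example === '<|tool▁calls▁begin|>...<|tool▁calls▁end|>'
```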