feat: add reasoning_format param & reasoning_content in completion result
jhen0409 committed Feb 18, 2025
1 parent e475f84 commit 812b91e
Showing 6 changed files with 62 additions and 3 deletions.
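
Taken together, the changes thread a new `reasoning_format` context parameter from the TypeScript API down to the native `common_params`, and surface parsed `reasoning_content` / `content` fields in the completion result. Below is a minimal usage sketch, assuming the `initLlama` / `completion` API shapes used by example/src/App.tsx; the model path and prompt are placeholders, not part of this commit:

import { initLlama } from 'llama.rn'

async function run() {
  const context = await initLlama({
    model: '/path/to/deepseek-r1-distill.gguf', // placeholder path
    n_ctx: 2048,
    reasoning_format: 'deepseek', // extract reasoning into reasoning_content
  })

  const result = await context.completion({
    messages: [{ role: 'user', content: 'Why is the sky blue?' }],
    jinja: true, // reasoning extraction runs through the Jinja chat path
  })

  console.log(result.reasoning_content) // parsed reasoning, when present
  console.log(result.content) // text with reasoning / tool calls stripped
  console.log(result.text) // the original, unparsed output
}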
3 changes: 3 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -70,6 +70,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.getString("model"),
// String chat_template,
params.hasKey("chat_template") ? params.getString("chat_template") : "",
// String reasoning_format,
params.hasKey("reasoning_format") ? params.getString("reasoning_format") : "none",
// boolean embedding,
params.hasKey("embedding") ? params.getBoolean("embedding") : false,
// int embd_normalize,
@@ -470,6 +472,7 @@ protected static native WritableMap modelInfo(
protected static native long initContext(
String model,
String chat_template,
String reasoning_format,
boolean embedding,
int embd_normalize,
int n_ctx,
21 changes: 21 additions & 0 deletions android/src/main/jni.cpp
@@ -223,6 +223,7 @@ Java_com_rnllama_LlamaContext_initContext(
jobject thiz,
jstring model_path_str,
jstring chat_template,
jstring reasoning_format,
jboolean embedding,
jint embd_normalize,
jint n_ctx,
@@ -259,6 +260,13 @@ Java_com_rnllama_LlamaContext_initContext(
const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
defaultParams.chat_template = chat_template_chars;

const char *reasoning_format_chars = env->GetStringUTFChars(reasoning_format, nullptr);
if (strcmp(reasoning_format_chars, "deepseek") == 0) { // compare contents, not pointers
defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
} else {
defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
}

defaultParams.n_ctx = n_ctx;
defaultParams.n_batch = n_batch;
defaultParams.n_ubatch = n_ubatch;
@@ -326,6 +334,7 @@ Java_com_rnllama_LlamaContext_initContext(

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(chat_template, chat_template_chars);
env->ReleaseStringUTFChars(reasoning_format, reasoning_format_chars);
env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);

@@ -884,10 +893,16 @@ Java_com_rnllama_LlamaContext_doCompletion(
llama->is_predicting = false;

auto toolCalls = createWritableArray(env);
std::string reasoningContent = "";
std::string content;
auto toolCallsSize = 0;
if (!llama->is_interrupted) {
try {
common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
if (!message.reasoning_content.empty()) {
reasoningContent = message.reasoning_content;
}
content = &message.content;
for (const auto &tc : message.tool_calls) {
auto toolCall = createWriteableMap(env);
putString(env, toolCall, "type", "function");
Expand All @@ -908,6 +923,12 @@ Java_com_rnllama_LlamaContext_doCompletion(

auto result = createWriteableMap(env);
putString(env, result, "text", llama->generated_text.c_str());
if (!content.empty()) {
putString(env, result, "content", content.c_str());
}
if (!reasoningContent.empty()) {
putString(env, result, "reasoning_content", reasoningContent.c_str());
}
if (toolCallsSize > 0) {
putArray(env, result, "tool_calls", toolCalls);
}
1 change: 1 addition & 0 deletions cpp/rn-llama.cpp
@@ -232,6 +232,7 @@ common_chat_params llama_rn_context::getFormattedChatWithJinja(
if (!json_schema.empty()) {
inputs.json_schema = json::parse(json_schema);
}
inputs.extract_reasoning = params.reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.stream = true;

// If chat_template is provided, create new one and use it (probably slow)
6 changes: 4 additions & 2 deletions example/src/App.tsx
@@ -144,10 +144,12 @@ export default function App() {
initLlama(
{
model: file.uri,
n_ctx: 200,
use_mlock: true,
lora_list: loraFile ? [{ path: loraFile.uri, scaled: 1.0 }] : undefined, // Or lora: loraFile?.uri,

// If using a DeepSeek R1 Distill model
reasoning_format: 'deepseek',

// Currently only for iOS
n_gpu_layers: Platform.OS === 'ios' ? 99 : 0,
// no_gpu_devices: true, // (iOS only)
@@ -474,7 +476,7 @@ export default function App() {
],
}
// Comment out to test without Jinja:
jinjaParams = undefined
jinjaParams = { jinja: true }
}

// Test area
17 changes: 16 additions & 1 deletion ios/RNLlamaContext.mm
@@ -90,6 +90,13 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
NSLog(@"chatTemplate: %@", chatTemplate);
}

NSString *reasoningFormat = params[@"reasoning_format"];
if (reasoningFormat && [reasoningFormat isEqualToString:@"deepseek"]) {
defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
} else {
defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
}

if (params[@"n_ctx"]) defaultParams.n_ctx = [params[@"n_ctx"] intValue];
if (params[@"use_mlock"]) defaultParams.use_mlock = [params[@"use_mlock"]boolValue];

@@ -610,10 +617,16 @@ - (NSDictionary *)completion:(NSDictionary *)params
const auto timings = llama_perf_context(llama->ctx);

NSMutableArray *toolCalls = nil;
NSString *reasoningContent = nil;
NSString *content = nil;
if (!llama->is_interrupted) {
try {
auto chat_format = params[@"chat_format"] ? [params[@"chat_format"] intValue] : COMMON_CHAT_FORMAT_CONTENT_ONLY;
common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
if (!message.reasoning_content.empty()) {
reasoningContent = [NSString stringWithUTF8String:message.reasoning_content.c_str()];
}
content = [NSString stringWithUTF8String:message.content.c_str()];
toolCalls = [[NSMutableArray alloc] init];
for (const auto &tc : message.tool_calls) {
[toolCalls addObject:@{
@@ -631,7 +644,9 @@ - (NSDictionary *)completion:(NSDictionary *)params
}

NSMutableDictionary *result = [[NSMutableDictionary alloc] init];
result[@"text"] = [NSString stringWithUTF8String:llama->generated_text.c_str()];
result[@"text"] = [NSString stringWithUTF8String:llama->generated_text.c_str()]; // Original text
if (content) result[@"content"] = content;
if (reasoningContent) result[@"reasoning_content"] = reasoningContent;
if (toolCalls && toolCalls.count > 0) result[@"tool_calls"] = toolCalls;
result[@"completion_probabilities"] = [self tokenProbsToDict:llama->generated_token_probs];
result[@"tokens_predicted"] = @(llama->num_tokens_predicted);
17 changes: 17 additions & 0 deletions src/NativeRNLlama.ts
@@ -12,6 +12,8 @@ export type NativeContextParams = {
*/
chat_template?: string

/**
* Reasoning format for parsing reasoning content ('none' or 'deepseek')
*/
reasoning_format?: string

is_model_asset?: boolean
use_progress_callback?: boolean

@@ -236,7 +238,18 @@ export type NativeCompletionResultTimings = {
}

export type NativeCompletionResult = {
/**
* Original raw generated text (reasoning_content / tool_calls are not parsed out)
*/
text: string

/**
* Reasoning content (parsed from the output of a reasoning model)
*/
reasoning_content: string
/**
* Tool calls
*/
tool_calls: Array<{
type: 'function'
function: {
@@ -245,6 +258,10 @@ export type NativeCompletionResult = {
}
id?: string
}>
/**
* Content text (the generated text with reasoning_content / tool_calls filtered out)
*/
content: string

tokens_predicted: number
tokens_evaluated: number
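
Since the result now carries three related text fields, here is a hedged sketch of how a caller might consume them. The tag split shown in the comments assumes the DeepSeek `<think>` convention that `COMMON_REASONING_FORMAT_DEEPSEEK` parses, and `describeResult` is a hypothetical helper, not part of this commit:

import type { NativeCompletionResult } from './NativeRNLlama'

// For a raw output like "<think>chain of thought...</think>42" with
// reasoning_format: 'deepseek', the native layer is expected to produce:
//   result.text              -> "<think>chain of thought...</think>42"
//   result.reasoning_content -> "chain of thought..."
//   result.content           -> "42"
function describeResult(result: NativeCompletionResult) {
  // Fall back to the raw text when parsing is off (reasoning_format: 'none')
  const answer = result.content || result.text
  return { answer, thoughts: result.reasoning_content }
}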
