diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java
index 63b8d4b3..9272c874 100644
--- a/android/src/main/java/com/rnllama/LlamaContext.java
+++ b/android/src/main/java/com/rnllama/LlamaContext.java
@@ -70,6 +70,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       params.getString("model"),
       // String chat_template,
       params.hasKey("chat_template") ? params.getString("chat_template") : "",
+      // String reasoning_format,
+      params.hasKey("reasoning_format") ? params.getString("reasoning_format") : "none",
       // boolean embedding,
       params.hasKey("embedding") ? params.getBoolean("embedding") : false,
       // int embd_normalize,
@@ -470,6 +472,7 @@ protected static native WritableMap modelInfo(
   protected static native long initContext(
     String model,
     String chat_template,
+    String reasoning_format,
     boolean embedding,
     int embd_normalize,
     int n_ctx,
diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp
index 935c59db..3fb76e74 100644
--- a/android/src/main/jni.cpp
+++ b/android/src/main/jni.cpp
@@ -223,6 +223,7 @@ Java_com_rnllama_LlamaContext_initContext(
     jobject thiz,
     jstring model_path_str,
     jstring chat_template,
+    jstring reasoning_format,
     jboolean embedding,
     jint embd_normalize,
     jint n_ctx,
@@ -259,6 +260,13 @@ Java_com_rnllama_LlamaContext_initContext(
     const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
     defaultParams.chat_template = chat_template_chars;

+    const char *reasoning_format_chars = env->GetStringUTFChars(reasoning_format, nullptr);
+    if (strcmp(reasoning_format_chars, "deepseek") == 0) {
+        defaultParams.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    } else {
+        defaultParams.reasoning_format = COMMON_REASONING_FORMAT_NONE;
+    }
+
     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
     defaultParams.n_ubatch = n_ubatch;
@@ -326,6 +334,7 @@ Java_com_rnllama_LlamaContext_initContext(

     env->ReleaseStringUTFChars(model_path_str, model_path_chars);
     env->ReleaseStringUTFChars(chat_template, chat_template_chars);
+    env->ReleaseStringUTFChars(reasoning_format, reasoning_format_chars);
     env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
     env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);
@@ -884,10 +893,16 @@ Java_com_rnllama_LlamaContext_doCompletion(
     llama->is_predicting = false;

     auto toolCalls = createWritableArray(env);
+    std::string reasoningContent = "";
+    std::string content = "";
     auto toolCallsSize = 0;
     if (!llama->is_interrupted) {
         try {
             common_chat_msg message = common_chat_parse(llama->generated_text, static_cast<common_chat_format>(chat_format));
+            if (!message.reasoning_content.empty()) {
+                reasoningContent = message.reasoning_content;
+            }
+            content = message.content;
             for (const auto &tc : message.tool_calls) {
                 auto toolCall = createWriteableMap(env);
                 putString(env, toolCall, "type", "function");
@@ -908,6 +923,12 @@ Java_com_rnllama_LlamaContext_doCompletion(

     auto result = createWriteableMap(env);
     putString(env, result, "text", llama->generated_text.c_str());
+    if (!content.empty()) {
+        putString(env, result, "content", content.c_str());
+    }
+    if (!reasoningContent.empty()) {
+        putString(env, result, "reasoning_content", reasoningContent.c_str());
+    }
     if (toolCallsSize > 0) {
         putArray(env, result, "tool_calls", toolCalls);
     }
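On both platforms the new `reasoning_format` string is matched against the single recognized value `"deepseek"`; anything else, including the Android default `"none"`, falls back to `COMMON_REASONING_FORMAT_NONE`. A minimal TypeScript sketch of how a caller could constrain the value before passing it down (this helper is hypothetical and not part of the diff):

```ts
// Hypothetical JS-side guard (not in this PR): only 'deepseek' is understood
// by the native layer; everything else is treated as 'none'.
type ReasoningFormat = 'none' | 'deepseek'

const toReasoningFormat = (value?: string): ReasoningFormat =>
  value === 'deepseek' ? 'deepseek' : 'none'
```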
diff --git a/cpp/rn-llama.cpp b/cpp/rn-llama.cpp
index d969ebb3..0859d110 100644
--- a/cpp/rn-llama.cpp
+++ b/cpp/rn-llama.cpp
@@ -232,6 +232,7 @@ common_chat_params llama_rn_context::getFormattedChatWithJinja(
     if (!json_schema.empty()) {
         inputs.json_schema = json::parse(json_schema);
     }
+    inputs.extract_reasoning = params.reasoning_format != COMMON_REASONING_FORMAT_NONE;
     inputs.stream = true;

     // If chat_template is provided, create new one and use it (probably slow)
diff --git a/example/src/App.tsx b/example/src/App.tsx
index 41332ab7..c29477e2 100644
--- a/example/src/App.tsx
+++ b/example/src/App.tsx
@@ -144,10 +144,12 @@ export default function App() {
     initLlama(
       {
         model: file.uri,
-        n_ctx: 200,
         use_mlock: true,
         lora_list: loraFile ? [{ path: loraFile.uri, scaled: 1.0 }] : undefined, // Or lora: loraFile?.uri,
+        // If using a DeepSeek R1 Distill model
+        reasoning_format: 'deepseek',
+        // Currently only for iOS
         n_gpu_layers: Platform.OS === 'ios' ? 99 : 0,
         // no_gpu_devices: true, // (iOS only)
@@ -474,7 +476,7 @@ export default function App() {
         ],
       }
       // Comment to test:
-      jinjaParams = undefined
+      jinjaParams = { jinja: true }
     }

     // Test area
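Note that `inputs.extract_reasoning` is only set inside `getFormattedChatWithJinja`, so reasoning extraction takes effect on the Jinja chat path; that is why the example app now passes `{ jinja: true }`. A minimal usage sketch, assuming the public `initLlama` / `context.completion` API and an assumed local path to a DeepSeek R1 Distill GGUF:

```ts
import { initLlama } from 'llama.rn'

// Assumed model path; any DeepSeek R1 Distill GGUF should exercise the new parsing.
const context = await initLlama({
  model: '/path/to/deepseek-r1-distill.gguf',
  reasoning_format: 'deepseek',
})

const result = await context.completion({
  jinja: true, // use the Jinja chat template path, where extract_reasoning is applied
  messages: [{ role: 'user', content: 'How many r are in strawberry?' }],
  n_predict: 512,
})
```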
[params[@"chat_format"] intValue] : COMMON_CHAT_FORMAT_CONTENT_ONLY; common_chat_msg message = common_chat_parse(llama->generated_text, static_cast(chat_format)); + if (!message.reasoning_content.empty()) { + reasoningContent = [NSString stringWithUTF8String:message.reasoning_content.c_str()]; + } + content = [NSString stringWithUTF8String:message.content.c_str()]; toolCalls = [[NSMutableArray alloc] init]; for (const auto &tc : message.tool_calls) { [toolCalls addObject:@{ @@ -631,7 +644,9 @@ - (NSDictionary *)completion:(NSDictionary *)params } NSMutableDictionary *result = [[NSMutableDictionary alloc] init]; - result[@"text"] = [NSString stringWithUTF8String:llama->generated_text.c_str()]; + result[@"text"] = [NSString stringWithUTF8String:llama->generated_text.c_str()]; // Original text + if (content) result[@"content"] = content; + if (reasoningContent) result[@"reasoning_content"] = reasoningContent; if (toolCalls && toolCalls.count > 0) result[@"tool_calls"] = toolCalls; result[@"completion_probabilities"] = [self tokenProbsToDict:llama->generated_token_probs]; result[@"tokens_predicted"] = @(llama->num_tokens_predicted); diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index c8221206..910b6fc7 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -12,6 +12,8 @@ export type NativeContextParams = { */ chat_template?: string + reasoning_format?: string + is_model_asset?: boolean use_progress_callback?: boolean @@ -236,7 +238,18 @@ export type NativeCompletionResultTimings = { } export type NativeCompletionResult = { + /** + * Original text (Ignored reasoning_content / tool_calls) + */ text: string + + /** + * Reasoning content (parsed for reasoning model) + */ + reasoning_content: string + /** + * Tool calls + */ tool_calls: Array<{ type: 'function' function: { @@ -245,6 +258,10 @@ export type NativeCompletionResult = { } id?: string }> + /** + * Content text (Filtered text by reasoning_content / tool_calls) + */ + content: string tokens_predicted: number tokens_evaluated: number