Skip to content

Commit

Permalink
feat: add chat_template to override the model's default template on init
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Feb 5, 2025
1 parent 91913f1 commit 6a6a9bd
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 3 deletions.
3 changes: 3 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
this.context = initContext(
// String model,
params.getString("model"),
// String chat_template,
params.hasKey("chat_template") ? params.getString("chat_template") : "",
// boolean embedding,
params.hasKey("embedding") ? params.getBoolean("embedding") : false,
// int embd_normalize,
Expand Down Expand Up @@ -437,6 +439,7 @@ protected static native WritableMap modelInfo(
);
protected static native long initContext(
String model,
String chat_template,
boolean embedding,
int embd_normalize,
int n_ctx,
Expand Down
5 changes: 5 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ Java_com_rnllama_LlamaContext_initContext(
JNIEnv *env,
jobject thiz,
jstring model_path_str,
jstring chat_template,
jboolean embedding,
jint embd_normalize,
jint n_ctx,
Expand Down Expand Up @@ -255,6 +256,9 @@ Java_com_rnllama_LlamaContext_initContext(
const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
defaultParams.model = model_path_chars;

const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
defaultParams.chat_template = chat_template_chars;

defaultParams.n_ctx = n_ctx;
defaultParams.n_batch = n_batch;
defaultParams.n_ubatch = n_ubatch;
Expand Down Expand Up @@ -321,6 +325,7 @@ Java_com_rnllama_LlamaContext_initContext(
bool is_model_loaded = llama->loadModel(defaultParams);

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(chat_template, chat_template_chars);
env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);

Expand Down
2 changes: 1 addition & 1 deletion cpp/rn-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ bool llama_rn_context::loadModel(common_params &params_)
LOG_ERROR("unable to load model: %s", params_.model.c_str());
return false;
}
templates = common_chat_templates_from_model(model, "");
templates = common_chat_templates_from_model(model, params.chat_template);
n_ctx = llama_n_ctx(ctx);

// We can uncomment for debugging or after this fix: https://github.com/ggerganov/llama.cpp/pull/11101
Expand Down
6 changes: 6 additions & 0 deletions ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
if (isAsset) path = [[NSBundle mainBundle] pathForResource:modelPath ofType:nil];
defaultParams.model = [path UTF8String];

NSString *chatTemplate = params[@"chat_template"];
if (chatTemplate) {
defaultParams.chat_template = [chatTemplate UTF8String];
NSLog(@"chatTemplate: %@", chatTemplate);
}

if (params[@"n_ctx"]) defaultParams.n_ctx = [params[@"n_ctx"] intValue];
if (params[@"use_mlock"]) defaultParams.use_mlock = [params[@"use_mlock"]boolValue];

Expand Down
5 changes: 5 additions & 0 deletions src/NativeRNLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ export type NativeEmbeddingParams = {

export type NativeContextParams = {
model: string
/**
* Chat template to override the default one from the model.
*/
chat_template?: string

is_model_asset?: boolean
use_progress_callback?: boolean

Expand Down
5 changes: 3 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ export type CompletionResponseFormat = {
export type CompletionBaseParams = {
prompt?: string
messages?: RNLlamaOAICompatibleMessage[]
chatTemplate?: string
chatTemplate?: string // deprecated
chat_template?: string
jinja?: boolean
tools?: object
parallel_tool_calls?: object
Expand Down Expand Up @@ -232,7 +233,7 @@ export class LlamaContext {
// messages always win
const formattedResult = await this.getFormattedChat(
params.messages,
params.chatTemplate,
params.chat_template || params.chatTemplate,
{
jinja: params.jinja,
tools: params.tools,
Expand Down

0 comments on commit 6a6a9bd

Please sign in to comment.