From 6d3797f4b400820c4dc9085c62f80126234d896b Mon Sep 17 00:00:00 2001 From: Rusyaidi Date: Sat, 20 Jul 2024 21:36:42 +0800 Subject: [PATCH 1/4] fix: loadSession not taking paths with file:// --- src/index.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/index.ts b/src/index.ts index e074bbd6..eabdebbc 100644 --- a/src/index.ts +++ b/src/index.ts @@ -75,7 +75,9 @@ export class LlamaContext { * Load cached prompt & completion state from a file. */ async loadSession(filepath: string): Promise { - return RNLlama.loadSession(this.id, filepath) + let path = filepath + if (path.startsWith('file://')) path = path.slice(7) + return RNLlama.loadSession(this.id, path) } /** From 3f71538d54057b1ce79eb68699f1bd53a7295117 Mon Sep 17 00:00:00 2001 From: Vali98 Date: Sun, 17 Nov 2024 14:59:42 +0800 Subject: [PATCH 2/4] feat: exposed dry sampling --- .../main/java/com/rnllama/LlamaContext.java | 15 +++++++++++ android/src/main/jni.cpp | 25 +++++++++++++++++++ example/src/App.tsx | 5 ++++ src/NativeRNLlama.ts | 6 +++++ 4 files changed, 51 insertions(+) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index ee1d5717..b369bb1a 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -237,6 +237,16 @@ public WritableMap completion(ReadableMap params) { params.hasKey("ignore_eos") ? params.getBoolean("ignore_eos") : false, // double[][] logit_bias, logit_bias, + // float dry_multiplier, + params.hasKey("dry_multiplier") ? (float) params.getDouble("dry_multiplier") : 0.00f, + // float dry_base, + params.hasKey("dry_base") ? (float) params.getDouble("dry_base") : 1.75f, + // int dry_allowed_length, + params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2, + // int dry_penalty_last_n, + params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1, + // String[] dry_sequence_breakers, when undef, we use the default definition from common.h + params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"}, // PartialCompletionCallback partial_completion_callback new PartialCompletionCallback( this, @@ -445,6 +455,11 @@ protected static native WritableMap doCompletion( String[] stop, boolean ignore_eos, double[][] logit_bias, + float dry_multiplier, + float dry_base, + int dry_allowed_length, + int dry_penalty_last_n, + String[] dry_sequence_breakers, PartialCompletionCallback partial_completion_callback ); protected static native void stopCompletion(long contextPtr); diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index db496fd3..1ec3d19f 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -535,6 +535,11 @@ Java_com_rnllama_LlamaContext_doCompletion( jobjectArray stop, jboolean ignore_eos, jobjectArray logit_bias, + jfloat dry_multiplier, + jfloat dry_base, + jint dry_allowed_length, + jint dry_penalty_last_n, + jobjectArray dry_sequence_breakers, jobject partial_completion_callback ) { UNUSED(thiz); @@ -573,12 +578,32 @@ Java_com_rnllama_LlamaContext_doCompletion( sparams.grammar = env->GetStringUTFChars(grammar, nullptr); sparams.xtc_threshold = xtc_threshold; sparams.xtc_probability = xtc_probability; + sparams.dry_multiplier = dry_multiplier; + sparams.dry_base = dry_base; + sparams.dry_allowed_length = dry_allowed_length; + sparams.dry_penalty_last_n = dry_penalty_last_n; sparams.logit_bias.clear(); if (ignore_eos) { sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY; } + // dry break seq + + jint size = env->GetArrayLength(dry_sequence_breakers); + std::vector dry_sequence_breakers_vector; + + for (jint i = 0; i < size; i++) { + jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i); + const char *nativeString = env->GetStringUTFChars(javaString, 0); + dry_sequence_breakers_vector.push_back(std::string(nativeString)); + env->ReleaseStringUTFChars(javaString, nativeString); + env->DeleteLocalRef(javaString); + } + + sparams.dry_sequence_breakers = dry_sequence_breakers_vector; + + // logit bias const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx)); jsize logit_bias_len = env->GetArrayLength(logit_bias); diff --git a/example/src/App.tsx b/example/src/App.tsx index 06572115..4d44a757 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -410,6 +410,11 @@ export default function App() { '<|end_of_turn|>', '<|endoftext|>', ], + dry_multiplier: 1, + dry_base: 1.75, + dry_allowed_length: 200, + dry_penalty_last_n: -1, + dry_sequence_breakers: ["\n", ":", "\"", "*"], grammar, // n_threads: 4, // logit_bias: [[15043,1.0]], diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 400e84bb..04e06e86 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -72,6 +72,12 @@ export type NativeCompletionParams = { penalize_nl?: boolean seed?: number + dry_multiplier?: number + dry_base?: number + dry_allowed_length?: number + dry_penalty_last_n?: number + dry_sequence_breakers?: Array + ignore_eos?: boolean logit_bias?: Array> From 0c4675d26bc4907fa6fea4762a95bed8da84200c Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Mon, 18 Nov 2024 11:02:06 +0800 Subject: [PATCH 3/4] feat(ios): expose DRY sampler params --- ios/RNLlamaContext.mm | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index ac4e6014..c2cb5932 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -292,6 +292,19 @@ - (NSDictionary *)completion:(NSDictionary *)params if (params[@"xtc_probability"]) sparams.xtc_probability = [params[@"xtc_probability"] doubleValue]; if (params[@"typical_p"]) sparams.typ_p = [params[@"typical_p"] doubleValue]; + if (params[@"dry_multiplier"]) sparams.dry_multiplier = [params[@"dry_multiplier"] doubleValue]; + if (params[@"dry_base"]) sparams.dry_base = [params[@"dry_base"] doubleValue]; + if (params[@"dry_allowed_length"]) sparams.dry_allowed_length = [params[@"dry_allowed_length"] intValue]; + if (params[@"dry_penalty_last_n"]) sparams.dry_penalty_last_n = [params[@"dry_penalty_last_n"] intValue]; + + // dry break seq + if (params[@"dry_sequence_breakers"] && [params[@"dry_sequence_breakers"] isKindOfClass:[NSArray class]]) { + NSArray *dry_sequence_breakers = params[@"dry_sequence_breakers"]; + for (NSString *s in dry_sequence_breakers) { + sparams.dry_sequence_breakers.push_back([s UTF8String]); + } + } + if (params[@"grammar"]) { sparams.grammar = [params[@"grammar"] UTF8String]; } From d84cf18b5872e6af062d6d616e923d5565179798 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Mon, 18 Nov 2024 11:24:42 +0800 Subject: [PATCH 4/4] feat(example): use default values in completion & remove comments --- example/src/App.tsx | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/example/src/App.tsx b/example/src/App.tsx index 4d44a757..fceb129b 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -383,22 +383,32 @@ export default function App() { { messages: msgs, n_predict: 100, + grammar, + seed: -1, + n_probs: 0, + + // Sampling params + top_k: 40, + top_p: 0.5, + min_p: 0.05, xtc_probability: 0.5, xtc_threshold: 0.1, + typical_p: 1.0, temperature: 0.7, - top_k: 40, // <= 0 to use vocab size - top_p: 0.5, // 1.0 = disabled - typical_p: 1.0, // 1.0 = disabled - penalty_last_n: 256, // 0 = disable penalty, -1 = context size - penalty_repeat: 1.18, // 1.0 = disabled - penalty_freq: 0.0, // 0.0 = disabled - penalty_present: 0.0, // 0.0 = disabled - mirostat: 0, // 0/1/2 - mirostat_tau: 5, // target entropy - mirostat_eta: 0.1, // learning rate - penalize_nl: false, // penalize newlines - seed: -1, // random seed - n_probs: 0, // Show probabilities + penalty_last_n: 64, + penalty_repeat: 1.0, + penalty_freq: 0.0, + penalty_present: 0.0, + dry_multiplier: 0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + dry_sequence_breakers: ["\n", ":", "\"", "*"], + mirostat: 0, + mirostat_tau: 5, + mirostat_eta: 0.1, + penalize_nl: false, + ignore_eos: false, stop: [ '', '<|end|>', @@ -410,12 +420,6 @@ export default function App() { '<|end_of_turn|>', '<|endoftext|>', ], - dry_multiplier: 1, - dry_base: 1.75, - dry_allowed_length: 200, - dry_penalty_last_n: -1, - dry_sequence_breakers: ["\n", ":", "\"", "*"], - grammar, // n_threads: 4, // logit_bias: [[15043,1.0]], },