From 8cd53a4cc970324686c87c85cc4775e42b7207bc Mon Sep 17 00:00:00 2001 From: Jhen Date: Mon, 13 Nov 2023 07:13:26 +0800 Subject: [PATCH] feat: add min_p completion param --- android/src/main/java/com/rnllama/LlamaContext.java | 3 +++ android/src/main/jni.cpp | 2 ++ ios/RNLlamaContext.mm | 1 + src/NativeRNLlama.ts | 1 + 4 files changed, 7 insertions(+) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index 03cb24b..bab6803 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -167,6 +167,8 @@ public WritableMap completion(ReadableMap params) { params.hasKey("top_k") ? params.getInt("top_k") : 40, // float top_p, params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f, + // float min_p, + params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f, // float tfs_z, params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f, // float typical_p, @@ -316,6 +318,7 @@ protected static native WritableMap doCompletion( float mirostat_eta, int top_k, float top_p, + float min_p, float tfs_z, float typical_p, String[] stop, diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index e1c20df..eb1a7fc 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -286,6 +286,7 @@ Java_com_rnllama_LlamaContext_doCompletion( jfloat mirostat_eta, jint top_k, jfloat top_p, + jfloat min_p, jfloat tfs_z, jfloat typical_p, jobjectArray stop, @@ -321,6 +322,7 @@ Java_com_rnllama_LlamaContext_doCompletion( sparams.mirostat_eta = mirostat_eta; sparams.top_k = top_k; sparams.top_p = top_p; + sparams.min_p = min_p; sparams.tfs_z = tfs_z; sparams.typical_p = typical_p; sparams.n_probs = n_probs; diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index ab8ea2c..f9c6035 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -159,6 +159,7 @@ - (NSDictionary *)completion:(NSDictionary *)params if (params[@"top_k"]) sparams.top_k = [params[@"top_k"] intValue]; if (params[@"top_p"]) sparams.top_p = [params[@"top_p"] doubleValue]; + if (params[@"min_p"]) sparams.min_p = [params[@"min_p"] doubleValue]; if (params[@"tfs_z"]) sparams.tfs_z = [params[@"tfs_z"] doubleValue]; if (params[@"typical_p"]) sparams.typical_p = [params[@"typical_p"] doubleValue]; diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 896a37a..6cb88e6 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -46,6 +46,7 @@ export type NativeCompletionParams = { mirostat_eta?: number top_k?: number top_p?: number + min_p?: number tfs_z?: number typical_p?: number