Skip to content

Commit

Permalink
feat: add min_p completion param
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Nov 12, 2023
1 parent 1f20cef commit 8cd53a4
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 0 deletions.
3 changes: 3 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("top_k") ? params.getInt("top_k") : 40,
// float top_p,
params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f,
// float min_p,
params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f,
// float tfs_z,
params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f,
// float typical_p,
Expand Down Expand Up @@ -316,6 +318,7 @@ protected static native WritableMap doCompletion(
float mirostat_eta,
int top_k,
float top_p,
float min_p,
float tfs_z,
float typical_p,
String[] stop,
Expand Down
2 changes: 2 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
jfloat mirostat_eta,
jint top_k,
jfloat top_p,
jfloat min_p,
jfloat tfs_z,
jfloat typical_p,
jobjectArray stop,
Expand Down Expand Up @@ -321,6 +322,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
sparams.mirostat_eta = mirostat_eta;
sparams.top_k = top_k;
sparams.top_p = top_p;
sparams.min_p = min_p;
sparams.tfs_z = tfs_z;
sparams.typical_p = typical_p;
sparams.n_probs = n_probs;
Expand Down
1 change: 1 addition & 0 deletions ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ - (NSDictionary *)completion:(NSDictionary *)params

if (params[@"top_k"]) sparams.top_k = [params[@"top_k"] intValue];
if (params[@"top_p"]) sparams.top_p = [params[@"top_p"] doubleValue];
if (params[@"min_p"]) sparams.min_p = [params[@"min_p"] doubleValue];
if (params[@"tfs_z"]) sparams.tfs_z = [params[@"tfs_z"] doubleValue];

if (params[@"typical_p"]) sparams.typical_p = [params[@"typical_p"] doubleValue];
Expand Down
1 change: 1 addition & 0 deletions src/NativeRNLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export type NativeCompletionParams = {
mirostat_eta?: number
top_k?: number
top_p?: number
min_p?: number
tfs_z?: number
typical_p?: number

Expand Down

0 comments on commit 8cd53a4

Please sign in to comment.