From 6190f5746fdea73e282dabcf2c49075c097f1a18 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Sat, 16 Nov 2024 18:44:48 +0800 Subject: [PATCH] feat: update embedding method (#88) * feat: update embedding method * docs(api): build * fix(example): revert unnecessary change --- .../main/java/com/rnllama/LlamaContext.java | 24 +++++++- .../src/main/java/com/rnllama/RNLlama.java | 4 +- android/src/main/jni.cpp | 58 ++++++++++++++++--- .../java/com/rnllama/RNLlamaModule.java | 4 +- .../java/com/rnllama/RNLlamaModule.java | 4 +- cpp/rn-llama.hpp | 17 +++--- docs/API/README.md | 33 +++++++---- docs/API/classes/LlamaContext.md | 33 ++++++----- docs/API/classes/SchemaGrammarConverter.md | 32 +++++----- ios/RNLlama.mm | 42 ++++++++------ ios/RNLlamaContext.h | 2 +- ios/RNLlamaContext.mm | 48 ++++++++++++--- src/NativeRNLlama.ts | 18 +++++- src/index.ts | 34 +++++++++-- 14 files changed, 251 insertions(+), 102 deletions(-) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index 36b1a1c..c2e06cf 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -44,6 +44,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa params.getString("model"), // boolean embedding, params.hasKey("embedding") ? params.getBoolean("embedding") : false, + // int embd_normalize, + params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1, // int n_ctx, params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512, // int n_batch, @@ -66,9 +68,14 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f, // float rope_freq_scale params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f, + // int pooling_type, + params.hasKey("pooling_type") ? params.getInt("pooling_type") : -1, // LoadProgressCallback load_progress_callback params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null ); + if (this.context == -1) { + throw new IllegalStateException("Failed to initialize context"); + } this.modelDetails = loadModelDetails(this.context); this.reactContext = reactContext; } @@ -258,11 +265,16 @@ public String detokenize(ReadableArray tokens) { return detokenize(this.context, toks); } - public WritableMap getEmbedding(String text) { + public WritableMap getEmbedding(String text, ReadableMap params) { if (isEmbeddingEnabled(this.context) == false) { throw new IllegalStateException("Embedding is not enabled"); } - WritableMap result = embedding(this.context, text); + WritableMap result = embedding( + this.context, + text, + // int embd_normalize, + params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1 + ); if (result.hasKey("error")) { throw new IllegalStateException(result.getString("error")); } @@ -365,6 +377,7 @@ protected static native WritableMap modelInfo( protected static native long initContext( String model, boolean embedding, + int embd_normalize, int n_ctx, int n_batch, int n_threads, @@ -376,6 +389,7 @@ protected static native long initContext( float lora_scaled, float rope_freq_base, float rope_freq_scale, + int pooling_type, LoadProgressCallback load_progress_callback ); protected static native void interruptLoad(long contextPtr); @@ -429,7 +443,11 @@ protected static native WritableMap doCompletion( protected static native WritableArray tokenize(long contextPtr, String text); protected static native String detokenize(long contextPtr, int[] tokens); protected static native boolean isEmbeddingEnabled(long contextPtr); - protected static native WritableMap embedding(long contextPtr, String text); + protected static native WritableMap embedding( + long contextPtr, + String text, + int embd_normalize + ); protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr); protected static native void freeContext(long contextPtr); protected static native void logToAndroid(); diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java index 1f02f2d..aa19731 100644 --- a/android/src/main/java/com/rnllama/RNLlama.java +++ b/android/src/main/java/com/rnllama/RNLlama.java @@ -349,7 +349,7 @@ protected void onPostExecute(String result) { tasks.put(task, "detokenize-" + contextId); } - public void embedding(double id, final String text, final Promise promise) { + public void embedding(double id, final String text, final ReadableMap params, final Promise promise) { final int contextId = (int) id; AsyncTask task = new AsyncTask() { private Exception exception; @@ -361,7 +361,7 @@ protected WritableMap doInBackground(Void... voids) { if (context == null) { throw new Exception("Context not found"); } - return context.getEmbedding(text); + return context.getEmbedding(text, params); } catch (Exception e) { exception = e; } diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 3beb203..4475c95 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -115,6 +115,15 @@ static inline void pushDouble(JNIEnv *env, jobject arr, double value) { env->CallVoidMethod(arr, pushDoubleMethod, value); } +// Method to push string into WritableArray +static inline void pushString(JNIEnv *env, jobject arr, const char *value) { + jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray"); + jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V"); + + jstring jValue = env->NewStringUTF(value); + env->CallVoidMethod(arr, pushStringMethod, jValue); +} + // Method to push WritableMap into WritableArray static inline void pushMap(JNIEnv *env, jobject arr, jobject value) { jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray"); @@ -213,6 +222,7 @@ Java_com_rnllama_LlamaContext_initContext( jobject thiz, jstring model_path_str, jboolean embedding, + jint embd_normalize, jint n_ctx, jint n_batch, jint n_threads, @@ -224,6 +234,7 @@ Java_com_rnllama_LlamaContext_initContext( jfloat lora_scaled, jfloat rope_freq_base, jfloat rope_freq_scale, + jint pooling_type, jobject load_progress_callback ) { UNUSED(thiz); @@ -238,11 +249,22 @@ Java_com_rnllama_LlamaContext_initContext( const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr); defaultParams.model = model_path_chars; - defaultParams.embedding = embedding; - defaultParams.n_ctx = n_ctx; defaultParams.n_batch = n_batch; + if (pooling_type != -1) { + defaultParams.pooling_type = static_cast(pooling_type); + } + + defaultParams.embedding = embedding; + if (embd_normalize != -1) { + defaultParams.embd_normalize = embd_normalize; + } + if (embedding) { + // For non-causal models, batch size must be equal to ubatch size + defaultParams.n_ubatch = defaultParams.n_batch; + } + int max_threads = std::thread::hardware_concurrency(); // Use 2 threads by default on 4-core devices, 4 threads on more cores int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads); @@ -291,16 +313,21 @@ Java_com_rnllama_LlamaContext_initContext( bool is_model_loaded = llama->loadModel(defaultParams); + env->ReleaseStringUTFChars(model_path_str, model_path_chars); + env->ReleaseStringUTFChars(lora_str, lora_chars); + LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false")); if (is_model_loaded) { + if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) { + LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported"); + llama_free(llama->ctx); + return -1; + } context_map[(long) llama->ctx] = llama; } else { llama_free(llama->ctx); } - env->ReleaseStringUTFChars(model_path_str, model_path_chars); - env->ReleaseStringUTFChars(lora_str, lora_chars); - return reinterpret_cast(llama->ctx); } @@ -745,10 +772,21 @@ Java_com_rnllama_LlamaContext_isEmbeddingEnabled( JNIEXPORT jobject JNICALL Java_com_rnllama_LlamaContext_embedding( - JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) { + JNIEnv *env, jobject thiz, + jlong context_ptr, + jstring text, + jint embd_normalize +) { UNUSED(thiz); auto llama = context_map[(long) context_ptr]; + common_params embdParams; + embdParams.embedding = true; + embdParams.embd_normalize = llama->params.embd_normalize; + if (embd_normalize != -1) { + embdParams.embd_normalize = embd_normalize; + } + const char *text_chars = env->GetStringUTFChars(text, nullptr); llama->rewind(); @@ -769,7 +807,7 @@ Java_com_rnllama_LlamaContext_embedding( llama->loadPrompt(); llama->doCompletion(); - std::vector embedding = llama->getEmbedding(); + std::vector embedding = llama->getEmbedding(embdParams); auto embeddings = createWritableArray(env); for (const auto &val : embedding) { @@ -777,6 +815,12 @@ Java_com_rnllama_LlamaContext_embedding( } putArray(env, result, "embedding", embeddings); + auto promptTokens = createWritableArray(env); + for (const auto &tok : llama->embd) { + pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str()); + } + putArray(env, result, "prompt_tokens", promptTokens); + env->ReleaseStringUTFChars(text, text_chars); return result; } diff --git a/android/src/newarch/java/com/rnllama/RNLlamaModule.java b/android/src/newarch/java/com/rnllama/RNLlamaModule.java index a41aa05..19077c8 100644 --- a/android/src/newarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/newarch/java/com/rnllama/RNLlamaModule.java @@ -83,8 +83,8 @@ public void detokenize(double id, final ReadableArray tokens, final Promise prom } @ReactMethod - public void embedding(double id, final String text, final Promise promise) { - rnllama.embedding(id, text, promise); + public void embedding(double id, final String text, final ReadableMap params, final Promise promise) { + rnllama.embedding(id, text, params, promise); } @ReactMethod diff --git a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java index 4f01542..a96bf3a 100644 --- a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java @@ -84,8 +84,8 @@ public void detokenize(double id, final ReadableArray tokens, final Promise prom } @ReactMethod - public void embedding(double id, final String text, final Promise promise) { - rnllama.embedding(id, text, promise); + public void embedding(double id, final String text, final ReadableMap params, final Promise promise) { + rnllama.embedding(id, text, params, promise); } @ReactMethod diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index 1da49e6..86f702a 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -595,28 +595,29 @@ struct llama_rn_context return token_with_probs; } - std::vector getEmbedding() + std::vector getEmbedding(common_params &embd_params) { static const int n_embd = llama_n_embd(llama_get_model(ctx)); - if (!params.embedding) + if (!embd_params.embedding) { - LOG_WARNING("embedding disabled, embedding: %s", params.embedding); + LOG_WARNING("embedding disabled, embedding: %s", embd_params.embedding); return std::vector(n_embd, 0.0f); } float *data; - if(params.pooling_type == 0){ + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + printf("pooling_type: %d\n", pooling_type); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { data = llama_get_embeddings(ctx); - } - else { + } else { data = llama_get_embeddings_seq(ctx, 0); } - if(!data) { + if (!data) { return std::vector(n_embd, 0.0f); } std::vector embedding(data, data + n_embd), out(data, data + n_embd); - common_embd_normalize(embedding.data(), out.data(), n_embd, params.embd_normalize); + common_embd_normalize(embedding.data(), out.data(), n_embd, embd_params.embd_normalize); return out; } diff --git a/docs/API/README.md b/docs/API/README.md index 2dc0cee..6419ef5 100644 --- a/docs/API/README.md +++ b/docs/API/README.md @@ -14,6 +14,7 @@ llama.rn - [BenchResult](README.md#benchresult) - [CompletionParams](README.md#completionparams) - [ContextParams](README.md#contextparams) +- [EmbeddingParams](README.md#embeddingparams) - [TokenData](README.md#tokendata) ### Functions @@ -44,7 +45,7 @@ llama.rn #### Defined in -[index.ts:52](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L52) +[index.ts:57](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L57) ___ @@ -54,17 +55,27 @@ ___ #### Defined in -[index.ts:44](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L44) +[index.ts:49](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L49) ___ ### ContextParams -Ƭ **ContextParams**: `NativeContextParams` +Ƭ **ContextParams**: `Omit`<`NativeContextParams`, ``"pooling_type"``\> & { `pooling_type?`: ``"none"`` \| ``"mean"`` \| ``"cls"`` \| ``"last"`` \| ``"rank"`` } #### Defined in -[index.ts:42](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L42) +[index.ts:43](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L43) + +___ + +### EmbeddingParams + +Ƭ **EmbeddingParams**: `NativeEmbeddingParams` + +#### Defined in + +[index.ts:47](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L47) ___ @@ -81,7 +92,7 @@ ___ #### Defined in -[index.ts:32](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L32) +[index.ts:33](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L33) ## Functions @@ -105,7 +116,7 @@ ___ #### Defined in -[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L824) +[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L824) ___ @@ -117,7 +128,7 @@ ___ | Name | Type | | :------ | :------ | -| `«destructured»` | `NativeContextParams` | +| `«destructured»` | [`ContextParams`](README.md#contextparams) | | `onProgress?` | (`progress`: `number`) => `void` | #### Returns @@ -126,7 +137,7 @@ ___ #### Defined in -[index.ts:208](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L208) +[index.ts:225](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L225) ___ @@ -146,7 +157,7 @@ ___ #### Defined in -[index.ts:202](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L202) +[index.ts:210](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L210) ___ @@ -160,7 +171,7 @@ ___ #### Defined in -[index.ts:245](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L245) +[index.ts:269](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L269) ___ @@ -180,4 +191,4 @@ ___ #### Defined in -[index.ts:188](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L188) +[index.ts:196](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L196) diff --git a/docs/API/classes/LlamaContext.md b/docs/API/classes/LlamaContext.md index e6dd8a3..0240db7 100644 --- a/docs/API/classes/LlamaContext.md +++ b/docs/API/classes/LlamaContext.md @@ -42,7 +42,7 @@ #### Defined in -[index.ts:73](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L73) +[index.ts:78](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L78) ## Properties @@ -52,7 +52,7 @@ #### Defined in -[index.ts:65](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L65) +[index.ts:70](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L70) ___ @@ -62,7 +62,7 @@ ___ #### Defined in -[index.ts:63](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L63) +[index.ts:68](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L68) ___ @@ -78,7 +78,7 @@ ___ #### Defined in -[index.ts:69](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L69) +[index.ts:74](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L74) ___ @@ -88,7 +88,7 @@ ___ #### Defined in -[index.ts:67](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L67) +[index.ts:72](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L72) ## Methods @@ -111,7 +111,7 @@ ___ #### Defined in -[index.ts:163](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L163) +[index.ts:171](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L171) ___ @@ -132,7 +132,7 @@ ___ #### Defined in -[index.ts:110](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L110) +[index.ts:115](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L115) ___ @@ -152,19 +152,20 @@ ___ #### Defined in -[index.ts:155](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L155) +[index.ts:160](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L160) ___ ### embedding -▸ **embedding**(`text`): `Promise`<`NativeEmbeddingResult`\> +▸ **embedding**(`text`, `params?`): `Promise`<`NativeEmbeddingResult`\> #### Parameters | Name | Type | | :------ | :------ | | `text` | `string` | +| `params?` | `NativeEmbeddingParams` | #### Returns @@ -172,7 +173,7 @@ ___ #### Defined in -[index.ts:159](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L159) +[index.ts:164](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L164) ___ @@ -192,7 +193,7 @@ ___ #### Defined in -[index.ts:99](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L99) +[index.ts:104](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L104) ___ @@ -214,7 +215,7 @@ Load cached prompt & completion state from a file. #### Defined in -[index.ts:83](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L83) +[index.ts:88](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L88) ___ @@ -228,7 +229,7 @@ ___ #### Defined in -[index.ts:183](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L183) +[index.ts:191](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L191) ___ @@ -252,7 +253,7 @@ Save current cached prompt & completion state to a file. #### Defined in -[index.ts:92](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L92) +[index.ts:97](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L97) ___ @@ -266,7 +267,7 @@ ___ #### Defined in -[index.ts:147](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L147) +[index.ts:152](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L152) ___ @@ -286,4 +287,4 @@ ___ #### Defined in -[index.ts:151](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L151) +[index.ts:156](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L156) diff --git a/docs/API/classes/SchemaGrammarConverter.md b/docs/API/classes/SchemaGrammarConverter.md index 15778b4..a29b886 100644 --- a/docs/API/classes/SchemaGrammarConverter.md +++ b/docs/API/classes/SchemaGrammarConverter.md @@ -46,7 +46,7 @@ #### Defined in -[grammar.ts:211](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L211) +[grammar.ts:211](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L211) ## Properties @@ -56,7 +56,7 @@ #### Defined in -[grammar.ts:201](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L201) +[grammar.ts:201](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L201) ___ @@ -66,7 +66,7 @@ ___ #### Defined in -[grammar.ts:203](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L203) +[grammar.ts:203](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L203) ___ @@ -76,7 +76,7 @@ ___ #### Defined in -[grammar.ts:199](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L199) +[grammar.ts:199](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L199) ___ @@ -90,7 +90,7 @@ ___ #### Defined in -[grammar.ts:207](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L207) +[grammar.ts:207](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L207) ___ @@ -100,7 +100,7 @@ ___ #### Defined in -[grammar.ts:209](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L209) +[grammar.ts:209](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L209) ___ @@ -114,7 +114,7 @@ ___ #### Defined in -[grammar.ts:205](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L205) +[grammar.ts:205](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L205) ## Methods @@ -135,7 +135,7 @@ ___ #### Defined in -[grammar.ts:693](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L693) +[grammar.ts:693](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L693) ___ @@ -156,7 +156,7 @@ ___ #### Defined in -[grammar.ts:224](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L224) +[grammar.ts:224](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L224) ___ @@ -179,7 +179,7 @@ ___ #### Defined in -[grammar.ts:710](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L710) +[grammar.ts:710](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L710) ___ @@ -200,7 +200,7 @@ ___ #### Defined in -[grammar.ts:312](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L312) +[grammar.ts:312](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L312) ___ @@ -220,7 +220,7 @@ ___ #### Defined in -[grammar.ts:518](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L518) +[grammar.ts:518](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L518) ___ @@ -241,7 +241,7 @@ ___ #### Defined in -[grammar.ts:323](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L323) +[grammar.ts:323](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L323) ___ @@ -255,7 +255,7 @@ ___ #### Defined in -[grammar.ts:813](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L813) +[grammar.ts:813](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L813) ___ @@ -276,7 +276,7 @@ ___ #### Defined in -[grammar.ts:247](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L247) +[grammar.ts:247](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L247) ___ @@ -297,4 +297,4 @@ ___ #### Defined in -[grammar.ts:529](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L529) +[grammar.ts:529](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L529) diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm index d36dff7..0f61f67 100644 --- a/ios/RNLlama.mm +++ b/ios/RNLlama.mm @@ -53,23 +53,27 @@ @implementation RNLlama return; } - RNLlamaContext *context = [RNLlamaContext initWithParams:contextParams onProgress:^(unsigned int progress) { - dispatch_async(dispatch_get_main_queue(), ^{ - [self sendEventWithName:@"@RNLlama_onInitContextProgress" body:@{ @"contextId": @(contextId), @"progress": @(progress) }]; - }); - }]; - if (![context isModelLoaded]) { - reject(@"llama_cpp_error", @"Failed to load the model", nil); - return; + @try { + RNLlamaContext *context = [RNLlamaContext initWithParams:contextParams onProgress:^(unsigned int progress) { + dispatch_async(dispatch_get_main_queue(), ^{ + [self sendEventWithName:@"@RNLlama_onInitContextProgress" body:@{ @"contextId": @(contextId), @"progress": @(progress) }]; + }); + }]; + if (![context isModelLoaded]) { + reject(@"llama_cpp_error", @"Failed to load the model", nil); + return; + } + + [llamaContexts setObject:context forKey:contextIdNumber]; + + resolve(@{ + @"gpu": @([context isMetalEnabled]), + @"reasonNoGPU": [context reasonNoMetal], + @"model": [context modelInfo], + }); + } @catch (NSException *exception) { + reject(@"llama_cpp_error", exception.reason, nil); } - - [llamaContexts setObject:context forKey:contextIdNumber]; - - resolve(@{ - @"gpu": @([context isMetalEnabled]), - @"reasonNoGPU": [context reasonNoMetal], - @"model": [context modelInfo], - }); } RCT_EXPORT_METHOD(getFormattedChat:(double)contextId @@ -229,6 +233,7 @@ - (NSArray *)supportedEvents { RCT_EXPORT_METHOD(embedding:(double)contextId text:(NSString *)text + params:(NSDictionary *)params withResolver:(RCTPromiseResolveBlock)resolve withRejecter:(RCTPromiseRejectBlock)reject) { @@ -238,9 +243,8 @@ - (NSArray *)supportedEvents { return; } @try { - NSMutableArray *embedding = [context embedding:text]; - resolve(@{ @"embedding": embedding }); - [embedding release]; + NSDictionary *embedding = [context embedding:text params:params]; + resolve(embedding); } @catch (NSException *exception) { reject(@"llama_cpp_error", exception.reason, nil); } diff --git a/ios/RNLlamaContext.h b/ios/RNLlamaContext.h index 922f8d1..52c4e92 100644 --- a/ios/RNLlamaContext.h +++ b/ios/RNLlamaContext.h @@ -28,7 +28,7 @@ - (void)stopCompletion; - (NSArray *)tokenize:(NSString *)text; - (NSString *)detokenize:(NSArray *)tokens; -- (NSArray *)embedding:(NSString *)text; +- (NSDictionary *)embedding:(NSString *)text params:(NSDictionary *)params; - (NSString *)getFormattedChat:(NSArray *)messages withTemplate:(NSString *)chatTemplate; - (NSDictionary *)loadSession:(NSString *)path; - (int)saveSession:(NSString *)path size:(int)size; diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index 36249ef..e4d22ec 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -57,10 +57,6 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig if (isAsset) path = [[NSBundle mainBundle] pathForResource:modelPath ofType:nil]; defaultParams.model = [path UTF8String]; - if (params[@"embedding"] && [params[@"embedding"] boolValue]) { - defaultParams.embedding = true; - } - if (params[@"n_ctx"]) defaultParams.n_ctx = [params[@"n_ctx"] intValue]; if (params[@"use_mlock"]) defaultParams.use_mlock = [params[@"use_mlock"]boolValue]; @@ -100,6 +96,20 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig if (params[@"n_batch"]) defaultParams.n_batch = [params[@"n_batch"] intValue]; if (params[@"use_mmap"]) defaultParams.use_mmap = [params[@"use_mmap"] boolValue]; + if (params[@"pooling_type"] && [params[@"pooling_type"] isKindOfClass:[NSNumber class]]) { + defaultParams.pooling_type = static_cast([params[@"pooling_type"] intValue]); + } + + if (params[@"embedding"] && [params[@"embedding"] boolValue]) { + defaultParams.embedding = true; + // For non-causal models, batch size must be equal to ubatch size + defaultParams.n_ubatch = defaultParams.n_batch; + + if (params[@"embd_normalize"] && [params[@"embd_normalize"] isKindOfClass:[NSNumber class]]) { + defaultParams.embd_normalize = [params[@"embd_normalize"] intValue]; + } + } + if (params[@"lora"]) { float lora_scaled = 1.0f; if (params[@"lora_scaled"]) lora_scaled = [params[@"lora_scaled"] floatValue]; @@ -136,6 +146,15 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig } context->is_model_loaded = context->llama->loadModel(defaultParams); + + if ( + params[@"embedding"] && [params[@"embedding"] boolValue] && + llama_model_has_encoder(context->llama->model) && llama_model_has_decoder(context->llama->model) + ) { + delete context->llama; + @throw [NSException exceptionWithName:@"LlamaException" reason:@"Embedding is not supported in encoder-decoder models" userInfo:nil]; + } + context->is_metal_enabled = isMetalEnabled; context->reason_no_metal = reasonNoMetal; @@ -418,11 +437,19 @@ - (NSString *)detokenize:(NSArray *)tokens { return [NSString stringWithUTF8String:text.c_str()]; } -- (NSArray *)embedding:(NSString *)text { +- (NSDictionary *)embedding:(NSString *)text params:(NSDictionary *)params { if (llama->params.embedding != true) { @throw [NSException exceptionWithName:@"LlamaException" reason:@"Embedding is not enabled" userInfo:nil]; } + common_params embdParams; + embdParams.embedding = true; + embdParams.embd_normalize = llama->params.embd_normalize; + + if (params[@"embd_normalize"] && [params[@"embd_normalize"] isKindOfClass:[NSNumber class]]) { + embdParams.embd_normalize = [params[@"embd_normalize"] intValue]; + } + llama->rewind(); llama_perf_context_reset(llama->ctx); @@ -438,15 +465,22 @@ - (NSArray *)embedding:(NSString *)text { llama->loadPrompt(); llama->doCompletion(); - std::vector result = llama->getEmbedding(); + std::vector result = llama->getEmbedding(embdParams); + NSMutableDictionary *resultDict = [[NSMutableDictionary alloc] init]; NSMutableArray *embeddingResult = [[NSMutableArray alloc] init]; for (float f : result) { [embeddingResult addObject:@(f)]; } + resultDict[@"embedding"] = embeddingResult; + NSMutableArray *promptTokens = [[NSMutableArray alloc] init]; + for (llama_token tok : llama->embd) { + [promptTokens addObject:[NSString stringWithUTF8String:common_token_to_piece(llama->ctx, tok).c_str()]]; + } + resultDict[@"prompt_tokens"] = promptTokens; llama->is_predicting = false; - return embeddingResult; + return resultDict; } - (NSDictionary *)loadSession:(NSString *)path { diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 5427d9c..cf4d49e 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -1,13 +1,15 @@ import type { TurboModule } from 'react-native' import { TurboModuleRegistry } from 'react-native' +export type NativeEmbeddingParams = { + embd_normalize?: number +} + export type NativeContextParams = { model: string is_model_asset?: boolean use_progress_callback?: boolean - embedding?: boolean - n_ctx?: number n_batch?: number @@ -23,6 +25,12 @@ export type NativeContextParams = { rope_freq_base?: number rope_freq_scale?: number + + pooling_type?: number + + // Embedding params + embedding?: boolean + embd_normalize?: number } export type NativeCompletionParams = { @@ -145,7 +153,11 @@ export interface Spec extends TurboModule { stopCompletion(contextId: number): Promise tokenize(contextId: number, text: string): Promise detokenize(contextId: number, tokens: number[]): Promise - embedding(contextId: number, text: string): Promise + embedding( + contextId: number, + text: string, + params: NativeEmbeddingParams, + ): Promise bench( contextId: number, pp: number, diff --git a/src/index.ts b/src/index.ts index a03c2b8..3e9d46a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -10,6 +10,7 @@ import type { NativeTokenizeResult, NativeEmbeddingResult, NativeSessionLoadResult, + NativeEmbeddingParams, } from './NativeRNLlama' import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar' import type { RNLlamaOAICompatibleMessage } from './chat' @@ -39,7 +40,11 @@ type TokenNativeEvent = { tokenResult: TokenData } -export type ContextParams = NativeContextParams +export type ContextParams = Omit & { + pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank' +} + +export type EmbeddingParams = NativeEmbeddingParams export type CompletionParams = Omit< NativeCompletionParams, @@ -156,8 +161,11 @@ export class LlamaContext { return RNLlama.detokenize(this.id, tokens) } - embedding(text: string): Promise { - return RNLlama.embedding(this.id, text) + embedding( + text: string, + params?: EmbeddingParams, + ): Promise { + return RNLlama.embedding(this.id, text, params || {}) } async bench( @@ -197,7 +205,7 @@ const modelInfoSkip = [ // Large fields 'tokenizer.ggml.tokens', 'tokenizer.ggml.token_type', - 'tokenizer.ggml.merges' + 'tokenizer.ggml.merges', ] export async function loadLlamaModelInfo(model: string): Promise { let path = model @@ -205,8 +213,22 @@ export async function loadLlamaModelInfo(model: string): Promise { return RNLlama.modelInfo(path, modelInfoSkip) } +const poolTypeMap = { + // -1 is unspecified as undefined + none: 0, + mean: 1, + cls: 2, + last: 3, + rank: 4, +} + export async function initLlama( - { model, is_model_asset: isModelAsset, ...rest }: ContextParams, + { + model, + is_model_asset: isModelAsset, + pooling_type: poolingType, + ...rest + }: ContextParams, onProgress?: (progress: number) => void, ): Promise { let path = model @@ -225,6 +247,7 @@ export async function initLlama( ) } + const poolType = poolTypeMap[poolingType as keyof typeof poolTypeMap] const { gpu, reasonNoGPU, @@ -233,6 +256,7 @@ export async function initLlama( model: path, is_model_asset: !!isModelAsset, use_progress_callback: !!onProgress, + pooling_type: poolType, ...rest, }).catch((err: any) => { removeProgressListener?.remove()