From 41b779fe2dcb69bdc8b6cc0962798551df98cf66 Mon Sep 17 00:00:00 2001
From: Jhen-Jie Hong
Date: Mon, 4 Nov 2024 14:39:21 +0800
Subject: [PATCH] feat: add progress callback in initLlama (#82)

* feat(ios): add progress callback in initLlama

* feat(android): add progress callback in initLlama

* fix(ts): skip random context id on testing
---
 .../main/java/com/rnllama/LlamaContext.java | 33 +++++++++++--
 .../src/main/java/com/rnllama/RNLlama.java  | 14 ++++--
 android/src/main/jni.cpp                    | 48 ++++++++++++++++++-
 .../java/com/rnllama/RNLlamaModule.java     |  4 +-
 .../java/com/rnllama/RNLlamaModule.java     |  4 +-
 cpp/common.cpp                              |  3 ++
 cpp/common.h                                |  3 ++
 cpp/rn-llama.hpp                            | 19 ++++----
 example/ios/.xcode.env.local                |  2 +-
 example/src/App.tsx                         | 19 +++++++-
 ios/RNLlama.mm                              | 23 ++++++---
 ios/RNLlamaContext.h                        |  7 ++-
 ios/RNLlamaContext.mm                       | 25 ++++++++--
 jest/mock.js                                |  1 -
 scripts/common.cpp.patch                    | 22 ++++++---
 scripts/common.h.patch                      | 22 ++++++---
 src/NativeRNLlama.ts                        |  3 +-
 src/__tests__/index.test.ts                 |  2 +-
 src/index.ts                                | 40 ++++++++++++----
 19 files changed, 236 insertions(+), 58 deletions(-)

diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java
index 18cd87ee..337ed04e 100644
--- a/android/src/main/java/com/rnllama/LlamaContext.java
+++ b/android/src/main/java/com/rnllama/LlamaContext.java
@@ -37,6 +37,7 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
     }
     Log.d(NAME, "Setting log callback");
     logToAndroid();
+    eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
     this.id = id;
     this.context = initContext(
       // String model,
@@ -64,11 +65,16 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       // float rope_freq_base,
       params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
       // float rope_freq_scale
-      params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f
+      params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
+      // LoadProgressCallback load_progress_callback
+      params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
     );
     this.modelDetails = loadModelDetails(this.context);
     this.reactContext = reactContext;
-    eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
+  }
+
+  public void interruptLoad() {
+    interruptLoad(this.context);
   }
 
   public long getContext() {
@@ -87,6 +93,25 @@ public String getFormattedChat(ReadableArray messages, String chatTemplate) {
     return getFormattedChat(this.context, msgs, chatTemplate == null ? "" : chatTemplate);
   }
 
+  private void emitLoadProgress(int progress) {
+    WritableMap event = Arguments.createMap();
+    event.putInt("contextId", LlamaContext.this.id);
+    event.putInt("progress", progress);
+    eventEmitter.emit("@RNLlama_onInitContextProgress", event);
+  }
+
+  private static class LoadProgressCallback {
+    LlamaContext context;
+
+    public LoadProgressCallback(LlamaContext context) {
+      this.context = context;
+    }
+
+    void onLoadProgress(int progress) {
+      context.emitLoadProgress(progress);
+    }
+  }
+
   private void emitPartialCompletion(WritableMap tokenResult) {
     WritableMap event = Arguments.createMap();
     event.putInt("contextId", LlamaContext.this.id);
@@ -346,8 +371,10 @@ protected static native long initContext(
     String lora,
     float lora_scaled,
     float rope_freq_base,
-    float rope_freq_scale
+    float rope_freq_scale,
+    LoadProgressCallback load_progress_callback
   );
+  protected static native void interruptLoad(long contextPtr);
   protected static native WritableMap loadModelDetails(
     long contextPtr
   );
diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java
index ac96eb25..eb027554 100644
--- a/android/src/main/java/com/rnllama/RNLlama.java
+++ b/android/src/main/java/com/rnllama/RNLlama.java
@@ -42,21 +42,24 @@ public void setContextLimit(double limit, Promise promise) {
     promise.resolve(null);
   }
 
-  public void initContext(final ReadableMap params, final Promise promise) {
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
       private Exception exception;
 
       @Override
       protected WritableMap doInBackground(Void... voids) {
         try {
-          int id = Math.abs(new Random().nextInt());
-          LlamaContext llamaContext = new LlamaContext(id, reactContext, params);
+          LlamaContext context = contexts.get(contextId);
+          if (context != null) {
+            throw new Exception("Context already exists");
+          }
+          LlamaContext llamaContext = new LlamaContext(contextId, reactContext, params);
           if (llamaContext.getContext() == 0) {
             throw new Exception("Failed to initialize context");
           }
-          contexts.put(id, llamaContext);
+          contexts.put(contextId, llamaContext);
           WritableMap result = Arguments.createMap();
-          result.putInt("contextId", id);
           result.putBoolean("gpu", false);
           result.putString("reasonNoGPU", "Currently not supported");
           result.putMap("model", llamaContext.getModelDetails());
@@ -393,6 +396,7 @@ protected Void doInBackground(Void... voids) {
           if (context == null) {
             throw new Exception("Context " + id + " not found");
           }
+          context.interruptLoad();
           context.stopCompletion();
           AsyncTask completionTask = null;
           for (AsyncTask task : tasks.keySet()) {
diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp
index 4ad058b5..128aa3fe 100644
--- a/android/src/main/jni.cpp
+++ b/android/src/main/jni.cpp
@@ -132,6 +132,11 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
     env->CallVoidMethod(map, putArrayMethod, jKey, value);
 }
 
+struct callback_context {
+    JNIEnv *env;
+    rnllama::llama_rn_context *llama;
+    jobject callback;
+};
 
 std::unordered_map<long, rnllama::llama_rn_context *> context_map;
 
@@ -151,7 +156,8 @@ Java_com_rnllama_LlamaContext_initContext(
     jstring lora_str,
     jfloat lora_scaled,
     jfloat rope_freq_base,
-    jfloat rope_freq_scale
+    jfloat rope_freq_scale,
+    jobject load_progress_callback
 ) {
     UNUSED(thiz);
 
@@ -190,6 +196,32 @@ Java_com_rnllama_LlamaContext_initContext(
     defaultParams.rope_freq_scale = rope_freq_scale;
 
     auto llama = new rnllama::llama_rn_context();
+    llama->is_load_interrupted = false;
+    llama->loading_progress = 0;
+
+    if (load_progress_callback != nullptr) {
+        defaultParams.progress_callback = [](float progress, void * user_data) {
+            callback_context *cb_ctx = (callback_context *)user_data;
+            JNIEnv *env = cb_ctx->env;
+            auto llama = cb_ctx->llama;
+            jobject callback = cb_ctx->callback;
+            int percentage = (int) (100 * progress);
+            if (percentage > llama->loading_progress) {
+                llama->loading_progress = percentage;
+                jclass callback_class = env->GetObjectClass(callback);
+                jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
+                env->CallVoidMethod(callback, onLoadProgress, percentage);
+            }
+            return !llama->is_load_interrupted;
+        };
+
+        callback_context *cb_ctx = new callback_context;
+        cb_ctx->env = env;
+        cb_ctx->llama = llama;
+        cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
+        defaultParams.progress_callback_user_data = cb_ctx;
+    }
+
 
     bool is_model_loaded = llama->loadModel(defaultParams);
     LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
@@ -205,6 +237,20 @@ Java_com_rnllama_LlamaContext_initContext(
     return reinterpret_cast<jlong>(llama->ctx);
 }
+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_interruptLoad(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    if (llama) {
+        llama->is_load_interrupted = true;
+    }
+}
+
 
 JNIEXPORT jobject JNICALL
 Java_com_rnllama_LlamaContext_loadModelDetails(
     JNIEnv *env,
diff --git a/android/src/newarch/java/com/rnllama/RNLlamaModule.java b/android/src/newarch/java/com/rnllama/RNLlamaModule.java
index 7527c0f5..5bab9b16 100644
--- a/android/src/newarch/java/com/rnllama/RNLlamaModule.java
+++ b/android/src/newarch/java/com/rnllama/RNLlamaModule.java
@@ -38,8 +38,8 @@ public void setContextLimit(double limit, Promise promise) {
   }
 
   @ReactMethod
-  public void initContext(final ReadableMap params, final Promise promise) {
-    rnllama.initContext(params, promise);
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    rnllama.initContext(id, params, promise);
   }
 
   @ReactMethod
diff --git a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
index 4e6cc6fb..2719515b 100644
--- a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
+++ b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
@@ -39,8 +39,8 @@ public void setContextLimit(double limit, Promise promise) {
   }
 
   @ReactMethod
-  public void initContext(final ReadableMap params, final Promise promise) {
-    rnllama.initContext(params, promise);
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    rnllama.initContext(id, params, promise);
   }
 
   @ReactMethod
diff --git a/cpp/common.cpp b/cpp/common.cpp
index a7e4a467..dfaa0378 100644
--- a/cpp/common.cpp
+++ b/cpp/common.cpp
@@ -1001,6 +1001,9 @@ struct llama_model_params common_model_params_to_llama(const common_params & par
         mparams.kv_overrides = params.kv_overrides.data();
     }
 
+    mparams.progress_callback = params.progress_callback;
+    mparams.progress_callback_user_data = params.progress_callback_user_data;
+
     return mparams;
 }
 
diff --git a/cpp/common.h b/cpp/common.h
index 5a79c8c5..c60e5bef 100644
--- a/cpp/common.h
+++ b/cpp/common.h
@@ -283,6 +283,9 @@ struct common_params {
     bool warmup = true;         // warmup run
     bool check_tensors = false; // validate tensor data
 
+    llama_progress_callback progress_callback;
+    void * progress_callback_user_data;
+
     std::string cache_type_k = "f16"; // KV cache data type for the K
     std::string cache_type_v = "f16"; // KV cache data type for the V
 
diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp
index c8ca2ce7..a0b6c9a2 100644
--- a/cpp/rn-llama.hpp
+++ b/cpp/rn-llama.hpp
@@ -158,9 +158,12 @@ struct llama_rn_context
     common_params params;
 
     llama_model *model = nullptr;
+    float loading_progress = 0;
+    bool is_load_interrupted = false;
+
     llama_context *ctx = nullptr;
     common_sampler *ctx_sampling = nullptr;
-    
+
     int n_ctx;
 
     bool truncated = false;
@@ -367,7 +370,7 @@ struct llama_rn_context
                 n_eval = params.n_batch;
             }
             if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval)))
-            {
+            {
                 LOG_ERROR("failed to eval, n_eval: %d, n_past: %d, n_threads: %d, embd: %s",
                     n_eval,
                     n_past,
@@ -378,7 +381,7 @@ struct llama_rn_context
                 return result;
             }
             n_past += n_eval;
-            
+
             if(is_interrupted) {
                 LOG_INFO("Decoding Interrupted");
                 embd.resize(n_past);
@@ -400,11 +403,11 @@ struct llama_rn_context
             candidates.reserve(llama_n_vocab(model));
 
             result.tok = common_sampler_sample(ctx_sampling, ctx, -1);
-            
+
             llama_token_data_array cur_p = *common_sampler_get_candidates(ctx_sampling);
 
             const int32_t n_probs = params.sparams.n_probs;
-            
+
             // deprecated
             /*if (params.sparams.temp <= 0 && n_probs > 0)
             {
@@ -412,7 +415,7 @@
                 llama_sampler_init_softmax();
             }*/
-            
+
 
             for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
             {
@@ -542,14 +545,14 @@ struct llama_rn_context
             return std::vector<float>(n_embd, 0.0f);
         }
 
        float *data;
-        
+
        if(params.pooling_type == 0){
            data = llama_get_embeddings(ctx);
        } else {
            data = llama_get_embeddings_seq(ctx, 0);
        }
-        
+
        if(!data) {
            return std::vector<float>(n_embd, 0.0f);
        }
diff --git a/example/ios/.xcode.env.local b/example/ios/.xcode.env.local
index 289c3d07..347de307 100644
--- a/example/ios/.xcode.env.local
+++ b/example/ios/.xcode.env.local
@@ -1 +1 @@
-export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1730514789911-0.16979892623603998/node
+export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1730697817603-0.6786179339916347/node
diff --git a/example/src/App.tsx b/example/src/App.tsx
index 7f5959a8..4a8e9d24 100644
--- a/example/src/App.tsx
+++ b/example/src/App.tsx
@@ -64,6 +64,7 @@ export default function App() {
       metadata: { system: true, ...metadata },
     }
     addMessage(textMessage)
+    return textMessage.id
   }
 
   const handleReleaseContext = async () => {
@@ -82,12 +83,28 @@ export default function App() {
 
   const handleInitContext = async (file: DocumentPickerResponse) => {
     await handleReleaseContext()
-    addSystemMessage('Initializing context...')
+    const msgId = addSystemMessage('Initializing context...')
     initLlama({
       model: file.uri,
       use_mlock: true,
       n_gpu_layers: Platform.OS === 'ios' ? 0 : 0, // > 0: enable GPU
       // embedding: true,
+    }, (progress) => {
+      setMessages((msgs) => {
+        const index = msgs.findIndex((msg) => msg.id === msgId)
+        if (index >= 0) {
+          return msgs.map((msg, i) => {
+            if (msg.type == 'text' && i === index) {
+              return {
+                ...msg,
+                text: `Initializing context... ${progress}%`,
+              }
+            }
+            return msg
+          })
+        }
+        return msgs
+      })
     })
       .then((ctx) => {
         setContext(ctx)
diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm
index 89b37c06..cef66304 100644
--- a/ios/RNLlama.mm
+++ b/ios/RNLlama.mm
@@ -21,10 +21,17 @@ @implementation RNLlama
     resolve(nil);
 }
 
-RCT_EXPORT_METHOD(initContext:(NSDictionary *)contextParams
+RCT_EXPORT_METHOD(initContext:(double)contextId
+                 withContextParams:(NSDictionary *)contextParams
                  withResolver:(RCTPromiseResolveBlock)resolve
                  withRejecter:(RCTPromiseRejectBlock)reject)
 {
+    NSNumber *contextIdNumber = [NSNumber numberWithDouble:contextId];
+    if (llamaContexts[contextIdNumber] != nil) {
+        reject(@"llama_error", @"Context already exists", nil);
+        return;
+    }
+
     if (llamaDQueue == nil) {
         llamaDQueue = dispatch_queue_create("com.rnllama", DISPATCH_QUEUE_SERIAL);
     }
@@ -38,19 +45,19 @@ @implementation RNLlama
         return;
     }
 
-    RNLlamaContext *context = [RNLlamaContext initWithParams:contextParams];
+    RNLlamaContext *context = [RNLlamaContext initWithParams:contextParams onProgress:^(unsigned int progress) {
+        dispatch_async(dispatch_get_main_queue(), ^{
+            [self sendEventWithName:@"@RNLlama_onInitContextProgress" body:@{ @"contextId": @(contextId), @"progress": @(progress) }];
+        });
+    }];
     if (![context isModelLoaded]) {
         reject(@"llama_cpp_error", @"Failed to load the model", nil);
         return;
     }
 
-    double contextId = (double) arc4random_uniform(1000000);
-
-    NSNumber *contextIdNumber = [NSNumber numberWithDouble:contextId];
     [llamaContexts setObject:context forKey:contextIdNumber];
 
     resolve(@{
-        @"contextId": contextIdNumber,
        @"gpu": @([context isMetalEnabled]),
        @"reasonNoGPU": [context reasonNoMetal],
        @"model": [context modelInfo],
@@ -125,6 +132,7 @@ @implementation RNLlama
 - (NSArray *)supportedEvents {
   return@[
+    @"@RNLlama_onInitContextProgress",
     @"@RNLlama_onToken",
   ];
 }
 
@@ -260,6 +268,9 @@ - (NSArray *)supportedEvents {
         reject(@"llama_error", @"Context not found", nil);
         return;
     }
+    if (![context isModelLoaded]) {
+        [context interruptLoad];
+    }
     [context stopCompletion];
     dispatch_barrier_sync(llamaDQueue, ^{});
     [context invalidate];
diff --git a/ios/RNLlamaContext.h b/ios/RNLlamaContext.h
index 37a34bb8..3d70a540 100644
--- a/ios/RNLlamaContext.h
+++ b/ios/RNLlamaContext.h
@@ -6,13 +6,16 @@
 
 @interface RNLlamaContext : NSObject {
     bool is_metal_enabled;
-    NSString * reason_no_metal;
     bool is_model_loaded;
+    NSString * reason_no_metal;
+
+    void (^onProgress)(unsigned int progress);
 
     rnllama::llama_rn_context * llama;
 }
 
-+ (instancetype)initWithParams:(NSDictionary *)params;
++ (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsigned int progress))onProgress;
+- (void)interruptLoad;
 - (bool)isMetalEnabled;
 - (NSString *)reasonNoMetal;
 - (NSDictionary *)modelInfo;
diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm
index b2f959d3..1e4ad35d 100644
--- a/ios/RNLlamaContext.mm
+++ b/ios/RNLlamaContext.mm
@@ -3,7 +3,7 @@
 
 @implementation RNLlamaContext
 
-+ (instancetype)initWithParams:(NSDictionary *)params {
++ (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsigned int progress))onProgress {
     // llama_backend_init(false);
     common_params defaultParams;
 
@@ -78,9 +78,24 @@ + (instancetype)initWithParams:(NSDictionary *)params {
     defaultParams.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads;
 
     RNLlamaContext *context = [[RNLlamaContext alloc] init];
-    if (context->llama == nullptr) {
-        context->llama = new rnllama::llama_rn_context();
+    context->llama = new rnllama::llama_rn_context();
+    context->llama->is_load_interrupted = false;
+    context->llama->loading_progress = 0;
+    context->onProgress = onProgress;
+
+    if (params[@"use_progress_callback"] && [params[@"use_progress_callback"] boolValue]) {
+        defaultParams.progress_callback = [](float progress, void * user_data) {
+            RNLlamaContext *context = (__bridge RNLlamaContext *)(user_data);
+            unsigned percentage = (unsigned) (100 * progress);
+            if (percentage > context->llama->loading_progress) {
+                context->llama->loading_progress = percentage;
+                context->onProgress(percentage);
+            }
+            return !context->llama->is_load_interrupted;
+        };
+        defaultParams.progress_callback_user_data = context;
     }
+
     context->is_model_loaded = context->llama->loadModel(defaultParams);
     context->is_metal_enabled = isMetalEnabled;
     context->reason_no_metal = reasonNoMetal;
@@ -88,6 +103,10 @@ + (instancetype)initWithParams:(NSDictionary *)params {
     return context;
 }
 
+- (void)interruptLoad {
+    llama->is_load_interrupted = true;
+}
+
 - (bool)isMetalEnabled {
     return is_metal_enabled;
 }
diff --git a/jest/mock.js b/jest/mock.js
index 48914857..ace189e0 100644
--- a/jest/mock.js
+++ b/jest/mock.js
@@ -4,7 +4,6 @@ if (!NativeModules.RNLlama) {
   NativeModules.RNLlama = {
     initContext: jest.fn(() =>
       Promise.resolve({
-        contextId: 1,
        gpu: false,
        reasonNoGPU: 'Test',
      }),
diff --git a/scripts/common.cpp.patch b/scripts/common.cpp.patch
index ae22c096..49c1622e 100644
--- a/scripts/common.cpp.patch
+++ b/scripts/common.cpp.patch
@@ -1,18 +1,18 @@
---- common.cpp.orig	2024-11-02 10:33:10
-+++ common.cpp	2024-11-02 10:33:11
-@@ -53,6 +53,12 @@
- #include
+--- common.cpp.orig	2024-11-04 12:59:08
++++ common.cpp	2024-11-04 12:58:17
+@@ -54,6 +54,12 @@
  #include
  #endif
-+
+
 +// build info
 +int LLAMA_BUILD_NUMBER = 0;
 +char const *LLAMA_COMMIT = "unknown";
 +char const *LLAMA_COMPILER = "unknown";
 +char const *LLAMA_BUILD_TARGET = "unknown";
- 
++
  #if defined(_MSC_VER)
  #pragma warning(disable: 4244 4267) // possible loss of data
+ #endif
 @@ -979,6 +985,8 @@
      if (params.n_gpu_layers != -1) {
          mparams.n_gpu_layers = params.n_gpu_layers;
@@ -22,3 +22,13 @@
      mparams.rpc_servers = params.rpc_servers.c_str();
      mparams.main_gpu = params.main_gpu;
      mparams.split_mode = params.split_mode;
+@@ -993,6 +1001,9 @@
+         mparams.kv_overrides = params.kv_overrides.data();
+     }
+ 
++    mparams.progress_callback = params.progress_callback;
++    mparams.progress_callback_user_data = params.progress_callback_user_data;
++
+     return mparams;
+ }
+ 
diff --git a/scripts/common.h.patch b/scripts/common.h.patch
index 354d31df..87b54b8e 100644
--- a/scripts/common.h.patch
+++ b/scripts/common.h.patch
@@ -1,10 +1,9 @@
---- common.h.orig	2024-11-02 10:33:10
-+++ common.h	2024-11-02 10:33:11
-@@ -40,6 +40,17 @@
- extern char const * LLAMA_BUILD_TARGET;
+--- common.h.orig	2024-11-04 12:59:08
++++ common.h	2024-11-04 12:58:24
+@@ -41,6 +41,17 @@
 
  struct common_control_vector_load_info;
-+
+
 +#define print_build_info() do { \
 +  fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \
 +  fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
@@ -15,9 +14,10 @@
 +extern char const *LLAMA_COMMIT;
 +extern char const *LLAMA_COMPILER;
 +extern char const *LLAMA_BUILD_TARGET;
- 
++
  //
  // CPU utils
+ //
 
 @@ -154,6 +165,7 @@
  };
@@ -26,3 +26,13 @@
  int32_t n_predict = -1; // new tokens to predict
  int32_t n_ctx = 0; // context size
  int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
+@@ -271,6 +283,9 @@
+     bool warmup = true; // warmup run
+     bool check_tensors = false; // validate tensor data
+ 
++    llama_progress_callback progress_callback;
++    void * progress_callback_user_data;
++
+     std::string cache_type_k = "f16"; // KV cache data type for the K
+     std::string cache_type_v = "f16"; // KV cache data type for the V
+ 
diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts
index d42d3fbc..c0eda5dd 100644
--- a/src/NativeRNLlama.ts
+++ b/src/NativeRNLlama.ts
@@ -4,6 +4,7 @@ import { TurboModuleRegistry } from 'react-native'
 export type NativeContextParams = {
   model: string
   is_model_asset?: boolean
+  use_progress_callback?: boolean
 
   embedding?: boolean
 
@@ -119,7 +120,7 @@ export type NativeLlamaChatMessage = {
 export interface Spec extends TurboModule {
   setContextLimit(limit: number): Promise<void>
 
-  initContext(params: NativeContextParams): Promise<NativeLlamaContext>
+  initContext(contextId: number, params: NativeContextParams): Promise<NativeLlamaContext>
 
   getFormattedChat(
     contextId: number,
diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts
index e980961f..aa99b8ec 100644
--- a/src/__tests__/index.test.ts
+++ b/src/__tests__/index.test.ts
@@ -9,7 +9,7 @@ test('Mock', async () => {
   const context = await initLlama({
     model: 'test.bin',
   })
-  expect(context.id).toBe(1)
+  expect(context.id).toBe(0)
   const events: TokenData[] = []
   const completionResult = await context.completion({
     prompt: 'Test',
diff --git a/src/index.ts b/src/index.ts
index 151dd18b..ccce03b8 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -17,6 +17,7 @@ import { formatChat } from './chat'
 
 export { SchemaGrammarConverter, convertJsonSchemaToGrammar }
 
+const EVENT_ON_INIT_CONTEXT_PROGRESS = '@RNLlama_onInitContextProgress'
 const EVENT_ON_TOKEN = '@RNLlama_onToken'
 
 let EventEmitter: NativeEventEmitter | DeviceEventEmitterStatic
@@ -110,9 +111,9 @@ export class LlamaContext {
     params: CompletionParams,
     callback?: (data: TokenData) => void,
   ): Promise<NativeCompletionResult> {
-    
     let finalPrompt = params.prompt
-    if (params.messages) { // messages always win
+    if (params.messages) {
+      // messages always win
       finalPrompt = await this.getFormattedChat(params.messages)
     }
 
@@ -188,23 +189,44 @@ export async function setContextLimit(limit: number): Promise<void> {
   return RNLlama.setContextLimit(limit)
 }
 
-export async function initLlama({
-  model,
-  is_model_asset: isModelAsset,
-  ...rest
-}: ContextParams): Promise<LlamaContext> {
+let contextIdCounter = 0
+const contextIdRandom = () =>
+  process.env.NODE_ENV === 'test' ? 0 : Math.floor(Math.random() * 100000)
+
+export async function initLlama(
+  { model, is_model_asset: isModelAsset, ...rest }: ContextParams,
+  onProgress?: (progress: number) => void,
+): Promise<LlamaContext> {
   let path = model
   if (path.startsWith('file://')) path = path.slice(7)
+
+  const contextId = contextIdCounter + contextIdRandom()
+  contextIdCounter += 1
+
+  let removeProgressListener: any = null
+  if (onProgress) {
+    removeProgressListener = EventEmitter.addListener(
+      EVENT_ON_INIT_CONTEXT_PROGRESS,
+      (evt: { contextId: number; progress: number }) => {
+        if (evt.contextId !== contextId) return
+        onProgress(evt.progress)
+      },
+    )
+  }
+
   const {
-    contextId,
     gpu,
     reasonNoGPU,
     model: modelDetails,
-  } = await RNLlama.initContext({
+  } = await RNLlama.initContext(contextId, {
     model: path,
     is_model_asset: !!isModelAsset,
+    use_progress_callback: !!onProgress,
     ...rest,
+  }).catch((err: any) => {
+    removeProgressListener?.remove()
+    throw err
   })
+  removeProgressListener?.remove()
   return new LlamaContext({ contextId, gpu, reasonNoGPU, model: modelDetails })
 }
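
For reference, a minimal TypeScript sketch (not part of the diff above) of how an app could consume
the progress callback this patch adds, based on the new initLlama(params, onProgress) signature in
src/index.ts. It assumes the package is imported as 'llama.rn'; the model path is a placeholder, and
use_mlock / n_gpu_layers mirror the example app. Passing onProgress makes initLlama set
use_progress_callback and subscribe to '@RNLlama_onInitContextProgress' on the caller's behalf.

    import { initLlama } from 'llama.rn'

    async function loadModel() {
      const context = await initLlama(
        {
          model: '/path/to/model.gguf', // placeholder model path
          use_mlock: true,
          n_gpu_layers: 0, // > 0 to enable GPU, as in the example app
        },
        (progress) => {
          // progress is an integer percentage (0-100) emitted while the model loads
          console.log(`Initializing context... ${progress}%`)
        },
      )
      return context
    }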