Skip to content

Commit

Permalink
Merge branch 'main' into sim-metal
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Feb 6, 2025
2 parents fc339a1 + 12abb68 commit 93a1a7c
Show file tree
Hide file tree
Showing 18 changed files with 326 additions and 83 deletions.
28 changes: 25 additions & 3 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,29 @@ public class LlamaContext {

private static String loadedLibrary = "";

private static class NativeLogCallback {
DeviceEventManagerModule.RCTDeviceEventEmitter eventEmitter;

public NativeLogCallback(ReactApplicationContext reactContext) {
this.eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
}

void emitNativeLog(String level, String text) {
WritableMap event = Arguments.createMap();
event.putString("level", level);
event.putString("text", text);
eventEmitter.emit("@RNLlama_onNativeLog", event);
}
}

static void toggleNativeLog(ReactApplicationContext reactContext, boolean enabled) {
if (enabled) {
setupLog(new NativeLogCallback(reactContext));
} else {
unsetLog();
}
}

private int id;
private ReactApplicationContext reactContext;
private long context;
Expand All @@ -37,8 +60,6 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
if (!params.hasKey("model")) {
throw new IllegalArgumentException("Missing required parameter: model");
}
Log.d(NAME, "Setting log callback");
logToAndroid();
eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
this.id = id;
this.context = initContext(
Expand Down Expand Up @@ -539,5 +560,6 @@ protected static native WritableMap embedding(
protected static native void removeLoraAdapters(long contextPtr);
protected static native WritableArray getLoadedLoraAdapters(long contextPtr);
protected static native void freeContext(long contextPtr);
protected static native void logToAndroid();
protected static native void setupLog(NativeLogCallback logCallback);
protected static native void unsetLog();
}
26 changes: 26 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,32 @@ public RNLlama(ReactApplicationContext reactContext) {

private HashMap<Integer, LlamaContext> contexts = new HashMap<>();

public void toggleNativeLog(boolean enabled, Promise promise) {
new AsyncTask<Void, Void, Boolean>() {
private Exception exception;

@Override
protected Boolean doInBackground(Void... voids) {
try {
LlamaContext.toggleNativeLog(reactContext, enabled);
return true;
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(Boolean result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
}

private int llamaContextLimit = -1;

public void setContextLimit(double limit, Promise promise) {
Expand Down
71 changes: 68 additions & 3 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ static inline int min(int a, int b) {
return (a < b) ? a : b;
}

static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
static void rnllama_log_callback_default(lm_ggml_log_level level, const char * fmt, void * data) {
if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
Expand Down Expand Up @@ -1110,11 +1110,76 @@ Java_com_rnllama_LlamaContext_freeContext(
delete llama;
}

struct log_callback_context {
JavaVM *jvm;
jobject callback;
};

static void rnllama_log_callback_to_j(lm_ggml_log_level level, const char * text, void * data) {
auto level_c = "";
if (level == LM_GGML_LOG_LEVEL_ERROR) {
__android_log_print(ANDROID_LOG_ERROR, TAG, text, nullptr);
level_c = "error";
} else if (level == LM_GGML_LOG_LEVEL_INFO) {
__android_log_print(ANDROID_LOG_INFO, TAG, text, nullptr);
level_c = "info";
} else if (level == LM_GGML_LOG_LEVEL_WARN) {
__android_log_print(ANDROID_LOG_WARN, TAG, text, nullptr);
level_c = "warn";
} else {
__android_log_print(ANDROID_LOG_DEFAULT, TAG, text, nullptr);
}

log_callback_context *cb_ctx = (log_callback_context *) data;

JNIEnv *env;
bool need_detach = false;
int getEnvResult = cb_ctx->jvm->GetEnv((void**)&env, JNI_VERSION_1_6);

if (getEnvResult == JNI_EDETACHED) {
if (cb_ctx->jvm->AttachCurrentThread(&env, nullptr) == JNI_OK) {
need_detach = true;
} else {
return;
}
} else if (getEnvResult != JNI_OK) {
return;
}

jobject callback = cb_ctx->callback;
jclass cb_class = env->GetObjectClass(callback);
jmethodID emitNativeLog = env->GetMethodID(cb_class, "emitNativeLog", "(Ljava/lang/String;Ljava/lang/String;)V");

jstring level_str = env->NewStringUTF(level_c);
jstring text_str = env->NewStringUTF(text);
env->CallVoidMethod(callback, emitNativeLog, level_str, text_str);
env->DeleteLocalRef(level_str);
env->DeleteLocalRef(text_str);

if (need_detach) {
cb_ctx->jvm->DetachCurrentThread();
}
}

JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_setupLog(JNIEnv *env, jobject thiz, jobject logCallback) {
UNUSED(thiz);

log_callback_context *cb_ctx = new log_callback_context;

JavaVM *jvm;
env->GetJavaVM(&jvm);
cb_ctx->jvm = jvm;
cb_ctx->callback = env->NewGlobalRef(logCallback);

llama_log_set(rnllama_log_callback_to_j, cb_ctx);
}

JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
Java_com_rnllama_LlamaContext_unsetLog(JNIEnv *env, jobject thiz) {
UNUSED(env);
UNUSED(thiz);
llama_log_set(log_callback, NULL);
llama_log_set(rnllama_log_callback_default, NULL);
}

} // extern "C"
5 changes: 5 additions & 0 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public String getName() {
return NAME;
}

@ReactMethod
public void toggleNativeLog(boolean enabled, Promise promise) {
rnllama.toggleNativeLog(enabled, promise);
}

@ReactMethod
public void setContextLimit(double limit, Promise promise) {
rnllama.setContextLimit(limit, promise);
Expand Down
5 changes: 5 additions & 0 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ public String getName() {
return NAME;
}

@ReactMethod
public void toggleNativeLog(boolean enabled, Promise promise) {
rnllama.toggleNativeLog(enabled, promise);
}

@ReactMethod
public void setContextLimit(double limit, Promise promise) {
rnllama.setContextLimit(limit, promise);
Expand Down
Loading

0 comments on commit 93a1a7c

Please sign in to comment.