diff --git a/java/client/build.gradle b/java/client/build.gradle index c35794d882..3fa5072199 100644 --- a/java/client/build.gradle +++ b/java/client/build.gradle @@ -219,6 +219,8 @@ tasks.withType(Test) { events "started", "skipped", "passed", "failed" showStandardStreams true } + // This is needed for the FFI tests + jvmArgs "-Djava.library.path=${projectDir}/../target/debug" } jar { diff --git a/java/client/src/test/java/glide/ffi/FfiTest.java b/java/client/src/test/java/glide/ffi/FfiTest.java index b0c72f77c9..af70730e3d 100644 --- a/java/client/src/test/java/glide/ffi/FfiTest.java +++ b/java/client/src/test/java/glide/ffi/FfiTest.java @@ -5,6 +5,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import glide.ffi.resolvers.RedisValueResolver; @@ -42,6 +43,18 @@ public class FfiTest { public static native long createLeakedLongSet(long[] value); + // This tests that panics do not cross the FFI boundary and an exception is thrown if a panic is + // caught + public static native long handlePanics( + boolean shouldPanic, boolean errorPresent, long value, long defaultValue); + + // This tests that Rust errors are properly converted into Java exceptions and thrown + public static native long handleErrors(boolean isSuccess, long value, long defaultValue); + + // This tests that a Java exception is properly thrown across the FFI boundary + public static native void throwException( + boolean throwTwice, boolean isRuntimeException, String message); + @Test public void redisValueToJavaValue_Nil() { long ptr = FfiTest.createLeakedNil(); @@ -141,4 +154,52 @@ public void redisValueToJavaValue_Set() { () -> assertTrue(result.contains(2L)), () -> assertEquals(result.size(), 2)); } + + @Test + public void handlePanics_panic() { + long expectedValue = 0L; + long value = FfiTest.handlePanics(true, false, 1L, expectedValue); + assertEquals(expectedValue, value); + } + + @Test + public void handlePanics_returnError() { + long expectedValue = 0L; + long value = FfiTest.handlePanics(false, true, 1L, expectedValue); + assertEquals(expectedValue, value); + } + + @Test + public void handlePanics_returnValue() { + long expectedValue = 2L; + long value = FfiTest.handlePanics(false, false, expectedValue, 0L); + assertEquals(expectedValue, value); + } + + @Test + public void handleErrors_success() { + long expectedValue = 0L; + long value = FfiTest.handleErrors(true, expectedValue, 1L); + assertEquals(expectedValue, value); + } + + @Test + public void handleErrors_error() { + assertThrows(Exception.class, () -> FfiTest.handleErrors(false, 0L, 1L)); + } + + @Test + public void throwException() { + assertThrows(Exception.class, () -> FfiTest.throwException(false, false, "My message")); + } + + @Test + public void throwException_throwTwice() { + assertThrows(Exception.class, () -> FfiTest.throwException(true, false, "My message")); + } + + @Test + public void throwException_throwRuntimeException() { + assertThrows(RuntimeException.class, () -> FfiTest.throwException(false, true, "My message")); + } } diff --git a/java/src/errors.rs b/java/src/errors.rs new file mode 100644 index 0000000000..031a898551 --- /dev/null +++ b/java/src/errors.rs @@ -0,0 +1,108 @@ +/** + * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 + */ +use jni::{errors::Error as JNIError, JNIEnv}; +use log::error; +use std::string::FromUtf8Error; + +pub enum FFIError { + Jni(JNIError), + Uds(String), + Utf8(FromUtf8Error), +} + +impl From for FFIError { + fn from(value: jni::errors::Error) -> Self { + FFIError::Jni(value) + } +} + +impl From for FFIError { + fn from(value: FromUtf8Error) -> Self { + FFIError::Utf8(value) + } +} + +impl std::fmt::Display for FFIError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FFIError::Jni(err) => write!(f, "{}", err), + FFIError::Uds(err) => write!(f, "{}", err), + FFIError::Utf8(err) => write!(f, "{}", err), + } + } +} + +#[derive(Copy, Clone)] +pub enum ExceptionType { + Exception, + RuntimeException, +} + +impl std::fmt::Display for ExceptionType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ExceptionType::Exception => write!(f, "java/lang/Exception"), + ExceptionType::RuntimeException => write!(f, "java/lang/RuntimeException"), + } + } +} + +// This handles `FFIError`s by converting them to Java exceptions and throwing them. +pub fn handle_errors(env: &mut JNIEnv, result: Result) -> Option { + match result { + Ok(value) => Some(value), + Err(err) => { + match err { + FFIError::Utf8(utf8_error) => throw_java_exception( + env, + ExceptionType::RuntimeException, + &utf8_error.to_string(), + ), + error => throw_java_exception(env, ExceptionType::Exception, &error.to_string()), + }; + // Return `None` because we need to still return a value after throwing. + // This signals to the caller that we need to return the default value. + None + } + } +} + +// This function handles Rust panics by converting them into Java exceptions and throwing them. +// `func` returns an `Option` because this is intended to wrap the output of `handle_errors`. +pub fn handle_panics Option>( + func: F, + ffi_func_name: &str, +) -> Option { + match std::panic::catch_unwind(func) { + Ok(value) => value, + Err(_err) => { + // Following https://github.com/jni-rs/jni-rs/issues/76#issuecomment-363523906 + // and throwing a runtime exception is not feasible here because of https://github.com/jni-rs/jni-rs/issues/432 + error!("Native function {} panicked.", ffi_func_name); + None + } + } +} + +pub fn throw_java_exception(env: &mut JNIEnv, exception_type: ExceptionType, message: &str) { + match env.exception_check() { + Ok(true) => (), + Ok(false) => { + env.throw_new(exception_type.to_string(), message) + .unwrap_or_else(|err| { + error!( + "Failed to create exception with string {}: {}", + message, + err.to_string() + ); + }); + } + Err(err) => { + error!( + "Failed to check if an exception is currently being thrown: {}", + err.to_string() + ); + } + } +} diff --git a/java/src/ffi_test.rs b/java/src/ffi_test.rs index 5cebaf2fd3..03b4e52726 100644 --- a/java/src/ffi_test.rs +++ b/java/src/ffi_test.rs @@ -1,9 +1,10 @@ /** * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +use crate::errors::{handle_errors, handle_panics, throw_java_exception, ExceptionType, FFIError}; use jni::{ - objects::{JClass, JLongArray}, - sys::jlong, + objects::{JByteArray, JClass, JLongArray, JString}, + sys::{jboolean, jdouble, jlong}, JNIEnv, }; use redis::Value; @@ -21,7 +22,7 @@ pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedNil<'local>( pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedSimpleString<'local>( mut env: JNIEnv<'local>, _class: JClass<'local>, - value: jni::objects::JString<'local>, + value: JString<'local>, ) -> jlong { let value: String = env.get_string(&value).unwrap().into(); let redis_value = Value::SimpleString(value); @@ -51,7 +52,7 @@ pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedInt<'local>( pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedBulkString<'local>( env: JNIEnv<'local>, _class: JClass<'local>, - value: jni::objects::JByteArray<'local>, + value: JByteArray<'local>, ) -> jlong { let value = env.convert_byte_array(&value).unwrap(); let value = value.into_iter().collect::>(); @@ -88,7 +89,7 @@ pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedMap<'local>( pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedDouble<'local>( _env: JNIEnv<'local>, _class: JClass<'local>, - value: jni::sys::jdouble, + value: jdouble, ) -> jlong { let redis_value = Value::Double(value.into()); Box::leak(Box::new(redis_value)) as *mut Value as jlong @@ -98,7 +99,7 @@ pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedDouble<'local>( pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedBoolean<'local>( _env: JNIEnv<'local>, _class: JClass<'local>, - value: jni::sys::jboolean, + value: jboolean, ) -> jlong { let redis_value = Value::Boolean(value != 0); Box::leak(Box::new(redis_value)) as *mut Value as jlong @@ -108,7 +109,7 @@ pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedBoolean<'local>( pub extern "system" fn Java_glide_ffi_FfiTest_createLeakedVerbatimString<'local>( mut env: JNIEnv<'local>, _class: JClass<'local>, - value: jni::objects::JString<'local>, + value: JString<'local>, ) -> jlong { use redis::VerbatimFormat; let value: String = env.get_string(&value).unwrap().into(); @@ -144,3 +145,68 @@ fn java_long_array_to_value<'local>( .map(|value| Value::Int(*value)) .collect::>() } + +#[no_mangle] +pub extern "system" fn Java_glide_ffi_FfiTest_handlePanics<'local>( + _env: JNIEnv<'local>, + _class: JClass<'local>, + should_panic: jboolean, + error_present: jboolean, + value: jlong, + default_value: jlong, +) -> jlong { + let should_panic = should_panic != 0; + let error_present = error_present != 0; + handle_panics( + || { + if should_panic { + panic!("Panicking") + } else if error_present { + None + } else { + Some(value) + } + }, + "handlePanics", + ) + .unwrap_or(default_value) +} + +#[no_mangle] +pub extern "system" fn Java_glide_ffi_FfiTest_handleErrors<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + is_success: jboolean, + value: jlong, + default_value: jlong, +) -> jlong { + let is_success = is_success != 0; + let error = FFIError::Uds("Error starting socket listener".to_string()); + let result = if is_success { Ok(value) } else { Err(error) }; + handle_errors(&mut env, result).unwrap_or(default_value) +} + +#[no_mangle] +pub extern "system" fn Java_glide_ffi_FfiTest_throwException<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + throw_twice: jboolean, + is_runtime_exception: jboolean, + message: JString<'local>, +) { + let throw_twice = throw_twice != 0; + let is_runtime_exception = is_runtime_exception != 0; + + let exception_type = if is_runtime_exception { + ExceptionType::RuntimeException + } else { + ExceptionType::Exception + }; + + let message: String = env.get_string(&message).unwrap().into(); + throw_java_exception(&mut env, exception_type, &message); + + if throw_twice { + throw_java_exception(&mut env, exception_type, &message); + } +} diff --git a/java/src/lib.rs b/java/src/lib.rs index a6154d023c..a7dade4afd 100644 --- a/java/src/lib.rs +++ b/java/src/lib.rs @@ -1,15 +1,18 @@ /** * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ -use glide_core::start_socket_listener; +use glide_core::start_socket_listener as start_socket_listener_core; -use jni::objects::{JClass, JObject, JObjectArray, JString, JThrowable}; +use jni::objects::{JClass, JObject, JObjectArray, JString}; use jni::sys::jlong; use jni::JNIEnv; -use log::error; use redis::Value; use std::sync::mpsc; +mod errors; + +use errors::{handle_errors, handle_panics, FFIError}; + #[cfg(ffi_test)] mod ffi_test; #[cfg(ffi_test)] @@ -20,93 +23,65 @@ fn redis_value_to_java<'local>( env: &mut JNIEnv<'local>, val: Value, encoding_utf8: bool, -) -> JObject<'local> { +) -> Result, FFIError> { match val { - Value::Nil => JObject::null(), - Value::SimpleString(str) => JObject::from(env.new_string(str).unwrap()), - Value::Okay => JObject::from(env.new_string("OK").unwrap()), - Value::Int(num) => env - .new_object("java/lang/Long", "(J)V", &[num.into()]) - .unwrap(), + Value::Nil => Ok(JObject::null()), + Value::SimpleString(str) => Ok(JObject::from(env.new_string(str)?)), + Value::Okay => Ok(JObject::from(env.new_string("OK")?)), + Value::Int(num) => Ok(env.new_object("java/lang/Long", "(J)V", &[num.into()])?), Value::BulkString(data) => { if encoding_utf8 { - let Ok(utf8_str) = String::from_utf8(data) else { - let _ = env.throw("Failed to construct UTF-8 string"); - return JObject::null(); - }; - match env.new_string(utf8_str) { - Ok(string) => JObject::from(string), - Err(e) => { - let _ = env.throw(format!( - "Failed to construct Java UTF-8 string from Rust UTF-8 string. {:?}", - e - )); - JObject::null() - } - } + let utf8_str = String::from_utf8(data)?; + Ok(JObject::from(env.new_string(utf8_str)?)) } else { - let Ok(bytearr) = env.byte_array_from_slice(&data) else { - let _ = env.throw("Failed to allocate byte array"); - return JObject::null(); - }; - bytearr.into() + Ok(JObject::from(env.byte_array_from_slice(&data)?)) } } Value::Array(array) => { - let items: JObjectArray = env - .new_object_array(array.len() as i32, "java/lang/Object", JObject::null()) - .unwrap(); + let items: JObjectArray = + env.new_object_array(array.len() as i32, "java/lang/Object", JObject::null())?; for (i, item) in array.into_iter().enumerate() { - let java_value = redis_value_to_java(env, item, encoding_utf8); - env.set_object_array_element(&items, i as i32, java_value) - .unwrap(); + let java_value = redis_value_to_java(env, item, encoding_utf8)?; + env.set_object_array_element(&items, i as i32, java_value)?; } - items.into() + Ok(items.into()) } Value::Map(map) => { - let linked_hash_map = env - .new_object("java/util/LinkedHashMap", "()V", &[]) - .unwrap(); + let linked_hash_map = env.new_object("java/util/LinkedHashMap", "()V", &[])?; for (key, value) in map { - let java_key = redis_value_to_java(env, key, encoding_utf8); - let java_value = redis_value_to_java(env, value, encoding_utf8); + let java_key = redis_value_to_java(env, key, encoding_utf8)?; + let java_value = redis_value_to_java(env, value, encoding_utf8)?; env.call_method( &linked_hash_map, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", &[(&java_key).into(), (&java_value).into()], - ) - .unwrap(); + )?; } - linked_hash_map + Ok(linked_hash_map) } - Value::Double(float) => env - .new_object("java/lang/Double", "(D)V", &[float.into()]) - .unwrap(), - Value::Boolean(bool) => env - .new_object("java/lang/Boolean", "(Z)V", &[bool.into()]) - .unwrap(), - Value::VerbatimString { format: _, text } => JObject::from(env.new_string(text).unwrap()), + Value::Double(float) => Ok(env.new_object("java/lang/Double", "(D)V", &[float.into()])?), + Value::Boolean(bool) => Ok(env.new_object("java/lang/Boolean", "(Z)V", &[bool.into()])?), + Value::VerbatimString { format: _, text } => Ok(JObject::from(env.new_string(text)?)), Value::BigNumber(_num) => todo!(), Value::Set(array) => { - let set = env.new_object("java/util/HashSet", "()V", &[]).unwrap(); + let set = env.new_object("java/util/HashSet", "()V", &[])?; for elem in array { - let java_value = redis_value_to_java(env, elem, encoding_utf8); + let java_value = redis_value_to_java(env, elem, encoding_utf8)?; env.call_method( &set, "add", "(Ljava/lang/Object;)Z", &[(&java_value).into()], - ) - .unwrap(); + )?; } - set + Ok(set) } Value::Attribute { data: _, @@ -122,8 +97,21 @@ pub extern "system" fn Java_glide_ffi_resolvers_RedisValueResolver_valueFromPoin _class: JClass<'local>, pointer: jlong, ) -> JObject<'local> { - let value = unsafe { Box::from_raw(pointer as *mut Value) }; - redis_value_to_java(&mut env, *value, true) + handle_panics( + move || { + fn value_from_pointer<'a>( + env: &mut JNIEnv<'a>, + pointer: jlong, + ) -> Result, FFIError> { + let value = unsafe { Box::from_raw(pointer as *mut Value) }; + redis_value_to_java(env, *value, true) + } + let result = value_from_pointer(&mut env, pointer); + handle_errors(&mut env, result) + }, + "valueFromPointer", + ) + .unwrap_or(JObject::null()) } #[no_mangle] @@ -134,59 +122,58 @@ pub extern "system" fn Java_glide_ffi_resolvers_RedisValueResolver_valueFromPoin _class: JClass<'local>, pointer: jlong, ) -> JObject<'local> { - let value = unsafe { Box::from_raw(pointer as *mut Value) }; - redis_value_to_java(&mut env, *value, false) + handle_panics( + move || { + fn value_from_pointer_binary<'a>( + env: &mut JNIEnv<'a>, + pointer: jlong, + ) -> Result, FFIError> { + let value = unsafe { Box::from_raw(pointer as *mut Value) }; + redis_value_to_java(env, *value, false) + } + let result = value_from_pointer_binary(&mut env, pointer); + handle_errors(&mut env, result) + }, + "valueFromPointerBinary", + ) + .unwrap_or(JObject::null()) } #[no_mangle] pub extern "system" fn Java_glide_ffi_resolvers_SocketListenerResolver_startSocketListener< 'local, >( - env: JNIEnv<'local>, + mut env: JNIEnv<'local>, _class: JClass<'local>, ) -> JObject<'local> { - let (tx, rx) = mpsc::channel::>(); - - start_socket_listener(move |socket_path: Result| { - // Signals that thread has started - let _ = tx.send(socket_path); - }); - - // Wait until the thread has started - let socket_path = rx.recv(); - - match socket_path { - Ok(Ok(path)) => env.new_string(path).unwrap().into(), - Ok(Err(error_message)) => { - throw_java_exception(env, error_message); - JObject::null() - } - Err(error) => { - throw_java_exception(env, error.to_string()); - JObject::null() - } - } -} - -fn throw_java_exception(mut env: JNIEnv, message: String) { - let res = env.new_object( - "java/lang/Exception", - "(Ljava/lang/String;)V", - &[(&env.new_string(message.clone()).unwrap()).into()], - ); - - match res { - Ok(res) => { - let _ = env.throw(JThrowable::from(res)); - } - Err(err) => { - error!( - "Failed to create exception with string {}: {}", - message, - err.to_string() - ); - } - }; + handle_panics( + move || { + fn start_socket_listener<'a>(env: &mut JNIEnv<'a>) -> Result, FFIError> { + let (tx, rx) = mpsc::channel::>(); + + start_socket_listener_core(move |socket_path: Result| { + // Signals that thread has started + let _ = tx.send(socket_path); + }); + + // Wait until the thread has started + let socket_path = rx.recv(); + + match socket_path { + Ok(Ok(path)) => env + .new_string(path) + .map(|p| p.into()) + .map_err(|err| FFIError::Uds(err.to_string())), + Ok(Err(error_message)) => Err(FFIError::Uds(error_message)), + Err(error) => Err(FFIError::Uds(error.to_string())), + } + } + let result = start_socket_listener(&mut env); + handle_errors(&mut env, result) + }, + "startSocketListener", + ) + .unwrap_or(JObject::null()) } #[no_mangle] @@ -195,9 +182,22 @@ pub extern "system" fn Java_glide_ffi_resolvers_ScriptResolver_storeScript<'loca _class: JClass<'local>, code: JString, ) -> JObject<'local> { - let code_str: String = env.get_string(&code).unwrap().into(); - let hash = glide_core::scripts_container::add_script(&code_str); - JObject::from(env.new_string(hash).unwrap()) + handle_panics( + move || { + fn store_script<'a>( + env: &mut JNIEnv<'a>, + code: JString, + ) -> Result, FFIError> { + let code_str: String = env.get_string(&code)?.into(); + let hash = glide_core::scripts_container::add_script(&code_str); + Ok(JObject::from(env.new_string(hash)?)) + } + let result = store_script(&mut env, code); + handle_errors(&mut env, result) + }, + "storeScript", + ) + .unwrap_or(JObject::null()) } #[no_mangle] @@ -206,6 +206,17 @@ pub extern "system" fn Java_glide_ffi_resolvers_ScriptResolver_dropScript<'local _class: JClass<'local>, hash: JString, ) { - let hash_str: String = env.get_string(&hash).unwrap().into(); - glide_core::scripts_container::remove_script(&hash_str); + handle_panics( + move || { + fn drop_script(env: &mut JNIEnv<'_>, hash: JString) -> Result<(), FFIError> { + let hash_str: String = env.get_string(&hash)?.into(); + glide_core::scripts_container::remove_script(&hash_str); + Ok(()) + } + let result = drop_script(&mut env, hash); + handle_errors(&mut env, result) + }, + "dropScript", + ) + .unwrap_or(()) }