diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f42fad8..bbee08aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -439,7 +439,7 @@ set(DUCKDB_SRC_FILES src/duckdb/extension/json/json_serializer.cpp src/duckdb/ub_extension_json_json_functions.cpp) -set(JEMACLLOC_SRC_FILES +set(JEMALLOC_SRC_FILES src/duckdb/extension/jemalloc/jemalloc_extension.cpp src/duckdb/extension/jemalloc/jemalloc/src/jemalloc.c src/duckdb/extension/jemalloc/jemalloc/src/arena.c @@ -553,13 +553,14 @@ add_jar(duckdb_jdbc_tests ${JAVA_TEST_FILES} INCLUDE_JARS duckdb_jdbc) if(MSVC) list(APPEND DUCKDB_SRC_FILES duckdb_java.def) else() - list(APPEND DUCKDB_SRC_FILES ${JEMACLLOC_SRC_FILES}) + list(APPEND DUCKDB_SRC_FILES ${JEMALLOC_SRC_FILES}) endif() add_library(duckdb_java SHARED src/jni/config.cpp src/jni/duckdb_java.cpp src/jni/functions.cpp + src/jni/holders.cpp src/jni/refs.cpp src/jni/types.cpp src/jni/util.cpp diff --git a/CMakeLists.txt.in b/CMakeLists.txt.in index 0b45d9f4..992658a5 100644 --- a/CMakeLists.txt.in +++ b/CMakeLists.txt.in @@ -46,7 +46,7 @@ set(DUCKDB_DEFINITIONS set(DUCKDB_SRC_FILES ${SOURCES}) -set(JEMACLLOC_SRC_FILES +set(JEMALLOC_SRC_FILES ${JEMALLOC_SOURCES}) @@ -95,13 +95,14 @@ add_jar(duckdb_jdbc_tests ${JAVA_TEST_FILES} INCLUDE_JARS duckdb_jdbc) if(MSVC) list(APPEND DUCKDB_SRC_FILES duckdb_java.def) else() - list(APPEND DUCKDB_SRC_FILES ${JEMACLLOC_SRC_FILES}) + list(APPEND DUCKDB_SRC_FILES ${JEMALLOC_SRC_FILES}) endif() add_library(duckdb_java SHARED src/jni/config.cpp src/jni/duckdb_java.cpp src/jni/functions.cpp + src/jni/holders.cpp src/jni/refs.cpp src/jni/types.cpp src/jni/util.cpp diff --git a/duckdb_java.def b/duckdb_java.def index 29ccd472..450a2cb2 100644 --- a/duckdb_java.def +++ b/duckdb_java.def @@ -34,6 +34,7 @@ Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1auto_1commit Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1catalog Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1schema Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1interrupt +Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1is_1result_1open Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepare Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepared_1statement_1meta Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1result_1meta diff --git a/duckdb_java.exp b/duckdb_java.exp index 4bf7a9d4..36db0eb1 100644 --- a/duckdb_java.exp +++ b/duckdb_java.exp @@ -31,6 +31,7 @@ _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1auto_1commit _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1catalog _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1schema _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1interrupt +_Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1is_1result_1open _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepare _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepared_1statement_1meta _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1result_1meta diff --git a/duckdb_java.map b/duckdb_java.map index 7991cd98..58484b7d 100644 --- a/duckdb_java.map +++ b/duckdb_java.map @@ -33,6 +33,7 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1catalog; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1schema; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1interrupt; + Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1is_1result_1open; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepare; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepared_1statement_1meta; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1result_1meta; diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index de09536b..d3c55984 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -13,12 +13,14 @@ #include "duckdb/main/extension_util.hpp" #include "duckdb/parser/parsed_data/create_type_info.hpp" #include "functions.hpp" +#include "holders.hpp" #include "refs.hpp" #include "types.hpp" #include "util.hpp" #include #include +#include using namespace duckdb; using namespace std; @@ -37,6 +39,9 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { } try { + ConnectionHolder::init_statics(); + StatementHolder::init_statics(); + ResultHolder::init_statics(); create_refs(env); } catch (const std::exception &e) { if (!env->ExceptionCheck()) { @@ -59,40 +64,6 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { delete_global_refs(env); } -/** - * Associates a duckdb::Connection with a duckdb::DuckDB. The DB may be shared amongst many ConnectionHolders, but the - * Connection is unique to this holder. Every Java DuckDBConnection has exactly 1 of these holders, and they are never - * shared. The holder is freed when the DuckDBConnection is closed. When the last holder sharing a DuckDB is freed, the - * DuckDB is released as well. - */ -struct ConnectionHolder { - const duckdb::shared_ptr db; - const duckdb::unique_ptr connection; - - ConnectionHolder(duckdb::shared_ptr _db) - : db(_db), connection(make_uniq(*_db)) { - } -}; - -/** - * Throws a SQLException and returns nullptr if a valid Connection can't be retrieved from the buffer. - */ -static Connection *get_connection(JNIEnv *env, jobject conn_ref_buf) { - if (!conn_ref_buf) { - throw ConnectionException("Invalid connection"); - } - auto conn_holder = (ConnectionHolder *)env->GetDirectBufferAddress(conn_ref_buf); - if (!conn_holder) { - throw ConnectionException("Invalid connection"); - } - auto conn_ref = conn_holder->connection.get(); - if (!conn_ref || !conn_ref->context) { - throw ConnectionException("Invalid connection"); - } - - return conn_ref; -} - //! The database instance cache, used so that multiple connections to the same file point to the same database object duckdb::DBInstanceCache instance_cache; @@ -100,40 +71,45 @@ jobject _duckdb_jdbc_startup(JNIEnv *env, jclass, jbyteArray database_j, jboolea auto database = byte_array_to_string(env, database_j); std::unique_ptr config = create_db_config(env, read_only, props); bool cache_instance = database != ":memory:" && !database.empty(); + auto shared_db = instance_cache.GetOrCreateInstance(database, *config, cache_instance); - auto conn_holder = new ConnectionHolder(shared_db); + auto conn_ref_ptr = make_uniq(shared_db); + ConnectionHolder::track(conn_ref_ptr.get()); - return env->NewDirectByteBuffer(conn_holder, 0); + return env->NewDirectByteBuffer(conn_ref_ptr.release(), 0); } jobject _duckdb_jdbc_connect(JNIEnv *env, jclass, jobject conn_ref_buf) { - auto conn_ref = (ConnectionHolder *)env->GetDirectBufferAddress(conn_ref_buf); + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); + auto config = ClientConfig::GetConfig(*conn_ref->connection->context); - auto conn = new ConnectionHolder(conn_ref->db); - conn->connection->context->config = config; - return env->NewDirectByteBuffer(conn, 0); + auto conn_ref_dup = make_uniq(conn_ref->db, std::move(config)); + ConnectionHolder::track(conn_ref_dup.get()); + + return env->NewDirectByteBuffer(conn_ref_dup.release(), 0); } jstring _duckdb_jdbc_get_schema(JNIEnv *env, jclass, jobject conn_ref_buf) { - auto conn_ref = get_connection(env, conn_ref_buf); - if (!conn_ref) { - return nullptr; - } - - auto entry = ClientData::Get(*conn_ref->context).catalog_search_path->GetDefault(); + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); + auto entry = conn_ref->client_data().catalog_search_path->GetDefault(); return env->NewStringUTF(entry.schema.c_str()); } static void set_catalog_search_path(JNIEnv *env, jobject conn_ref_buf, CatalogSearchEntry search_entry) { - auto conn_ref = get_connection(env, conn_ref_buf); - if (!conn_ref) { - return; - } + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); - conn_ref->context->RunFunctionInTransaction([&]() { - ClientData::Get(*conn_ref->context).catalog_search_path->Set(search_entry, CatalogSetPathType::SET_SCHEMA); - }); + conn_ref->conn().context->RunFunctionInTransaction( + [&]() { conn_ref->client_data().catalog_search_path->Set(search_entry, CatalogSetPathType::SET_SCHEMA); }); } void _duckdb_jdbc_set_schema(JNIEnv *env, jclass, jobject conn_ref_buf, jstring schema) { @@ -145,65 +121,122 @@ void _duckdb_jdbc_set_catalog(JNIEnv *env, jclass, jobject conn_ref_buf, jstring } jstring _duckdb_jdbc_get_catalog(JNIEnv *env, jclass, jobject conn_ref_buf) { - auto conn_ref = get_connection(env, conn_ref_buf); - if (!conn_ref) { - return nullptr; - } + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); - auto entry = ClientData::Get(*conn_ref->context).catalog_search_path->GetDefault(); + auto entry = conn_ref->client_data().catalog_search_path->GetDefault(); if (entry.catalog == INVALID_CATALOG) { - entry.catalog = DatabaseManager::GetDefaultDatabase(*conn_ref->context); + entry.catalog = DatabaseManager::GetDefaultDatabase(*conn_ref->conn().context); } - return env->NewStringUTF(entry.catalog.c_str()); } void _duckdb_jdbc_set_auto_commit(JNIEnv *env, jclass, jobject conn_ref_buf, jboolean auto_commit) { - auto conn_ref = get_connection(env, conn_ref_buf); - if (!conn_ref) { - return; - } - conn_ref->context->RunFunctionInTransaction([&]() { conn_ref->SetAutoCommit(auto_commit); }); + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); + + conn_ref->conn().context->RunFunctionInTransaction([&]() { conn_ref->conn().SetAutoCommit(auto_commit); }); } jboolean _duckdb_jdbc_get_auto_commit(JNIEnv *env, jclass, jobject conn_ref_buf) { - auto conn_ref = get_connection(env, conn_ref_buf); - if (!conn_ref) { - return false; - } - return conn_ref->IsAutoCommit(); + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); + + return conn_ref->conn().IsAutoCommit(); } void _duckdb_jdbc_interrupt(JNIEnv *env, jclass, jobject conn_ref_buf) { - auto conn_ref = get_connection(env, conn_ref_buf); - if (!conn_ref) { - return; + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); + + conn_ref->conn().Interrupt(); +} + +// Statement lock must be held when calling this function, +// see details in holders.hpp. +static void close_stmt_tracked_results(StatementHolder *stmt_ref) { + // last created - first deleted + for (auto it = stmt_ref->res_list.rbegin(); it != stmt_ref->res_list.rend(); it++) { + ResultHolder *res_ref = *it; + + auto res_mtx = ResultHolder::mutex_no_throw(res_ref); + if (res_mtx.get() == nullptr) { + continue; + } + std::lock_guard res_guard(*res_mtx); + bool res_was_tracked = ResultHolder::untrack(res_ref); + if (!res_was_tracked) { + continue; + } + delete res_ref; } - conn_ref->Interrupt(); + + stmt_ref->res_list.clear(); + stmt_ref->res_set.clear(); } void _duckdb_jdbc_disconnect(JNIEnv *env, jclass, jobject conn_ref_buf) { - auto conn_ref = (ConnectionHolder *)env->GetDirectBufferAddress(conn_ref_buf); - if (conn_ref) { - delete conn_ref; + if (conn_ref_buf == nullptr) { + return; + } + auto conn_ref = ConnectionHolder::unwrap_ref_buf_no_throw(env, conn_ref_buf); + if (conn_ref == nullptr) { + return; + } + auto mtx = ConnectionHolder::mutex_no_throw(conn_ref); + if (mtx.get() == nullptr) { + return; + } + std::lock_guard guard(*mtx); + bool was_tracked = ConnectionHolder::untrack(conn_ref); + if (!was_tracked) { + return; } -} -struct StatementHolder { - duckdb::unique_ptr stmt; -}; + // cancel active query if any + conn_ref->conn().Interrupt(); + + // clean up statements, last created - first deleted + for (auto it = conn_ref->stmt_list.rbegin(); it != conn_ref->stmt_list.rend(); it++) { + StatementHolder *stmt_ref = *it; + + auto stmt_mtx = StatementHolder::mutex_no_throw(stmt_ref); + if (stmt_mtx.get() == nullptr) { + continue; + } + std::lock_guard stmt_guard(*stmt_mtx); + bool stmt_was_tracked = StatementHolder::untrack(stmt_ref); + if (!stmt_was_tracked) { + continue; + } + close_stmt_tracked_results(stmt_ref); + delete stmt_ref; + } + + conn_ref->stmt_list.clear(); + conn_ref->stmt_set.clear(); + delete conn_ref; +} #include "utf8proc_wrapper.hpp" jobject _duckdb_jdbc_prepare(JNIEnv *env, jclass, jobject conn_ref_buf, jbyteArray query_j) { - auto conn_ref = get_connection(env, conn_ref_buf); - if (!conn_ref) { - return nullptr; - } + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); auto query = byte_array_to_string(env, query_j); - auto statements = conn_ref->ExtractStatements(query.c_str()); + auto statements = conn_ref->conn().ExtractStatements(query.c_str()); if (statements.empty()) { throw InvalidInputException("No statements to execute."); } @@ -211,32 +244,25 @@ jobject _duckdb_jdbc_prepare(JNIEnv *env, jclass, jobject conn_ref_buf, jbyteArr // if there are multiple statements, we directly execute the statements besides the last one // we only return the result of the last statement to the user, unless one of the previous statements fails for (idx_t i = 0; i + 1 < statements.size(); i++) { - auto res = conn_ref->Query(std::move(statements[i])); + auto res = conn_ref->conn().Query(std::move(statements[i])); if (res->HasError()) { res->ThrowError(); } } - auto stmt_ref = new StatementHolder(); - stmt_ref->stmt = conn_ref->Prepare(std::move(statements.back())); - if (stmt_ref->stmt->HasError()) { - string error_msg = string(stmt_ref->stmt->GetError()); - stmt_ref->stmt = nullptr; - - // No success, so it must be deleted - delete stmt_ref; + auto stmt = conn_ref->conn().Prepare(std::move(statements.back())); + if (stmt->HasError()) { + string error_msg = string(stmt->GetError()); ThrowJNI(env, error_msg.c_str()); - // Just return control flow back to JVM, as an Exception is pending anyway return nullptr; } - return env->NewDirectByteBuffer(stmt_ref, 0); -} + auto stmt_ref_ptr = make_uniq(conn_ref, std::move(stmt)); + StatementHolder::track(stmt_ref_ptr.get()); + conn_ref->track_stmt(stmt_ref_ptr.get()); -struct ResultHolder { - duckdb::unique_ptr res; - duckdb::unique_ptr chunk; -}; + return env->NewDirectByteBuffer(stmt_ref_ptr.release(), 0); +} Value ToValue(JNIEnv *env, jobject param, duckdb::shared_ptr context) { param = env->CallStaticObjectMethod(J_Timestamp, J_Timestamp_valueOf, param); @@ -348,12 +374,11 @@ Value ToValue(JNIEnv *env, jobject param, duckdb::shared_ptr cont } jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectArray params) { - auto stmt_ref = (StatementHolder *)env->GetDirectBufferAddress(stmt_ref_buf); - if (!stmt_ref) { - throw InvalidInputException("Invalid statement"); - } + auto stmt_ref = StatementHolder::unwrap_ref_buf(env, stmt_ref_buf); + auto mtx = StatementHolder::mutex(stmt_ref); + std::lock_guard guard(*mtx); + StatementHolder::check_tracked(stmt_ref); - auto res_ref = make_uniq(); duckdb::vector duckdb_params; idx_t param_len = env->GetArrayLength(params); @@ -375,28 +400,63 @@ jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectA bool stream_results = stmt_ref->stmt->context->TryGetCurrentSetting("jdbc_stream_results", result) ? result.GetValue() : false; - res_ref->res = stmt_ref->stmt->Execute(duckdb_params, stream_results); - if (res_ref->res->HasError()) { - string error_msg = string(res_ref->res->GetError()); - res_ref->res = nullptr; + auto res = stmt_ref->stmt->Execute(duckdb_params, stream_results); + if (res->HasError()) { + string error_msg = string(res->GetError()); ThrowJNI(env, error_msg.c_str()); return nullptr; } + auto res_ref = make_uniq(stmt_ref, std::move(res)); + stmt_ref->track_result(res_ref.get()); + ResultHolder::track(res_ref.get()); return env->NewDirectByteBuffer(res_ref.release(), 0); } -void _duckdb_jdbc_release(JNIEnv *env, jclass, jobject stmt_ref_buf) { - auto stmt_ref = (StatementHolder *)env->GetDirectBufferAddress(stmt_ref_buf); - if (stmt_ref) { - delete stmt_ref; - } +static jobject build_meta(JNIEnv *env, size_t column_count, size_t n_param, const duckdb::vector &names, + const duckdb::vector &types, StatementProperties properties); + +jobject _duckdb_jdbc_prepared_statement_meta(JNIEnv *env, jclass, jobject stmt_ref_buf) { + auto stmt_ref = StatementHolder::unwrap_ref_buf(env, stmt_ref_buf); + auto mtx = StatementHolder::mutex(stmt_ref); + std::lock_guard guard(*mtx); + StatementHolder::check_tracked(stmt_ref); + + auto &stmt = stmt_ref->stmt; + auto n_param = stmt->named_param_map.size(); + + return build_meta(env, stmt->ColumnCount(), n_param, stmt->GetNames(), stmt->GetTypes(), + stmt->GetStatementProperties()); } -void _duckdb_jdbc_free_result(JNIEnv *env, jclass, jobject res_ref_buf) { - auto res_ref = (ResultHolder *)env->GetDirectBufferAddress(res_ref_buf); - if (res_ref) { - delete res_ref; +void _duckdb_jdbc_release(JNIEnv *env, jclass, jobject stmt_ref_buf) { + if (stmt_ref_buf == nullptr) { + return; + } + auto stmt_ref = StatementHolder::unwrap_ref_buf_no_throw(env, stmt_ref_buf); + if (stmt_ref == nullptr) { + return; + } + auto mtx = StatementHolder::mutex_no_throw(stmt_ref); + if (mtx.get() == nullptr) { + return; + } + std::lock_guard guard(*mtx); + bool was_tracked = StatementHolder::untrack(stmt_ref); + if (!was_tracked) { + return; } + { + auto conn_mtx = ConnectionHolder::mutex_no_throw(stmt_ref->conn_ref); + if (conn_mtx.get() != nullptr) { + std::lock_guard conn_guard(*conn_mtx); + bool conn_tracked = ConnectionHolder::is_tracked(stmt_ref->conn_ref); + if (conn_tracked) { + stmt_ref->conn_ref->untrack_stmt(stmt_ref); + } + } + } + close_stmt_tracked_results(stmt_ref); + delete stmt_ref; } static jobject build_meta(JNIEnv *env, size_t column_count, size_t n_param, const duckdb::vector &names, @@ -427,43 +487,34 @@ static jobject build_meta(JNIEnv *env, size_t column_count, size_t n_param, cons } jobject _duckdb_jdbc_query_result_meta(JNIEnv *env, jclass, jobject res_ref_buf) { - auto res_ref = (ResultHolder *)env->GetDirectBufferAddress(res_ref_buf); - if (!res_ref || !res_ref->res || res_ref->res->HasError()) { + auto res_ref = ResultHolder::unwrap_ref_buf(env, res_ref_buf); + auto mtx = ResultHolder::mutex(res_ref); + std::lock_guard guard(*mtx); + ResultHolder::check_tracked(res_ref); + if (!res_ref->res || res_ref->res->HasError()) { throw InvalidInputException("Invalid result set"); } - auto &result = res_ref->res; + auto &result = res_ref->res; auto n_param = -1; // no params now - return build_meta(env, result->ColumnCount(), n_param, result->names, result->types, result->properties); } -jobject _duckdb_jdbc_prepared_statement_meta(JNIEnv *env, jclass, jobject stmt_ref_buf) { - - auto stmt_ref = (StatementHolder *)env->GetDirectBufferAddress(stmt_ref_buf); - if (!stmt_ref || !stmt_ref->stmt || stmt_ref->stmt->HasError()) { - throw InvalidInputException("Invalid statement"); - } - - auto &stmt = stmt_ref->stmt; - auto n_param = stmt->named_param_map.size(); - - return build_meta(env, stmt->ColumnCount(), n_param, stmt->GetNames(), stmt->GetTypes(), - stmt->GetStatementProperties()); -} - jobject ProcessVector(JNIEnv *env, Connection *conn_ref, Vector &vec, idx_t row_count); jobjectArray _duckdb_jdbc_fetch(JNIEnv *env, jclass, jobject res_ref_buf, jobject conn_ref_buf) { - auto res_ref = (ResultHolder *)env->GetDirectBufferAddress(res_ref_buf); - if (!res_ref || !res_ref->res || res_ref->res->HasError()) { + auto res_ref = ResultHolder::unwrap_ref_buf(env, res_ref_buf); + auto mtx = ResultHolder::mutex(res_ref); + std::lock_guard guard(*mtx); + ResultHolder::check_tracked(res_ref); + if (!res_ref->res || res_ref->res->HasError()) { throw InvalidInputException("Invalid result set"); } - auto conn_ref = get_connection(env, conn_ref_buf); - if (conn_ref == nullptr) { - return nullptr; - } + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto conn_mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard conn_guard(*conn_mtx); + ConnectionHolder::check_tracked(conn_ref); res_ref->chunk = res_ref->res->Fetch(); if (!res_ref->chunk) { @@ -475,7 +526,7 @@ jobjectArray _duckdb_jdbc_fetch(JNIEnv *env, jclass, jobject res_ref_buf, jobjec for (idx_t col_idx = 0; col_idx < res_ref->chunk->ColumnCount(); col_idx++) { auto &vec = res_ref->chunk->data[col_idx]; - auto jvec = ProcessVector(env, conn_ref, vec, row_count); + auto jvec = ProcessVector(env, &conn_ref->conn(), vec, row_count); env->SetObjectArrayElement(vec_array, col_idx, jvec); } @@ -483,6 +534,47 @@ jobjectArray _duckdb_jdbc_fetch(JNIEnv *env, jclass, jobject res_ref_buf, jobjec return vec_array; } +jboolean _duckdb_jdbc_is_result_open(JNIEnv *env, jclass, jobject res_ref_buf) { + if (res_ref_buf == nullptr) { + return false; + } + auto res_ref = ResultHolder::unwrap_ref_buf_no_throw(env, res_ref_buf); + if (res_ref == nullptr) { + return false; + } + return ResultHolder::is_tracked(res_ref); +} + +void _duckdb_jdbc_free_result(JNIEnv *env, jclass, jobject res_ref_buf) { + if (res_ref_buf == nullptr) { + return; + } + auto res_ref = ResultHolder::unwrap_ref_buf_no_throw(env, res_ref_buf); + if (res_ref == nullptr) { + return; + } + auto mtx = ResultHolder::mutex_no_throw(res_ref); + if (mtx.get() == nullptr) { + return; + } + std::lock_guard guard(*mtx); + bool was_tracked = ResultHolder::untrack(res_ref); + if (!was_tracked) { + return; + } + { + auto stmt_mtx = StatementHolder::mutex_no_throw(res_ref->stmt_ref); + if (stmt_mtx.get() != nullptr) { + std::lock_guard stmt_guard(*stmt_mtx); + bool stmt_tracked = StatementHolder::is_tracked(res_ref->stmt_ref); + if (stmt_tracked) { + res_ref->stmt_ref->untrack_result(res_ref); + } + } + } + delete res_ref; +} + jobject ProcessVector(JNIEnv *env, Connection *conn_ref, Vector &vec, idx_t row_count) { auto type_str = env->NewStringUTF(type_to_jduckdb_type(vec.GetType()).c_str()); // construct nullmask @@ -704,14 +796,14 @@ jint _duckdb_jdbc_fetch_size(JNIEnv *, jclass) { jobject _duckdb_jdbc_create_appender(JNIEnv *env, jclass, jobject conn_ref_buf, jbyteArray schema_name_j, jbyteArray table_name_j) { + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); - auto conn_ref = get_connection(env, conn_ref_buf); - if (!conn_ref) { - return nullptr; - } auto schema_name = byte_array_to_string(env, schema_name_j); auto table_name = byte_array_to_string(env, table_name_j); - auto appender = new Appender(*conn_ref, schema_name, table_name); + auto appender = new Appender(conn_ref->conn(), schema_name, table_name); return env->NewDirectByteBuffer(appender, 0); } @@ -850,11 +942,11 @@ class JavaArrowTabularStreamFactory { void _duckdb_jdbc_arrow_register(JNIEnv *env, jclass, jobject conn_ref_buf, jlong arrow_array_stream_pointer, jbyteArray name_j) { + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); - auto conn = get_connection(env, conn_ref_buf); - if (conn == nullptr) { - return; - } auto name = byte_array_to_string(env, name_j); auto arrow_array_stream = (ArrowArrayStream *)(uintptr_t)arrow_array_stream_pointer; @@ -864,17 +956,16 @@ void _duckdb_jdbc_arrow_register(JNIEnv *env, jclass, jobject conn_ref_buf, jlon parameters.push_back(Value::POINTER((uintptr_t)factory)); parameters.push_back(Value::POINTER((uintptr_t)JavaArrowTabularStreamFactory::Produce)); parameters.push_back(Value::POINTER((uintptr_t)JavaArrowTabularStreamFactory::GetSchema)); - conn->TableFunction("arrow_scan_dumb", parameters)->CreateView(name, true, true); + conn_ref->conn().TableFunction("arrow_scan_dumb", parameters)->CreateView(name, true, true); } -void _duckdb_jdbc_create_extension_type(JNIEnv *env, jclass, jobject conn_buf) { - - auto connection = get_connection(env, conn_buf); - if (!connection) { - return; - } +void _duckdb_jdbc_create_extension_type(JNIEnv *env, jclass, jobject conn_ref_buf) { + auto conn_ref = ConnectionHolder::unwrap_ref_buf(env, conn_ref_buf); + auto mtx = ConnectionHolder::mutex(conn_ref); + std::lock_guard guard(*mtx); + ConnectionHolder::check_tracked(conn_ref); - auto &db_instance = DatabaseInstance::GetDatabase(*connection->context); + auto &db_instance = DatabaseInstance::GetDatabase(*conn_ref->conn().context); child_list_t children = {{"hello", LogicalType::VARCHAR}, {"world", LogicalType::VARCHAR}}; auto hello_world_type = LogicalType::STRUCT(children); hello_world_type.SetAlias("test_type"); diff --git a/src/jni/functions.cpp b/src/jni/functions.cpp index e4f0826b..2d54c07e 100644 --- a/src/jni/functions.cpp +++ b/src/jni/functions.cpp @@ -174,6 +174,17 @@ JNIEXPORT jobjectArray JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch( } } +JNIEXPORT jboolean JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1is_1result_1open(JNIEnv * env, jclass param0, jobject param1) { + try { + return _duckdb_jdbc_is_result_open(env, param0, param1); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + + return false; + } +} + JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch_1size(JNIEnv * env, jclass param0) { try { return _duckdb_jdbc_fetch_size(env, param0); diff --git a/src/jni/functions.hpp b/src/jni/functions.hpp index 1267b3f7..1637a298 100644 --- a/src/jni/functions.hpp +++ b/src/jni/functions.hpp @@ -73,6 +73,10 @@ jobjectArray _duckdb_jdbc_fetch(JNIEnv * env, jclass param0, jobject param1, job JNIEXPORT jobjectArray JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch(JNIEnv * env, jclass param0, jobject param1, jobject param2); +jboolean _duckdb_jdbc_is_result_open(JNIEnv * env, jclass param0, jobject param1); + +JNIEXPORT jboolean JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1is_1result_1open(JNIEnv * env, jclass param0, jobject param1); + jint _duckdb_jdbc_fetch_size(JNIEnv * env, jclass param0); JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch_1size(JNIEnv * env, jclass param0); diff --git a/src/jni/holders.cpp b/src/jni/holders.cpp new file mode 100644 index 00000000..c465da9d --- /dev/null +++ b/src/jni/holders.cpp @@ -0,0 +1,393 @@ +#include "holders.hpp" + +#include "duckdb/main/client_data.hpp" + +#include +#include + +// All following statics are initialized from JNI_OnLoad + +static std::shared_ptr conn_map_mutex() { + static auto mutex = std::make_shared(); + return mutex; +} + +static std::shared_ptr>> conn_map() { + static auto map = std::make_shared>>(); + return map; +} + +static std::shared_ptr stmt_map_mutex() { + static auto mutex = std::make_shared(); + return mutex; +} + +static std::shared_ptr>> stmt_map() { + static auto map = std::make_shared>>(); + return map; +} + +static std::shared_ptr res_map_mutex() { + static auto mutex = std::make_shared(); + return mutex; +} + +static std::shared_ptr>> res_map() { + static auto map = std::make_shared>>(); + return map; +} + +// Connection + +ConnectionHolder::ConnectionHolder(duckdb::shared_ptr db_in) : db(std::move(db_in)) { + auto conn_ptr = duckdb::make_uniq(*this->db); + this->connection = std::move(conn_ptr); +} + +ConnectionHolder::ConnectionHolder(duckdb::shared_ptr db_in, duckdb::ClientConfig config) + : ConnectionHolder(std::move(db_in)) { + this->connection->context->config = config; +} + +duckdb::Connection &ConnectionHolder::conn() { + return *this->connection; +} + +duckdb::ClientData &ConnectionHolder::client_data() { + return duckdb::ClientData::Get(*this->conn().context); +} + +void ConnectionHolder::track_stmt(StatementHolder *stmt_ref) { + auto it = this->stmt_set.emplace(stmt_ref); + if (!it.second) { + throw std::runtime_error("Statement is already registered with the connection"); + } + this->stmt_list.emplace_back(stmt_ref); +} + +void ConnectionHolder::untrack_stmt(StatementHolder *stmt_ref) { + auto num_removed = this->stmt_set.erase(stmt_ref); + if (1 != num_removed) { + return; + } + stmt_list.remove(stmt_ref); +} + +void ConnectionHolder::init_statics() { + conn_map_mutex(); + conn_map(); +} + +void ConnectionHolder::track(ConnectionHolder *conn_ref) { + if (conn_ref == nullptr) { + throw std::runtime_error("Invalid connection ref"); + } + auto mtx = conn_map_mutex(); + std::lock_guard guard(*mtx); + auto map = conn_map(); + auto pair = map->emplace(conn_ref, std::make_shared()); + if (!pair.second) { + throw std::runtime_error("Connection is already registered"); + } +} + +void ConnectionHolder::check_tracked(ConnectionHolder *conn_ref) { + bool tracked = ConnectionHolder::is_tracked(conn_ref); + if (!tracked) { + throw std::runtime_error("Connection is closed"); + } +} + +bool ConnectionHolder::is_tracked(ConnectionHolder *conn_ref) { + auto mtx = conn_map_mutex(); + std::lock_guard guard(*mtx); + auto map = conn_map(); + auto count = map->count(conn_ref); + return count == 1; +} + +bool ConnectionHolder::untrack(ConnectionHolder *conn_ref) { + if (conn_ref == nullptr) { + return false; + } + auto mtx = conn_map_mutex(); + std::lock_guard guard(*mtx); + auto map = conn_map(); + auto num_removed = map->erase(conn_ref); + return num_removed == 1; +} + +static std::shared_ptr lookup_conn_mutex(ConnectionHolder *conn_ref, bool throw_on_not_found) { + if (conn_ref == nullptr) { + if (throw_on_not_found) { + throw std::runtime_error("Invalid connection ref"); + } else { + return std::shared_ptr(); + } + } + auto mtx = conn_map_mutex(); + std::lock_guard guard(*mtx); + auto map = conn_map(); + auto it = map->find(conn_ref); + if (it == map->end()) { + if (throw_on_not_found) { + throw std::runtime_error("Connection is closed"); + } else { + return std::shared_ptr(); + } + } + return it->second; +} + +std::shared_ptr ConnectionHolder::mutex(ConnectionHolder *conn_ref) { + return lookup_conn_mutex(conn_ref, true); +} + +std::shared_ptr ConnectionHolder::mutex_no_throw(ConnectionHolder *conn_ref) { + return lookup_conn_mutex(conn_ref, false); +} + +static ConnectionHolder *get_conn_from_ref_buf(JNIEnv *env, jobject conn_ref_buf, bool throw_on_not_found) { + if (conn_ref_buf == nullptr) { + if (throw_on_not_found) { + throw std::runtime_error("Invalid connection ref buffer"); + } else { + return nullptr; + } + } + auto conn_holder = reinterpret_cast(env->GetDirectBufferAddress(conn_ref_buf)); + if (conn_holder == nullptr && throw_on_not_found) { + throw std::runtime_error("Invalid connection"); + } + return conn_holder; +} + +ConnectionHolder *ConnectionHolder::unwrap_ref_buf(JNIEnv *env, jobject conn_ref_buf) { + return get_conn_from_ref_buf(env, conn_ref_buf, true); +} + +ConnectionHolder *ConnectionHolder::unwrap_ref_buf_no_throw(JNIEnv *env, jobject conn_ref_buf) { + return get_conn_from_ref_buf(env, conn_ref_buf, false); +} + +// Statement + +StatementHolder::StatementHolder(ConnectionHolder *conn_ref_in, duckdb::unique_ptr stmt_in) + : conn_ref(conn_ref_in), stmt(std::move(stmt_in)) { +} + +void StatementHolder::track_result(ResultHolder *res_ref) { + auto it = this->res_set.emplace(res_ref); + if (!it.second) { + throw std::runtime_error("Result is already registered with the connection"); + } + this->res_list.emplace_back(res_ref); +} + +void StatementHolder::untrack_result(ResultHolder *res_ref) { + auto num_removed = this->res_set.erase(res_ref); + if (1 != num_removed) { + return; + } + res_list.remove(res_ref); +} + +void StatementHolder::init_statics() { + stmt_map_mutex(); + stmt_map(); +} + +void StatementHolder::track(StatementHolder *stmt_ref) { + if (stmt_ref == nullptr) { + throw std::runtime_error("Invalid statement ref"); + } + auto mtx = stmt_map_mutex(); + std::lock_guard guard(*mtx); + auto map = stmt_map(); + auto pair = map->emplace(stmt_ref, std::make_shared()); + if (!pair.second) { + throw std::runtime_error("Statement is already registered"); + } +} + +void StatementHolder::check_tracked(StatementHolder *stmt_ref) { + bool tracked = StatementHolder::is_tracked(stmt_ref); + if (!tracked) { + throw std::runtime_error("Statement is closed"); + } +} + +bool StatementHolder::is_tracked(StatementHolder *stmt_ref) { + auto mtx = stmt_map_mutex(); + std::lock_guard guard(*mtx); + auto map = stmt_map(); + auto count = map->count(stmt_ref); + return count == 1; +} + +bool StatementHolder::untrack(StatementHolder *stmt_ref) { + if (stmt_ref == nullptr) { + return false; + } + auto mtx = stmt_map_mutex(); + std::lock_guard guard(*mtx); + auto map = stmt_map(); + auto num_removed = map->erase(stmt_ref); + return num_removed == 1; +} + +static std::shared_ptr lookup_stmt_mutex(StatementHolder *stmt_ref, bool throw_on_not_found) { + if (stmt_ref == nullptr) { + if (throw_on_not_found) { + throw std::runtime_error("Invalid statement ref"); + } else { + return std::shared_ptr(); + } + } + auto mtx = stmt_map_mutex(); + std::lock_guard guard(*mtx); + auto map = stmt_map(); + auto it = map->find(stmt_ref); + if (it == map->end()) { + if (throw_on_not_found) { + throw std::runtime_error("Statement is closed"); + } else { + return std::shared_ptr(); + } + } + return it->second; +} + +std::shared_ptr StatementHolder::mutex(StatementHolder *stmt_ref) { + return lookup_stmt_mutex(stmt_ref, true); +} + +std::shared_ptr StatementHolder::mutex_no_throw(StatementHolder *stmt_ref) { + return lookup_stmt_mutex(stmt_ref, false); +} + +static StatementHolder *get_stmt_from_ref_buf(JNIEnv *env, jobject stmt_ref_buf, bool throw_on_not_found) { + if (stmt_ref_buf == nullptr) { + if (throw_on_not_found) { + throw std::runtime_error("Invalid statement ref buffer"); + } else { + return nullptr; + } + } + auto stmt_ref = reinterpret_cast(env->GetDirectBufferAddress(stmt_ref_buf)); + if (stmt_ref == nullptr && throw_on_not_found) { + throw std::runtime_error("Invalid statement"); + } + return stmt_ref; +} + +StatementHolder *StatementHolder::unwrap_ref_buf(JNIEnv *env, jobject stmt_ref_buf) { + return get_stmt_from_ref_buf(env, stmt_ref_buf, true); +} + +StatementHolder *StatementHolder::unwrap_ref_buf_no_throw(JNIEnv *env, jobject stmt_ref_buf) { + return get_stmt_from_ref_buf(env, stmt_ref_buf, false); +} + +// Result + +ResultHolder::ResultHolder(StatementHolder *stmt_ref_in, duckdb::unique_ptr res_in) + : stmt_ref(stmt_ref_in), res(std::move(res_in)) { +} + +void ResultHolder::init_statics() { + res_map_mutex(); + res_map(); +} + +void ResultHolder::track(ResultHolder *res_ref) { + if (res_ref == nullptr) { + throw std::runtime_error("Invalid result ref"); + } + auto mtx = res_map_mutex(); + std::lock_guard guard(*mtx); + auto map = res_map(); + auto pair = map->emplace(res_ref, std::make_shared()); + if (!pair.second) { + throw std::runtime_error("Result is already registered"); + } +} + +void ResultHolder::check_tracked(ResultHolder *res_ref) { + bool tracked = ResultHolder::is_tracked(res_ref); + if (!tracked) { + throw std::runtime_error("Result is closed"); + } +} + +bool ResultHolder::is_tracked(ResultHolder *res_ref) { + auto mtx = res_map_mutex(); + std::lock_guard guard(*mtx); + auto map = res_map(); + auto count = map->count(res_ref); + return count == 1; +} + +bool ResultHolder::untrack(ResultHolder *res_ref) { + if (res_ref == nullptr) { + return false; + } + auto mtx = res_map_mutex(); + std::lock_guard guard(*mtx); + auto map = res_map(); + auto num_removed = map->erase(res_ref); + return num_removed == 1; +} + +std::shared_ptr lookup_res_mutex(ResultHolder *res_ref, bool throw_on_not_found) { + if (res_ref == nullptr) { + if (throw_on_not_found) { + throw std::runtime_error("Invalid result ref"); + } else { + return std::shared_ptr(); + } + } + auto mtx = res_map_mutex(); + std::lock_guard guard(*mtx); + auto map = res_map(); + auto it = map->find(res_ref); + if (it == map->end()) { + if (throw_on_not_found) { + throw std::runtime_error("Result is closed"); + } else { + return std::shared_ptr(); + } + } + return it->second; +} + +std::shared_ptr ResultHolder::mutex(ResultHolder *res_ref) { + return lookup_res_mutex(res_ref, true); +} + +std::shared_ptr ResultHolder::mutex_no_throw(ResultHolder *res_ref) { + return lookup_res_mutex(res_ref, false); +} + +static ResultHolder *get_res_from_ref_buf(JNIEnv *env, jobject res_ref_buf, bool throw_on_not_found) { + if (res_ref_buf == nullptr) { + if (throw_on_not_found) { + throw std::runtime_error("Invalid result set ref buffer"); + } else { + return nullptr; + } + } + auto res_ref = reinterpret_cast(env->GetDirectBufferAddress(res_ref_buf)); + if (res_ref == nullptr && throw_on_not_found) { + throw std::runtime_error("Invalid result set"); + } + return res_ref; +} + +ResultHolder *ResultHolder::unwrap_ref_buf(JNIEnv *env, jobject res_ref_buf) { + return get_res_from_ref_buf(env, res_ref_buf, true); +} + +ResultHolder *ResultHolder::unwrap_ref_buf_no_throw(JNIEnv *env, jobject res_ref_buf) { + return get_res_from_ref_buf(env, res_ref_buf, false); +} diff --git a/src/jni/holders.hpp b/src/jni/holders.hpp new file mode 100644 index 00000000..a0076bc3 --- /dev/null +++ b/src/jni/holders.hpp @@ -0,0 +1,157 @@ +#pragma once + +#include "duckdb.hpp" + +#include +#include +#include +#include +#include + +/** + * This header contains holders for Connection, Statement and Result. + * Instances of these objects are shared with Java part as bare poiters + * inside empty ByteBuffers. When instance is closed - the underlying + * object is destroyed. Concurrent closure can happen any time, thus + * the pointer coming from Java cannot be dereferenced without holding + * the same lock that is held when deleting this object. + * + * Locks registry is a global synchonized unordered_map for every type + * of the object, it is managed using `track/untrack/check_tracked` + * static methods. Mutex for the corresponding object is created when + * the object is registered and deleted when the object is deleted. + * Object, shared with Java, is considered alive (and thus a pointer + * to it can be dereferenced) only while it has a mutex in the + * registry. 'init_statics' methods are used to initialize the + * registries. They are called from JNI_OnLoad to not rely on + * thread-safety of a static var initialization, that, while required + * by the standard, is not thread-safe in older versions of MSVC and + * can be disabled manually in GCC. + * + * To dereferenced a pointer coming from Java it is necessary first to + * lock the instance taking a mutex from registry and then, while + * holding the lock, re-check that the instance is still tracked. + * These steps are required in the beginning of every JNI call, they + * are the same for all object types, example for a Statement: + * + * > jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, ...) { + * > auto stmt_ref = StatementHolder::unwrap_ref_buf(env, stmt_ref_buf); + * > auto mtx = StatementHolder::mutex(stmt_ref); + * > std::lock_guard guard(*mtx); + * > StatementHolder::check_tracked(stmt_ref); + * > ... + * + * Note, the 'mtx' above is a shared_ptr and must be kept in (at least) + * the same scope as the lock_guard created on it. + * + * Connection maintains an ordered set of Statements opened on it, and + * Statement maintains an ordered set of Results opened on it (set here + * is used for consistency, only one active result is expected). When + * the parent object is closed - these child objects are closed too. + * + * Close methods should not throw, so they use *_no_throw versions of + * the steps above. + */ + +struct StatementHolder; +struct ResultHolder; + +/** + * Associates a duckdb::Connection with a duckdb::DuckDB. The DB may be shared amongst many ConnectionHolders, but the + * Connection is unique to this holder. Every Java DuckDBConnection has exactly 1 of these holders, and they are never + * shared. The holder is freed when the DuckDBConnection is closed. When the last holder sharing a DuckDB is freed, the + * DuckDB is released as well. + */ +struct ConnectionHolder { + duckdb::shared_ptr db; + duckdb::unique_ptr connection; + std::unordered_set stmt_set; + std::list stmt_list; + + explicit ConnectionHolder(duckdb::shared_ptr db_in); + + explicit ConnectionHolder(duckdb::shared_ptr db_in, duckdb::ClientConfig config); + + duckdb::Connection &conn(); + + duckdb::ClientData &client_data(); + + void track_stmt(StatementHolder *stmt_ref); + + void untrack_stmt(StatementHolder *stmt_ref); + + static void init_statics(); + + static void track(ConnectionHolder *conn_ref); + + static void check_tracked(ConnectionHolder *conn_ref); + + static bool is_tracked(ConnectionHolder *conn_ref); + + static bool untrack(ConnectionHolder *conn_ref); + + static std::shared_ptr mutex(ConnectionHolder *conn_ref); + + static std::shared_ptr mutex_no_throw(ConnectionHolder *conn_ref); + + static ConnectionHolder *unwrap_ref_buf(JNIEnv *env, jobject conn_ref_buf); + + static ConnectionHolder *unwrap_ref_buf_no_throw(JNIEnv *env, jobject conn_ref_buf); +}; + +struct StatementHolder { + ConnectionHolder *conn_ref; + duckdb::unique_ptr stmt; + std::unordered_set res_set; + std::list res_list; + + explicit StatementHolder(ConnectionHolder *conn_ref_in, duckdb::unique_ptr stmt_in); + + void track_result(ResultHolder *res_ref); + + void untrack_result(ResultHolder *res_ref); + + static void init_statics(); + + static void track(StatementHolder *stmt_ref); + + static void check_tracked(StatementHolder *conn_ref); + + static bool is_tracked(StatementHolder *conn_ref); + + static bool untrack(StatementHolder *stmt_ref); + + static std::shared_ptr mutex(StatementHolder *stmt_ref); + + static std::shared_ptr mutex_no_throw(StatementHolder *stmt_ref); + + static StatementHolder *unwrap_ref_buf(JNIEnv *env, jobject stmt_ref_buf); + + static StatementHolder *unwrap_ref_buf_no_throw(JNIEnv *env, jobject stmt_ref_buf); +}; + +struct ResultHolder { + StatementHolder *stmt_ref; + duckdb::unique_ptr res; + duckdb::unique_ptr chunk; + + ResultHolder(StatementHolder *stmt_ref_in, duckdb::unique_ptr res_in); + + static void init_statics(); + + static void track(ResultHolder *res_ref); + + static void check_tracked(ResultHolder *conn_ref); + + static bool is_tracked(ResultHolder *conn_ref); + + static bool untrack(ResultHolder *res_ref); + + static std::shared_ptr mutex(ResultHolder *res_ref); + + static std::shared_ptr mutex_no_throw(ResultHolder *res_ref); + + static ResultHolder *unwrap_ref_buf(JNIEnv *env, jobject res_ref_buf); + + static ResultHolder *unwrap_ref_buf_no_throw(JNIEnv *env, jobject res_ref_buf); +}; diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index 798c7292..453df5b3 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -32,9 +32,10 @@ public final class DuckDBConnection implements java.sql.Connection { /** Name of the DuckDB default schema. */ public static final String DEFAULT_SCHEMA = "main"; - ByteBuffer conn_ref; - boolean autoCommit = true; - boolean transactionRunning; + volatile ByteBuffer conn_ref; + volatile boolean autoCommit = true; + volatile boolean transactionRunning; + final String url; private final boolean readOnly; @@ -110,7 +111,7 @@ protected void finalize() throws Throwable { close(); } - public synchronized void close() throws SQLException { + public void close() throws SQLException { if (conn_ref != null) { DuckDBNative.duckdb_jdbc_disconnect(conn_ref); conn_ref = null; diff --git a/src/main/java/org/duckdb/DuckDBDatabaseMetaData.java b/src/main/java/org/duckdb/DuckDBDatabaseMetaData.java index ed326931..b62dd205 100644 --- a/src/main/java/org/duckdb/DuckDBDatabaseMetaData.java +++ b/src/main/java/org/duckdb/DuckDBDatabaseMetaData.java @@ -173,74 +173,74 @@ public String getIdentifierQuoteString() throws SQLException { @Override public String getSQLKeywords() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery("SELECT keyword_name FROM duckdb_keywords()"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery("SELECT keyword_name FROM duckdb_keywords()")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override public String getNumericFunctions() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery("SELECT DISTINCT function_name FROM duckdb_functions() " - + "WHERE parameter_types[1] ='DECIMAL'" - + "OR parameter_types[1] ='DOUBLE'" - + "OR parameter_types[1] ='SMALLINT'" - + "OR parameter_types[1] = 'BIGINT'"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery("SELECT DISTINCT function_name FROM duckdb_functions() " + + "WHERE parameter_types[1] ='DECIMAL'" + + "OR parameter_types[1] ='DOUBLE'" + + "OR parameter_types[1] ='SMALLINT'" + + "OR parameter_types[1] = 'BIGINT'")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override public String getStringFunctions() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery( - "SELECT DISTINCT function_name FROM duckdb_functions() WHERE parameter_types[1] = 'VARCHAR'"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery( + "SELECT DISTINCT function_name FROM duckdb_functions() WHERE parameter_types[1] = 'VARCHAR'")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override public String getSystemFunctions() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery( - "SELECT DISTINCT function_name FROM duckdb_functions() WHERE length(parameter_types) = 0"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery( + "SELECT DISTINCT function_name FROM duckdb_functions() WHERE length(parameter_types) = 0")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override public String getTimeDateFunctions() throws SQLException { - Statement statement = conn.createStatement(); - statement.closeOnCompletion(); - ResultSet rs = statement.executeQuery( - "SELECT DISTINCT function_name FROM duckdb_functions() WHERE parameter_types[1] LIKE 'TIME%'"); - StringBuilder sb = new StringBuilder(); - while (rs.next()) { - sb.append(rs.getString(1)); - sb.append(','); + try (Statement statement = conn.createStatement(); + ResultSet rs = statement.executeQuery( + "SELECT DISTINCT function_name FROM duckdb_functions() WHERE parameter_types[1] LIKE 'TIME%'")) { + StringBuilder sb = new StringBuilder(); + while (rs.next()) { + sb.append(rs.getString(1)); + sb.append(','); + } + return sb.toString(); } - return sb.toString(); } @Override diff --git a/src/main/java/org/duckdb/DuckDBNative.java b/src/main/java/org/duckdb/DuckDBNative.java index d03eff58..3886d7d0 100644 --- a/src/main/java/org/duckdb/DuckDBNative.java +++ b/src/main/java/org/duckdb/DuckDBNative.java @@ -110,6 +110,8 @@ protected static native DuckDBResultSetMetaData duckdb_jdbc_prepared_statement_m protected static native DuckDBVector[] duckdb_jdbc_fetch(ByteBuffer res_ref, ByteBuffer conn_ref) throws SQLException; + protected static native boolean duckdb_jdbc_is_result_open(ByteBuffer res_ref) throws SQLException; + protected static native int duckdb_jdbc_fetch_size(); protected static native long duckdb_jdbc_arrow_stream(ByteBuffer res_ref, long batch_size); diff --git a/src/main/java/org/duckdb/DuckDBPreparedStatement.java b/src/main/java/org/duckdb/DuckDBPreparedStatement.java index 1b078f03..476ec5c4 100644 --- a/src/main/java/org/duckdb/DuckDBPreparedStatement.java +++ b/src/main/java/org/duckdb/DuckDBPreparedStatement.java @@ -1,6 +1,5 @@ package org.duckdb; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.Reader; @@ -31,7 +30,6 @@ import java.time.LocalDateTime; import java.time.OffsetDateTime; import java.util.ArrayList; -import java.util.Arrays; import java.util.Calendar; import java.util.List; import java.util.logging.Level; @@ -40,15 +38,15 @@ public class DuckDBPreparedStatement implements PreparedStatement { private static Logger logger = Logger.getLogger(DuckDBPreparedStatement.class.getName()); - private DuckDBConnection conn; + private volatile DuckDBConnection conn; + private volatile ByteBuffer stmt_ref = null; + private volatile DuckDBResultSet select_result = null; + volatile boolean closeOnCompletion = false; - private ByteBuffer stmt_ref = null; - private DuckDBResultSet select_result = null; private int update_result = 0; private boolean returnsChangedRows = false; private boolean returnsNothing = false; private boolean returnsResultSet = false; - boolean closeOnCompletion = false; private Object[] params = new Object[0]; private DuckDBResultSetMetaData meta = null; private final List batchedParams = new ArrayList<>(); @@ -76,15 +74,22 @@ public DuckDBPreparedStatement(DuckDBConnection conn, String sql) throws SQLExce } private void startTransaction() throws SQLException { - if (this.conn.autoCommit || this.conn.transactionRunning) { - return; + if (isClosed()) { + throw new SQLException("Statement was closed"); } + try { + if (this.conn.autoCommit || this.conn.transactionRunning) { + return; + } - this.conn.transactionRunning = true; + this.conn.transactionRunning = true; - // Start transaction via Statement - try (Statement s = conn.createStatement()) { - s.execute("BEGIN TRANSACTION;"); + // Start transaction via Statement + try (Statement s = conn.createStatement()) { + s.execute("BEGIN TRANSACTION;"); + } + } catch (NullPointerException e) { + throw new SQLException(e); } } @@ -96,28 +101,32 @@ private void prepare(String sql) throws SQLException { throw new SQLException("sql query parameter cannot be null"); } - // In case the statement is reused, release old one first - if (stmt_ref != null) { - DuckDBNative.duckdb_jdbc_release(stmt_ref); - stmt_ref = null; - } + try { + // In case the statement is reused, release old one first + if (stmt_ref != null) { + DuckDBNative.duckdb_jdbc_release(stmt_ref); + stmt_ref = null; + } - meta = null; - params = null; + meta = null; + params = null; - if (select_result != null) { - select_result.close(); - } - select_result = null; - update_result = 0; - - try { - stmt_ref = DuckDBNative.duckdb_jdbc_prepare(conn.conn_ref, sql.getBytes(StandardCharsets.UTF_8)); - meta = DuckDBNative.duckdb_jdbc_prepared_statement_meta(stmt_ref); - params = new Object[0]; - } catch (SQLException e) { - // Delete stmt_ref as it might already be allocated - close(); + if (select_result != null) { + select_result.close(); + } + select_result = null; + update_result = 0; + + try { + stmt_ref = DuckDBNative.duckdb_jdbc_prepare(conn.conn_ref, sql.getBytes(StandardCharsets.UTF_8)); + meta = DuckDBNative.duckdb_jdbc_prepared_statement_meta(stmt_ref); + params = new Object[0]; + } catch (SQLException e) { + // Delete stmt_ref as it might already be allocated + close(); + throw new SQLException(e); + } + } catch (NullPointerException e) { throw new SQLException(e); } } @@ -151,7 +160,7 @@ private boolean execute(boolean startTransaction) throws SQLException { returnsResultSet = result_meta.return_type.equals(StatementReturnType.QUERY_RESULT); returnsChangedRows = result_meta.return_type.equals(StatementReturnType.CHANGED_ROWS); returnsNothing = result_meta.return_type.equals(StatementReturnType.NOTHING); - } catch (SQLException e) { + } catch (NullPointerException | SQLException e) { // Delete stmt_ref as it cannot be used anymore and // result_ref as it might be allocated if (select_result != null) { @@ -161,7 +170,10 @@ private boolean execute(boolean startTransaction) throws SQLException { result_ref = null; } close(); - throw e; + if (e instanceof SQLException) { + throw e; + } + throw new SQLException(e); } if (returnsChangedRows) { @@ -349,10 +361,9 @@ public void setQueryTimeout(int seconds) throws SQLException { /** * This function calls the underlying C++ interrupt function which aborts the query running on that connection. - * It is not safe to call this function when the connection is already closed. */ @Override - public synchronized void cancel() throws SQLException { + public void cancel() throws SQLException { if (conn.conn_ref != null) { DuckDBNative.duckdb_jdbc_interrupt(conn.conn_ref); } @@ -559,7 +570,9 @@ public int getResultSetHoldability() throws SQLException { @Override public boolean isClosed() throws SQLException { - return conn == null; + // Cannot check native stmt here because it is created only + // when prepare() is called. + return conn == null || conn.conn_ref == null; } @Override diff --git a/src/main/java/org/duckdb/DuckDBResultSet.java b/src/main/java/org/duckdb/DuckDBResultSet.java index 5b6e0b62..96ada130 100644 --- a/src/main/java/org/duckdb/DuckDBResultSet.java +++ b/src/main/java/org/duckdb/DuckDBResultSet.java @@ -40,7 +40,8 @@ public class DuckDBResultSet implements ResultSet { /** * {@code null} if this result set is closed. */ - private ByteBuffer result_ref; + private volatile ByteBuffer result_ref; + private DuckDBVector[] current_chunk = {}; private int chunk_idx = 0; private boolean finished = false; @@ -49,10 +50,14 @@ public class DuckDBResultSet implements ResultSet { public DuckDBResultSet(DuckDBPreparedStatement stmt, DuckDBResultSetMetaData meta, ByteBuffer result_ref, ByteBuffer conn_ref) throws SQLException { - this.stmt = Objects.requireNonNull(stmt); - this.result_ref = Objects.requireNonNull(result_ref); - this.meta = Objects.requireNonNull(meta); - this.conn_ref = Objects.requireNonNull(conn_ref); + try { + this.stmt = Objects.requireNonNull(stmt); + this.result_ref = Objects.requireNonNull(result_ref); + this.meta = Objects.requireNonNull(meta); + this.conn_ref = Objects.requireNonNull(conn_ref); + } catch (NullPointerException e) { + throw new SQLException(e); + } } public Statement getStatement() throws SQLException { @@ -69,7 +74,7 @@ public ResultSetMetaData getMetaData() throws SQLException { return meta; } - public synchronized boolean next() throws SQLException { + public boolean next() throws SQLException { if (isClosed()) { throw new SQLException("ResultSet was closed"); } @@ -88,7 +93,7 @@ public synchronized boolean next() throws SQLException { return true; } - public synchronized void close() throws SQLException { + public void close() throws SQLException { if (result_ref != null) { DuckDBNative.duckdb_jdbc_free_result(result_ref); // Nullness is used to determine whether we're closed @@ -106,8 +111,15 @@ protected void finalize() throws Throwable { close(); } - public synchronized boolean isClosed() throws SQLException { - return result_ref == null; + public boolean isClosed() throws SQLException { + if (this.result_ref == null) { + return true; + } + boolean open = DuckDBNative.duckdb_jdbc_is_result_open(this.result_ref); + if (!open) { + this.result_ref = null; + } + return !open; } private void check(int columnIndex) throws SQLException { diff --git a/src/test/java/org/duckdb/TestClosure.java b/src/test/java/org/duckdb/TestClosure.java new file mode 100644 index 00000000..1baaf52b --- /dev/null +++ b/src/test/java/org/duckdb/TestClosure.java @@ -0,0 +1,223 @@ +package org.duckdb; + +import static org.duckdb.TestDuckDBJDBC.JDBC_URL; +import static org.duckdb.test.Assertions.*; + +import java.io.File; +import java.sql.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class TestClosure { + + // https://github.com/duckdb/duckdb-java/issues/101 + public static void test_unclosed_statement_does_not_hang() throws Exception { + String dbName = "test_issue_101.db"; + String url = JDBC_URL + dbName; + Connection conn = DriverManager.getConnection(url); + Statement stmt = conn.createStatement(); + stmt.execute("select 42"); + // statement not closed explicitly + conn.close(); + assertTrue(stmt.isClosed()); + Connection connOther = DriverManager.getConnection(url); + connOther.close(); + assertTrue(new File(dbName).delete()); + } + + public static void test_result_set_auto_closed() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + Statement stmt = conn.createStatement(); + ResultSet rs1 = stmt.executeQuery("select 42"); + ResultSet rs2 = stmt.executeQuery("select 43"); + assertTrue(rs1.isClosed()); + stmt.close(); + assertTrue(rs2.isClosed()); + } + } + + public static void test_statements_auto_closed_on_conn_close() throws Exception { + Connection conn = DriverManager.getConnection(JDBC_URL); + Statement stmt1 = conn.createStatement(); + stmt1.execute("select 42"); + PreparedStatement stmt2 = conn.prepareStatement("select 43"); + stmt2.execute(); + Statement stmt3 = conn.createStatement(); + stmt3.execute("select 44"); + stmt3.close(); + conn.close(); + assertTrue(stmt1.isClosed()); + assertTrue(stmt2.isClosed()); + } + + public static void test_results_auto_closed_on_conn_close() throws Exception { + Connection conn = DriverManager.getConnection(JDBC_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("select 42"); + rs.next(); + conn.close(); + assertTrue(rs.isClosed()); + assertTrue(stmt.isClosed()); + } + + public static void test_statement_auto_closed_on_completion() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + Statement stmt = conn.createStatement(); + stmt.closeOnCompletion(); + assertTrue(stmt.isCloseOnCompletion()); + try (ResultSet rs = stmt.executeQuery("select 42")) { + rs.next(); + } + assertTrue(stmt.isClosed()); + } + } + + public static void test_long_query_conn_close() throws Exception { + Connection conn = DriverManager.getConnection(JDBC_URL); + Statement stmt = conn.createStatement(); + stmt.execute("DROP TABLE IF EXISTS test_fib1"); + stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); + stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); + long start = System.currentTimeMillis(); + Thread th = new Thread(() -> { + try { + Thread.sleep(1000); + conn.close(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + th.start(); + assertThrows( + () + -> stmt.executeQuery( + "WITH RECURSIVE cte AS (" + + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 150000) " + + "SELECT avg(f) FROM cte"), + SQLException.class); + th.join(); + long elapsed = System.currentTimeMillis() - start; + assertTrue(elapsed < 2000); + assertTrue(stmt.isClosed()); + assertTrue(conn.isClosed()); + } + + public static void test_long_query_stmt_close() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + Statement stmt = conn.createStatement(); + stmt.execute("DROP TABLE IF EXISTS test_fib1"); + stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); + stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); + long start = System.currentTimeMillis(); + Thread th = new Thread(() -> { + try { + Thread.sleep(1000); + stmt.cancel(); + stmt.close(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + th.start(); + assertThrows( + () + -> stmt.executeQuery( + "WITH RECURSIVE cte AS (" + + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 150000) " + + "SELECT avg(f) FROM cte"), + SQLException.class); + th.join(); + long elapsed = System.currentTimeMillis() - start; + assertTrue(elapsed < 2000); + assertTrue(stmt.isClosed()); + assertFalse(conn.isClosed()); + } + } + + public static void test_conn_close_no_crash() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + for (int i = 0; i < 1 << 7; i++) { + Connection conn = DriverManager.getConnection(JDBC_URL); + Statement stmt = conn.createStatement(); + Future future = executor.submit(() -> { + try { + conn.close(); + } catch (SQLException e) { + fail(); + } + }); + try { + stmt.executeQuery("select 42"); + } catch (SQLException e) { + } + future.get(); + } + } + + public static void test_stmt_close_no_crash() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + for (int i = 0; i < 1 << 10; i++) { + Statement stmt = conn.createStatement(); + Future future = executor.submit(() -> { + try { + stmt.close(); + } catch (SQLException e) { + fail(); + } + }); + try { + stmt.executeQuery("select 42"); + } catch (SQLException e) { + } + future.get(); + } + } + } + + public static void test_results_close_no_crash() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + for (int i = 0; i < 1 << 12; i++) { + ResultSet rs = stmt.executeQuery("select 42"); + Future future = executor.submit(() -> { + try { + rs.close(); + } catch (SQLException e) { + fail(); + } + }); + try { + rs.next(); + } catch (SQLException e) { + } + future.get(); + } + } + } + + public static void test_results_close_prepared_stmt_no_crash() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try (Connection conn = DriverManager.getConnection(JDBC_URL); + PreparedStatement stmt = conn.prepareStatement("select 42")) { + for (int i = 0; i < 1 << 12; i++) { + ResultSet rs = stmt.executeQuery(); + Future future = executor.submit(() -> { + try { + rs.close(); + } catch (SQLException e) { + fail(); + } + }); + try { + rs.next(); + } catch (SQLException e) { + } + future.get(); + } + } + } +} diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index 7b415200..bff41d5d 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -3085,7 +3085,7 @@ public static void test_get_schema() throws Exception { conn.getSchema(); fail(); } catch (SQLException e) { - assertEquals(e.getMessage(), "Connection Error: Invalid connection"); + assertEquals(e.getMessage(), "Invalid Error: Invalid connection ref buffer"); } } @@ -4801,7 +4801,8 @@ public static void main(String[] args) throws Exception { statusCode = runTests(new String[0], clazz); } else { // extension installation fails on CI, Spatial test is temporary disabled - statusCode = runTests(args, TestDuckDBJDBC.class, TestExtensionTypes.class /*, TestSpatial.class */); + statusCode = runTests(args, TestDuckDBJDBC.class, TestExtensionTypes.class /*, TestSpatial.class */, + TestClosure.class); } System.exit(statusCode); } diff --git a/src/test/java/org/duckdb/test/Assertions.java b/src/test/java/org/duckdb/test/Assertions.java index d72114dc..9a036e70 100644 --- a/src/test/java/org/duckdb/test/Assertions.java +++ b/src/test/java/org/duckdb/test/Assertions.java @@ -71,12 +71,12 @@ public static void assertEquals(double a, double b, double epsilon) throws Excep assertTrue(Math.abs(a - b) < epsilon); } - public static void fail() throws Exception { + public static void fail() { fail(null); } - public static void fail(String s) throws Exception { - throw new Exception(s); + public static void fail(String s) { + throw new RuntimeException(s); } public static String assertThrows(Thrower thrower, Class exception) throws Exception {