diff --git a/include/LightGBM/bin.h b/include/LightGBM/bin.h index ec330bd94f8c..71b60c493504 100644 --- a/include/LightGBM/bin.h +++ b/include/LightGBM/bin.h @@ -261,8 +261,9 @@ class Bin { /*! * \brief Initialize for pushing. By default, no action needed. * \param num_thread The number of external threads that will be calling the push APIs + * \param omp_max_threads The maximum number of OpenMP threads to allocate for */ - virtual void InitStreaming(uint32_t /*num_thread*/) { } + virtual void InitStreaming(uint32_t /*num_thread*/, int32_t /*omp_max_threads*/) { } /*! * \brief Push one record * \param tid Thread id diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 7bde951b65e4..27fb0b620a07 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -153,6 +153,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc * \param has_queries Whether the dataset has Metadata queries/groups * \param nclasses Number of initial score classes * \param nthreads Number of external threads that will use the PushRows APIs + * \param omp_max_threads Maximum number of OpenMP threads (-1 for default) * \return 0 when succeed, -1 when failure happens */ LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset, @@ -160,7 +161,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset, int32_t has_init_scores, int32_t has_queries, int32_t nclasses, - int32_t nthreads); + int32_t nthreads, + int32_t omp_max_threads); /*! * \brief Push data to existing dataset, if ``nrow + start_row == num_total_row``, will call ``dataset->FinishLoad``. diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index c60aaf037c71..74e3e9c1dad4 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -458,10 +458,18 @@ class Dataset { int32_t has_init_scores, int32_t has_queries, int32_t nclasses, - int32_t nthreads) { + int32_t nthreads, + int32_t omp_max_threads) { + // Initialize optional max thread count with either parameter or OMP setting + if (omp_max_threads > 0) { + omp_max_threads_ = omp_max_threads; + } else if (omp_max_threads_ <= 0) { + omp_max_threads_ = OMP_NUM_THREADS(); + } + metadata_.Init(num_data, has_weights, has_init_scores, has_queries, nclasses); for (int i = 0; i < num_groups_; ++i) { - feature_groups_[i]->InitStreaming(nthreads); + feature_groups_[i]->InitStreaming(nthreads, omp_max_threads_); } } @@ -846,6 +854,9 @@ class Dataset { /*! \brief Get whether FinishLoad is automatically called when pushing last row. */ inline bool wait_for_manual_finish() const { return wait_for_manual_finish_; } + /*! \brief Get the maximum number of OpenMP threads to allocate for. */ + inline int omp_max_threads() const { return omp_max_threads_; } + /*! \brief Set whether the Dataset is finished automatically when last row is pushed or with a manual * MarkFinished API call. Set to true for thread-safe streaming and/or if will be coalesced later. * FinishLoad should not be called on any Dataset that will be coalesced. @@ -947,6 +958,7 @@ class Dataset { std::vector feature_need_push_zeros_; std::vector> raw_data_; bool wait_for_manual_finish_; + int omp_max_threads_ = -1; bool has_raw_; /*! map feature (inner index) to its index in the list of numeric (non-categorical) features */ std::vector numeric_feature_map_; diff --git a/include/LightGBM/feature_group.h b/include/LightGBM/feature_group.h index 72d9fcac08dc..0ddfd857bce1 100644 --- a/include/LightGBM/feature_group.h +++ b/include/LightGBM/feature_group.h @@ -192,14 +192,15 @@ class FeatureGroup { /*! * \brief Initialize for pushing in a streaming fashion. By default, no action needed. * \param num_thread The number of external threads that will be calling the push APIs + * \param omp_max_threads The maximum number of OpenMP threads to allocate for */ - void InitStreaming(int32_t num_thread) { + void InitStreaming(int32_t num_thread, int32_t omp_max_threads) { if (is_multi_val_) { for (int i = 0; i < num_feature_; ++i) { - multi_bin_data_[i]->InitStreaming(num_thread); + multi_bin_data_[i]->InitStreaming(num_thread, omp_max_threads); } } else { - bin_data_->InitStreaming(num_thread); + bin_data_->InitStreaming(num_thread, omp_max_threads); } } diff --git a/src/c_api.cpp b/src/c_api.cpp index 20633273134e..004a1f230c74 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -1018,11 +1018,12 @@ int LGBM_DatasetInitStreaming(DatasetHandle dataset, int32_t has_init_scores, int32_t has_queries, int32_t nclasses, - int32_t nthreads) { + int32_t nthreads, + int32_t omp_max_threads) { API_BEGIN(); auto p_dataset = reinterpret_cast(dataset); auto num_data = p_dataset->num_data(); - p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads); + p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads, omp_max_threads); p_dataset->set_wait_for_manual_finish(true); API_END(); } @@ -1073,19 +1074,20 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset, if (!data) { Log::Fatal("data cannot be null."); } - const int num_omp_threads = OMP_NUM_THREADS(); auto p_dataset = reinterpret_cast(dataset); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1); if (p_dataset->has_raw()) { p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow); } + const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS(); + OMP_INIT_EX(); #pragma omp parallel for schedule(static) for (int i = 0; i < nrow; ++i) { OMP_LOOP_EX_BEGIN(); // convert internal thread id to be unique based on external thread id - const int internal_tid = omp_get_thread_num() + (num_omp_threads * tid); + const int internal_tid = omp_get_thread_num() + (max_omp_threads * tid); auto one_row = get_row_fun(i); p_dataset->PushOneRow(internal_tid, start_row + i, one_row); OMP_LOOP_EX_END(); @@ -1154,19 +1156,21 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset, if (!data) { Log::Fatal("data cannot be null."); } - const int num_omp_threads = OMP_NUM_THREADS(); auto p_dataset = reinterpret_cast(dataset); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int32_t nrow = static_cast(nindptr - 1); if (p_dataset->has_raw()) { p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow); } + + const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS(); + OMP_INIT_EX(); #pragma omp parallel for schedule(static) for (int i = 0; i < nrow; ++i) { OMP_LOOP_EX_BEGIN(); // convert internal thread id to be unique based on external thread id - const int internal_tid = omp_get_thread_num() + (num_omp_threads * tid); + const int internal_tid = omp_get_thread_num() + (max_omp_threads * tid); auto one_row = get_row_fun(i); p_dataset->PushOneRow(internal_tid, static_cast(start_row + i), one_row); OMP_LOOP_EX_END(); diff --git a/src/io/sparse_bin.hpp b/src/io/sparse_bin.hpp index 24931fd4eff0..79ebb25d08dd 100644 --- a/src/io/sparse_bin.hpp +++ b/src/io/sparse_bin.hpp @@ -81,10 +81,10 @@ class SparseBin : public Bin { ~SparseBin() {} - void InitStreaming(uint32_t num_thread) override { - // Each thread needs its own push buffer, so allocate external num_thread times the number of OMP threads - int num_omp_threads = OMP_NUM_THREADS(); - push_buffers_.resize(num_omp_threads * num_thread); + void InitStreaming(uint32_t num_thread, int32_t omp_max_threads) override { + // Each external thread needs its own set of OpenMP push buffers, + // so allocate num_thread times the maximum number of OMP threads per external thread + push_buffers_.resize(omp_max_threads * num_thread); }; void ReSize(data_size_t num_data) override { num_data_ = num_data; } diff --git a/tests/cpp_tests/test_stream.cpp b/tests/cpp_tests/test_stream.cpp index e8bcc7a76a22..bc5f73b0a3ee 100644 --- a/tests/cpp_tests/test_stream.cpp +++ b/tests/cpp_tests/test_stream.cpp @@ -79,7 +79,7 @@ void test_stream_dense( &dataset_handle); EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result; - result = LGBM_DatasetInitStreaming(dataset_handle, has_weights, has_init_scores, has_queries, nclasses, 1); + result = LGBM_DatasetInitStreaming(dataset_handle, has_weights, has_init_scores, has_queries, nclasses, 1, -1); EXPECT_EQ(0, result) << "LGBM_DatasetInitStreaming result code: " << result; break; } @@ -197,7 +197,7 @@ void test_stream_sparse( EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result; dataset = static_cast(dataset_handle); - dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2); + dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2, -1); break; }