diff --git a/include/LightGBM/dataset_loader.h b/include/LightGBM/dataset_loader.h index 27bea113b052..e72dd4910804 100644 --- a/include/LightGBM/dataset_loader.h +++ b/include/LightGBM/dataset_loader.h @@ -67,7 +67,7 @@ class DatasetLoader { /*! \brief Random generator*/ Random random_; /*! \brief prediction function for initial model */ - const PredictFunction predict_fun_; + const PredictFunction& predict_fun_; /*! \brief number of classes */ int num_class_; /*! \brief index of label column */ diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp index 51ee7e4e540a..2d2c4d622b1c 100644 --- a/src/io/dataset_loader.cpp +++ b/src/io/dataset_loader.cpp @@ -1143,7 +1143,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat double tmp_label = 0.0f; auto& ref_text_data = *text_data; std::vector feature_row(dataset->num_features_); - if (!predict_fun_) { + if (predict_fun_ == nullptr) { OMP_INIT_EX(); // if doesn't need to prediction with initial model #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row) @@ -1262,7 +1262,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser, const std::vector& used_data_indices, Dataset* dataset) { std::vector init_score; - if (predict_fun_) { + if (predict_fun_ != nullptr) { init_score = std::vector(dataset->num_data_ * num_class_); } std::function&)> process_fun =