diff --git a/src/c_api.cpp b/src/c_api.cpp index fa42adab4072..1da7d40cb557 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -121,7 +121,7 @@ struct SingleRowPredictor { int predict_type, Boosting *boosting, int start_iter, - int num_iter) : booster_shared_lock(booster_mutex), config(Config::Str2Map(parameters)), data_type(data_type), num_cols(num_cols), single_row_predictor_inner(predict_type, boosting, config, start_iter, num_iter) { + int num_iter) : config(Config::Str2Map(parameters)), data_type(data_type), num_cols(num_cols), single_row_predictor_inner(predict_type, boosting, config, start_iter, num_iter), booster_mutex(booster_mutex) { if (!config.predict_disable_shape_check && num_cols != boosting->MaxFeatureIdx() + 1) { Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\ "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", num_cols, boosting->MaxFeatureIdx() + 1); @@ -131,6 +131,7 @@ struct SingleRowPredictor { void Predict(std::function>(int row_idx)> get_row_fun, double* out_result, int64_t* out_len) const { UNIQUE_LOCK(single_row_predictor_mutex) + yamc::shared_lock booster_shared_lock(booster_mutex); auto one_row = get_row_fun(0); single_row_predictor_inner.predict_function(one_row, out_result); @@ -138,10 +139,6 @@ struct SingleRowPredictor { *out_len = single_row_predictor_inner.num_pred_in_one_row; } - private: - // Prevent the booster from being modified while we have a predictor relying on it - yamc::shared_lock booster_shared_lock; - public: Config config; const int data_type; @@ -150,6 +147,9 @@ struct SingleRowPredictor { private: SingleRowPredictorInner single_row_predictor_inner; + // Prevent the booster from being modified while we have a predictor relying on it during prediction + yamc::alternate::shared_mutex *booster_mutex; + // If several threads try to predict at the same time using the same SingleRowPredictor // we want them to still provide correct values, so the mutex is necessary due to the shared // resources in the predictor.