Skip to content

Commit

Permalink
check whether freeze is due to booster shared lock being held by the …
Browse files Browse the repository at this point in the history
…SingleRowPredictor
  • Loading branch information
Ten0 committed Aug 7, 2023
1 parent feaf3dc commit f41755b
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ struct SingleRowPredictor {
int predict_type,
Boosting *boosting,
int start_iter,
int num_iter) : booster_shared_lock(booster_mutex), config(Config::Str2Map(parameters)), data_type(data_type), num_cols(num_cols), single_row_predictor_inner(predict_type, boosting, config, start_iter, num_iter) {
int num_iter) : config(Config::Str2Map(parameters)), data_type(data_type), num_cols(num_cols), single_row_predictor_inner(predict_type, boosting, config, start_iter, num_iter), booster_mutex(booster_mutex) {
if (!config.predict_disable_shape_check && num_cols != boosting->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", num_cols, boosting->MaxFeatureIdx() + 1);
Expand All @@ -131,17 +131,14 @@ struct SingleRowPredictor {
void Predict(std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
double* out_result, int64_t* out_len) const {
UNIQUE_LOCK(single_row_predictor_mutex)
yamc::shared_lock<yamc::alternate::shared_mutex> booster_shared_lock(booster_mutex);

auto one_row = get_row_fun(0);
single_row_predictor_inner.predict_function(one_row, out_result);

*out_len = single_row_predictor_inner.num_pred_in_one_row;
}

private:
// Prevent the booster from being modified while we have a predictor relying on it
yamc::shared_lock<yamc::alternate::shared_mutex> booster_shared_lock;

public:
Config config;
const int data_type;
Expand All @@ -150,6 +147,9 @@ struct SingleRowPredictor {
private:
SingleRowPredictorInner single_row_predictor_inner;

// Prevent the booster from being modified while we have a predictor relying on it during prediction
yamc::alternate::shared_mutex *booster_mutex;

// If several threads try to predict at the same time using the same SingleRowPredictor
// we want them to still provide correct values, so the mutex is necessary due to the shared
// resources in the predictor.
Expand Down

0 comments on commit f41755b

Please sign in to comment.