diff --git a/docs/Parameters.rst b/docs/Parameters.rst index ce98f3d6296b..b69142d6839d 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -1075,6 +1075,17 @@ Objective Parameters - separate by ``,`` +- ``lambdarank_unbiased`` :raw-html:`<a id="lambdarank_unbiased" title="Permalink to this parameter" href="#lambdarank_unbiased">🔗︎</a>`, default = ``false``, type = bool + + - used only in ``lambdarank`` application + + - set this to ``true`` to use the position bias correction of `Unbiased LambdaMART <https://arxiv.org/pdf/1809.05818.pdf>`__ + +- ``lambdarank_bias_p_norm`` :raw-html:`<a id="lambdarank_bias_p_norm" title="Permalink to this parameter" href="#lambdarank_bias_p_norm">🔗︎</a>`, default = ``0.5``, type = double, constraints: ``lambdarank_bias_p_norm >= 0.0`` + + - used only in ``lambdarank`` application where ``lambdarank_unbiased = true`` + + Metric Parameters ----------------- diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index c045970a8f1f..21e0fa185f6c 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -900,6 +900,14 @@ struct Config { // desc = separate by ``,`` std::vector<double> label_gain; + // desc = used only in ``lambdarank`` application + // desc = set this to ``true`` to use the position bias correction of `Unbiased LambdaMART <https://arxiv.org/pdf/1809.05818.pdf>`__ + bool lambdarank_unbiased = false; + + // check = >=0.0 + // desc = used only in ``lambdarank`` application where ``lambdarank_unbiased = true`` + double lambdarank_bias_p_norm = 0.5; + #pragma endregion #pragma region Metric Parameters diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 9f3dd7a188f1..3ee06e9625c2 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -299,6 +299,8 @@ const std::unordered_set<std::string>& Config::parameter_set() { "lambdarank_truncation_level", "lambdarank_norm", "label_gain", + "lambdarank_unbiased", + "lambdarank_bias_p_norm", "metric", "metric_freq", "is_provide_training_metric", @@ -606,6 +608,11 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) { if (GetString(params, "label_gain", &tmp_str)) { label_gain = Common::StringToArray<double>(tmp_str, ','); } + GetBool(params, "lambdarank_unbiased", &lambdarank_unbiased); + + GetDouble(params, "lambdarank_bias_p_norm", &lambdarank_bias_p_norm); + CHECK_GE(lambdarank_bias_p_norm, 
0.0); + GetInt(params, "metric_freq", &metric_freq); CHECK_GT(metric_freq, 0); @@ -741,6 +748,8 @@ std::string Config::SaveMembersToString() const { str_buf << "[lambdarank_truncation_level: " << lambdarank_truncation_level << "]\n"; str_buf << "[lambdarank_norm: " << lambdarank_norm << "]\n"; str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n"; + str_buf << "[lambdarank_unbiased: " << lambdarank_unbiased << "]\n"; + str_buf << "[lambdarank_bias_p_norm: " << lambdarank_bias_p_norm << "]\n"; str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n"; str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n"; str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n"; diff --git a/src/objective/rank_objective.hpp b/src/objective/rank_objective.hpp index 239bb3651f53..b99e4b974578 100644 --- a/src/objective/rank_objective.hpp +++ b/src/objective/rank_objective.hpp @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -101,7 +102,9 @@ class LambdarankNDCG : public RankingObjective { : RankingObjective(config), sigmoid_(config.sigmoid), norm_(config.lambdarank_norm), - truncation_level_(config.lambdarank_truncation_level) { + truncation_level_(config.lambdarank_truncation_level), + unbiased_(config.lambdarank_unbiased), + bias_p_norm_(config.lambdarank_bias_p_norm) { label_gain_ = config.label_gain; // initialize DCG calculator DCGCalculator::DefaultLabelGain(&label_gain_); @@ -111,6 +114,14 @@ class LambdarankNDCG : public RankingObjective { if (sigmoid_ <= 0.0) { Log::Fatal("Sigmoid param %f should be greater than zero", sigmoid_); } + + #pragma omp parallel + #pragma omp master + { + num_threads_ = omp_get_num_threads(); + } + + position_bias_regularizer = 1.0f / (1.0f + bias_p_norm_); } explicit LambdarankNDCG(const std::vector<std::string>& strs) @@ -135,12 +146,24 @@ } // construct Sigmoid table to speed up Sigmoid transform ConstructSigmoidTable(); + + // initialize 
position bias vectors + InitPositionBiasesAndGradients(); + } + + void GetGradients(const double* score, score_t* gradients, + score_t* hessians) const override { + RankingObjective::GetGradients(score, gradients, hessians); + + if (unbiased_) { UpdatePositionBiasesAndGradients(); } } inline void GetGradientsForOneQuery(data_size_t query_id, data_size_t cnt, const label_t* label, const double* score, score_t* lambdas, score_t* hessians) const override { + const int tid = omp_get_thread_num(); // get thread id + // get max DCG on current query const double inverse_max_dcg = inverse_max_dcgs_[query_id]; // initialize with zero @@ -199,15 +222,26 @@ // get delta NDCG double delta_pair_NDCG = dcg_gap * paired_discount * inverse_max_dcg; // regular the delta_pair_NDCG by score distance - if (norm_ && best_score != worst_score) { + if ((norm_ || unbiased_) && best_score != worst_score) { delta_pair_NDCG /= (0.01f + fabs(delta_score)); } // calculate lambda for this pair double p_lambda = GetSigmoid(delta_score); double p_hessian = p_lambda * (1.0f - p_lambda); + + int debias_high_rank = static_cast<int>(std::min(high, truncation_level_ - 1)); + int debias_low_rank = static_cast<int>(std::min(low, truncation_level_ - 1)); + + if (unbiased_) { + double p_cost = log(1.0f / (1.0f - p_lambda)) * delta_pair_NDCG; + + // more relevant (clicked) gets debiased by less relevant (unclicked) + i_costs_buffer_[tid][debias_high_rank] += p_cost / j_biases_pow_[debias_low_rank]; + j_costs_buffer_[tid][debias_low_rank] += p_cost / i_biases_pow_[debias_high_rank]; // and vice versa + } // update - p_lambda *= -sigmoid_ * delta_pair_NDCG; - p_hessian *= sigmoid_ * sigmoid_ * delta_pair_NDCG; + p_lambda *= -sigmoid_ * delta_pair_NDCG / i_biases_pow_[debias_high_rank] / j_biases_pow_[debias_low_rank]; + p_hessian *= sigmoid_ * sigmoid_ * delta_pair_NDCG / i_biases_pow_[debias_high_rank] / j_biases_pow_[debias_low_rank]; lambdas[low] -= 
static_cast<score_t>(p_lambda); hessians[low] += static_cast<score_t>(p_hessian); lambdas[high] += static_cast<score_t>(p_lambda); @@ -253,9 +287,86 @@ } } + void InitPositionBiasesAndGradients() { + i_biases_pow_.resize(truncation_level_); + j_biases_pow_.resize(truncation_level_); + i_costs_.resize(truncation_level_); + j_costs_.resize(truncation_level_); + + for (int i = 0; i < truncation_level_; ++i) { + // init position biases + i_biases_pow_[i] = 1.0f; + j_biases_pow_[i] = 1.0f; + + // init position gradients + i_costs_[i] = 0.0f; + j_costs_[i] = 0.0f; + } + + // init gradient buffers for gathering results across threads + for (int i = 0; i < num_threads_; i++) { + i_costs_buffer_.emplace_back(truncation_level_, 0.0f); + j_costs_buffer_.emplace_back(truncation_level_, 0.0f); + } + } + + void UpdatePositionBiasesAndGradients() const { + // accumulate the parallel results + for (int i = 0; i < num_threads_; i++) { + for (int j = 0; j < truncation_level_; j++) { + i_costs_[j] += i_costs_buffer_[i][j]; + j_costs_[j] += j_costs_buffer_[i][j]; + } + } + + for (int i = 0; i < num_threads_; i++) { + for (int j = 0; j < truncation_level_; j++) { + // clear buffer for next run + i_costs_buffer_[i][j] = 0.0f; + j_costs_buffer_[i][j] = 0.0f; + } + } + + for (int i = 0; i < truncation_level_; i++) { + // Update bias + i_biases_pow_[i] = pow(i_costs_[i] / i_costs_[0], position_bias_regularizer); + j_biases_pow_[i] = pow(j_costs_[i] / j_costs_[0], position_bias_regularizer); + } + + LogDebugPositionBiases(); + + for (int i = 0; i < truncation_level_; i++) { + // Clear position info + i_costs_[i] = 0.0f; + j_costs_[i] = 0.0f; + } + } + const char* GetName() const override { return "lambdarank"; } private: + void LogDebugPositionBiases() const { + std::stringstream message_stream; + message_stream << std::setw(10) << "position" + << std::setw(15) << "bias_i" + << std::setw(15) << "bias_j" + << std::setw(15) << "i_cost" + << std::setw(15) << "j_cost" + << 
std::endl; + Log::Debug(message_stream.str().c_str()); + message_stream.str(""); + + for (int i = 0; i < truncation_level_; ++i) { + message_stream << std::setw(10) << i + << std::setw(15) << i_biases_pow_[i] + << std::setw(15) << j_biases_pow_[i] + << std::setw(15) << i_costs_[i] + << std::setw(15) << j_costs_[i]; + Log::Debug(message_stream.str().c_str()); + message_stream.str(""); + } + } + /*! \brief Sigmoid param */ double sigmoid_; /*! \brief Normalize the lambdas or not */ @@ -276,6 +387,35 @@ double max_sigmoid_input_ = 50; /*! \brief Factor that covert score to bin in Sigmoid table */ double sigmoid_table_idx_factor_; + + // bias correction variables + /*! \brief power of (click) position biases */ + mutable std::vector<double> i_biases_pow_; + + /*! \brief power of (unclick) position biases */ + mutable std::vector<double> j_biases_pow_; + + // mutable double position cost; + mutable std::vector<double> i_costs_; + mutable std::vector<std::vector<double>> i_costs_buffer_; + + mutable std::vector<double> j_costs_; + mutable std::vector<std::vector<double>> j_costs_buffer_; + + /*! + * \brief Should use lambdarank with position bias correction + * [arxiv.org/pdf/1809.05818.pdf] + */ + bool unbiased_; + + /*! \brief Position bias regularizer norm */ + double bias_p_norm_; + + /*! \brief Position bias regularizer exponent, 1 / (1 + bias_p_norm_) */ + double position_bias_regularizer; + + /*! \brief Number of threads */ + int num_threads_; }; /*! 
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index d4112078f39e..d996808a271a 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -145,6 +145,21 @@ def test_lambdarank(): assert gbm.best_score_['valid_0']['ndcg@3'] > 0.578 +def test_lambdarank_unbiased(): + rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank' + X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train')) + X_test, y_test = load_svmlight_file(str(rank_example_dir / 'rank.test')) + q_train = np.loadtxt(str(rank_example_dir / 'rank.train.query')) + q_test = np.loadtxt(str(rank_example_dir / 'rank.test.query')) + gbm = lgb.LGBMRanker(n_estimators=50, lambdarank_unbiased=True, sigmoid=2) + gbm.fit(X_train, y_train, group=q_train, eval_set=[(X_test, y_test)], + eval_group=[q_test], eval_at=[1, 3], early_stopping_rounds=10, verbose=False, + callbacks=[lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))]) + assert gbm.best_iteration_ <= 24 + assert gbm.best_score_['valid_0']['ndcg@1'] > 0.569 + assert gbm.best_score_['valid_0']['ndcg@3'] > 0.62 + + def test_xendcg(): xendcg_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'xendcg' X_train, y_train = load_svmlight_file(str(xendcg_example_dir / 'rank.train'))