Treat position bias via GAM in LambdaMART #5929

Merged — 103 commits, Sep 4, 2023
Changes from 7 commits
Commits
4fea7aa
Update dataset.h
metpavel Jun 15, 2023
e1bbbba
Update metadata.cpp
metpavel Jun 15, 2023
6027716
Update rank_objective.hpp
metpavel Jun 15, 2023
199d412
Update metadata.cpp
metpavel Jun 15, 2023
1a56b5a
Update rank_objective.hpp
metpavel Jun 17, 2023
93f67d1
Update metadata.cpp
metpavel Jun 17, 2023
c615b7d
Update dataset.h
metpavel Jun 17, 2023
e374e9a
Update rank_objective.hpp
metpavel Jun 20, 2023
6c7b86f
Update metadata.cpp
metpavel Jun 20, 2023
365ca75
Update test_engine.py
metpavel Jun 20, 2023
9f033ed
Update test_engine.py
metpavel Jun 20, 2023
50659d7
Add files via upload
metpavel Jun 20, 2023
45fbe8b
Update test_engine.py
metpavel Jun 20, 2023
7579ce1
Update test_engine.py
metpavel Jun 20, 2023
e74f75e
Update test_engine.py
metpavel Jun 20, 2023
a1985e1
Update test_engine.py
metpavel Jun 20, 2023
01189b0
Update test_engine.py
metpavel Jun 21, 2023
065269b
Update _rank.train.position
metpavel Jun 21, 2023
b1d5529
Update test_engine.py
metpavel Jun 21, 2023
adcd240
Update test_engine.py
metpavel Jun 21, 2023
943e1a3
Update test_engine.py
metpavel Jun 21, 2023
6333c77
Update test_engine.py
metpavel Jun 21, 2023
4fc927e
Update _rank.train.position
metpavel Jun 21, 2023
f083d62
Update _rank.train.position
metpavel Jun 21, 2023
93238ac
Update test_engine.py
metpavel Jun 21, 2023
ae647e4
Update _rank.train.position
metpavel Jun 22, 2023
4ecbbf9
Update test_engine.py
metpavel Jun 22, 2023
3bc9415
Update test_engine.py
metpavel Jun 22, 2023
7df4323
Update test_engine.py
metpavel Jun 22, 2023
844d101
Update test_engine.py
metpavel Jun 22, 2023
1ae780a
Update test_engine.py
metpavel Jun 22, 2023
dffdda1
Update the position of import statement
shiyu1994 Jun 26, 2023
055bd1c
Update rank_objective.hpp
metpavel Jun 27, 2023
2885928
Update config.h
metpavel Jun 27, 2023
a566872
Update config_auto.cpp
metpavel Jun 27, 2023
0707cc0
Update rank_objective.hpp
metpavel Jun 27, 2023
64ec098
Update rank_objective.hpp
metpavel Jun 27, 2023
dd58f69
update documentation
shiyu1994 Jun 29, 2023
58666e3
remove extra blank line
shiyu1994 Jun 29, 2023
04e66ed
Update src/io/metadata.cpp
shiyu1994 Jul 5, 2023
a8c77c0
Update src/io/metadata.cpp
shiyu1994 Jul 5, 2023
fb0d251
remove _rank.train.position
shiyu1994 Jul 10, 2023
923b41a
add position in python API
shiyu1994 Jul 12, 2023
baa7b05
merge master
shiyu1994 Jul 12, 2023
2b91c24
fix set_positions in basic.py
shiyu1994 Jul 12, 2023
465f9fa
Update Advanced-Topics.rst
metpavel Jul 13, 2023
4cdf730
Update Advanced-Topics.rst
metpavel Jul 13, 2023
343968f
Update Advanced-Topics.rst
metpavel Jul 13, 2023
0186906
Update Advanced-Topics.rst
metpavel Jul 13, 2023
0819f4a
Update Advanced-Topics.rst
metpavel Jul 13, 2023
4e54f70
Update Advanced-Topics.rst
metpavel Jul 13, 2023
36f44f1
Update Advanced-Topics.rst
metpavel Jul 13, 2023
acfde2c
Update Advanced-Topics.rst
metpavel Jul 13, 2023
a7d0ae5
Update Advanced-Topics.rst
metpavel Jul 13, 2023
e8e1854
Update Advanced-Topics.rst
metpavel Jul 13, 2023
108f630
Update Advanced-Topics.rst
metpavel Jul 13, 2023
482d394
Update docs/Advanced-Topics.rst
shiyu1994 Jul 13, 2023
03e41bd
Update docs/Advanced-Topics.rst
metpavel Jul 14, 2023
621c187
Update Advanced-Topics.rst
metpavel Jul 14, 2023
d8adb4c
Update Advanced-Topics.rst
metpavel Jul 14, 2023
b93ecca
Update Advanced-Topics.rst
metpavel Jul 14, 2023
68ffab8
Update Advanced-Topics.rst
metpavel Jul 14, 2023
81b5f09
remove List from _LGBM_PositionType
shiyu1994 Jul 20, 2023
c917f83
Merge branch 'metpavel-posbias_GAM' of https://github.com/metpavel/Li…
shiyu1994 Jul 20, 2023
0db4caf
move new position parameter to the last in Dataset constructor
shiyu1994 Jul 20, 2023
102fb97
add position_filename as a parameter
shiyu1994 Jul 28, 2023
de43fc8
Update docs/Advanced-Topics.rst
metpavel Aug 2, 2023
626fa16
Update docs/Advanced-Topics.rst
metpavel Aug 2, 2023
adcde0c
Update Advanced-Topics.rst
metpavel Aug 2, 2023
f3c3387
Update src/objective/rank_objective.hpp
metpavel Aug 2, 2023
e3c5e6f
Update src/io/metadata.cpp
metpavel Aug 2, 2023
d737680
Update metadata.cpp
metpavel Aug 2, 2023
5d836ed
Update python-package/lightgbm/basic.py
shiyu1994 Aug 4, 2023
1c69862
Update python-package/lightgbm/basic.py
shiyu1994 Aug 4, 2023
5232cb8
Update python-package/lightgbm/basic.py
shiyu1994 Aug 4, 2023
7d9f0bb
Update python-package/lightgbm/basic.py
shiyu1994 Aug 4, 2023
14c9c60
Update src/io/metadata.cpp
shiyu1994 Aug 4, 2023
3a2031a
more infomrative fatal message
shiyu1994 Aug 4, 2023
a92dd1a
Merge branch 'metpavel-posbias_GAM' of https://github.com/metpavel/Li…
shiyu1994 Aug 4, 2023
757f7cb
update documentation for more flexible position specification
shiyu1994 Aug 4, 2023
70fc191
fix SetPosition
shiyu1994 Aug 4, 2023
d92f6d0
remove position_filename
shiyu1994 Aug 8, 2023
b55e44b
remove useless changes
shiyu1994 Aug 8, 2023
fdda50f
Update python-package/lightgbm/basic.py
shiyu1994 Aug 8, 2023
56d77c2
remove useless files
shiyu1994 Aug 8, 2023
0830fac
Merge branch 'metpavel-posbias_GAM' of https://github.com/metpavel/Li…
shiyu1994 Aug 8, 2023
9a4ac88
move position file when position set in Dataset
shiyu1994 Aug 8, 2023
74b6934
warn when positions are overwritten
shiyu1994 Aug 8, 2023
59c275f
skip ranking with position test in cuda
shiyu1994 Aug 8, 2023
8bbbc43
split test case
shiyu1994 Aug 9, 2023
e843de6
remove useless import
shiyu1994 Aug 9, 2023
5efb9a9
Update test_engine.py
metpavel Aug 25, 2023
ee2fc4b
Update test_engine.py
metpavel Aug 25, 2023
98ac42d
Update test_engine.py
metpavel Aug 25, 2023
ca4bd04
Update docs/Advanced-Topics.rst
metpavel Aug 26, 2023
56a337f
Update Parameters.rst
metpavel Aug 26, 2023
2b88856
Update rank_objective.hpp
metpavel Aug 26, 2023
893405b
Update config.h
metpavel Aug 26, 2023
5dd6428
Merge branch 'master' into metpavel-posbias_GAM
shiyu1994 Sep 4, 2023
f758b26
update config_auto.cppp
shiyu1994 Sep 4, 2023
960c758
Update docs/Advanced-Topics.rst
shiyu1994 Sep 4, 2023
3d934e6
fix randomness in test case for gpu
shiyu1994 Sep 4, 2023
ad188f6
Merge branch 'metpavel-posbias_GAM' of https://github.com/metpavel/Li…
shiyu1994 Sep 4, 2023
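Before the file diffs, a minimal Python usage sketch of the feature this PR adds: per-row result positions are attached to the training Dataset so the lambdarank objective can learn position bias factors and debias its gradients. The `position` keyword argument follows the "add position in python API" commit above; the data, group sizes, and parameter values are illustrative assumptions, not taken from the PR.

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(7)
n_queries, docs_per_query, n_features = 20, 50, 16
X = rng.normal(size=(n_queries * docs_per_query, n_features))
y = rng.integers(0, 4, size=n_queries * docs_per_query)   # graded relevance labels
group = [docs_per_query] * n_queries                       # docs per query, in data order
position = np.tile(np.arange(docs_per_query), n_queries)   # 0-based display position of each doc

train_set = lgb.Dataset(X, label=y, group=group, position=position)
params = {"objective": "lambdarank", "verbosity": -1}
booster = lgb.train(params, train_set, num_boost_round=10)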
41 changes: 41 additions & 0 deletions include/LightGBM/dataset.h
@@ -213,6 +213,38 @@ class Metadata {
}
}

/*!
* \brief Get positions, if does not exist then return nullptr
* \return Pointer of positions
*/
inline const size_t* positions() const {
if (!positions_.empty()) {
return positions_.data();
} else {
return nullptr;
}
}

/*!
* \brief Get position IDs, if does not exist then return nullptr
* \return Pointer of position IDs
*/
inline const std::string* position_ids() const {
if (!position_ids_.empty()) {
return position_ids_.data();
} else {
return nullptr;
}
}

/*!
* \brief Get number of different position IDs
* \return number of different position IDs
*/
inline size_t num_position_ids() const {
return position_ids_.size();
}

/*!
* \brief Get data boundaries on queries, if not exists, will return nullptr
* we assume data will order by query,
@@ -289,6 +321,8 @@ class Metadata {
private:
/*! \brief Load wights from file */
void LoadWeights();
/*! \brief Load positions from file */
void LoadPositions();
/*! \brief Load query boundaries from file */
void LoadQueryBoundaries();
/*! \brief Calculate query weights from queries */
@@ -309,10 +343,16 @@
data_size_t num_data_;
/*! \brief Number of weights, used to check correct weight file */
data_size_t num_weights_;
/*! \brief Number of positions, used to check correct position file */
data_size_t num_positions_;
/*! \brief Label data */
std::vector<label_t> label_;
/*! \brief Weights data */
std::vector<label_t> weights_;
/*! \brief Positions data */
std::vector<size_t> positions_;
/*! \brief Position identifiers */
std::vector<std::string> position_ids_;
/*! \brief Query boundaries */
std::vector<data_size_t> query_boundaries_;
/*! \brief Query weights */
@@ -328,6 +368,7 @@
/*! \brief mutex for threading safe call */
std::mutex mutex_;
bool weight_load_from_file_;
bool position_load_from_file_;
bool query_load_from_file_;
bool init_score_load_from_file_;
#ifdef USE_CUDA
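The new Metadata accessors above (positions(), position_ids(), num_position_ids()) are also fed from the Python side. A hedged sketch of setting positions on an already constructed Dataset follows; the set_position method name is inferred from the "fix set_positions in basic.py" commit and should be checked against the released Python API docs.

import numpy as np
import lightgbm as lgb

X = np.random.rand(100, 8)
y = np.random.randint(0, 4, size=100)
train_set = lgb.Dataset(X, label=y, group=[25, 25, 25, 25])
# One position per data row, aligned with the row order of X (assumed API).
train_set.set_position(np.tile(np.arange(25), 4))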
56 changes: 56 additions & 0 deletions src/io/metadata.cpp
@@ -15,7 +15,9 @@ Metadata::Metadata() {
num_init_score_ = 0;
num_data_ = 0;
num_queries_ = 0;
num_positions_ = 0;
weight_load_from_file_ = false;
position_load_from_file_ = false;
query_load_from_file_ = false;
init_score_load_from_file_ = false;
#ifdef USE_CUDA
@@ -28,6 +30,7 @@ void Metadata::Init(const char* data_filename) {
// for lambdarank, it needs query data for partition data in distributed learning
LoadQueryBoundaries();
LoadWeights();
LoadPositions();
CalculateQueryWeights();
LoadInitialScore(data_filename_);
}
@@ -214,6 +217,13 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
Log::Fatal("Weights size doesn't match data size");
}

// check positions
if (!positions_.empty() && num_positions_ != num_data_) {
positions_.clear();
num_positions_ = 0;
Log::Fatal("Positions size doesn't match data size");
}

// check query boundries
if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_) {
query_boundaries_.clear();
@@ -251,6 +261,25 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
old_weights.clear();
}
}
// check positions
if (position_load_from_file_) {
if (positions_.size() > 0 && num_positions_ != num_all_data) {
positions_.clear();
num_positions_ = 0;
Log::Fatal("Positions size doesn't match data size");
}
// get local positions
if (!positions_.empty()) {
auto old_positions = positions_;
num_positions_ = num_data_;
positions_ = std::vector<size_t>(num_data_);
#pragma omp parallel for schedule(static, 512)
for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) {
positions_[i] = old_positions[used_data_indices[i]];
}
old_positions.clear();
}
}
if (query_load_from_file_) {
// check query boundries
if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
@@ -528,6 +557,33 @@ void Metadata::LoadWeights() {
weight_load_from_file_ = true;
}

void Metadata::LoadPositions() {
num_positions_ = 0;
std::string position_filename(data_filename_);
// default position file name
position_filename.append(".position");
TextReader<size_t> reader(position_filename.c_str(), false);
reader.ReadAllLines();
if (reader.Lines().empty()) {
return;
}
Log::Info("Loading positions...");
num_positions_ = static_cast<data_size_t>(reader.Lines().size());
positions_ = std::vector<size_t>(num_positions_);
position_ids_ = std::vector<std::string>();
std::unordered_map<std::string, size_t> map_id2pos;
for (data_size_t i = 0; i < num_positions_; ++i) {
std::string& line = reader.Lines()[i];
if (map_id2pos.count(line) == 0) {
map_id2pos[line] = position_ids_.size();
position_ids_.push_back(line);
}
positions_[i] = map_id2pos.at(line);
}
position_load_from_file_ = true;
}


void Metadata::LoadInitialScore(const std::string& data_filename) {
num_init_score_ = 0;
std::string init_score_filename(data_filename);
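For users who rely on side files rather than the Python API: LoadPositions() above looks for a file named "<training_data_file>.position" next to the training data, with one entry per data row, and every distinct line becomes a position ID. A small sketch of writing such a file, assuming the training file is named rank.train (the file name and values are illustrative):

# One position per training row; distinct lines become position IDs in
# LoadPositions(). Values may be integers or arbitrary strings.
positions = ["1", "2", "3", "1", "2", "3"]  # aligned with the rows of rank.train
with open("rank.train.position", "w") as f:
    f.write("\n".join(positions) + "\n")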
83 changes: 81 additions & 2 deletions src/objective/rank_objective.hpp
@@ -37,6 +37,12 @@ class RankingObjective : public ObjectiveFunction {
label_ = metadata.label();
// get weights
weights_ = metadata.weights();
// get positions
positions_ = metadata.positions();
// get position ids
position_ids_ = metadata.position_ids();
// get number of different position ids
num_position_ids_ = static_cast<data_size_t>(metadata.num_position_ids());
// get boundries
query_boundaries_ = metadata.query_boundaries();
if (query_boundaries_ == nullptr) {
@@ -62,13 +68,18 @@
}
}
}
if (num_position_ids_ > 0) {
UpdatePositionBiasFactors(gradients, hessians);
}
}

virtual void GetGradientsForOneQuery(data_size_t query_id, data_size_t cnt,
const label_t* label,
const double* score, score_t* lambdas,
score_t* hessians) const = 0;

virtual void UpdatePositionBiasFactors(const score_t* lambdas, const score_t* hessians) const {}

const char* GetName() const override = 0;

std::string ToString() const override {
@@ -88,6 +99,12 @@
const label_t* label_;
/*! \brief Pointer of weights */
const label_t* weights_;
/*! \brief Pointer of positions */
const size_t* positions_;
/*! \brief Pointer of position IDs */
const std::string* position_ids_;
/*! \brief Number of different position IDs */
data_size_t num_position_ids_;
/*! \brief Query boundaries */
const data_size_t* query_boundaries_;
};
@@ -111,6 +128,7 @@ class LambdarankNDCG : public RankingObjective {
if (sigmoid_ <= 0.0) {
Log::Fatal("Sigmoid param %f should be greater than zero", sigmoid_);
}
learning_rate_ = config.learning_rate;
}

explicit LambdarankNDCG(const std::vector<std::string>& strs)
@@ -135,6 +153,8 @@
}
// construct Sigmoid table to speed up Sigmoid transform
ConstructSigmoidTable();
// initialize position bias vectors
pos_biases_.resize(num_position_ids_, 0.0);
}

inline void GetGradientsForOneQuery(data_size_t query_id, data_size_t cnt,
@@ -181,14 +201,18 @@
}
const data_size_t high = sorted_idx[high_rank];
const int high_label = static_cast<int>(label[high]);
const double high_score = score[high];
double high_score = score[high];
const double high_label_gain = label_gain_[high_label];
const double high_discount = DCGCalculator::GetDiscount(high_rank);
const data_size_t low = sorted_idx[low_rank];
const int low_label = static_cast<int>(label[low]);
const double low_score = score[low];
double low_score = score[low];
const double low_label_gain = label_gain_[low_label];
const double low_discount = DCGCalculator::GetDiscount(low_rank);
if (num_position_ids_ > 0) {
high_score += pos_biases_[positions_[query_boundaries_[query_id] + high]];
low_score += pos_biases_[positions_[query_boundaries_[query_id] + low]];
}

const double delta_score = high_score - low_score;

@@ -253,9 +277,60 @@
}
}

void UpdatePositionBiasFactors(const score_t* lambdas, const score_t* hessians) const override {
/// get number of threads
int num_threads = 1;
#pragma omp parallel
#pragma omp master
{
num_threads = omp_get_num_threads();
}
// create per-thread buffers for first and second derivatives of utility w.r.t. position bias factors
std::vector<double> bias_first_derivatives(num_position_ids_ * num_threads, 0.0);
std::vector<double> bias_second_derivatives(num_position_ids_ * num_threads, 0.0);
#pragma omp parallel for schedule(guided)
for (data_size_t i = 0; i < num_data_; i++) {
// get thread ID
const int tid = omp_get_thread_num();
size_t offset = static_cast<size_t>(positions_[i] + tid * num_position_ids_);  // each thread owns a contiguous block of num_position_ids_ entries
// accumulate first derivatives of utility w.r.t. position bias factors, for each position
bias_first_derivatives[offset] -= lambdas[i];
// accumulate second derivatives of utility w.r.t. position bias factors, for each position
bias_second_derivatives[offset] -= hessians[i];
}
#pragma omp parallel for schedule(guided)
for (data_size_t i = 0; i < num_position_ids_; i++) {
double bias_first_derivative = 0.0;
double bias_second_derivative = 0.0;
// aggregate derivatives from per-thread buffers
for (int tid = 0; tid < num_threads; tid++) {
size_t offset = static_cast<size_t>(i + tid * num_position_ids_);
bias_first_derivative += bias_first_derivatives[offset];
bias_second_derivative += bias_second_derivatives[offset];
}
// do Newton-Raphson step to update position bias factors
pos_biases_[i] += learning_rate_ * bias_first_derivative / (std::abs(bias_second_derivative) + 0.001);
}
LogDebugPositionBiasFactors();
}

const char* GetName() const override { return "lambdarank"; }

protected:
void LogDebugPositionBiasFactors() const {
std::stringstream message_stream;
message_stream << std::setw(15) << "position"
<< std::setw(15) << "bias_factor"
<< std::endl;
Log::Debug(message_stream.str().c_str());
message_stream.str("");
for (int i = 0; i < num_position_ids_; ++i) {
message_stream << std::setw(15) << position_ids_[i]
<< std::setw(15) << pos_biases_[i];
Log::Debug(message_stream.str().c_str());
message_stream.str("");
}
}
/*! \brief Sigmoid param */
double sigmoid_;
/*! \brief Normalize the lambdas or not */
Expand All @@ -276,6 +351,10 @@ class LambdarankNDCG : public RankingObjective {
double max_sigmoid_input_ = 50;
/*! \brief Factor that covert score to bin in Sigmoid table */
double sigmoid_table_idx_factor_;
/*! \brief Position bias factors */
mutable std::vector<label_t> pos_biases_;
/*! \brief Learning rate to update position bias factors */
double learning_rate_;
};

/*!
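To make the update in UpdatePositionBiasFactors() above easier to follow, here is a NumPy sketch of the same damped Newton step without the per-thread buffers; the 0.001 damping constant and the learning-rate scaling mirror the C++ above, while the function signature itself is only illustrative.

import numpy as np

def update_position_bias_factors(pos_biases, positions, lambdas, hessians, learning_rate):
    # Sum first and second derivatives of the utility w.r.t. each position bias factor.
    first = np.zeros_like(pos_biases)
    second = np.zeros_like(pos_biases)
    np.add.at(first, positions, -lambdas)
    np.add.at(second, positions, -hessians)
    # Damped Newton-Raphson step, matching the C++ update above.
    return pos_biases + learning_rate * first / (np.abs(second) + 0.001)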