From 7e34d23c05599ce3a8a6f22cdba29e103f57d218 Mon Sep 17 00:00:00 2001 From: Pavel Metrikov <46672636+metpavel@users.noreply.github.com> Date: Mon, 4 Sep 2023 02:05:46 -0700 Subject: [PATCH 01/14] Treat position bias via GAM in LambdaMART (#5929) * Update dataset.h * Update metadata.cpp * Update rank_objective.hpp * Update metadata.cpp * Update rank_objective.hpp * Update metadata.cpp * Update dataset.h * Update rank_objective.hpp * Update metadata.cpp * Update test_engine.py * Update test_engine.py * Add files via upload * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update _rank.train.position * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update _rank.train.position * Update _rank.train.position * Update test_engine.py * Update _rank.train.position * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update the position of import statement * Update rank_objective.hpp * Update config.h * Update config_auto.cpp * Update rank_objective.hpp * Update rank_objective.hpp * update documentation * remove extra blank line * Update src/io/metadata.cpp Co-authored-by: James Lamb * Update src/io/metadata.cpp Co-authored-by: James Lamb * remove _rank.train.position * add position in python API * fix set_positions in basic.py * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * remove List from _LGBM_PositionType * move new position parameter to the last in Dataset constructor * add position_filename as a parameter * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update Advanced-Topics.rst * Update src/objective/rank_objective.hpp Co-authored-by: James Lamb * Update src/io/metadata.cpp Co-authored-by: James Lamb * Update metadata.cpp * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * Update src/io/metadata.cpp Co-authored-by: James Lamb * more infomrative fatal message address more comments * update documentation for more flexible position specification * fix SetPosition add tests for get_position and set_position * remove position_filename * remove useless changes * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * remove useless files * move position file when position set in Dataset * warn when positions are overwritten * skip ranking with position test in cuda * split test case * remove useless import * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update Parameters.rst * Update rank_objective.hpp * Update config.h * update config_auto.cppp * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * fix 
randomness in test case for gpu --------- Co-authored-by: shiyu1994 Co-authored-by: James Lamb --- docs/Advanced-Topics.rst | 41 ++++++ docs/Parameters.rst | 4 + include/LightGBM/config.h | 4 + include/LightGBM/dataset.h | 43 ++++++ python-package/lightgbm/basic.py | 68 +++++++-- src/io/config_auto.cpp | 7 + src/io/dataset.cpp | 5 + src/io/metadata.cpp | 97 +++++++++++++ src/objective/rank_objective.hpp | 96 ++++++++++++- tests/python_package_test/test_engine.py | 168 ++++++++++++++++++++++- 10 files changed, 522 insertions(+), 11 deletions(-) diff --git a/docs/Advanced-Topics.rst b/docs/Advanced-Topics.rst index d1787b998479..345a1361bfa9 100644 --- a/docs/Advanced-Topics.rst +++ b/docs/Advanced-Topics.rst @@ -77,3 +77,44 @@ Recommendations for gcc Users (MinGW, \*nix) -------------------------------------------- - Refer to `gcc Tips <./gcc-Tips.rst>`__. + +Support for Position Bias Treatment +------------------------------------ + +The relevance labels provided in Learning-to-Rank tasks are often derived from implicit user feedback (e.g., clicks) and may therefore be biased by the position or location on the screen at which an item was presented to the user. +LightGBM can make use of positional data to account for this bias. + +For example, consider the case where you expect that the first 3 results from a search engine will be visible in users' browsers without scrolling, and all other results for a query would require scrolling. + +LightGBM could be told to account for the position bias from results being "above the fold" by providing a ``positions`` array encoded as follows: + +:: + + 0 + 0 + 0 + 1 + 1 + 0 + 0 + 0 + 1 + ... + +Where ``0 = "above the fold"`` and ``1 = "requires scrolling"``. +The specific values are not important, as long as they are consistent across all observations in the training data. +An encoding like ``100 = "above the fold"`` and ``17 = "requires scrolling"`` would result in exactly the same trained model. + +In that way, ``positions`` in LightGBM's API are similar to a categorical feature. +Just as with non-ordinal categorical features, the integer representation is used only for memory and computational efficiency; LightGBM does not care about the absolute or relative magnitude of the values. + +Unlike a categorical feature, however, ``positions`` are used to adjust the target to reduce the bias in predictions made by the trained model. + +The position file corresponds to the training data file line by line, with one position per line. If the training data file is named ``train.txt``, the position file should be named ``train.txt.position`` and placed in the same folder as the data file. +In this case, LightGBM will load the position file automatically if it exists. Positions can also be specified through the ``Dataset`` constructor when using the Python API. If positions are specified in both ways, the ``.position`` file will be ignored. + +Currently, position bias is modeled using an idea from Generalized Additive Models (`GAM `_): the document score ``s`` is linearly decomposed into the sum of a relevance component ``f`` and a positional component ``g``: ``s(x, pos) = f(x) + g(pos)``, where the former component depends on the original query-document features and the latter depends on the position of an item.
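+
+As a minimal numeric sketch of this decomposition (purely illustrative, not LightGBM's implementation; all values below are hypothetical):
+
+::
+
+    # s(x, pos) = f(x) + g(pos) for three documents of one query
+    f_x = [0.7, 0.3, 0.1]       # relevance component f(x), computed from query-document features
+    g_pos = {0: 0.4, 1: -0.2}   # positional component g(pos), one learned factor per position id
+    positions = [0, 0, 1]
+    s = [f + g_pos[p] for f, p in zip(f_x, positions)]  # compound scores used to fit the ranking loss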
+During training, the compound scoring function ``s(x, pos)`` is fit with a standard ranking algorithm (e.g., LambdaMART), which boils down to jointly learning the relevance component ``f(x)`` (it is later returned as an unbiased model) and the position factors ``g(pos)`` that help better explain the observed (biased) labels. +Similar score decomposition ideas have previously been applied to classification & pointwise ranking tasks with assumptions of binary labels and binary relevance (a.k.a. "two-tower" models, refer to the papers: `Towards Disentangling Relevance and Bias in Unbiased Learning to Rank `_, `PAL: a position-bias aware learning framework for CTR prediction in live recommender systems `_, `A General Framework for Debiasing in CTR Prediction `_). +In LightGBM, we adapt this idea to general pairwise Learning-to-Rank with arbitrary ordinal relevance labels. +In addition, GAMs have been used in the context of explainable ML (`Accurate Intelligible Models with Pairwise Interactions `_) to linearly decompose the contribution of each feature (and possibly their pairwise interactions) to the overall score, for subsequent analysis and interpretation of their effects in the trained models. diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 5eecc27889b6..7d825f9f135a 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -1137,6 +1137,10 @@ Objective Parameters - separate by ``,`` +- ``lambdarank_position_bias_regularization`` :raw-html:`🔗︎`, default = ``0.0``, type = double, constraints: ``lambdarank_position_bias_regularization >= 0.0`` + + - used only in ``lambdarank`` application when positional information is provided and position bias is modeled. Larger values reduce the inferred position bias factors. + Metric Parameters ----------------- diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index e01578396259..343abf51e17f 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -965,6 +965,10 @@ struct Config { // desc = separate by ``,`` std::vector label_gain; + // check = >=0.0 + // desc = used only in ``lambdarank`` application when positional information is provided and position bias is modeled. Larger values reduce the inferred position bias factors. + double lambdarank_position_bias_regularization = 0.0; + #ifndef __NVCC__ #pragma endregion diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 825c5c6ebcf8..e7baa42dc2e6 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -114,6 +114,8 @@ class Metadata { void SetQuery(const data_size_t* query, data_size_t len); + void SetPosition(const data_size_t* position, data_size_t len); + /*! * \brief Set initial scores * \param init_score Initial scores, this class will manage memory for init_score. */ @@ -213,6 +215,38 @@ class Metadata { } } + /*! + * \brief Get positions, if does not exist then return nullptr + * \return Pointer of positions + */ + inline const data_size_t* positions() const { + if (!positions_.empty()) { + return positions_.data(); + } else { + return nullptr; + } + } + + /*! + * \brief Get position IDs, if does not exist then return nullptr + * \return Pointer of position IDs + */ + inline const std::string* position_ids() const { + if (!position_ids_.empty()) { + return position_ids_.data(); + } else { + return nullptr; + } + } + + /*! + * \brief Get Number of different position IDs + * \return number of different position IDs + */ + inline size_t num_position_ids() const { + return position_ids_.size(); + } + /*!
* \brief Get data boundaries on queries, if not exists, will return nullptr * we assume data will order by query, @@ -289,6 +323,8 @@ class Metadata { private: /*! \brief Load wights from file */ void LoadWeights(); + /*! \brief Load positions from file */ + void LoadPositions(); /*! \brief Load query boundaries from file */ void LoadQueryBoundaries(); /*! \brief Calculate query weights from queries */ @@ -309,10 +345,16 @@ class Metadata { data_size_t num_data_; /*! \brief Number of weights, used to check correct weight file */ data_size_t num_weights_; + /*! \brief Number of positions, used to check correct position file */ + data_size_t num_positions_; /*! \brief Label data */ std::vector label_; /*! \brief Weights data */ std::vector weights_; + /*! \brief Positions data */ + std::vector positions_; + /*! \brief Position identifiers */ + std::vector position_ids_; /*! \brief Query boundaries */ std::vector query_boundaries_; /*! \brief Query weights */ @@ -328,6 +370,7 @@ class Metadata { /*! \brief mutex for threading safe call */ std::mutex mutex_; bool weight_load_from_file_; + bool position_load_from_file_; bool query_load_from_file_; bool init_score_load_from_file_; #ifdef USE_CUDA diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 6ff448dfeb3d..2f061bdacf31 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -62,6 +62,10 @@ np.ndarray, pd_Series ] +_LGBM_PositionType = Union[ + np.ndarray, + pd_Series +] _LGBM_InitScoreType = Union[ List[float], List[List[float]], @@ -577,7 +581,8 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va "label": _C_API_DTYPE_FLOAT32, "weight": _C_API_DTYPE_FLOAT32, "init_score": _C_API_DTYPE_FLOAT64, - "group": _C_API_DTYPE_INT32 + "group": _C_API_DTYPE_INT32, + "position": _C_API_DTYPE_INT32 } """String name to int feature importance type mapper""" @@ -1525,7 +1530,8 @@ def __init__( feature_name: _LGBM_FeatureNameConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', params: Optional[Dict[str, Any]] = None, - free_raw_data: bool = True + free_raw_data: bool = True, + position: Optional[_LGBM_PositionType] = None, ): """Initialize Dataset. @@ -1565,6 +1571,8 @@ def __init__( Other parameters for Dataset. free_raw_data : bool, optional (default=True) If True, raw data is freed after constructing inner Dataset. + position : numpy 1-D array, pandas Series or None, optional (default=None) + Position of items used in unbiased learning-to-rank task. 
""" self._handle: Optional[_DatasetHandle] = None self.data = data @@ -1572,6 +1580,7 @@ def __init__( self.reference = reference self.weight = weight self.group = group + self.position = position self.init_score = init_score self.feature_name: _LGBM_FeatureNameConfiguration = feature_name self.categorical_feature: _LGBM_CategoricalFeatureConfiguration = categorical_feature @@ -1836,7 +1845,8 @@ def _lazy_init( predictor: Optional[_InnerPredictor], feature_name: _LGBM_FeatureNameConfiguration, categorical_feature: _LGBM_CategoricalFeatureConfiguration, - params: Optional[Dict[str, Any]] + params: Optional[Dict[str, Any]], + position: Optional[_LGBM_PositionType] ) -> "Dataset": if data is None: self._handle = None @@ -1925,6 +1935,8 @@ def _lazy_init( self.set_weight(weight) if group is not None: self.set_group(group) + if position is not None: + self.set_position(position) if isinstance(predictor, _InnerPredictor): if self._predictor is None and init_score is not None: _log_warning("The init_score will be overridden by the prediction of init_model.") @@ -2219,7 +2231,7 @@ def construct(self) -> "Dataset": if self.used_indices is None: # create valid self._lazy_init(data=self.data, label=self.label, reference=self.reference, - weight=self.weight, group=self.group, + weight=self.weight, group=self.group, position=self.position, init_score=self.init_score, predictor=self._predictor, feature_name=self.feature_name, categorical_feature='auto', params=self.params) else: @@ -2242,6 +2254,8 @@ def construct(self) -> "Dataset": self.get_data() if self.group is not None: self.set_group(self.group) + if self.position is not None: + self.set_position(self.position) if self.get_label() is None: raise ValueError("Label should not be None.") if isinstance(self._predictor, _InnerPredictor) and self._predictor is not self.reference._predictor: @@ -2256,7 +2270,8 @@ def construct(self) -> "Dataset": self._lazy_init(data=self.data, label=self.label, reference=None, weight=self.weight, group=self.group, init_score=self.init_score, predictor=self._predictor, - feature_name=self.feature_name, categorical_feature=self.categorical_feature, params=self.params) + feature_name=self.feature_name, categorical_feature=self.categorical_feature, + params=self.params, position=self.position) if self.free_raw_data: self.data = None self.feature_name = self.get_feature_name() @@ -2269,7 +2284,8 @@ def create_valid( weight: Optional[_LGBM_WeightType] = None, group: Optional[_LGBM_GroupType] = None, init_score: Optional[_LGBM_InitScoreType] = None, - params: Optional[Dict[str, Any]] = None + params: Optional[Dict[str, Any]] = None, + position: Optional[_LGBM_PositionType] = None ) -> "Dataset": """Create validation data align with current Dataset. @@ -2292,6 +2308,8 @@ def create_valid( Init score for Dataset. params : dict or None, optional (default=None) Other parameters for validation Dataset. + position : numpy 1-D array, pandas Series or None, optional (default=None) + Position of items used in unbiased learning-to-rank task. Returns ------- @@ -2299,7 +2317,7 @@ def create_valid( Validation Dataset with reference to self. 
""" ret = Dataset(data, label=label, reference=self, - weight=weight, group=group, init_score=init_score, + weight=weight, group=group, position=position, init_score=init_score, params=params, free_raw_data=self.free_raw_data) ret._predictor = self._predictor ret.pandas_categorical = self.pandas_categorical @@ -2434,7 +2452,7 @@ def set_field( 'In multiclass classification init_score can also be a list of lists, numpy 2-D array or pandas DataFrame.' ) else: - dtype = np.int32 if field_name == 'group' else np.float32 + dtype = np.int32 if (field_name == 'group' or field_name == 'position') else np.float32 data = _list_to_1d_numpy(data, dtype=dtype, name=field_name) ptr_data: Union[_ctypes_float_ptr, _ctypes_int_ptr] @@ -2727,6 +2745,28 @@ def set_group( self.set_field('group', group) return self + def set_position( + self, + position: Optional[_LGBM_PositionType] + ) -> "Dataset": + """Set position of Dataset (used for ranking). + + Parameters + ---------- + position : numpy 1-D array, pandas Series or None, optional (default=None) + Position of items used in unbiased learning-to-rank task. + + Returns + ------- + self : Dataset + Dataset with set position. + """ + self.position = position + if self._handle is not None and position is not None: + position = _list_to_1d_numpy(position, dtype=np.int32, name='position') + self.set_field('position', position) + return self + def get_feature_name(self) -> List[str]: """Get the names of columns (features) in the Dataset. @@ -2853,6 +2893,18 @@ def get_group(self) -> Optional[np.ndarray]: self.group = np.diff(self.group) return self.group + def get_position(self) -> Optional[np.ndarray]: + """Get the position of the Dataset. + + Returns + ------- + position : numpy 1-D array or None + Position of items used in unbiased learning-to-rank task. + """ + if self.position is None: + self.position = self.get_field('position') + return self.position + def num_data(self) -> int: """Get the number of rows in the Dataset. 
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 0906ba4b6439..8182c9b52b93 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -304,6 +304,7 @@ const std::unordered_set& Config::parameter_set() { "lambdarank_truncation_level", "lambdarank_norm", "label_gain", + "lambdarank_position_bias_regularization", "metric", "metric_freq", "is_provide_training_metric", @@ -619,6 +620,9 @@ void Config::GetMembersFromString(const std::unordered_map(tmp_str, ','); } + GetDouble(params, "lambdarank_position_bias_regularization", &lambdarank_position_bias_regularization); + CHECK_GE(lambdarank_position_bias_regularization, 0.0); + GetInt(params, "metric_freq", &metric_freq); CHECK_GT(metric_freq, 0); @@ -754,6 +758,7 @@ std::string Config::SaveMembersToString() const { str_buf << "[lambdarank_truncation_level: " << lambdarank_truncation_level << "]\n"; str_buf << "[lambdarank_norm: " << lambdarank_norm << "]\n"; str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n"; + str_buf << "[lambdarank_position_bias_regularization: " << lambdarank_position_bias_regularization << "]\n"; str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n"; str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n"; str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n"; @@ -893,6 +898,7 @@ const std::unordered_map>& Config::paramet {"lambdarank_truncation_level", {}}, {"lambdarank_norm", {}}, {"label_gain", {}}, + {"lambdarank_position_bias_regularization", {}}, {"metric", {"metrics", "metric_types"}}, {"metric_freq", {"output_freq"}}, {"is_provide_training_metric", {"training_metric", "is_training_metric", "train_metric"}}, @@ -1035,6 +1041,7 @@ const std::unordered_map& Config::ParameterTypes() { {"lambdarank_truncation_level", "int"}, {"lambdarank_norm", "bool"}, {"label_gain", "vector"}, + {"lambdarank_position_bias_regularization", "double"}, {"metric", "vector"}, {"metric_freq", "int"}, {"is_provide_training_metric", "bool"}, diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index 9e590b79821c..d5aa707adcc0 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -937,6 +937,8 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, name = Common::Trim(name); if (name == std::string("query") || name == std::string("group")) { metadata_.SetQuery(field_data, num_element); + } else if (name == std::string("position")) { + metadata_.SetPosition(field_data, num_element); } else { return false; } @@ -987,6 +989,9 @@ bool Dataset::GetIntField(const char* field_name, data_size_t* out_len, if (name == std::string("query") || name == std::string("group")) { *out_ptr = metadata_.query_boundaries(); *out_len = metadata_.num_queries() + 1; + } else if (name == std::string("position")) { + *out_ptr = metadata_.positions(); + *out_len = num_data_; } else { return false; } diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp index 2a589fa24ef8..1fc47c46787f 100644 --- a/src/io/metadata.cpp +++ b/src/io/metadata.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -15,7 +16,9 @@ Metadata::Metadata() { num_init_score_ = 0; num_data_ = 0; num_queries_ = 0; + num_positions_ = 0; weight_load_from_file_ = false; + position_load_from_file_ = false; query_load_from_file_ = false; init_score_load_from_file_ = false; #ifdef USE_CUDA @@ -28,6 +31,7 @@ void Metadata::Init(const char* data_filename) { // for lambdarank, it needs query data for partition data in distributed learning LoadQueryBoundaries(); LoadWeights(); + 
LoadPositions(); CalculateQueryWeights(); LoadInitialScore(data_filename_); } @@ -214,6 +218,13 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector 0 && num_positions_ != num_all_data) { + positions_.clear(); + num_positions_ = 0; + Log::Fatal("Positions size (%i) doesn't match data size (%i)", num_positions_, num_data_); + } + // get local positions + if (!positions_.empty()) { + auto old_positions = positions_; + num_positions_ = num_data_; + positions_ = std::vector(num_data_); + #pragma omp parallel for schedule(static, 512) + for (int i = 0; i < static_cast(used_data_indices.size()); ++i) { + positions_[i] = old_positions[used_data_indices[i]]; + } + old_positions.clear(); + } + } if (query_load_from_file_) { // check query boundries if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) { @@ -489,6 +519,47 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) { #endif // USE_CUDA } +void Metadata::SetPosition(const data_size_t* positions, data_size_t len) { + std::lock_guard lock(mutex_); + // save to nullptr + if (positions == nullptr || len == 0) { + positions_.clear(); + num_positions_ = 0; + return; + } + #ifdef USE_CUDA + Log::Fatal("Positions in learning to rank is not supported in CUDA version yet."); + #endif // USE_CUDA + if (num_data_ != len) { + Log::Fatal("Positions size (%i) doesn't match data size (%i)", len, num_data_); + } + if (positions_.empty()) { + positions_.resize(num_data_); + } else { + Log::Warning("Overwritting positions in dataset."); + } + num_positions_ = num_data_; + + position_load_from_file_ = false; + + position_ids_.clear(); + std::unordered_map map_id2pos; + for (data_size_t i = 0; i < num_positions_; ++i) { + if (map_id2pos.count(positions[i]) == 0) { + int pos = static_cast(map_id2pos.size()); + map_id2pos[positions[i]] = pos; + position_ids_.push_back(std::to_string(positions[i])); + } + } + + Log::Debug("number of unique positions found = %ld", position_ids_.size()); + + #pragma omp parallel for schedule(static, 512) if (num_positions_ >= 1024) + for (data_size_t i = 0; i < num_positions_; ++i) { + positions_[i] = map_id2pos.at(positions[i]); + } +} + void Metadata::InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len) { if (!queries) { Log::Fatal("Passed null queries"); @@ -528,6 +599,32 @@ void Metadata::LoadWeights() { weight_load_from_file_ = true; } +void Metadata::LoadPositions() { + num_positions_ = 0; + std::string position_filename(data_filename_); + // default position file name + position_filename.append(".position"); + TextReader reader(position_filename.c_str(), false); + reader.ReadAllLines(); + if (reader.Lines().empty()) { + return; + } + Log::Info("Loading positions from %s ...", position_filename.c_str()); + num_positions_ = static_cast(reader.Lines().size()); + positions_ = std::vector(num_positions_); + position_ids_ = std::vector(); + std::unordered_map map_id2pos; + for (data_size_t i = 0; i < num_positions_; ++i) { + std::string& line = reader.Lines()[i]; + if (map_id2pos.count(line) == 0) { + map_id2pos[line] = static_cast(position_ids_.size()); + position_ids_.push_back(line); + } + positions_[i] = map_id2pos.at(line); + } + position_load_from_file_ = true; +} + void Metadata::LoadInitialScore(const std::string& data_filename) { num_init_score_ = 0; std::string init_score_filename(data_filename); diff --git a/src/objective/rank_objective.hpp b/src/objective/rank_objective.hpp index 653fc6e8609a..6bd5324812f8 100644 --- 
a/src/objective/rank_objective.hpp +++ b/src/objective/rank_objective.hpp @@ -25,7 +25,10 @@ namespace LightGBM { class RankingObjective : public ObjectiveFunction { public: explicit RankingObjective(const Config& config) - : seed_(config.objective_seed) {} + : seed_(config.objective_seed) { + learning_rate_ = config.learning_rate; + position_bias_regularization_ = config.lambdarank_position_bias_regularization; + } explicit RankingObjective(const std::vector&) : seed_(0) {} @@ -37,12 +40,20 @@ class RankingObjective : public ObjectiveFunction { label_ = metadata.label(); // get weights weights_ = metadata.weights(); + // get positions + positions_ = metadata.positions(); + // get position ids + position_ids_ = metadata.position_ids(); + // get number of different position ids + num_position_ids_ = static_cast(metadata.num_position_ids()); // get boundries query_boundaries_ = metadata.query_boundaries(); if (query_boundaries_ == nullptr) { Log::Fatal("Ranking tasks require query information"); } num_queries_ = metadata.num_queries(); + // initialize position bias vectors + pos_biases_.resize(num_position_ids_, 0.0); } void GetGradients(const double* score, score_t* gradients, @@ -51,7 +62,13 @@ class RankingObjective : public ObjectiveFunction { for (data_size_t i = 0; i < num_queries_; ++i) { const data_size_t start = query_boundaries_[i]; const data_size_t cnt = query_boundaries_[i + 1] - query_boundaries_[i]; - GetGradientsForOneQuery(i, cnt, label_ + start, score + start, + std::vector score_adjusted; + if (num_position_ids_ > 0) { + for (data_size_t j = 0; j < cnt; ++j) { + score_adjusted.push_back(score[start + j] + pos_biases_[positions_[start + j]]); + } + } + GetGradientsForOneQuery(i, cnt, label_ + start, num_position_ids_ > 0 ? score_adjusted.data() : score + start, gradients + start, hessians + start); if (weights_ != nullptr) { for (data_size_t j = 0; j < cnt; ++j) { @@ -62,6 +79,9 @@ class RankingObjective : public ObjectiveFunction { } } } + if (num_position_ids_ > 0) { + UpdatePositionBiasFactors(gradients, hessians); + } } virtual void GetGradientsForOneQuery(data_size_t query_id, data_size_t cnt, @@ -69,6 +89,8 @@ class RankingObjective : public ObjectiveFunction { const double* score, score_t* lambdas, score_t* hessians) const = 0; + virtual void UpdatePositionBiasFactors(const score_t* /*lambdas*/, const score_t* /*hessians*/) const {} + const char* GetName() const override = 0; std::string ToString() const override { @@ -88,8 +110,20 @@ class RankingObjective : public ObjectiveFunction { const label_t* label_; /*! \brief Pointer of weights */ const label_t* weights_; + /*! \brief Pointer of positions */ + const data_size_t* positions_; + /*! \brief Pointer of position IDs */ + const std::string* position_ids_; + /*! \brief Number of different position IDs */ + data_size_t num_position_ids_; /*! \brief Query boundaries */ const data_size_t* query_boundaries_; + /*! \brief Position bias factors */ + mutable std::vector pos_biases_; + /*! \brief Learning rate to update position bias factors */ + double learning_rate_; + /*! \brief Position bias regularization */ + double position_bias_regularization_; }; /*!
@@ -253,9 +287,67 @@ class LambdarankNDCG : public RankingObjective { } } + void UpdatePositionBiasFactors(const score_t* lambdas, const score_t* hessians) const override { + /// get number of threads + int num_threads = 1; + #pragma omp parallel + #pragma omp master + { + num_threads = omp_get_num_threads(); + } + // create per-thread buffers for first and second derivatives of utility w.r.t. position bias factors + std::vector bias_first_derivatives(num_position_ids_ * num_threads, 0.0); + std::vector bias_second_derivatives(num_position_ids_ * num_threads, 0.0); + std::vector instance_counts(num_position_ids_ * num_threads, 0); + #pragma omp parallel for schedule(guided) + for (data_size_t i = 0; i < num_data_; i++) { + // get thread ID + const int tid = omp_get_thread_num(); + size_t offset = static_cast(positions_[i] + tid * num_position_ids_); + // accumulate first derivatives of utility w.r.t. position bias factors, for each position + bias_first_derivatives[offset] -= lambdas[i]; + // accumulate second derivatives of utility w.r.t. position bias factors, for each position + bias_second_derivatives[offset] -= hessians[i]; + instance_counts[offset]++; + } + #pragma omp parallel for schedule(guided) + for (data_size_t i = 0; i < num_position_ids_; i++) { + double bias_first_derivative = 0.0; + double bias_second_derivative = 0.0; + int instance_count = 0; + // aggregate derivatives from per-thread buffers + for (int tid = 0; tid < num_threads; tid++) { + size_t offset = static_cast(i + tid * num_position_ids_); + bias_first_derivative += bias_first_derivatives[offset]; + bias_second_derivative += bias_second_derivatives[offset]; + instance_count += instance_counts[offset]; + } + // L2 regularization on position bias factors + bias_first_derivative -= pos_biases_[i] * position_bias_regularization_ * instance_count; + bias_second_derivative -= position_bias_regularization_ * instance_count; + // do Newton-Raphson step to update position bias factors + pos_biases_[i] += learning_rate_ * bias_first_derivative / (std::abs(bias_second_derivative) + 0.001); + } + LogDebugPositionBiasFactors(); + } + const char* GetName() const override { return "lambdarank"; } protected: + void LogDebugPositionBiasFactors() const { + std::stringstream message_stream; + message_stream << std::setw(15) << "position" + << std::setw(15) << "bias_factor" + << std::endl; + Log::Debug(message_stream.str().c_str()); + message_stream.str(""); + for (int i = 0; i < num_position_ids_; ++i) { + message_stream << std::setw(15) << position_ids_[i] + << std::setw(15) << pos_biases_[i]; + Log::Debug(message_stream.str().c_str()); + message_stream.str(""); + } + } /*! \brief Sigmoid param */ double sigmoid_; /*! 
\brief Normalize the lambdas or not */ diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index e9e7179a9b66..25413d7ea072 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -9,6 +9,7 @@ import re from os import getenv from pathlib import Path +from shutil import copyfile import numpy as np import psutil @@ -19,7 +20,7 @@ from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split import lightgbm as lgb -from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame +from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series from .utils import (SERIALIZERS, dummy_obj, load_breast_cancer, load_digits, load_iris, logistic_sigmoid, make_synthetic_regression, mse_obj, pickle_and_unpickle_object, sklearn_multiclass_custom_objective, @@ -747,6 +748,171 @@ def test_ranking_prediction_early_stopping(): np.testing.assert_allclose(ret_early, ret_early_more_strict) +# Simulates position bias for a given ranking dataset. +# The output dataset is identical to the input one except for the relevance labels. +# The new labels are generated according to an instance of a cascade user model: +# for each query, the user is simulated to be traversing the list of documents ranked by a baseline ranker +# (in our example it is simply the ordering by some feature correlated with relevance, e.g., 34) +# and clicks on that document (new_label=1) with some probability 'pclick' depending on its true relevance; +# at each position the user may stop the traversal with some probability pstop. For the non-clicked documents, +# new_label=0. Thus the generated new labels are biased towards the baseline ranker. +# The positions of the documents in the ranked lists produced by the baseline are returned.
+def simulate_position_bias(file_dataset_in, file_query_in, file_dataset_out, baseline_feature): + # a mapping of a document's true relevance (defined on a 5-grade scale) into the probability of clicking it + def get_pclick(label): + if label == 0: + return 0.4 + elif label == 1: + return 0.6 + elif label == 2: + return 0.7 + elif label == 3: + return 0.8 + else: + return 0.9 + # an instantiation of a cascade model where the user stops with probability 0.2 after observing each document + pstop = 0.2 + + f_dataset_in = open(file_dataset_in, 'r') + f_dataset_out = open(file_dataset_out, 'w') + random.seed(10) + positions_all = [] + for line in open(file_query_in): + docs_num = int (line) + lines = [] + index_values = [] + positions = [0] * docs_num + for index in range(docs_num): + features = f_dataset_in.readline().split() + lines.append(features) + val = 0.0 + for feature_val in features: + feature_val_split = feature_val.split(":") + if int(feature_val_split[0]) == baseline_feature: + val = float(feature_val_split[1]) + index_values.append([index, val]) + index_values.sort(key=lambda x: -x[1]) + stop = False + for pos in range(docs_num): + index = index_values[pos][0] + new_label = 0 + if not stop: + label = int(lines[index][0]) + pclick = get_pclick(label) + if random.random() < pclick: + new_label = 1 + stop = random.random() < pstop + lines[index][0] = str(new_label) + positions[index] = pos + for features in lines: + f_dataset_out.write(' '.join(features) + '\n') + positions_all.extend(positions) + f_dataset_out.close() + return positions_all + + +@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Positions in learning to rank is not supported in CUDA version yet') +def test_ranking_with_position_information_with_file(tmp_path): + rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank' + params = { + 'objective': 'lambdarank', + 'verbose': -1, + 'eval_at': [3], + 'metric': 'ndcg', + 'bagging_freq': 1, + 'bagging_fraction': 0.9, + 'min_data_in_leaf': 50, + 'min_sum_hessian_in_leaf': 5.0 + } + + # simulate position bias for the train dataset and put the train dataset with biased labels to temp directory + positions = simulate_position_bias(str(rank_example_dir / 'rank.train'), str(rank_example_dir / 'rank.train.query'), str(tmp_path / 'rank.train'), baseline_feature=34) + copyfile(str(rank_example_dir / 'rank.train.query'), str(tmp_path / 'rank.train.query')) + copyfile(str(rank_example_dir / 'rank.test'), str(tmp_path / 'rank.test')) + copyfile(str(rank_example_dir / 'rank.test.query'), str(tmp_path / 'rank.test.query')) + + lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params) + lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))] + gbm_baseline = lgb.train(params, lgb_train, valid_sets = lgb_valid, num_boost_round=50) + + f_positions_out = open(str(tmp_path / 'rank.train.position'), 'w') + for pos in positions: + f_positions_out.write(str(pos) + '\n') + f_positions_out.close() + + lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params) + lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))] + gbm_unbiased_with_file = lgb.train(params, lgb_train, valid_sets = lgb_valid, num_boost_round=50) + + # the performance of the unbiased LambdaMART should outperform the plain LambdaMART on the dataset with position bias + assert gbm_baseline.best_score['valid_0']['ndcg@3'] + 0.03 <= gbm_unbiased_with_file.best_score['valid_0']['ndcg@3'] + + # add extra row to position file + with open(str(tmp_path / 
'rank.train.position'), 'a') as file: + file.write('pos_1000\n') + file.close() + lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params) + lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))] + with pytest.raises(lgb.basic.LightGBMError, match="Positions size \(3006\) doesn't match data size"): + lgb.train(params, lgb_train, valid_sets = lgb_valid, num_boost_round=50) + + +@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Positions in learning to rank is not supported in CUDA version yet') +def test_ranking_with_position_information_with_dataset_constructor(tmp_path): + rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank' + params = { + 'objective': 'lambdarank', + 'verbose': -1, + 'eval_at': [3], + 'metric': 'ndcg', + 'bagging_freq': 1, + 'bagging_fraction': 0.9, + 'min_data_in_leaf': 50, + 'min_sum_hessian_in_leaf': 5.0, + 'num_threads': 1, + 'deterministic': True, + 'seed': 0 + } + + # simulate position bias for the train dataset and put the train dataset with biased labels to temp directory + positions = simulate_position_bias(str(rank_example_dir / 'rank.train'), str(rank_example_dir / 'rank.train.query'), str(tmp_path / 'rank.train'), baseline_feature=34) + copyfile(str(rank_example_dir / 'rank.train.query'), str(tmp_path / 'rank.train.query')) + copyfile(str(rank_example_dir / 'rank.test'), str(tmp_path / 'rank.test')) + copyfile(str(rank_example_dir / 'rank.test.query'), str(tmp_path / 'rank.test.query')) + + lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params) + lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))] + gbm_baseline = lgb.train(params, lgb_train, valid_sets = lgb_valid, num_boost_round=50) + + positions = np.array(positions) + + # test setting positions through Dataset constructor with numpy array + lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params, position=positions) + lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))] + gbm_unbiased = lgb.train(params, lgb_train, valid_sets = lgb_valid, num_boost_round=50) + + # the performance of the unbiased LambdaMART should outperform the plain LambdaMART on the dataset with position bias + assert gbm_baseline.best_score['valid_0']['ndcg@3'] + 0.03 <= gbm_unbiased.best_score['valid_0']['ndcg@3'] + + if PANDAS_INSTALLED: + # test setting positions through Dataset constructor with pandas Series + lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params, position=pd_Series(positions)) + lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))] + gbm_unbiased_pandas_series = lgb.train(params, lgb_train, valid_sets = lgb_valid, num_boost_round=50) + assert gbm_unbiased.best_score['valid_0']['ndcg@3'] == gbm_unbiased_pandas_series.best_score['valid_0']['ndcg@3'] + + # test setting positions through set_position + lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params) + lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))] + lgb_train.set_position(positions) + gbm_unbiased_set_position = lgb.train(params, lgb_train, valid_sets = lgb_valid, num_boost_round=50) + assert gbm_unbiased.best_score['valid_0']['ndcg@3'] == gbm_unbiased_set_position.best_score['valid_0']['ndcg@3'] + + # test get_position works + positions_from_get = lgb_train.get_position() + np.testing.assert_array_equal(positions_from_get, positions) + + def test_early_stopping(): X, y = load_breast_cancer(return_X_y=True) params = { From bca716cc3cc42c60aa7af50f6b0357b580bf22d9 Mon Sep 17 00:00:00 2001 
From: david-cortes Date: Mon, 4 Sep 2023 16:44:58 +0200 Subject: [PATCH 02/14] [R-package] Fix misdetected objective when passing `lgb.Dataset` instance to `lightgbm()` (#6005) --- R-package/R/lightgbm.R | 5 ++++- R-package/man/lightgbm.Rd | 2 +- R-package/tests/testthat/test_basic.R | 15 +++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/R-package/R/lightgbm.R b/R-package/R/lightgbm.R index cb3ef31e8afa..711b3ef0dc38 100644 --- a/R-package/R/lightgbm.R +++ b/R-package/R/lightgbm.R @@ -116,7 +116,7 @@ NULL #' \item If passing a factor with more than two variables, will use objective \code{"multiclass"} #' (note that parameter \code{num_class} in this case will also be determined automatically from #' \code{label}). -#' \item Otherwise, will use objective \code{"regression"}. +#' \item Otherwise (or if passing \code{lgb.Dataset} as input), will use objective \code{"regression"}. #' } #' #' \emph{New in version 4.0.0} @@ -211,6 +211,9 @@ lightgbm <- function(data, rm(temp) } else { data_processor <- NULL + if (objective == "auto") { + objective <- "regression" + } } # Set data to a temporary variable diff --git a/R-package/man/lightgbm.Rd b/R-package/man/lightgbm.Rd index 88f3e3188fec..09d7704605c1 100644 --- a/R-package/man/lightgbm.Rd +++ b/R-package/man/lightgbm.Rd @@ -68,7 +68,7 @@ set to the iteration number of the best iteration.} \item If passing a factor with more than two variables, will use objective \code{"multiclass"} (note that parameter \code{num_class} in this case will also be determined automatically from \code{label}). - \item Otherwise, will use objective \code{"regression"}. + \item Otherwise (or if passing \code{lgb.Dataset} as input), will use objective \code{"regression"}. } \emph{New in version 4.0.0}} diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 4a687cdd0950..57c33c35dfee 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -3790,3 +3790,18 @@ test_that("lightgbm() accepts named categorical_features", { ) expect_true(length(model$params$categorical_feature) > 0L) }) + +test_that("lightgbm() correctly sets objective when passing lgb.Dataset as input", { + data(mtcars) + y <- mtcars$mpg + x <- as.matrix(mtcars[, -1L]) + ds <- lgb.Dataset(x, label = y) + model <- lightgbm( + ds + , objective = "auto" + , verbose = .LGB_VERBOSITY + , nrounds = 5L + , num_threads = .LGB_MAX_THREADS + ) + expect_equal(model$params$objective, "regression") +}) From 5ea005790e3f0eedeb75574d63abd911c158e2ef Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 4 Sep 2023 11:09:30 -0500 Subject: [PATCH 03/14] [ci] [R-package] test against R 4.3 on Linux and macOS (#6075) --- .ci/test_r_package.sh | 7 ++++--- .github/workflows/r_package.yml | 14 +++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.ci/test_r_package.sh b/.ci/test_r_package.sh index 40a438e0899b..e4d70261aa36 100755 --- a/.ci/test_r_package.sh +++ b/.ci/test_r_package.sh @@ -21,9 +21,9 @@ if [[ "${R_MAJOR_VERSION}" == "3" ]]; then export R_LINUX_VERSION="3.6.3-1bionic" export R_APT_REPO="bionic-cran35/" elif [[ "${R_MAJOR_VERSION}" == "4" ]]; then - export R_MAC_VERSION=4.2.2 - export R_MAC_PKG_URL=${CRAN_MIRROR}/bin/macosx/base/R-${R_MAC_VERSION}.pkg - export R_LINUX_VERSION="4.2.2-1.2204.0" + export R_MAC_VERSION=4.3.1 + export R_MAC_PKG_URL=${CRAN_MIRROR}/bin/macosx/big-sur-x86_64/base/R-${R_MAC_VERSION}-x86_64.pkg + export R_LINUX_VERSION="4.3.1-1.2204.0" export 
R_APT_REPO="jammy-cran40/" else echo "Unrecognized R version: ${R_VERSION}" @@ -56,6 +56,7 @@ if [[ $OS_NAME == "linux" ]]; then texlive-latex-recommended \ texlive-fonts-recommended \ texlive-fonts-extra \ + tidy \ qpdf \ || exit -1 diff --git a/.github/workflows/r_package.yml b/.github/workflows/r_package.yml index eb2cb90a424e..838528617143 100644 --- a/.github/workflows/r_package.yml +++ b/.github/workflows/r_package.yml @@ -48,7 +48,7 @@ jobs: - os: ubuntu-latest task: r-package compiler: gcc - r_version: 4.2 + r_version: 4.3 build_type: cmake container: 'ubuntu:22.04' - os: ubuntu-latest @@ -60,19 +60,19 @@ jobs: - os: ubuntu-latest task: r-package compiler: clang - r_version: 4.2 + r_version: 4.3 build_type: cmake container: 'ubuntu:22.04' - os: macOS-latest task: r-package compiler: gcc - r_version: 4.2 + r_version: 4.3 build_type: cmake container: null - os: macOS-latest task: r-package compiler: clang - r_version: 4.2 + r_version: 4.3 build_type: cmake container: null - os: windows-latest @@ -125,13 +125,13 @@ jobs: - os: ubuntu-latest task: r-package compiler: gcc - r_version: 4.2 + r_version: 4.3 build_type: cran container: 'ubuntu:22.04' - os: macOS-latest task: r-package compiler: clang - r_version: 4.2 + r_version: 4.3 build_type: cran container: null ################ @@ -140,7 +140,7 @@ jobs: - os: ubuntu-latest task: r-rchk compiler: gcc - r_version: 4.2 + r_version: 4.3 build_type: cran container: 'ubuntu:22.04' steps: From 82033064005d07ae98cd3003190d675992061a61 Mon Sep 17 00:00:00 2001 From: mjmckp Date: Tue, 5 Sep 2023 11:34:35 +1000 Subject: [PATCH 04/14] Fix updates in random forest model using GOSS data sample strategy (#6017) --- src/boosting/rf.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/boosting/rf.hpp b/src/boosting/rf.hpp index 9a87e982483e..88ece154e432 100644 --- a/src/boosting/rf.hpp +++ b/src/boosting/rf.hpp @@ -115,6 +115,12 @@ class RF : public GBDT { const data_size_t bag_data_cnt = data_sample_strategy_->bag_data_cnt(); const std::vector>& bag_data_indices = data_sample_strategy_->bag_data_indices(); + // GOSSStrategy->Bagging may modify value of bag_data_cnt_ + if (is_use_subset && bag_data_cnt < num_data_) { + tmp_grad_.resize(num_data_); + tmp_hess_.resize(num_data_); + } + CHECK_EQ(gradients, nullptr); CHECK_EQ(hessians, nullptr); From ee51120118b1e4a04c13df32da923d5805f4f9f9 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 6 Sep 2023 08:14:20 -0500 Subject: [PATCH 05/14] [python-package] simplify processing of pandas data (#6066) --- python-package/lightgbm/basic.py | 118 +++++++++++++----------- python-package/lightgbm/plotting.py | 4 +- tests/python_package_test/test_basic.py | 21 ++++- 3 files changed, 83 insertions(+), 60 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 2f061bdacf31..182ec200d207 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -668,57 +668,52 @@ def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None: def _data_from_pandas( - data, - feature_name: Optional[_LGBM_FeatureNameConfiguration], - categorical_feature: Optional[_LGBM_CategoricalFeatureConfiguration], + data: pd_DataFrame, + feature_name: _LGBM_FeatureNameConfiguration, + categorical_feature: _LGBM_CategoricalFeatureConfiguration, pandas_categorical: Optional[List[List]] -): - if isinstance(data, pd_DataFrame): - if len(data.shape) != 2 or data.shape[0] < 1: - raise ValueError('Input data must be 2 dimensional and non empty.') - if 
feature_name == 'auto' or feature_name is None: - data = data.rename(columns=str, copy=False) - cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)] - cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered] - if pandas_categorical is None: # train dataset - pandas_categorical = [list(data[col].cat.categories) for col in cat_cols] - else: - if len(cat_cols) != len(pandas_categorical): - raise ValueError('train and valid dataset categorical_feature do not match.') - for col, category in zip(cat_cols, pandas_categorical): - if list(data[col].cat.categories) != list(category): - data[col] = data[col].cat.set_categories(category) - if len(cat_cols): # cat_cols is list - data = data.copy(deep=False) # not alter origin DataFrame - data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan}) - if categorical_feature is not None: - if feature_name is None: - feature_name = list(data.columns) - if categorical_feature == 'auto': # use cat cols from DataFrame - categorical_feature = cat_cols_not_ordered - else: # use cat cols specified by user - categorical_feature = list(categorical_feature) # type: ignore[assignment] - if feature_name == 'auto': - feature_name = list(data.columns) - _check_for_bad_pandas_dtypes(data.dtypes) - df_dtypes = [dtype.type for dtype in data.dtypes] - df_dtypes.append(np.float32) # so that the target dtype considers floats - target_dtype = np.result_type(*df_dtypes) - try: - # most common case (no nullable dtypes) - data = data.to_numpy(dtype=target_dtype, copy=False) - except TypeError: - # 1.0 <= pd version < 1.1 and nullable dtypes, least common case - # raises error because array is casted to type(pd.NA) and there's no na_value argument - data = data.astype(target_dtype, copy=False).values - except ValueError: - # data has nullable dtypes, but we can specify na_value argument and copy will be made - data = data.to_numpy(dtype=target_dtype, na_value=np.nan) +) -> Tuple[np.ndarray, List[str], List[str], List[List]]: + if len(data.shape) != 2 or data.shape[0] < 1: + raise ValueError('Input data must be 2 dimensional and non empty.') + + # determine feature names + if feature_name == 'auto': + feature_name = [str(col) for col in data.columns] + + # determine categorical features + cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)] + cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered] + if pandas_categorical is None: # train dataset + pandas_categorical = [list(data[col].cat.categories) for col in cat_cols] else: - if feature_name == 'auto': - feature_name = None - if categorical_feature == 'auto': - categorical_feature = None + if len(cat_cols) != len(pandas_categorical): + raise ValueError('train and valid dataset categorical_feature do not match.') + for col, category in zip(cat_cols, pandas_categorical): + if list(data[col].cat.categories) != list(category): + data[col] = data[col].cat.set_categories(category) + if len(cat_cols): # cat_cols is list + data = data.copy(deep=False) # not alter origin DataFrame + data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan}) + if categorical_feature == 'auto': # use cat cols from DataFrame + categorical_feature = cat_cols_not_ordered + else: # use cat cols specified by user + categorical_feature = list(categorical_feature) # type: ignore[assignment] + + # get numpy representation of the data + 
_check_for_bad_pandas_dtypes(data.dtypes) + df_dtypes = [dtype.type for dtype in data.dtypes] + df_dtypes.append(np.float32) # so that the target dtype considers floats + target_dtype = np.result_type(*df_dtypes) + try: + # most common case (no nullable dtypes) + data = data.to_numpy(dtype=target_dtype, copy=False) + except TypeError: + # 1.0 <= pd version < 1.1 and nullable dtypes, least common case + # raises error because array is casted to type(pd.NA) and there's no na_value argument + data = data.astype(target_dtype, copy=False).values + except ValueError: + # data has nullable dtypes, but we can specify na_value argument and copy will be made + data = data.to_numpy(dtype=target_dtype, na_value=np.nan) return data, feature_name, categorical_feature, pandas_categorical @@ -1004,7 +999,15 @@ def predict( ctypes.c_int(len(data_names)), ) ) - data = _data_from_pandas(data, None, None, self.pandas_categorical)[0] + + if isinstance(data, pd_DataFrame): + data = _data_from_pandas( + data=data, + feature_name="auto", + categorical_feature="auto", + pandas_categorical=self.pandas_categorical + )[0] + predict_type = _C_API_PREDICT_NORMAL if raw_score: predict_type = _C_API_PREDICT_RAW_SCORE @@ -1854,10 +1857,13 @@ def _lazy_init( if reference is not None: self.pandas_categorical = reference.pandas_categorical categorical_feature = reference.categorical_feature - data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas(data=data, - feature_name=feature_name, - categorical_feature=categorical_feature, - pandas_categorical=self.pandas_categorical) + if isinstance(data, pd_DataFrame): + data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas( + data=data, + feature_name=feature_name, + categorical_feature=categorical_feature, + pandas_categorical=self.pandas_categorical + ) # process for args params = {} if params is None else params @@ -1867,10 +1873,10 @@ def _lazy_init( _log_warning(f'{key} keyword has been found in `params` and will be ignored.\n' f'Please use {key} argument of the Dataset constructor to pass this parameter.') # get categorical features - if categorical_feature is not None: + if isinstance(categorical_feature, list): categorical_indices = set() feature_dict = {} - if feature_name is not None: + if isinstance(feature_name, list): feature_dict = {name: i for i, name in enumerate(feature_name)} for name in categorical_feature: if isinstance(name, str) and name in feature_dict: diff --git a/python-package/lightgbm/plotting.py b/python-package/lightgbm/plotting.py index f16a4f274313..85b245c187ef 100644 --- a/python-package/lightgbm/plotting.py +++ b/python-package/lightgbm/plotting.py @@ -712,8 +712,8 @@ def create_tree_digraph( if isinstance(example_case, pd_DataFrame): example_case = _data_from_pandas( data=example_case, - feature_name=None, - categorical_feature=None, + feature_name="auto", + categorical_feature="auto", pandas_categorical=booster.pandas_categorical )[0] example_case = example_case[0] diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 267041eae2e4..7f8980c271f7 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -723,7 +723,12 @@ def test_no_copy_when_single_float_dtype_dataframe(dtype, feature_name): pd = pytest.importorskip('pandas') X = np.random.rand(10, 2).astype(dtype) df = pd.DataFrame(X) - built_data = lgb.basic._data_from_pandas(df, feature_name, None, None)[0] + built_data = 
lgb.basic._data_from_pandas( + data=df, + feature_name=feature_name, + categorical_feature="auto", + pandas_categorical=None + )[0] assert built_data.dtype == dtype assert np.shares_memory(X, built_data) @@ -734,7 +739,12 @@ def test_categorical_code_conversion_doesnt_modify_original_data(feature_name): X = np.random.choice(['a', 'b'], 100).reshape(-1, 1) column_name = 'a' if feature_name == 'auto' else feature_name[0] df = pd.DataFrame(X.copy(), columns=[column_name], dtype='category') - data = lgb.basic._data_from_pandas(df, feature_name, None, None)[0] + data = lgb.basic._data_from_pandas( + data=df, + feature_name=feature_name, + categorical_feature="auto", + pandas_categorical=None + )[0] # check that the original data wasn't modified np.testing.assert_equal(df[column_name], X[:, 0]) # check that the built data has the codes @@ -806,3 +816,10 @@ def test_set_leaf_output(): leaf_output = bst.get_leaf_output(tree_id=0, leaf_id=leaf_id) bst.set_leaf_output(tree_id=0, leaf_id=leaf_id, value=leaf_output + 1) np.testing.assert_allclose(bst.predict(X), y_pred + 1) + + +def test_feature_names_are_set_correctly_when_no_feature_names_passed_into_Dataset(): + ds = lgb.Dataset( + data=np.random.randn(100, 3), + ) + assert ds.construct().feature_name == ["Column_0", "Column_1", "Column_2"] From 1d7ee63686272bceffd522284127573b511df6be Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Wed, 6 Sep 2023 20:00:18 +0200 Subject: [PATCH 06/14] Remove superfluous todo from gitignore (#6081) --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index d4045d9a4798..bcf6f48b4cea 100644 --- a/.gitignore +++ b/.gitignore @@ -139,8 +139,6 @@ publish/ # Publish Web Output *.[Pp]ublish.xml *.azurePubxml -# TODO: Comment the next line if you want to checkin your web deploy settings -# but database connection strings (with potential passwords) will be unencrypted *.pubxml *.publishproj From e9fface2f7013014a4100270b01f44efd2b920d3 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 7 Sep 2023 15:44:40 -0500 Subject: [PATCH 07/14] [ci] [docs] fix broken ACM links (#6083) --- docs/.linkcheckerrc | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/.linkcheckerrc b/docs/.linkcheckerrc index e6ab4ea1a5df..96fdcbd08157 100644 --- a/docs/.linkcheckerrc +++ b/docs/.linkcheckerrc @@ -9,6 +9,7 @@ threads=1 ignore= pythonapi/lightgbm\..*\.html.* http.*amd.com/.* + https.*dl.acm.org/doi/.* https.*tandfonline.com/.* ignorewarnings=http-robots-denied,https-certificate-error checkextern=1 From 04b66e066228a947e5d713626e5b14439ada0909 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 8 Sep 2023 11:29:56 -0500 Subject: [PATCH 08/14] [docs] add vaex-ml to list of external repositories (#6085) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index a44d557f058b..f6f4e8c570e0 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,8 @@ lightgbm-transform (feature transformation binding): https://github.com/microsof `postgresml` (LightGBM training and prediction in SQL, via a Postgres extension): https://github.com/postgresml/postgresml +`vaex-ml` (Python DataFrame library with its own interface to LightGBM): https://github.com/vaexio/vaex + Support ------- From 68621628cec76c5580338b24210b3587f8704924 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sat, 9 Sep 2023 05:34:05 +0900 Subject: [PATCH 09/14] [python-package] [docs] Update key format of eval_hist in docstring example (#5980) --- 
python-package/lightgbm/engine.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 512aa4016345..daa6e16b6a9a 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -651,13 +651,18 @@ def cv( Returns ------- - eval_hist : dict - Evaluation history. + eval_results : dict + History of evaluation results of each metric. The dictionary has the following format: - {'metric1-mean': [values], 'metric1-stdv': [values], - 'metric2-mean': [values], 'metric2-stdv': [values], + {'valid metric1-mean': [values], 'valid metric1-stdv': [values], + 'valid metric2-mean': [values], 'valid metric2-stdv': [values], ...}. If ``return_cvbooster=True``, also returns trained boosters wrapped in a ``CVBooster`` object via ``cvbooster`` key. + If ``eval_train_metric=True``, also returns the train metric history. + In this case, the dictionary has the following format: + {'train metric1-mean': [values], 'valid metric1-mean': [values], + 'train metric2-mean': [values], 'valid metric2-mean': [values], + ...}. """ if not isinstance(train_set, Dataset): raise TypeError(f"cv() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.") From 501ce1cb63e39c67ceb93a063662f3d9867e044c Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 11 Sep 2023 17:12:29 -0500 Subject: [PATCH 10/14] Release v4.1.0 (#6076) --- .appveyor.yml | 2 +- R-package/configure | 18 +++++----- R-package/cran-comments.md | 61 ++++++++++++++++++++++++++++++++++ R-package/pkgdown/_pkgdown.yml | 2 +- VERSION.txt | 2 +- docs/Parameters.rst | 2 ++ include/LightGBM/config.h | 1 + python-package/pyproject.toml | 2 +- 8 files changed, 77 insertions(+), 13 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index dc431ded5018..8733301fbfe9 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,4 +1,4 @@ -version: 4.0.0.99.{build} +version: 4.1.0.{build} image: Visual Studio 2015 platform: x64 diff --git a/R-package/configure b/R-package/configure index 867ef2d395a6..5f441f942e63 100755 --- a/R-package/configure +++ b/R-package/configure @@ -1,6 +1,6 @@ #! /bin/sh # Guess values for system-dependent variables and create Makefiles. -# Generated by GNU Autoconf 2.71 for lightgbm 4.0.0.99. +# Generated by GNU Autoconf 2.71 for lightgbm 4.1.0. # # # Copyright (C) 1992-1996, 1998-2017, 2020-2021 Free Software Foundation, @@ -607,8 +607,8 @@ MAKEFLAGS= # Identity of this package. PACKAGE_NAME='lightgbm' PACKAGE_TARNAME='lightgbm' -PACKAGE_VERSION='4.0.0.99' -PACKAGE_STRING='lightgbm 4.0.0.99' +PACKAGE_VERSION='4.1.0' +PACKAGE_STRING='lightgbm 4.1.0' PACKAGE_BUGREPORT='' PACKAGE_URL='' @@ -1211,7 +1211,7 @@ if test "$ac_init_help" = "long"; then # Omit some internal or obsolete options to make the list less imposing. # This message is too long to be a string in the A/UX 3.1 sh. cat <<_ACEOF -\`configure' configures lightgbm 4.0.0.99 to adapt to many kinds of systems. +\`configure' configures lightgbm 4.1.0 to adapt to many kinds of systems. Usage: $0 [OPTION]... [VAR=VALUE]... 
@@ -1273,7 +1273,7 @@ fi if test -n "$ac_init_help"; then case $ac_init_help in - short | recursive ) echo "Configuration of lightgbm 4.0.0.99:";; + short | recursive ) echo "Configuration of lightgbm 4.1.0:";; esac cat <<\_ACEOF @@ -1341,7 +1341,7 @@ fi test -n "$ac_init_help" && exit $ac_status if $ac_init_version; then cat <<\_ACEOF -lightgbm configure 4.0.0.99 +lightgbm configure 4.1.0 generated by GNU Autoconf 2.71 Copyright (C) 2021 Free Software Foundation, Inc. @@ -1378,7 +1378,7 @@ cat >config.log <<_ACEOF This file contains any messages produced by compilers while running configure, to aid debugging if configure makes a mistake. -It was created by lightgbm $as_me 4.0.0.99, which was +It was created by lightgbm $as_me 4.1.0, which was generated by GNU Autoconf 2.71. Invocation command line was $ $0$ac_configure_args_raw @@ -2454,7 +2454,7 @@ cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 # report actual input values of CONFIG_FILES etc. instead of their # values after options handling. ac_log=" -This file was extended by lightgbm $as_me 4.0.0.99, which was +This file was extended by lightgbm $as_me 4.1.0, which was generated by GNU Autoconf 2.71. Invocation command line was CONFIG_FILES = $CONFIG_FILES @@ -2509,7 +2509,7 @@ ac_cs_config_escaped=`printf "%s\n" "$ac_cs_config" | sed "s/^ //; s/'/'\\\\\\\\ cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 ac_cs_config='$ac_cs_config_escaped' ac_cs_version="\\ -lightgbm config.status 4.0.0.99 +lightgbm config.status 4.1.0 configured by $0, generated by GNU Autoconf 2.71, with options \\"\$ac_cs_config\\" diff --git a/R-package/cran-comments.md b/R-package/cran-comments.md index 6fa74cdac4cb..44b8ed391bfc 100644 --- a/R-package/cran-comments.md +++ b/R-package/cran-comments.md @@ -1,5 +1,66 @@ # CRAN Submission History +## v4.1.0 - not submitted + +v4.1.0 was not submitted to CRAN, because https://github.com/microsoft/LightGBM/issues/5987 had not been resolved. + +## v4.0.0 - Submission 2 - (July 19, 2023) + +### CRAN response + +> Dear maintainer, +> package lightgbm_4.0.0.tar.gz does not pass the incoming checks automatically. + +The logs linked from those messages showed one issue remaining on Debian (0 on Windows). + +```text +* checking examples ... [7s/4s] NOTE +Examples with CPU time > 2.5 times elapsed time + user system elapsed ratio +lgb.restore_handle 1.206 0.085 0.128 10.08 +``` + +### Maintainer Notes + +Chose to document the issue and need for a fix in https://github.com/microsoft/LightGBM/issues/5987, but not resubmit, +to avoid annoying CRAN maintainers. + +## v4.0.0 - Submission 1 - (July 16, 2023) + +### CRAN response + +> Dear maintainer, +> package lightgbm_4.0.0.tar.gz does not pass the incoming checks automatically. + +The logs linked from those messages showed the following issues from `R CMD check`. + +```text +* checking S3 generic/method consistency ... NOTE +Mismatches for apparent methods not registered: +merge: + function(x, y, ...) +merge.eval.string: + function(env) + +format: + function(x, ...) +format.eval.string: + function(eval_res, eval_err) +See section 'Registering S3 methods' in the 'Writing R Extensions' +manual. +``` + +```text +* checking examples ... [8s/4s] NOTE +Examples with CPU time > 2.5 times elapsed time + user system elapsed ratio +lgb.restore_handle 1.819 0.128 0.165 11.8 +``` + +### Maintainer Notes + +Attempted to fix these with https://github.com/microsoft/LightGBM/pull/5988 and resubmitted.
+ ## v3.3.5 - Submission 2 - (January 16, 2023) ### CRAN response diff --git a/R-package/pkgdown/_pkgdown.yml b/R-package/pkgdown/_pkgdown.yml index 233a31f0ead9..ca4a84a5d045 100644 --- a/R-package/pkgdown/_pkgdown.yml +++ b/R-package/pkgdown/_pkgdown.yml @@ -14,7 +14,7 @@ repo: user: https://github.com/ development: - mode: unreleased + mode: release authors: Yu Shi: diff --git a/VERSION.txt b/VERSION.txt index 200681852af8..ee74734aa225 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -4.0.0.99 +4.1.0 diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 7d825f9f135a..86104ba5be55 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -1141,6 +1141,8 @@ Objective Parameters - used only in ``lambdarank`` application when positional information is provided and position bias is modeled. Larger values reduce the inferred position bias factors. + - *New in version 4.1.0* + Metric Parameters ----------------- diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 343abf51e17f..187043cc2053 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -967,6 +967,7 @@ struct Config { // check = >=0.0 // desc = used only in ``lambdarank`` application when positional information is provided and position bias is modeled. Larger values reduce the inferred position bias factors. + // desc = *New in version 4.1.0* double lambdarank_position_bias_regularization = 0.0; #ifndef __NVCC__ diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index 40d57e1af634..79006c92b963 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -30,7 +30,7 @@ maintainers = [ name = "lightgbm" readme = "README.rst" requires-python = ">=3.6" -version = "4.0.0.99" +version = "4.1.0" [project.optional-dependencies] dask = [ From 5e592fe6ff2b6eed83dd77942aab8e464768235c Mon Sep 17 00:00:00 2001 From: david-cortes Date: Tue, 12 Sep 2023 04:53:36 +0200 Subject: [PATCH 11/14] [python-package] Fix misdetected objective after multiple calls to `LGBMClassifier.fit` (#6002) --- python-package/lightgbm/sklearn.py | 2 ++ tests/python_package_test/test_sklearn.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 7e909342c01f..c71c233df908 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -1103,6 +1103,8 @@ def fit( # type: ignore[override] self._classes = self._le.classes_ self._n_classes = len(self._classes) # type: ignore[arg-type] + if self.objective is None: + self._objective = None # adjust eval metrics to match whether binary or multiclass # classification is being performed diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index e41719845c0a..2247c9a512d2 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1561,3 +1561,20 @@ def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type ) preds = model.predict(X) assert spearmanr(preds, y).correlation >= 0.99 + + +def test_classifier_fit_detects_classes_every_time(): + rng = np.random.default_rng(seed=123) + nrows = 1000 + ncols = 20 + + X = rng.standard_normal(size=(nrows, ncols)) + y_bin = (rng.random(size=nrows) <= .3).astype(np.float64) + y_multi = rng.integers(4, size=nrows) + + model = lgb.LGBMClassifier(verbose=-1) + for _ in range(2): + model.fit(X, y_multi) + assert model.objective_ == "multiclass" + model.fit(X, 
y_bin) + assert model.objective_ == "binary" From cd39520c5e00992572186e9998740011122c6150 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 12 Sep 2023 09:04:11 -0500 Subject: [PATCH 12/14] bump development version to 4.1.0.99 (#6090) --- .appveyor.yml | 2 +- R-package/configure | 18 +++++++++--------- R-package/pkgdown/_pkgdown.yml | 2 +- VERSION.txt | 2 +- python-package/pyproject.toml | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index 8733301fbfe9..4cff03d571a1 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,4 +1,4 @@ -version: 4.1.0.{build} +version: 4.1.0.99.{build} image: Visual Studio 2015 platform: x64 diff --git a/R-package/configure b/R-package/configure index 5f441f942e63..39a18d669833 100755 --- a/R-package/configure +++ b/R-package/configure @@ -1,6 +1,6 @@ #! /bin/sh # Guess values for system-dependent variables and create Makefiles. -# Generated by GNU Autoconf 2.71 for lightgbm 4.1.0. +# Generated by GNU Autoconf 2.71 for lightgbm 4.1.0.99. # # # Copyright (C) 1992-1996, 1998-2017, 2020-2021 Free Software Foundation, @@ -607,8 +607,8 @@ MAKEFLAGS= # Identity of this package. PACKAGE_NAME='lightgbm' PACKAGE_TARNAME='lightgbm' -PACKAGE_VERSION='4.1.0' -PACKAGE_STRING='lightgbm 4.1.0' +PACKAGE_VERSION='4.1.0.99' +PACKAGE_STRING='lightgbm 4.1.0.99' PACKAGE_BUGREPORT='' PACKAGE_URL='' @@ -1211,7 +1211,7 @@ if test "$ac_init_help" = "long"; then # Omit some internal or obsolete options to make the list less imposing. # This message is too long to be a string in the A/UX 3.1 sh. cat <<_ACEOF -\`configure' configures lightgbm 4.1.0 to adapt to many kinds of systems. +\`configure' configures lightgbm 4.1.0.99 to adapt to many kinds of systems. Usage: $0 [OPTION]... [VAR=VALUE]... @@ -1273,7 +1273,7 @@ fi if test -n "$ac_init_help"; then case $ac_init_help in - short | recursive ) echo "Configuration of lightgbm 4.1.0:";; + short | recursive ) echo "Configuration of lightgbm 4.1.0.99:";; esac cat <<\_ACEOF @@ -1341,7 +1341,7 @@ fi test -n "$ac_init_help" && exit $ac_status if $ac_init_version; then cat <<\_ACEOF -lightgbm configure 4.1.0 +lightgbm configure 4.1.0.99 generated by GNU Autoconf 2.71 Copyright (C) 2021 Free Software Foundation, Inc. @@ -1378,7 +1378,7 @@ cat >config.log <<_ACEOF This file contains any messages produced by compilers while running configure, to aid debugging if configure makes a mistake. -It was created by lightgbm $as_me 4.1.0, which was +It was created by lightgbm $as_me 4.1.0.99, which was generated by GNU Autoconf 2.71. Invocation command line was $ $0$ac_configure_args_raw @@ -2454,7 +2454,7 @@ cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 # report actual input values of CONFIG_FILES etc. instead of their # values after options handling. ac_log=" -This file was extended by lightgbm $as_me 4.1.0, which was +This file was extended by lightgbm $as_me 4.1.0.99, which was generated by GNU Autoconf 2.71. 
Invocation command line was CONFIG_FILES = $CONFIG_FILES @@ -2509,7 +2509,7 @@ ac_cs_config_escaped=`printf "%s\n" "$ac_cs_config" | sed "s/^ //; s/'/'\\\\\\\\ cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 ac_cs_config='$ac_cs_config_escaped' ac_cs_version="\\ -lightgbm config.status 4.1.0 +lightgbm config.status 4.1.0.99 configured by $0, generated by GNU Autoconf 2.71, with options \\"\$ac_cs_config\\" diff --git a/R-package/pkgdown/_pkgdown.yml b/R-package/pkgdown/_pkgdown.yml index ca4a84a5d045..233a31f0ead9 100644 --- a/R-package/pkgdown/_pkgdown.yml +++ b/R-package/pkgdown/_pkgdown.yml @@ -14,7 +14,7 @@ repo: user: https://github.com/ development: - mode: release + mode: unreleased authors: Yu Shi: diff --git a/VERSION.txt b/VERSION.txt index ee74734aa225..1f06da0058c9 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -4.1.0 +4.1.0.99 diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index 79006c92b963..6e43dc242d1b 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -30,7 +30,7 @@ maintainers = [ name = "lightgbm" readme = "README.rst" requires-python = ">=3.6" -version = "4.1.0" +version = "4.1.0.99" [project.optional-dependencies] dask = [ From a92bf3742be78b96edc25bffac95027dc78fc400 Mon Sep 17 00:00:00 2001 From: shiyu1994 Date: Wed, 13 Sep 2023 01:06:20 +0800 Subject: [PATCH 13/14] [fix] fix quantized training (fixes #5982) (fixes #5994) (#6092) * fix leaf splits update after split in quantized training * fix preparation ordered gradients for quantized training * remove force_row_wise in distributed test for quantized training * Update src/treelearner/leaf_splits.hpp --------- Co-authored-by: James Lamb --- src/io/dataset.cpp | 37 +++++--- src/treelearner/leaf_splits.hpp | 19 ++++ src/treelearner/serial_tree_learner.cpp | 115 ++++++++++++++++++++---- src/treelearner/serial_tree_learner.h | 2 + tests/python_package_test/test_dask.py | 1 - 5 files changed, 142 insertions(+), 32 deletions(-) diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index d5aa707adcc0..cd692afb031a 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -1278,21 +1278,34 @@ void Dataset::ConstructHistogramsInner( auto ptr_ordered_grad = gradients; auto ptr_ordered_hess = hessians; if (num_used_dense_group > 0) { - if (USE_INDICES) { - if (USE_HESSIAN) { -#pragma omp parallel for schedule(static, 512) if (num_data >= 1024) + if (USE_QUANT_GRAD) { + int16_t* ordered_gradients_and_hessians = reinterpret_cast<int16_t*>(ordered_gradients); + const int16_t* gradients_and_hessians = reinterpret_cast<const int16_t*>(gradients); + if (USE_INDICES) { + #pragma omp parallel for schedule(static, 512) if (num_data >= 1024) for (data_size_t i = 0; i < num_data; ++i) { - ordered_gradients[i] = gradients[data_indices[i]]; - ordered_hessians[i] = hessians[data_indices[i]]; + ordered_gradients_and_hessians[i] = gradients_and_hessians[data_indices[i]]; } - ptr_ordered_grad = ordered_gradients; - ptr_ordered_hess = ordered_hessians; - } else { -#pragma omp parallel for schedule(static, 512) if (num_data >= 1024) - for (data_size_t i = 0; i < num_data; ++i) { - ordered_gradients[i] = gradients[data_indices[i]]; + ptr_ordered_grad = reinterpret_cast<const score_t*>(ordered_gradients); + ptr_ordered_hess = nullptr; + } + } else { + if (USE_INDICES) { + if (USE_HESSIAN) { + #pragma omp parallel for schedule(static, 512) if (num_data >= 1024) + for (data_size_t i = 0; i < num_data; ++i) { + ordered_gradients[i] = gradients[data_indices[i]]; + ordered_hessians[i] = hessians[data_indices[i]]; + } + 
ptr_ordered_grad = ordered_gradients; + ptr_ordered_hess = ordered_hessians; + } else { + #pragma omp parallel for schedule(static, 512) if (num_data >= 1024) + for (data_size_t i = 0; i < num_data; ++i) { + ordered_gradients[i] = gradients[data_indices[i]]; + } + ptr_ordered_grad = ordered_gradients; } - ptr_ordered_grad = ordered_gradients; } } OMP_INIT_EX(); diff --git a/src/treelearner/leaf_splits.hpp b/src/treelearner/leaf_splits.hpp index 163bfc4df9ca..fdf55693a0e9 100644 --- a/src/treelearner/leaf_splits.hpp +++ b/src/treelearner/leaf_splits.hpp @@ -53,6 +53,25 @@ class LeafSplits { weight_ = weight; } + /*! + * \brief Init split on current leaf on partial data. + * \param leaf Index of current leaf + * \param data_partition current data partition + * \param sum_gradients + * \param sum_hessians + * \param sum_gradients_and_hessians + * \param weight + */ + void Init(int leaf, const DataPartition* data_partition, double sum_gradients, + double sum_hessians, int64_t sum_gradients_and_hessians, double weight) { + leaf_index_ = leaf; + data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_); + sum_gradients_ = sum_gradients; + sum_hessians_ = sum_hessians; + int_sum_gradients_and_hessians_ = sum_gradients_and_hessians; + weight_ = weight; + } + /*! * \brief Init split on current leaf on partial data. * \param leaf Index of current leaf diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp index c322c1a796c2..37d9a2a50713 100644 --- a/src/treelearner/serial_tree_learner.cpp +++ b/src/treelearner/serial_tree_learner.cpp @@ -841,32 +841,65 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf, #endif // init the leaves that used on next iteration - if (best_split_info.left_count < best_split_info.right_count) { - CHECK_GT(best_split_info.left_count, 0); - smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(), - best_split_info.left_sum_gradient, - best_split_info.left_sum_hessian, - best_split_info.left_output); - larger_leaf_splits_->Init(*right_leaf, data_partition_.get(), - best_split_info.right_sum_gradient, - best_split_info.right_sum_hessian, - best_split_info.right_output); + if (!config_->use_quantized_grad) { + if (best_split_info.left_count < best_split_info.right_count) { + CHECK_GT(best_split_info.left_count, 0); + smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(), + best_split_info.left_sum_gradient, + best_split_info.left_sum_hessian, + best_split_info.left_output); + larger_leaf_splits_->Init(*right_leaf, data_partition_.get(), + best_split_info.right_sum_gradient, + best_split_info.right_sum_hessian, + best_split_info.right_output); + } else { + CHECK_GT(best_split_info.right_count, 0); + smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), + best_split_info.right_sum_gradient, + best_split_info.right_sum_hessian, + best_split_info.right_output); + larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), + best_split_info.left_sum_gradient, + best_split_info.left_sum_hessian, + best_split_info.left_output); + } } else { - CHECK_GT(best_split_info.right_count, 0); - smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), - best_split_info.right_sum_gradient, - best_split_info.right_sum_hessian, - best_split_info.right_output); - larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), - best_split_info.left_sum_gradient, - best_split_info.left_sum_hessian, - best_split_info.left_output); + if (best_split_info.left_count < 
best_split_info.right_count) { + CHECK_GT(best_split_info.left_count, 0); + smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(), + best_split_info.left_sum_gradient, + best_split_info.left_sum_hessian, + best_split_info.left_sum_gradient_and_hessian, + best_split_info.left_output); + larger_leaf_splits_->Init(*right_leaf, data_partition_.get(), + best_split_info.right_sum_gradient, + best_split_info.right_sum_hessian, + best_split_info.right_sum_gradient_and_hessian, + best_split_info.right_output); + } else { + CHECK_GT(best_split_info.right_count, 0); + smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), + best_split_info.right_sum_gradient, + best_split_info.right_sum_hessian, + best_split_info.right_sum_gradient_and_hessian, + best_split_info.right_output); + larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), + best_split_info.left_sum_gradient, + best_split_info.left_sum_hessian, + best_split_info.left_sum_gradient_and_hessian, + best_split_info.left_output); + } } if (config_->use_quantized_grad && config_->tree_learner != std::string("data")) { gradient_discretizer_->SetNumBitsInHistogramBin(*left_leaf, *right_leaf, data_partition_->leaf_count(*left_leaf), data_partition_->leaf_count(*right_leaf)); } + + #ifdef DEBUG + CheckSplit(best_split_info, *left_leaf, *right_leaf); + #endif + auto leaves_need_update = constraints_->Update( is_numerical_split, *left_leaf, *right_leaf, best_split_info.monotone_type, best_split_info.right_output, @@ -1024,4 +1057,48 @@ std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf); *split = bests[best_idx]; } +#ifdef DEBUG +void SerialTreeLearner::CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index) { + data_size_t num_data_in_left = 0; + data_size_t num_data_in_right = 0; + const data_size_t* data_indices_in_left = data_partition_->GetIndexOnLeaf(left_leaf_index, &num_data_in_left); + const data_size_t* data_indices_in_right = data_partition_->GetIndexOnLeaf(right_leaf_index, &num_data_in_right); + if (config_->use_quantized_grad) { + int32_t sum_left_gradient = 0; + int32_t sum_left_hessian = 0; + int32_t sum_right_gradient = 0; + int32_t sum_right_hessian = 0; + const int8_t* discretized_grad_and_hess = gradient_discretizer_->discretized_gradients_and_hessians(); + for (data_size_t i = 0; i < num_data_in_left; ++i) { + const data_size_t index = data_indices_in_left[i]; + sum_left_gradient += discretized_grad_and_hess[2 * index + 1]; + sum_left_hessian += discretized_grad_and_hess[2 * index]; + } + for (data_size_t i = 0; i < num_data_in_right; ++i) { + const data_size_t index = data_indices_in_right[i]; + sum_right_gradient += discretized_grad_and_hess[2 * index + 1]; + sum_right_hessian += discretized_grad_and_hess[2 * index]; + } + Log::Warning("============================ start leaf split info ============================"); + Log::Warning("left_leaf_index = %d, right_leaf_index = %d", left_leaf_index, right_leaf_index); + Log::Warning("num_data_in_left = %d, num_data_in_right = %d", num_data_in_left, num_data_in_right); + Log::Warning("sum_left_gradient = %d, best_split_info->left_sum_gradient_and_hessian.gradient = %d", sum_left_gradient, + static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32)); + Log::Warning("sum_left_hessian = %d, best_split_info->left_sum_gradient_and_hessian.hessian = %d", sum_left_hessian, + static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff)); + Log::Warning("sum_right_gradient = %d, 
best_split_info->right_sum_gradient_and_hessian.gradient = %d", sum_right_gradient, + static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32)); + Log::Warning("sum_right_hessian = %d, best_split_info->right_sum_gradient_and_hessian.hessian = %d", sum_right_hessian, + static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff)); + CHECK_EQ(num_data_in_left, best_split_info.left_count); + CHECK_EQ(num_data_in_right, best_split_info.right_count); + CHECK_EQ(sum_left_gradient, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32)); + CHECK_EQ(sum_left_hessian, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff)); + CHECK_EQ(sum_right_gradient, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32)); + CHECK_EQ(sum_right_hessian, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff)); + Log::Warning("============================ end leaf split info ============================"); + } +} +#endif + } // namespace LightGBM diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h index d815d265c0d2..93e0787a90cf 100644 --- a/src/treelearner/serial_tree_learner.h +++ b/src/treelearner/serial_tree_learner.h @@ -171,7 +171,9 @@ class SerialTreeLearner: public TreeLearner { std::set<int> FindAllForceFeatures(Json force_split_leaf_setting); + #ifdef DEBUG void CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index); + #endif /*! * \brief Get the number of data in a leaf diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index cb69440b3cde..9da50945385c 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -1838,7 +1838,6 @@ def test_distributed_quantized_training(cluster): 'num_grad_quant_bins': 30, 'quant_train_renew_leaf': True, 'verbose': -1, - 'force_row_wise': True, } quant_dask_classifier = lgb.DaskLGBMRegressor( From 921479b99fb5b691801e0e794f2196a94ea17d79 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 12 Sep 2023 13:40:41 -0500 Subject: [PATCH 14/14] update to fmt 10.1.1, fast_double_parser 0.7.0 (#6074) --- CMakeLists.txt | 7 +++++++ external_libs/fast_double_parser | 2 +- external_libs/fmt | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5087d6a8fddb..6705ef130052 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -326,6 +326,13 @@ if(UNIX OR MINGW OR CYGWIN) CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -Wextra -Wall -Wno-ignored-attributes -Wno-unknown-pragmas -Wno-return-type" ) + if(MINGW) + # ignore this warning: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=95353 + set( + CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS} -Wno-stringop-overflow" + ) + endif() if(USE_DEBUG) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0") else() diff --git a/external_libs/fast_double_parser b/external_libs/fast_double_parser index ace60646c02d..efec03532ef6 160000 --- a/external_libs/fast_double_parser +++ b/external_libs/fast_double_parser @@ -1 +1 @@ -Subproject commit ace60646c02dc54c57f19d644e49a61e7e7758ec +Subproject commit efec03532ef65984786e5e32dbc81f6e6a55a115 diff --git a/external_libs/fmt b/external_libs/fmt index b6f4ceaed0a0..f5e54359df4c 160000 --- a/external_libs/fmt +++ b/external_libs/fmt @@ -1 +1 @@ -Subproject commit b6f4ceaed0a0a24ccf575fab6c56dd50ccf6f1a9 +Subproject commit f5e54359df4c26b6230fc61d38aa294581393084
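
The ``cv()`` docstring updated in PATCH 09/14 above describes the key format of the returned dictionary. A minimal sketch of what those keys look like in practice (the dataset and parameters are illustrative only; the exact metric names depend on the ``metric`` parameter):

```python
import lightgbm as lgb
import numpy as np

X = np.random.rand(500, 10)
y = np.random.rand(500)

eval_results = lgb.cv(
    params={"objective": "regression", "metric": "l2", "verbose": -1},
    train_set=lgb.Dataset(X, y),
    num_boost_round=5,
    nfold=3,
    stratified=False,        # stratification does not apply to continuous regression targets
    eval_train_metric=True,  # adds the 'train ...' keys alongside the 'valid ...' keys
)

# Per the updated docstring, expect keys like:
# 'train l2-mean', 'train l2-stdv', 'valid l2-mean', 'valid l2-stdv'
print(sorted(eval_results.keys()))
```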
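
PATCH 13/14 above fixes quantized training and trims its distributed test. For context, a sketch of how the same settings are used in single-machine training; the parameter names are taken directly from ``test_distributed_quantized_training``, while the data and round count are illustrative:

```python
import lightgbm as lgb
import numpy as np

X = np.random.rand(1000, 20)
y = np.random.rand(1000)

params = {
    "objective": "regression",
    "use_quantized_grad": True,      # discretize gradients/hessians for histogram construction
    "num_grad_quant_bins": 30,       # number of quantization bins, as in the test
    "quant_train_renew_leaf": True,  # renew leaf outputs to offset quantization error
    "verbose": -1,
}

booster = lgb.train(params, lgb.Dataset(X, y), num_boost_round=10)
```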
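
The ``CheckSplit()`` debugging helper added in PATCH 13/14 treats ``left_sum_gradient_and_hessian`` (and its right-leaf counterpart) as a packed 64-bit integer: the quantized gradient sum lives in the high 32 bits and the hessian sum in the low 32 bits, which is what the ``>> 32`` and ``& 0x00000000ffffffff`` expressions extract. A small Python sketch of that packing scheme, under the assumption that both halves are signed 32-bit sums; the helper names here are hypothetical, not LightGBM API:

```python
def pack(gradient: int, hessian: int) -> int:
    """Pack two signed 32-bit sums into one 64-bit integer (hypothetical helper)."""
    return (gradient << 32) | (hessian & 0xFFFFFFFF)

def unpack(packed: int) -> tuple:
    """Mirror of the C++ expressions: '>> 32' for the gradient, '& 0xffffffff' for the hessian."""
    gradient = packed >> 32        # arithmetic shift preserves the gradient's sign
    hessian = packed & 0xFFFFFFFF  # low 32 bits, still unsigned at this point
    if hessian >= 2 ** 31:         # reinterpret as a signed 32-bit value
        hessian -= 2 ** 32
    return gradient, hessian

# Round-trip checks, including negative sums
assert unpack(pack(-5, 7)) == (-5, 7)
assert unpack(pack(123, -1)) == (123, -1)
```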