diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 4149233b63b0..7c128c7547f4 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -77,6 +77,7 @@ Booster <- R6::R6Class( LGBM_BoosterCreateFromModelfile_R , modelfile ) + params <- private$get_loaded_param(handle) } else if (!is.null(model_str)) { @@ -727,6 +728,20 @@ Booster <- R6::R6Class( }, + get_loaded_param = function(handle) { + params_str <- .Call( + LGBM_BoosterGetLoadedParam_R + , handle + ) + params <- jsonlite::fromJSON(params_str) + if ("interaction_constraints" %in% names(params)) { + params[["interaction_constraints"]] <- lapply(params[["interaction_constraints"]], function(x) x + 1L) + } + + return(params) + + }, + inner_eval = function(data_name, data_idx, feval = NULL) { # Check for unknown dataset (over the maximum provided range) diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp index 1d503ab7b465..82956daef4b9 100644 --- a/R-package/src/lightgbm_R.cpp +++ b/R-package/src/lightgbm_R.cpp @@ -1183,6 +1183,27 @@ SEXP LGBM_DumpParamAliases_R() { R_API_END(); } +SEXP LGBM_BoosterGetLoadedParam_R(SEXP handle) { + SEXP cont_token = PROTECT(R_MakeUnwindCont()); + R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); + SEXP params_str; + int64_t out_len = 0; + int64_t buf_len = 1024 * 1024; + std::vector inner_char_buf(buf_len); + CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), buf_len, &out_len, inner_char_buf.data())); + // if aliases string was larger than the initial buffer, allocate a bigger buffer and try again + if (out_len > buf_len) { + inner_char_buf.resize(out_len); + CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), out_len, &out_len, inner_char_buf.data())); + } + params_str = PROTECT(safe_R_string(static_cast(1), &cont_token)); + SET_STRING_ELT(params_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token)); + UNPROTECT(2); + return params_str; + R_API_END(); +} + // .Call() calls static const R_CallMethodDef CallEntries[] = { {"LGBM_HandleIsNull_R" , (DL_FUNC) &LGBM_HandleIsNull_R , 1}, @@ -1211,6 +1232,7 @@ static const R_CallMethodDef CallEntries[] = { {"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2}, {"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2}, {"LGBM_BoosterGetNumFeature_R" , (DL_FUNC) &LGBM_BoosterGetNumFeature_R , 1}, + {"LGBM_BoosterGetLoadedParam_R" , (DL_FUNC) &LGBM_BoosterGetLoadedParam_R , 1}, {"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1}, {"LGBM_BoosterUpdateOneIterCustom_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R , 4}, {"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1}, diff --git a/R-package/src/lightgbm_R.h b/R-package/src/lightgbm_R.h index 9818a233aee3..7141a06a207c 100644 --- a/R-package/src/lightgbm_R.h +++ b/R-package/src/lightgbm_R.h @@ -266,6 +266,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R( SEXP model_str ); +/*! +* \brief Get parameters as JSON string. +* \param handle Booster handle +* \return R character vector (length=1) with parameters in JSON format +*/ +LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLoadedParam_R( + SEXP handle +); + /*! * \brief Merge model in two Boosters to first handle * \param handle handle primary Booster handle, will merge other handle to this diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index 8208ef416a65..1bd565a07345 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -172,15 +172,24 @@ test_that("Loading a Booster from a text file works", { data(agaricus.test, package = "lightgbm") train <- agaricus.train test <- agaricus.test + params <- list( + num_leaves = 4L + , boosting = "rf" + , bagging_fraction = 0.8 + , bagging_freq = 1L + , boost_from_average = FALSE + , categorical_feature = c(1L, 2L) + , interaction_constraints = list(c(1L, 2L), 1L) + , feature_contri = rep(0.5, ncol(train$data)) + , metric = c("mape", "average_precision") + , learning_rate = 1.0 + , objective = "binary" + , verbosity = VERBOSITY + ) bst <- lightgbm( data = as.matrix(train$data) , label = train$label - , params = list( - num_leaves = 4L - , learning_rate = 1.0 - , objective = "binary" - , verbose = VERBOSITY - ) + , params = params , nrounds = 2L ) expect_true(lgb.is.Booster(bst)) @@ -199,6 +208,9 @@ test_that("Loading a Booster from a text file works", { ) pred2 <- predict(bst2, test$data) expect_identical(pred, pred2) + + # check that the parameters are loaded correctly + expect_equal(bst2$params[names(params)], params) }) test_that("boosters with linear models at leaves can be written to text file and re-loaded successfully", { diff --git a/helpers/parameter_generator.py b/helpers/parameter_generator.py index 9bc62b093a26..407f2c73e1e3 100644 --- a/helpers/parameter_generator.py +++ b/helpers/parameter_generator.py @@ -6,6 +6,7 @@ along with parameters description in LightGBM/docs/Parameters.rst file from the information in LightGBM/include/LightGBM/config.h file. """ +import re from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple @@ -373,6 +374,32 @@ def gen_parameter_code( } """ + str_to_write += """const std::unordered_map& Config::ParameterTypes() { + static std::unordered_map map({""" + int_t_pat = re.compile(r'int\d+_t') + # the following are stored as comma separated strings but are arrays in the wrappers + overrides = { + 'categorical_feature': 'vector', + 'ignore_column': 'vector', + 'interaction_constraints': 'vector>', + } + for x in infos: + for y in x: + name = y["name"][0] + if name == 'task': + continue + if name in overrides: + param_type = overrides[name] + else: + param_type = int_t_pat.sub('int', y["inner_type"][0]).replace('std::', '') + str_to_write += '\n {"' + name + '", "' + param_type + '"},' + str_to_write += """ + }); + return map; +} + +""" + str_to_write += "} // namespace LightGBM\n" with open(config_out_cpp, "w") as config_out_cpp_file: config_out_cpp_file.write(str_to_write) diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index 7530495c0e17..1bfc18b4470b 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -313,6 +313,8 @@ class LIGHTGBM_EXPORT Boosting { */ static Boosting* CreateBoosting(const std::string& type, const char* filename); + virtual std::string GetLoadedParam() const = 0; + virtual bool IsLinear() const { return false; } virtual std::string ParserConfigStr() const = 0; diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 4f89cc784da1..287826ea182c 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -595,6 +595,20 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(const char* model_str, int* out_num_iterations, BoosterHandle* out); +/*! + * \brief Get parameters as JSON string. + * \param handle Handle of booster. + * \param buffer_len Allocated space for string. + * \param[out] out_len Actual size of string. + * \param[out] out_str JSON string containing parameters. + * \return 0 when succeed, -1 when failure happens + */ +LIGHTGBM_C_EXPORT int LGBM_BoosterGetLoadedParam(BoosterHandle handle, + int64_t buffer_len, + int64_t* out_len, + char* out_str); + + /*! * \brief Free space for booster. * \param handle Handle of booster to be freed diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index f4791394c43c..2de8c984f70b 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -1077,6 +1077,7 @@ struct Config { static const std::unordered_set& parameter_set(); std::vector> auc_mu_weights_matrix; std::vector> interaction_constraints_vector; + static const std::unordered_map& ParameterTypes(); static const std::string DumpAliases(); private: diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index bda805808f75..974b6544ae7e 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -2816,6 +2816,9 @@ def __init__( ctypes.byref(out_num_class))) self.__num_class = out_num_class.value self.pandas_categorical = _load_pandas_categorical(file_name=model_file) + if params: + _log_warning('Ignoring params argument, using parameters from model file.') + params = self._get_loaded_param() elif model_str is not None: self.model_from_string(model_str) else: @@ -2864,6 +2867,28 @@ def __setstate__(self, state: Dict[str, Any]) -> None: state['handle'] = handle self.__dict__.update(state) + def _get_loaded_param(self) -> Dict[str, Any]: + buffer_len = 1 << 20 + tmp_out_len = ctypes.c_int64(0) + string_buffer = ctypes.create_string_buffer(buffer_len) + ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)]) + _safe_call(_LIB.LGBM_BoosterGetLoadedParam( + self.handle, + ctypes.c_int64(buffer_len), + ctypes.byref(tmp_out_len), + ptr_string_buffer)) + actual_len = tmp_out_len.value + # if buffer length is not long enough, re-allocate a buffer + if actual_len > buffer_len: + string_buffer = ctypes.create_string_buffer(actual_len) + ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)]) + _safe_call(_LIB.LGBM_BoosterGetLoadedParam( + self.handle, + ctypes.c_int64(actual_len), + ctypes.byref(tmp_out_len), + ptr_string_buffer)) + return json.loads(string_buffer.value.decode('utf-8')) + def free_dataset(self) -> "Booster": """Free Booster's Datasets. diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 3d9ef619f1dd..5cc3cc7541b0 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -157,6 +157,60 @@ class GBDT : public GBDTBase { */ int GetCurrentIteration() const override { return static_cast(models_.size()) / num_tree_per_iteration_; } + /*! + * \brief Get parameters as a JSON string + */ + std::string GetLoadedParam() const override { + if (loaded_parameter_.empty()) { + return std::string("{}"); + } + const auto param_types = Config::ParameterTypes(); + const auto lines = Common::Split(loaded_parameter_.c_str(), "\n"); + bool first = true; + std::stringstream str_buf; + str_buf << "{"; + for (const auto& line : lines) { + const auto pair = Common::Split(line.c_str(), ":"); + if (pair[1] == " ]") + continue; + if (first) { + first = false; + str_buf << "\""; + } else { + str_buf << ",\""; + } + const auto param = pair[0].substr(1); + const auto value_str = pair[1].substr(1, pair[1].size() - 2); + const auto param_type = param_types.at(param); + str_buf << param << "\": "; + if (param_type == "string") { + str_buf << "\"" << value_str << "\""; + } else if (param_type == "int") { + int value; + Common::Atoi(value_str.c_str(), &value); + str_buf << value; + } else if (param_type == "double") { + double value; + Common::Atof(value_str.c_str(), &value); + str_buf << value; + } else if (param_type == "bool") { + bool value = value_str == "1"; + str_buf << std::boolalpha << value; + } else if (param_type.substr(0, 6) == "vector") { + str_buf << "["; + if (param_type.substr(7, 6) == "string") { + const auto parts = Common::Split(value_str.c_str(), ","); + str_buf << "\"" << Common::Join(parts, "\",\"") << "\""; + } else { + str_buf << value_str; + } + str_buf << "]"; + } + } + str_buf << "}"; + return str_buf.str(); + } + /*! * \brief Can use early stopping for prediction or not * \return True if cannot use early stopping for prediction diff --git a/src/c_api.cpp b/src/c_api.cpp index 5a5dc81fd5e9..20633273134e 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -1748,6 +1748,21 @@ int LGBM_BoosterLoadModelFromString( API_END(); } +int LGBM_BoosterGetLoadedParam( + BoosterHandle handle, + int64_t buffer_len, + int64_t* out_len, + char* out_str) { + API_BEGIN(); + Booster* ref_booster = reinterpret_cast(handle); + std::string params = ref_booster->GetBoosting()->GetLoadedParam(); + *out_len = static_cast(params.size()) + 1; + if (*out_len <= buffer_len) { + std::memcpy(out_str, params.c_str(), *out_len); + } + API_END(); +} + #ifdef _MSC_VER #pragma warning(disable : 4702) #endif diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 6c2e3cabad00..a86abd3a2c1d 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -894,4 +894,141 @@ const std::unordered_map>& Config::paramet return map; } +const std::unordered_map& Config::ParameterTypes() { + static std::unordered_map map({ + {"config", "string"}, + {"objective", "string"}, + {"boosting", "string"}, + {"data", "string"}, + {"valid", "vector"}, + {"num_iterations", "int"}, + {"learning_rate", "double"}, + {"num_leaves", "int"}, + {"tree_learner", "string"}, + {"num_threads", "int"}, + {"device_type", "string"}, + {"seed", "int"}, + {"deterministic", "bool"}, + {"force_col_wise", "bool"}, + {"force_row_wise", "bool"}, + {"histogram_pool_size", "double"}, + {"max_depth", "int"}, + {"min_data_in_leaf", "int"}, + {"min_sum_hessian_in_leaf", "double"}, + {"bagging_fraction", "double"}, + {"pos_bagging_fraction", "double"}, + {"neg_bagging_fraction", "double"}, + {"bagging_freq", "int"}, + {"bagging_seed", "int"}, + {"feature_fraction", "double"}, + {"feature_fraction_bynode", "double"}, + {"feature_fraction_seed", "int"}, + {"extra_trees", "bool"}, + {"extra_seed", "int"}, + {"early_stopping_round", "int"}, + {"first_metric_only", "bool"}, + {"max_delta_step", "double"}, + {"lambda_l1", "double"}, + {"lambda_l2", "double"}, + {"linear_lambda", "double"}, + {"min_gain_to_split", "double"}, + {"drop_rate", "double"}, + {"max_drop", "int"}, + {"skip_drop", "double"}, + {"xgboost_dart_mode", "bool"}, + {"uniform_drop", "bool"}, + {"drop_seed", "int"}, + {"top_rate", "double"}, + {"other_rate", "double"}, + {"min_data_per_group", "int"}, + {"max_cat_threshold", "int"}, + {"cat_l2", "double"}, + {"cat_smooth", "double"}, + {"max_cat_to_onehot", "int"}, + {"top_k", "int"}, + {"monotone_constraints", "vector"}, + {"monotone_constraints_method", "string"}, + {"monotone_penalty", "double"}, + {"feature_contri", "vector"}, + {"forcedsplits_filename", "string"}, + {"refit_decay_rate", "double"}, + {"cegb_tradeoff", "double"}, + {"cegb_penalty_split", "double"}, + {"cegb_penalty_feature_lazy", "vector"}, + {"cegb_penalty_feature_coupled", "vector"}, + {"path_smooth", "double"}, + {"interaction_constraints", "vector>"}, + {"verbosity", "int"}, + {"input_model", "string"}, + {"output_model", "string"}, + {"saved_feature_importance_type", "int"}, + {"snapshot_freq", "int"}, + {"linear_tree", "bool"}, + {"max_bin", "int"}, + {"max_bin_by_feature", "vector"}, + {"min_data_in_bin", "int"}, + {"bin_construct_sample_cnt", "int"}, + {"data_random_seed", "int"}, + {"is_enable_sparse", "bool"}, + {"enable_bundle", "bool"}, + {"use_missing", "bool"}, + {"zero_as_missing", "bool"}, + {"feature_pre_filter", "bool"}, + {"pre_partition", "bool"}, + {"two_round", "bool"}, + {"header", "bool"}, + {"label_column", "string"}, + {"weight_column", "string"}, + {"group_column", "string"}, + {"ignore_column", "vector"}, + {"categorical_feature", "vector"}, + {"forcedbins_filename", "string"}, + {"save_binary", "bool"}, + {"precise_float_parser", "bool"}, + {"parser_config_file", "string"}, + {"start_iteration_predict", "int"}, + {"num_iteration_predict", "int"}, + {"predict_raw_score", "bool"}, + {"predict_leaf_index", "bool"}, + {"predict_contrib", "bool"}, + {"predict_disable_shape_check", "bool"}, + {"pred_early_stop", "bool"}, + {"pred_early_stop_freq", "int"}, + {"pred_early_stop_margin", "double"}, + {"output_result", "string"}, + {"convert_model_language", "string"}, + {"convert_model", "string"}, + {"objective_seed", "int"}, + {"num_class", "int"}, + {"is_unbalance", "bool"}, + {"scale_pos_weight", "double"}, + {"sigmoid", "double"}, + {"boost_from_average", "bool"}, + {"reg_sqrt", "bool"}, + {"alpha", "double"}, + {"fair_c", "double"}, + {"poisson_max_delta_step", "double"}, + {"tweedie_variance_power", "double"}, + {"lambdarank_truncation_level", "int"}, + {"lambdarank_norm", "bool"}, + {"label_gain", "vector"}, + {"metric", "vector"}, + {"metric_freq", "int"}, + {"is_provide_training_metric", "bool"}, + {"eval_at", "vector"}, + {"multi_error_top_k", "int"}, + {"auc_mu_weights", "vector"}, + {"num_machines", "int"}, + {"local_listen_port", "int"}, + {"time_out", "int"}, + {"machine_list_filename", "string"}, + {"machines", "string"}, + {"gpu_platform_id", "int"}, + {"gpu_device_id", "int"}, + {"gpu_use_dp", "bool"}, + {"num_gpu", "int"}, + }); + return map; +} + } // namespace LightGBM diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 35b5113be8a5..d68cc9ecbec7 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -1211,6 +1211,35 @@ def test_feature_name_with_non_ascii(): assert feature_names == gbm2.feature_name() +def test_parameters_are_loaded_from_model_file(tmp_path): + X = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))]) + y = np.random.rand(100) + ds = lgb.Dataset(X, y) + params = { + 'bagging_fraction': 0.8, + 'bagging_freq': 2, + 'boosting': 'rf', + 'feature_contri': [0.5, 0.5, 0.5], + 'feature_fraction': 0.7, + 'boost_from_average': False, + 'interaction_constraints': [[0, 1], [0]], + 'metric': ['l2', 'rmse'], + 'num_leaves': 5, + 'num_threads': 1, + } + model_file = tmp_path / 'model.txt' + lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2]).save_model(model_file) + bst = lgb.Booster(model_file=model_file) + set_params = {k: bst.params[k] for k in params.keys()} + assert set_params == params + assert bst.params['categorical_feature'] == [1, 2] + + # check that passing parameters to the constructor raises warning and ignores them + with pytest.warns(UserWarning, match='Ignoring params argument'): + bst2 = lgb.Booster(params={'num_leaves': 7}, model_file=model_file) + assert bst.params == bst2.params + + def test_save_load_copy_pickle(): def train_and_predict(init_model=None, return_model=False): X, y = make_synthetic_regression()