From 5079de4a0a7f936ff6df7c8c10268e653e02592e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Sun, 28 Aug 2022 10:22:28 -0500
Subject: [PATCH] include parameters from reference dataset on subset (fixes #5402) (#5416)

* include parameters from reference dataset on copy

* lint

* set non-default parameters
---
Reviewer notes (free-form text between the "---" separator and the first
"diff --git" header; not part of the commit message or of any hunk):

  * src/io/dataset.cpp: Dataset::CopyFeatureMapperFrom now also copies
    max_bin_, min_data_in_bin_, bin_construct_sample_cnt_, use_missing_
    and zero_as_missing_ from the dataset it is given, in addition to the
    fields it already copied (group_feature_cnt_, forced_bin_bounds_,
    feature_need_push_zeros_ are the ones visible as hunk context).
  * tests/python_package_test/test_basic.py: the new test builds a Dataset
    from a 100x2 random array with params max_bin=50 and min_data_in_bin=10,
    saves a 5-row subset with save_binary() under tmp_path, and then
    constructs a new Dataset from that file with the same params. It has no
    explicit assertions; it passes as long as nothing raises.
  * Presumably a subset previously kept default values for these fields, so
    reloading its binary file with the original params failed; that is
    inferred from the subject line (issue #5402), not shown by this patch.
    TODO confirm against the issue.
  * NOTE(review): this patch had been flattened onto a single line. Line
    breaks and hunk indentation (2 spaces for C++, 4 for Python) were
    reconstructed to match the hunk headers' line counts; verify whitespace
    against upstream before applying.

 src/io/dataset.cpp                      | 5 +++++
 tests/python_package_test/test_basic.py | 8 ++++++++
 2 files changed, 13 insertions(+)

diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index 6158f9a8e8d2..2842551cf2ee 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -740,6 +740,11 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
   group_feature_cnt_ = dataset->group_feature_cnt_;
   forced_bin_bounds_ = dataset->forced_bin_bounds_;
   feature_need_push_zeros_ = dataset->feature_need_push_zeros_;
+  max_bin_ = dataset->max_bin_;
+  min_data_in_bin_ = dataset->min_data_in_bin_;
+  bin_construct_sample_cnt_ = dataset->bin_construct_sample_cnt_;
+  use_missing_ = dataset->use_missing_;
+  zero_as_missing_ = dataset->zero_as_missing_;
 }
 
 void Dataset::CreateValid(const Dataset* dataset) {
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 57d32c21f4c5..dc4fb29a79a1 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -243,6 +243,14 @@ def test_chunked_dataset_linear():
     valid_data.construct()
 
 
+def test_save_dataset_subset_and_load_from_file(tmp_path):
+    data = np.random.rand(100, 2)
+    params = {'max_bin': 50, 'min_data_in_bin': 10}
+    ds = lgb.Dataset(data, params=params)
+    ds.subset([1, 2, 3, 5, 8]).save_binary(tmp_path / 'subset.bin')
+    lgb.Dataset(tmp_path / 'subset.bin', params=params).construct()
+
+
 def test_subset_group():
     rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
     X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))