From bd9177b1b3f166859c5f7421286a2bdef600853b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Thu, 11 Aug 2022 16:07:13 -0500
Subject: [PATCH 1/3] include parameters from reference dataset on copy

---
 src/io/dataset.cpp                      | 5 +++++
 tests/python_package_test/test_basic.py | 7 +++++++
 2 files changed, 12 insertions(+)

diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index 6158f9a8e8d2..2842551cf2ee 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -740,6 +740,11 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
   group_feature_cnt_ = dataset->group_feature_cnt_;
   forced_bin_bounds_ = dataset->forced_bin_bounds_;
   feature_need_push_zeros_ = dataset->feature_need_push_zeros_;
+  max_bin_ = dataset->max_bin_;
+  min_data_in_bin_ = dataset->min_data_in_bin_;
+  bin_construct_sample_cnt_ = dataset->bin_construct_sample_cnt_;
+  use_missing_ = dataset->use_missing_;
+  zero_as_missing_ = dataset->zero_as_missing_;
 }
 
 void Dataset::CreateValid(const Dataset* dataset) {
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 57d32c21f4c5..a8b76af7362d 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -243,6 +243,13 @@ def test_chunked_dataset_linear():
     valid_data.construct()
 
 
+def test_save_dataset_subset_and_load_from_file(tmp_path):
+    data = np.random.rand(100, 2)
+    ds = lgb.Dataset(data)
+    ds.subset([1,2,3,5,8]).save_binary(tmp_path / 'subset.bin')
+    lgb.Dataset(tmp_path / 'subset.bin').construct()
+
+
 def test_subset_group():
     rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
     X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))

From 7f91b290c30ff96eccfa563ad460a1aabccc8a2e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Thu, 11 Aug 2022 16:17:38 -0500
Subject: [PATCH 2/3] lint

---
 tests/python_package_test/test_basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index a8b76af7362d..1ba78fe4c216 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -246,7 +246,7 @@ def test_chunked_dataset_linear():
 def test_save_dataset_subset_and_load_from_file(tmp_path):
     data = np.random.rand(100, 2)
     ds = lgb.Dataset(data)
-    ds.subset([1,2,3,5,8]).save_binary(tmp_path / 'subset.bin')
+    ds.subset([1, 2, 3, 5, 8]).save_binary(tmp_path / 'subset.bin')
     lgb.Dataset(tmp_path / 'subset.bin').construct()
 
 

From 6a2fd1f76ab6e76886756385877e86e635a94a91 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Tue, 16 Aug 2022 08:56:25 -0500
Subject: [PATCH 3/3] set non-default parameters

---
 tests/python_package_test/test_basic.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 1ba78fe4c216..dc4fb29a79a1 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -245,9 +245,10 @@ def test_chunked_dataset_linear():
 
 def test_save_dataset_subset_and_load_from_file(tmp_path):
     data = np.random.rand(100, 2)
-    ds = lgb.Dataset(data)
+    params = {'max_bin': 50, 'min_data_in_bin': 10}
+    ds = lgb.Dataset(data, params=params)
     ds.subset([1, 2, 3, 5, 8]).save_binary(tmp_path / 'subset.bin')
-    lgb.Dataset(tmp_path / 'subset.bin').construct()
+    lgb.Dataset(tmp_path / 'subset.bin', params=params).construct()
 
 
 def test_subset_group():