Skip to content

Commit

Permalink
include parameters from reference dataset on subset (fixes #5402) (#5416)
Browse files Browse the repository at this point in the history

* include parameters from reference dataset on copy

* lint

* set non-default parameters
  • Loading branch information
jmoralez authored Aug 28, 2022
1 parent b6e2793 commit 5079de4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,11 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
group_feature_cnt_ = dataset->group_feature_cnt_;
forced_bin_bounds_ = dataset->forced_bin_bounds_;
feature_need_push_zeros_ = dataset->feature_need_push_zeros_;
max_bin_ = dataset->max_bin_;
min_data_in_bin_ = dataset->min_data_in_bin_;
bin_construct_sample_cnt_ = dataset->bin_construct_sample_cnt_;
use_missing_ = dataset->use_missing_;
zero_as_missing_ = dataset->zero_as_missing_;
}

void Dataset::CreateValid(const Dataset* dataset) {
Expand Down
8 changes: 8 additions & 0 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,14 @@ def test_chunked_dataset_linear():
valid_data.construct()


def test_save_dataset_subset_and_load_from_file(tmp_path):
    """Smoke test: a Dataset subset written to a binary file can be reloaded.

    Creates a Dataset from random data with explicit ``max_bin`` and
    ``min_data_in_bin`` parameters, takes a row subset, saves it with
    ``save_binary`` and then constructs a new Dataset from that file using the
    same parameters. There are no explicit assertions: the test passes when
    none of these calls raise.
    """
    features = np.random.rand(100, 2)
    # The same parameter dict is used both for the source Dataset and when
    # loading the saved subset back from disk.
    params = {'max_bin': 50, 'min_data_in_bin': 10}
    full_ds = lgb.Dataset(features, params=params)
    row_subset = full_ds.subset([1, 2, 3, 5, 8])
    subset_file = tmp_path / 'subset.bin'
    row_subset.save_binary(subset_file)
    reloaded = lgb.Dataset(subset_file, params=params)
    reloaded.construct()


def test_subset_group():
rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))
Expand Down

0 comments on commit 5079de4

Please sign in to comment.