[python] [R-package] refine the parameters for Dataset #2594

Merged: 57 commits, Feb 19, 2020

Commits (57)
edcf1e5
reset
guolinke Nov 26, 2019
f5e3583
fix a bug
guolinke Nov 26, 2019
da0ded8
fix test
guolinke Nov 26, 2019
63c87a1
Update c_api.h
guolinke Nov 26, 2019
a861360
support to no filter features by min_data
guolinke Nov 27, 2019
9b2eeff
add warning in reset config
guolinke Nov 27, 2019
49cf341
refine warnings for override dataset's parameter
guolinke Nov 27, 2019
5af7986
some cleans
guolinke Nov 27, 2019
e78bc21
clean code
guolinke Nov 28, 2019
1813614
clean code
guolinke Nov 28, 2019
94aca30
fixed conflict
StrikerRUS Nov 30, 2019
3fe37d5
refine C API function doxygen comments
StrikerRUS Nov 30, 2019
bb6b420
refined new param description
StrikerRUS Nov 30, 2019
472e4b9
refined doxygen comments for R API function
StrikerRUS Nov 30, 2019
85eddc5
removed stuff related to int8
StrikerRUS Nov 30, 2019
9c89a04
break long line in warning message
StrikerRUS Nov 30, 2019
ec6aab4
removed tests which results cannot be validated anymore
StrikerRUS Dec 1, 2019
d53c7e6
added test for warnings about unchangeable params
StrikerRUS Dec 1, 2019
5604f30
Merge branch 'master' into parameter-refine
StrikerRUS Dec 1, 2019
4f63b4b
write parameter from dataset to booster
guolinke Dec 3, 2019
d9778ec
consider free_raw_data.
guolinke Dec 3, 2019
d530089
fix params
guolinke Dec 4, 2019
080bd41
fix bug
guolinke Dec 4, 2019
7521efd
implementing R
guolinke Dec 4, 2019
c0cae4b
fix typo
guolinke Dec 4, 2019
b3fbb66
filter params in R
guolinke Dec 4, 2019
3d2f1d1
fix R
guolinke Dec 4, 2019
028c8fd
not min_data
guolinke Dec 4, 2019
d73e2d7
Merge remote-tracking branch 'origin/master' into parameter-refine
StrikerRUS Dec 5, 2019
6e44d98
refined tests
StrikerRUS Dec 6, 2019
f026ba8
fixed linting
StrikerRUS Dec 6, 2019
a403e3c
refine
guolinke Dec 6, 2019
fc73881
pilint
guolinke Dec 6, 2019
0d56dfe
add docstring
guolinke Dec 6, 2019
7514039
fix docstring
guolinke Dec 6, 2019
b787d3e
R lint
guolinke Dec 6, 2019
555d24f
updated description for C API function
StrikerRUS Dec 8, 2019
4b9b068
use param aliases in Python
StrikerRUS Dec 8, 2019
7428244
fixed typo
StrikerRUS Dec 8, 2019
7649d77
fixed typo
StrikerRUS Dec 8, 2019
94fac53
added more params to test
StrikerRUS Dec 8, 2019
f1a4035
removed debug print
StrikerRUS Dec 8, 2019
51eb0ff
fix dataset construct place
guolinke Dec 9, 2019
b27dcae
Merge remote-tracking branch 'origin/master' into parameter-refine
StrikerRUS Dec 17, 2019
2dbc6e2
Merge branch 'master' into parameter-refine
StrikerRUS Dec 19, 2019
d920fe5
Merge branch 'master' into parameter-refine
guolinke Jan 14, 2020
17b2f7d
Merge branch 'master' into parameter-refine
guolinke Feb 2, 2020
b74579b
fix merge bug
guolinke Feb 2, 2020
a29b17e
Update feature_histogram.hpp
guolinke Feb 3, 2020
9d9b870
Merge remote-tracking branch 'origin/master' into parameter-refine
guolinke Feb 3, 2020
82824e0
add is_sparse back
guolinke Feb 3, 2020
fb8364d
remove unused parameters
guolinke Feb 3, 2020
7680f05
fix lint
guolinke Feb 3, 2020
448781c
add data random seed
guolinke Feb 3, 2020
5df6b05
fixed conflicts
StrikerRUS Feb 8, 2020
330e3f1
update
StrikerRUS Feb 9, 2020
5b5f4e3
[R-package] centrallized Dataset parameter aliases and added tests on…
jameslamb Feb 18, 2020
Changes from all commits

79 changes: 77 additions & 2 deletions R-package/R/aliases.R
@@ -1,12 +1,86 @@
# Central location for parameter aliases.
# See https://lightgbm.readthedocs.io/en/latest/Parameters.html#core-parameters

# [description] List of respected parameter aliases specific to lgb.Dataset. Wrapped in a function to
# take advantage of lazy evaluation (so it doesn't matter what order
# R sources files during installation).
# [return] A named list, where each key is a parameter relevant to lgb.Dataset and each value is a character
# vector of corresponding aliases.
.DATASET_PARAMETERS <- function() {
return(list(
"bin_construct_sample_cnt" = c(
"bin_construct_sample_cnt"
, "subsample_for_bin"
)
, "categorical_feature" = c(
"categorical_feature"
, "cat_feature"
, "categorical_column"
, "cat_column"
)
, "seed" = c(
"seed"
, "data_random_seed"
, "feature_fraction_seed"
)
, "enable_bundle" = c(
"enable_bundle"
, "is_endable_bundle"
, "bundle"
)
, "enable_sparse" = c(
"enable_sparse"
, "is_sparse"
, "sparse"
)
, "feature_pre_filter" = "feature_pre_filter"
, "forcedbins_filename" = "forcedbins_filename"
, "group_column" = c(
"group_column"
, "group_id"
, "query_column"
, "query"
, "query_id"
)
, "header" = c(
"header"
, "has_header"
)
, "ignore_column" = c(
"ignore_column"
, "ignore_feature"
, "blacklist"
)
, "label_column" = c(
"label_column"
, "label"
)
, "max_bin" = "max_bin"
, "max_bin_by_feature" = "max_bin_by_feature"
, "pre_partition" = c(
"pre_parition"
, "is_pre_partition"
)
, "two_round" = c(
"two_round"
, "two_round_loading"
, "use_two_round_loading"
)
, "use_missing" = "use_missing"
, "weight_column" = c(
"weight_column"
, "weight"
)
, "zero_as_missing" = "zero_as_missing"
))
}

# [description] List of respected parameter aliases. Wrapped in a function to take advantage of
# lazy evaluation (so it doesn't matter what order R sources files during installation).
# [return] A named list, where each key is a main LightGBM parameter and each value is a character
# vector of corresponding aliases.
.PARAMETER_ALIASES <- function() {
return(list(
learning_params <- list(
"boosting" = c(
"boosting"
, "boost"
@@ -29,5 +103,6 @@
, "num_boost_round"
, "n_estimators"
)
))
)
return(c(learning_params, .DATASET_PARAMETERS()))
}
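
As a quick illustration (not part of the diff above), the alias table returned by .DATASET_PARAMETERS() can be used to resolve a user-supplied name to its canonical Dataset parameter; resolve_dataset_param() below is a hypothetical helper written only to show the intended lookup.

# Hypothetical helper, for illustration only: map a user-supplied parameter
# name to its canonical Dataset parameter via the alias table defined above.
resolve_dataset_param <- function(name) {
    aliases <- .DATASET_PARAMETERS()
    for (main_name in names(aliases)) {
        if (name %in% aliases[[main_name]]) {
            return(main_name)
        }
    }
    return(NA_character_)  # not a Dataset parameter
}

resolve_dataset_param("subsample_for_bin")  # returns "bin_construct_sample_cnt"
resolve_dataset_param("learning_rate")      # returns NA_character_
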
8 changes: 4 additions & 4 deletions R-package/R/lgb.Booster.R
@@ -31,25 +31,25 @@ Booster <- R6::R6Class(

# Create parameters and handle
params <- append(params, list(...))
params_str <- lgb.params2str(params)
handle <- 0.0

# Attempts to create a handle for the dataset
try({

# Check if training dataset is not null
if (!is.null(train_set)) {

# Check if training dataset is lgb.Dataset or not
if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
stop("lgb.Booster: Can only use lgb.Dataset as training data")
}

train_set_handle <- train_set$.__enclos_env__$private$get_handle()
params <- modifyList(params, train_set$get_params())
params_str <- lgb.params2str(params)
# Store booster handle
handle <- lgb.call(
"LGBM_BoosterCreate_R"
, ret = handle
, train_set$.__enclos_env__$private$get_handle()
, train_set_handle
, params_str
)

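A short sketch (not from the PR; train_matrix and train_label are hypothetical placeholders) of what the change above enables: parameters stored on the Dataset are merged into the Booster's parameters before the booster handle is created, so training respects the settings the Dataset was constructed with.

# Sketch: max_bin recorded on the Dataset is picked up by the Booster at
# construction time through the modifyList() merge shown above.
dtrain <- lgb.Dataset(train_matrix, label = train_label, params = list(max_bin = 15L))
bst <- lgb.train(
    params = list(objective = "regression")
    , data = dtrain
    , nrounds = 5L
)
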
44 changes: 35 additions & 9 deletions R-package/R/lgb.Dataset.R
@@ -530,22 +530,48 @@ Dataset <- R6::R6Class(

# Update parameters
update_params = function(params) {

# Parameter updating
if (!lgb.is.null.handle(private$handle)) {
lgb.call(
"LGBM_DatasetUpdateParam_R"
, ret = NULL
, private$handle
if (length(params) == 0L) {
return(invisible(self))
}
if (lgb.is.null.handle(private$handle)) {
private$params <- modifyList(private$params, params)
} else {
call_state <- 0L
call_state <- .Call(
"LGBM_DatasetUpdateParamChecking_R"
, lgb.params2str(private$params)
, lgb.params2str(params)
, call_state
, PACKAGE = "lib_lightgbm"
)
return(invisible(self))
call_state <- as.integer(call_state)
if (call_state != 0L) {

# raise error if raw data is freed
if (is.null(private$raw_data)) {
lgb.last_error()
}

# Overwrite params
private$params <- modifyList(private$params, params)
self$finalize()
}
}
private$params <- modifyList(private$params, params)
return(invisible(self))

},

get_params = function() {
dataset_params <- unname(unlist(.DATASET_PARAMETERS()))
ret <- list()
for (param_key in names(private$params)) {
if (param_key %in% dataset_params) {
ret[[param_key]] <- private$params[[param_key]]
}
}
return(ret)
},

# Set categorical feature parameter
set_categorical_feature = function(categorical_feature) {

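The following usage sketch (assuming the agaricus data bundled with the package) shows the intended behavior of the two methods added above: get_params() returns only Dataset-relevant parameters, and update_params() merges new values, reconstructing the Dataset from raw data when a constructed Dataset needs incompatible changes.

# Usage sketch; assumes the agaricus dataset shipped with the package.
library(lightgbm)
data(agaricus.train, package = "lightgbm")

dtrain <- lgb.Dataset(
    agaricus.train$data
    , label = agaricus.train$label
    , params = list(max_bin = 63L, is_sparse = TRUE)
)

# Returns only Dataset-relevant parameters (aliases such as is_sparse included)
dtrain$get_params()

# Merges new values; if the Dataset was already constructed and the C API
# rejects the change, it is rebuilt from raw data (or an error is raised
# when the raw data has already been freed)
dtrain$update_params(params = list(data_random_seed = 708L))
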
60 changes: 31 additions & 29 deletions R-package/R/utils.R
@@ -19,6 +19,36 @@ lgb.encode.char <- function(arr, len) {

}

lgb.last_error <- function() {
[Review comment from a collaborator]: I like this change! I'm surprised we didn't have a function like this before, actually.

# Perform text error buffering
buf_len <- 200L
act_len <- 0L
err_msg <- raw(buf_len)
err_msg <- .Call(
"LGBM_GetLastError_R"
, buf_len
, act_len
, err_msg
, PACKAGE = "lib_lightgbm"
)

# Check error buffer
if (act_len > buf_len) {
buf_len <- act_len
err_msg <- raw(buf_len)
err_msg <- .Call(
"LGBM_GetLastError_R"
, buf_len
, act_len
, err_msg
, PACKAGE = "lib_lightgbm"
)
}

# Return error
stop("api error: ", lgb.encode.char(err_msg, act_len))
}

lgb.call <- function(fun_name, ret, ...) {
# Set call state to a zero value
call_state <- 0L
@@ -43,35 +73,7 @@ lgb.call <- function(fun_name, ret, ...) {
call_state <- as.integer(call_state)
# Check for call state value post call
if (call_state != 0L) {

# Perform text error buffering
buf_len <- 200L
act_len <- 0L
err_msg <- raw(buf_len)
err_msg <- .Call(
"LGBM_GetLastError_R"
, buf_len
, act_len
, err_msg
, PACKAGE = "lib_lightgbm"
)

# Check error buffer
if (act_len > buf_len) {
buf_len <- act_len
err_msg <- raw(buf_len)
err_msg <- .Call(
"LGBM_GetLastError_R"
, buf_len
, act_len
, err_msg
, PACKAGE = "lib_lightgbm"
)
}

# Return error
stop("api error: ", lgb.encode.char(err_msg, act_len))

lgb.last_error()
}

return(ret)
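For context, a minimal sketch (not part of the diff; "LGBM_SomeRoutine_R" is a made-up routine name) of how lgb.call() and the new lgb.last_error() helper are meant to interact: any non-zero call_state now funnels through one place that reads the message from LGBM_GetLastError_R and raises it as an R error.

# Minimal sketch; "LGBM_SomeRoutine_R" is hypothetical and used only to show
# how errors surfaced through lgb.last_error() can be caught by callers.
result <- tryCatch(
    lgb.call("LGBM_SomeRoutine_R", ret = NULL)
    , error = function(e) {
        # the message text originates from LGBM_GetLastError_R via lgb.last_error()
        message("LightGBM call failed: ", conditionMessage(e))
        NULL
    }
)
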
77 changes: 77 additions & 0 deletions R-package/tests/testthat/test_dataset.R
@@ -126,3 +126,80 @@ test_that("Dataset$new() should throw an error if 'predictor' is provided but of
)
}, regexp = "predictor must be a", fixed = TRUE)
})

test_that("Dataset$get_params() successfully returns parameters if you passed them", {
# note that this list uses one "main" parameter (feature_pre_filter) and one that
# is an alias (is_sparse), to check that aliases are handled correctly
params <- list(
"feature_pre_filter" = TRUE
, "is_sparse" = FALSE
)
ds <- lgb.Dataset(
test_data
, label = test_label
, params = params
)
returned_params <- ds$get_params()
expect_true(methods::is(returned_params, "list"))
expect_identical(length(params), length(returned_params))
expect_identical(sort(names(params)), sort(names(returned_params)))
for (param_name in names(params)) {
expect_identical(params[[param_name]], returned_params[[param_name]])
}
})

test_that("Dataset$get_params() ignores irrelevant parameters", {
params <- list(
"feature_pre_filter" = TRUE
, "is_sparse" = FALSE
, "nonsense_parameter" = c(1.0, 2.0, 5.0)
)
ds <- lgb.Dataset(
test_data
, label = test_label
, params = params
)
returned_params <- ds$get_params()
expect_false("nonsense_parameter" %in% names(returned_params))
})

test_that("Dataset$update_parameters() does nothing for empty inputs", {
ds <- lgb.Dataset(
test_data
, label = test_label
)
initial_params <- ds$get_params()
expect_identical(initial_params, list())

# update_params() should return "self" so it can be chained
res <- ds$update_params(
params = list()
)
expect_true(lgb.is.Dataset(res))

new_params <- ds$get_params()
expect_identical(new_params, initial_params)
})

test_that("Dataset$update_params() works correctly for recognized Dataset parameters", {
ds <- lgb.Dataset(
test_data
, label = test_label
)
initial_params <- ds$get_params()
expect_identical(initial_params, list())

new_params <- list(
"data_random_seed" = 708L
, "enable_bundle" = FALSE
)
res <- ds$update_params(
params = new_params
)
expect_true(lgb.is.Dataset(res))

updated_params <- ds$get_params()
for (param_name in names(new_params)) {
expect_identical(new_params[[param_name]], updated_params[[param_name]])
}
})
7 changes: 6 additions & 1 deletion R-package/tests/testthat/test_parameters.R
@@ -44,13 +44,18 @@ test_that("Feature penalties work properly", {
expect_length(var_gain[[length(var_gain)]], 0L)
})

test_that(".PARAMETER_ALIASES() returns a named list", {
context("parameter aliases")

test_that(".PARAMETER_ALIASES() returns a named list of character vectors, where names are unique", {
param_aliases <- .PARAMETER_ALIASES()
expect_true(is.list(param_aliases))
expect_true(is.character(names(param_aliases)))
expect_true(is.character(param_aliases[["boosting"]]))
expect_true(is.character(param_aliases[["early_stopping_round"]]))
expect_true(is.character(param_aliases[["num_iterations"]]))
expect_true(length(names(param_aliases)) == length(param_aliases))
expect_true(all(sapply(param_aliases, is.character)))
expect_true(length(unique(names(param_aliases))) == length(param_aliases))
})

test_that("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", {
8 changes: 8 additions & 0 deletions docs/Parameters.rst
@@ -537,6 +537,14 @@ IO Parameters

- use this to avoid one-data-one-bin (potential over-fitting)

- ``feature_pre_filter`` :raw-html:`<a id="feature_pre_filter" title="Permalink to this parameter" href="#feature_pre_filter">&#x1F517;&#xFE0E;</a>`, default = ``true``, type = bool

- set this to ``true`` to pre-filter the unsplittable features by ``min_data_in_leaf``

- as the Dataset object is initialized only once and cannot be changed afterwards, you may need to set this to ``false`` when tuning ``min_data_in_leaf``; otherwise, features are pre-filtered by the initial ``min_data_in_leaf`` value unless you reconstruct the Dataset object

- **Note**: setting this to ``false`` may slow down the training

- ``bin_construct_sample_cnt`` :raw-html:`<a id="bin_construct_sample_cnt" title="Permalink to this parameter" href="#bin_construct_sample_cnt">&#x1F517;&#xFE0E;</a>`, default = ``200000``, type = int, aliases: ``subsample_for_bin``, constraints: ``bin_construct_sample_cnt > 0``

- number of data points sampled to construct histogram bins
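To make the ``feature_pre_filter`` guidance above concrete, here is a small R sketch (train_matrix and train_label are hypothetical placeholders) of building the Dataset once with pre-filtering disabled so that min_data_in_leaf can vary across trials.

# Sketch: build the Dataset once with feature_pre_filter disabled, then reuse
# it while tuning min_data_in_leaf.
dtrain <- lgb.Dataset(
    train_matrix
    , label = train_label
    , params = list(feature_pre_filter = FALSE)
)

for (min_data in c(5L, 20L, 100L)) {
    bst <- lgb.train(
        params = list(objective = "binary", min_data_in_leaf = min_data)
        , data = dtrain
        , nrounds = 10L
    )
}
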
3 changes: 2 additions & 1 deletion include/LightGBM/bin.h
@@ -142,12 +142,13 @@ class BinMapper {
* \param max_bin The maximal number of bin
* \param min_data_in_bin min number of data in one bin
* \param min_split_data
* \param pre_filter
* \param bin_type Type of this bin
* \param use_missing True to enable missing value handle
* \param zero_as_missing True to use zero as missing value
* \param forced_upper_bounds Vector of split points that must be used (if this has size less than max_bin, remaining splits are found by the algorithm)
*/
void FindBin(double* values, int num_values, size_t total_sample_cnt, int max_bin, int min_data_in_bin, int min_split_data, BinType bin_type,
void FindBin(double* values, int num_values, size_t total_sample_cnt, int max_bin, int min_data_in_bin, int min_split_data, bool pre_filter, BinType bin_type,
bool use_missing, bool zero_as_missing, const std::vector<double>& forced_upper_bounds);

/*!