Skip to content

Commit

Permalink
set train_lightgbm defaults to match with lgb.train()'s #25
Browse files Browse the repository at this point in the history
  • Loading branch information
Athospd committed Aug 20, 2020
1 parent 35959e9 commit 8a87e8c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
13 changes: 8 additions & 5 deletions R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ prepare_df_lgbm <- function(x, y = NULL) {
#' @return A fitted `lightgbm.Model` object.
#' @keywords internal
#' @export
train_lightgbm <- function(x, y, max_depth = 6, num_iterations = 100, learning_rate = 0.1,
feature_fraction = 1, min_data_in_leaf = 1, min_gain_to_split = 0, bagging_fraction = 1, ...) {
train_lightgbm <- function(x, y, max_depth = 17, num_iterations = 10, learning_rate = 0.1,
feature_fraction = 1, min_data_in_leaf = 20, min_gain_to_split = 0, bagging_fraction = 1, ...) {

force(x)
force(y)
Expand Down Expand Up @@ -239,10 +239,13 @@ train_lightgbm <- function(x, y, max_depth = 6, num_iterations = 100, learning_r
# parallelism should be explicitly specified by the user
if(all(sapply(others[c("num_threads", "num_thread", "nthread", "nthreads", "n_jobs")], is.null))) others$num_threads <- 1L

if(max_depth > 17) {
warning("max_depth > 17, num_leaves truncated to 2^17 - 1")
max_depth <- 17
}

if(is.null(others$num_leaves)) {
others$num_leaves = max(2^min(max_depth, 17) - 1, 2)
if(max_depth > 17)
warning("max_depth > 17, num_leaves truncated to 2^17 - 1")
others$num_leaves = max(2^max_depth - 1, 2)
}

arg_list <- purrr::compact(c(arg_list, others))
Expand Down
6 changes: 3 additions & 3 deletions man/train_lightgbm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 19 additions & 5 deletions tests/testthat/test-lightgbm.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
test_that("lightgbm", {

model <- parsnip::boost_tree(mtry = 1, trees = 50, tree_depth = 15)
model <- parsnip::boost_tree(mtry = 1, trees = 50, tree_depth = 15, min_n = 1)

expect_all_modes_works(model, 'lightgbm')
})
Expand Down Expand Up @@ -28,7 +28,7 @@ test_that("lightgbm mtry", {

hyperparameters <- data.frame(mtry = c(1, 2, 6))
for(i in 1:nrow(hyperparameters)) {
model <- parsnip::boost_tree(mtry = hyperparameters$mtry[i])
model <- parsnip::boost_tree(mtry = hyperparameters$mtry[i], min_n = 1)
expect_all_modes_works(model, 'lightgbm')
}

Expand All @@ -38,7 +38,7 @@ test_that("lightgbm trees", {

hyperparameters <- data.frame(trees = c(1, 20, 300))
for(i in 1:nrow(hyperparameters)) {
model <- parsnip::boost_tree(trees = hyperparameters$trees[i])
model <- parsnip::boost_tree(trees = hyperparameters$trees[i], min_n = 1)
expect_all_modes_works(model, 'lightgbm')
}

Expand All @@ -58,15 +58,29 @@ test_that("lightgbm min_n hyperparameter", {
test_that("lightgbm tree_depth", {
hyperparameters <- data.frame(tree_depth = c(1, 16))
for(i in 1:nrow(hyperparameters)) {
model <- parsnip::boost_tree(tree_depth = hyperparameters$tree_depth[i])
model <- parsnip::boost_tree(tree_depth = hyperparameters$tree_depth[i], min_n = 1)
expect_all_modes_works(model, 'lightgbm')
}
})

test_that("lightgbm loss_reduction", {
hyperparameters <- data.frame(loss_reduction = c(0, 0.2, 2))
for(i in 1:nrow(hyperparameters)) {
model <- parsnip::boost_tree(loss_reduction = hyperparameters$loss_reduction[i], min_n = 1)
expect_all_modes_works(model, 'lightgbm')
}
})

test_that("lightgbm tree_depth", {
hyperparameters <- data.frame(loss_reduction = c(0, 0.2, 2))
for(i in 1:nrow(hyperparameters)) {
model <- parsnip::boost_tree(loss_reduction = hyperparameters$loss_reduction[i], min_n = 1)
expect_all_modes_works(model, 'lightgbm')
}
})

test_that("lightgbm multi_predict", {
model <- parsnip::boost_tree(mtry = 5, trees = 5, mode = "regression")
model <- parsnip::boost_tree(mtry = 5, trees = 5, mode = "regression", min_n = 1)
model <- parsnip::set_engine(model, "lightgbm")

expect_multi_predict_works(model)
Expand Down

0 comments on commit 8a87e8c

Please sign in to comment.