From 8a87e8cb194d244defd31cbacc71ce2612966b29 Mon Sep 17 00:00:00 2001 From: Athos Damiani Date: Thu, 20 Aug 2020 13:03:10 -0300 Subject: [PATCH] set train_lightgbm defaults to match with lgb.train()'s #25 --- R/lightgbm.R | 13 ++++++++----- man/train_lightgbm.Rd | 6 +++--- tests/testthat/test-lightgbm.R | 24 +++++++++++++++++++----- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/R/lightgbm.R b/R/lightgbm.R index 608e942..263521e 100644 --- a/R/lightgbm.R +++ b/R/lightgbm.R @@ -184,8 +184,8 @@ prepare_df_lgbm <- function(x, y = NULL) { #' @return A fitted `lightgbm.Model` object. #' @keywords internal #' @export -train_lightgbm <- function(x, y, max_depth = 6, num_iterations = 100, learning_rate = 0.1, - feature_fraction = 1, min_data_in_leaf = 1, min_gain_to_split = 0, bagging_fraction = 1, ...) { +train_lightgbm <- function(x, y, max_depth = 17, num_iterations = 10, learning_rate = 0.1, + feature_fraction = 1, min_data_in_leaf = 20, min_gain_to_split = 0, bagging_fraction = 1, ...) { force(x) force(y) @@ -239,10 +239,13 @@ train_lightgbm <- function(x, y, max_depth = 6, num_iterations = 100, learning_r # parallelism should be explicitly specified by the user if(all(sapply(others[c("num_threads", "num_thread", "nthread", "nthreads", "n_jobs")], is.null))) others$num_threads <- 1L + if(max_depth > 17) { + warning("max_depth > 17, num_leaves truncated to 2^17 - 1") + max_depth <- 17 + } + if(is.null(others$num_leaves)) { - others$num_leaves = max(2^min(max_depth, 17) - 1, 2) - if(max_depth > 17) - warning("max_depth > 17, num_leaves truncated to 2^17 - 1") + others$num_leaves = max(2^max_depth - 1, 2) } arg_list <- purrr::compact(c(arg_list, others)) diff --git a/man/train_lightgbm.Rd b/man/train_lightgbm.Rd index 8e92878..5813b17 100644 --- a/man/train_lightgbm.Rd +++ b/man/train_lightgbm.Rd @@ -7,11 +7,11 @@ train_lightgbm( x, y, - max_depth = 6, - num_iterations = 100, + max_depth = 17, + num_iterations = 10, learning_rate = 0.1, feature_fraction = 1, - min_data_in_leaf = 1, + min_data_in_leaf = 20, min_gain_to_split = 0, bagging_fraction = 1, ... diff --git a/tests/testthat/test-lightgbm.R b/tests/testthat/test-lightgbm.R index 94538bf..96b8f0d 100644 --- a/tests/testthat/test-lightgbm.R +++ b/tests/testthat/test-lightgbm.R @@ -1,6 +1,6 @@ test_that("lightgbm", { - model <- parsnip::boost_tree(mtry = 1, trees = 50, tree_depth = 15) + model <- parsnip::boost_tree(mtry = 1, trees = 50, tree_depth = 15, min_n = 1) expect_all_modes_works(model, 'lightgbm') }) @@ -28,7 +28,7 @@ test_that("lightgbm mtry", { hyperparameters <- data.frame(mtry = c(1, 2, 6)) for(i in 1:nrow(hyperparameters)) { - model <- parsnip::boost_tree(mtry = hyperparameters$mtry[i]) + model <- parsnip::boost_tree(mtry = hyperparameters$mtry[i], min_n = 1) expect_all_modes_works(model, 'lightgbm') } @@ -38,7 +38,7 @@ test_that("lightgbm trees", { hyperparameters <- data.frame(trees = c(1, 20, 300)) for(i in 1:nrow(hyperparameters)) { - model <- parsnip::boost_tree(trees = hyperparameters$trees[i]) + model <- parsnip::boost_tree(trees = hyperparameters$trees[i], min_n = 1) expect_all_modes_works(model, 'lightgbm') } @@ -58,15 +58,29 @@ test_that("lightgbm min_n hyperparameter", { test_that("lightgbm tree_depth", { hyperparameters <- data.frame(tree_depth = c(1, 16)) for(i in 1:nrow(hyperparameters)) { - model <- parsnip::boost_tree(tree_depth = hyperparameters$tree_depth[i]) + model <- parsnip::boost_tree(tree_depth = hyperparameters$tree_depth[i], min_n = 1) expect_all_modes_works(model, 'lightgbm') } }) +test_that("lightgbm loss_reduction", { + hyperparameters <- data.frame(loss_reduction = c(0, 0.2, 2)) + for(i in 1:nrow(hyperparameters)) { + model <- parsnip::boost_tree(loss_reduction = hyperparameters$loss_reduction[i], min_n = 1) + expect_all_modes_works(model, 'lightgbm') + } +}) +test_that("lightgbm tree_depth", { + hyperparameters <- data.frame(loss_reduction = c(0, 0.2, 2)) + for(i in 1:nrow(hyperparameters)) { + model <- parsnip::boost_tree(loss_reduction = hyperparameters$loss_reduction[i], min_n = 1) + expect_all_modes_works(model, 'lightgbm') + } +}) test_that("lightgbm multi_predict", { - model <- parsnip::boost_tree(mtry = 5, trees = 5, mode = "regression") + model <- parsnip::boost_tree(mtry = 5, trees = 5, mode = "regression", min_n = 1) model <- parsnip::set_engine(model, "lightgbm") expect_multi_predict_works(model)