diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index a492eac..1c63826 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -11,18 +11,30 @@ jobs: runs-on: windows-latest steps: - uses: actions/checkout@v2 + + - uses: r-lib/actions/setup-pandoc@v1 + - uses: r-lib/actions/setup-r@master + - name: Install remotes run: install.packages(c("remotes", "rcmdcheck")) shell: Rscript {0} - - name: Install catboost + + - name: Install additional dependencies run: | - remotes::install_url("https://github.com/catboost/catboost/releases/download/v0.23/catboost-R-Windows-0.23.tgz", INSTALL_opts = c("--no-multiarch")) + remotes::install_deps(dependencies = TRUE, INSTALL_opts = c("--no-multiarch")) shell: Rscript {0} - - name: Install dependencies + + - name: Install catboost run: | - remotes::install_deps(dependencies = TRUE, INSTALL_opts = c("--no-multiarch")) + remotes::install_url("https://github.com/catboost/catboost/releases/download/v0.26/catboost-R-Windows-0.26.tgz", INSTALL_opts = c("--no-multiarch")) shell: Rscript {0} + - name: Check - run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--no-multiarch"), error_on = "error") + run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--no-multiarch"), error_on = "error", check_dir = "check") shell: Rscript {0} + + - name: Show testthat output + if: always() + run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true + shell: bash diff --git a/DESCRIPTION b/DESCRIPTION index cc6465d..a17a373 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -23,7 +23,9 @@ Suggests: lightgbm, knitr, rmarkdown, - readr + readr, + glue, + scales RoxygenNote: 7.1.1 Imports: rlang, diff --git a/R/catboost.R b/R/catboost.R index 53e7bde..3d91063 100644 --- a/R/catboost.R +++ b/R/catboost.R @@ -191,7 +191,7 @@ add_boost_tree_catboost <- function() { eng = "catboost", parsnip = "sample_prop", original = "subsample", - func = list(pkg = "dials", fun = "sample_size"), + func = list(pkg = "dials", fun = "sample_prop"), has_submodel = FALSE ) } diff --git a/R/lightgbm.R b/R/lightgbm.R index cb84245..dafccd3 100644 --- a/R/lightgbm.R +++ b/R/lightgbm.R @@ -168,7 +168,7 @@ add_boost_tree_lightgbm <- function() { eng = "lightgbm", parsnip = "sample_prop", original = "bagging_fraction", - func = list(pkg = "dials", fun = "sample_size"), + func = list(pkg = "dials", fun = "sample_prop"), has_submodel = FALSE ) } @@ -272,7 +272,7 @@ train_lightgbm <- function(x, y, max_depth = 17, num_iterations = 10, learning_r data = prepare_df_lgbm(x), label = y, categorical_feature = categorical_columns(x), - feature_pre_filter = FALSE + params = list(feature_pre_filter = FALSE) ) main_args <- list( @@ -345,7 +345,8 @@ predict_lightgbm_classification_raw <- function(object, new_data, ...) { #' @export predict_lightgbm_regression_numeric <- function(object, new_data, ...) { # train_colnames <- object$fit$.__enclos_env__$private$train_set$get_colnames() - p <- stats::predict(object$fit, prepare_df_lgbm(new_data), reshape = TRUE, predict_disable_shape_check=TRUE, ...) + p <- stats::predict(object$fit, prepare_df_lgbm(new_data), reshape = TRUE, + params = list(predict_disable_shape_check=TRUE), ...) p } diff --git a/R/train.R b/R/train.R index 7cc6f51..e66fe08 100644 --- a/R/train.R +++ b/R/train.R @@ -1,7 +1,18 @@ + .onLoad <- function(libname, pkgname){ - add_decision_tree_tree() - add_boost_tree_catboost() - add_boost_tree_lightgbm() + + if (!"lightgbm" %in% parsnip::get_model_env()$boost_tree$engine) { + add_boost_tree_lightgbm() + } + + if (!"catboost" %in% parsnip::get_model_env()$boost_tree$engine) { + add_boost_tree_catboost() + } + + if (!"tree" %in% parsnip::get_model_env()$decision_tree$engine) { + add_decision_tree_tree() + } + } diff --git a/R/tree.R b/R/tree.R index 79d2eec..32d1b08 100644 --- a/R/tree.R +++ b/R/tree.R @@ -12,6 +12,18 @@ add_decision_tree_tree <- function() { parsnip::set_dependency("decision_tree", eng = "tree", pkg = "tree") + parsnip::set_encoding( + model = "decision_tree", + eng = "tree", + mode = "regression", + options = list( + predictor_indicators = "none", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) + ) + parsnip::set_fit( model = "decision_tree", eng = "tree", diff --git a/tests/testthat/helper-model.R b/tests/testthat/helper-model.R index cfa8f36..0403c2e 100644 --- a/tests/testthat/helper-model.R +++ b/tests/testthat/helper-model.R @@ -5,7 +5,7 @@ mtcars_class_binary$vs <- as.factor(mtcars$vs) expect_all_modes_works <- function(model, engine) { if(engine == "lightgbm") { - model <- parsnip::set_engine(model, engine) + model <- parsnip::set_engine(model, engine, verbose = -1L) } else { model <- parsnip::set_engine(model, engine) } @@ -92,6 +92,9 @@ expect_categorical_vars_works <- function(model) { } expect_can_tune_boost_tree <- function(model) { + + mtcars <- dplyr::sample_n(mtcars, size = 500, replace = TRUE) + mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) diff --git a/tests/testthat/test-catboost.R b/tests/testthat/test-catboost.R index 0e85a95..d361e0a 100644 --- a/tests/testthat/test-catboost.R +++ b/tests/testthat/test-catboost.R @@ -16,8 +16,8 @@ test_that('catboost alternate objective', { info <- catboost::catboost.get_model_params(cat_fit$fit) expect_equal(info$loss_function$type, "Huber") - expect_equal(info$loss_function$params[1], "delta") - expect_equal(info$loss_function$params[2], "1") + expect_true(grepl("delta", info$loss_function$params[1])) + expect_equal(info$loss_function$params[[2]], "1") }) test_that("catboost with tune", { diff --git a/tests/testthat/test-lightgbm.R b/tests/testthat/test-lightgbm.R index 1309ffa..b3f3c7d 100644 --- a/tests/testthat/test-lightgbm.R +++ b/tests/testthat/test-lightgbm.R @@ -50,7 +50,7 @@ test_that("lightgbm mtry", { test_that("lightgbm trees", { - hyperparameters <- data.frame(trees = c(1, 20, 300)) + hyperparameters <- data.frame(trees = c(1, 20, 50)) for(i in 1:nrow(hyperparameters)) { model <- parsnip::boost_tree(trees = hyperparameters$trees[i], min_n = 1) expect_all_modes_works(model, 'lightgbm') diff --git a/vignettes/.gitignore b/vignettes/.gitignore index 097b241..22cde80 100644 --- a/vignettes/.gitignore +++ b/vignettes/.gitignore @@ -1,2 +1,6 @@ *.html *.R +working-with-lightgbm-catboost_cache +threading-forking-benchmark_cache +parallel-processing_cache +catboost_info diff --git a/vignettes/parallel-processing_cache/html/__packages b/vignettes/parallel-processing_cache/html/__packages deleted file mode 100644 index 6568700..0000000 --- a/vignettes/parallel-processing_cache/html/__packages +++ /dev/null @@ -1,19 +0,0 @@ -base -tidymodels -broom -scales -dials -dplyr -ggplot2 -infer -modeldata -parsnip -purrr -recipes -rsample -tibble -tidyr -tune -workflows -yardstick -treesnip diff --git a/vignettes/parallel-processing_cache/html/code_fed05125bc1448aed90ee1abe24570d5.RData b/vignettes/parallel-processing_cache/html/code_fed05125bc1448aed90ee1abe24570d5.RData deleted file mode 100644 index 7b24175..0000000 Binary files a/vignettes/parallel-processing_cache/html/code_fed05125bc1448aed90ee1abe24570d5.RData and /dev/null differ diff --git a/vignettes/parallel-processing_cache/html/code_fed05125bc1448aed90ee1abe24570d5.rdb b/vignettes/parallel-processing_cache/html/code_fed05125bc1448aed90ee1abe24570d5.rdb deleted file mode 100644 index e69de29..0000000 diff --git a/vignettes/parallel-processing_cache/html/code_fed05125bc1448aed90ee1abe24570d5.rdx b/vignettes/parallel-processing_cache/html/code_fed05125bc1448aed90ee1abe24570d5.rdx deleted file mode 100644 index 34b6730..0000000 Binary files a/vignettes/parallel-processing_cache/html/code_fed05125bc1448aed90ee1abe24570d5.rdx and /dev/null differ diff --git a/vignettes/parallel-processing_cache/html/forking_0880993d5c4023853d50944ad1d59fbe.RData b/vignettes/parallel-processing_cache/html/forking_0880993d5c4023853d50944ad1d59fbe.RData deleted file mode 100644 index 4873aad..0000000 Binary files a/vignettes/parallel-processing_cache/html/forking_0880993d5c4023853d50944ad1d59fbe.RData and /dev/null differ diff --git a/vignettes/parallel-processing_cache/html/forking_0880993d5c4023853d50944ad1d59fbe.rdb b/vignettes/parallel-processing_cache/html/forking_0880993d5c4023853d50944ad1d59fbe.rdb deleted file mode 100644 index e69de29..0000000 diff --git a/vignettes/parallel-processing_cache/html/forking_0880993d5c4023853d50944ad1d59fbe.rdx b/vignettes/parallel-processing_cache/html/forking_0880993d5c4023853d50944ad1d59fbe.rdx deleted file mode 100644 index 34b6730..0000000 Binary files a/vignettes/parallel-processing_cache/html/forking_0880993d5c4023853d50944ad1d59fbe.rdx and /dev/null differ diff --git a/vignettes/threading-forking-benchmark_cache/html/__packages b/vignettes/threading-forking-benchmark_cache/html/__packages deleted file mode 100644 index 6568700..0000000 --- a/vignettes/threading-forking-benchmark_cache/html/__packages +++ /dev/null @@ -1,19 +0,0 @@ -base -tidymodels -broom -scales -dials -dplyr -ggplot2 -infer -modeldata -parsnip -purrr -recipes -rsample -tibble -tidyr -tune -workflows -yardstick -treesnip diff --git a/vignettes/threading-forking-benchmark_cache/html/code_0f350d3724cc7b8a16bedcde1d89a6be.RData b/vignettes/threading-forking-benchmark_cache/html/code_0f350d3724cc7b8a16bedcde1d89a6be.RData deleted file mode 100644 index 55a145f..0000000 Binary files a/vignettes/threading-forking-benchmark_cache/html/code_0f350d3724cc7b8a16bedcde1d89a6be.RData and /dev/null differ diff --git a/vignettes/threading-forking-benchmark_cache/html/code_0f350d3724cc7b8a16bedcde1d89a6be.rdb b/vignettes/threading-forking-benchmark_cache/html/code_0f350d3724cc7b8a16bedcde1d89a6be.rdb deleted file mode 100644 index e69de29..0000000 diff --git a/vignettes/threading-forking-benchmark_cache/html/code_0f350d3724cc7b8a16bedcde1d89a6be.rdx b/vignettes/threading-forking-benchmark_cache/html/code_0f350d3724cc7b8a16bedcde1d89a6be.rdx deleted file mode 100644 index 34b6730..0000000 Binary files a/vignettes/threading-forking-benchmark_cache/html/code_0f350d3724cc7b8a16bedcde1d89a6be.rdx and /dev/null differ diff --git a/vignettes/working-with-lightgbm-catboost.Rmd b/vignettes/working-with-lightgbm-catboost.Rmd index 4f848a7..e037d71 100644 --- a/vignettes/working-with-lightgbm-catboost.Rmd +++ b/vignettes/working-with-lightgbm-catboost.Rmd @@ -13,7 +13,8 @@ knitr::opts_chunk$set( comment = "#>", eval = FALSE, warning = FALSE, - message = FALSE + message = FALSE, + eval = FALSE ) ``` @@ -21,7 +22,7 @@ knitr::opts_chunk$set( library(tidymodels) library(treesnip) data("diamonds", package = "ggplot2") - +diamonds <- diamonds %>% sample_n(1000) # vfold resamples diamonds_splits <- vfold_cv(diamonds, v = 5) ```