diff --git a/NEWS.md b/NEWS.md index 7b4b54ec3..1c58094a0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -10,6 +10,8 @@ * Fixed bug where prediction on rank dificient `lm()` models produced `.pred_res` instead of `.pred`. (#985) +* Fixed bug where `boost_tree()` models couldn't be fit with 1 predictor if `validation` argument was used. (#994) + # parsnip 1.1.0 This release of parsnip contains a number of new features and bug fixes, accompanied by several optimizations that substantially decrease the time to `fit()` and `predict()` with the package. diff --git a/R/boost_tree.R b/R/boost_tree.R index 329173b7a..6e65ea8e0 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -435,15 +435,22 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir # Split data m <- floor(n * (1 - validation)) + 1 trn_index <- sample(seq_len(n), size = max(m, 2)) - val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA) + val_data <- xgboost::xgb.DMatrix( + data = x[-trn_index, , drop = FALSE], + label = y[-trn_index], + missing = NA + ) watch_list <- list(validation = val_data) info_list <- list(label = y[trn_index]) if (!is.null(weights)) { info_list$weight <- weights[trn_index] } - dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list) - + dat <- xgboost::xgb.DMatrix( + data = x[trn_index, , drop = FALSE], + missing = NA, + info = info_list + ) } else { info_list <- list(label = y) diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index dcf1577a2..7abfcf9a4 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -38,3 +38,12 @@ test_that('argument checks for data dimensions', { expect_equal(args$min_instances_per_node, expr(min_rows(1000, x))) }) +test_that('boost_tree can be fit with 1 predictor if validation is used', { + spec <- boost_tree(trees = 1) %>% + set_engine("xgboost", validation = 0.5) %>% + set_mode("regression") + + expect_no_error( + fit(spec, mpg ~ disp, data = mtcars) + ) +})