Skip to content

Commit

Permalink
Merge branch 'main' into RC-1-1-1
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jul 28, 2023
2 parents bf30d4b + 40ec24f commit b96ad02
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

* A few censored regression helper functions were exported: `.extract_surv_status()` and `.extract_surv_time()` (#973, #980).

* Fixed bug where `boost_tree()` models couldn't be fit with 1 predictor if `validation` argument was used. (#994)

# parsnip 1.1.0

This release of parsnip contains a number of new features and bug fixes, accompanied by several optimizations that substantially decrease the time to `fit()` and `predict()` with the package.
Expand Down
13 changes: 10 additions & 3 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -435,15 +435,22 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir
# Split data
m <- floor(n * (1 - validation)) + 1
trn_index <- sample(seq_len(n), size = max(m, 2))
val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA)
val_data <- xgboost::xgb.DMatrix(
data = x[-trn_index, , drop = FALSE],
label = y[-trn_index],
missing = NA
)
watch_list <- list(validation = val_data)

info_list <- list(label = y[trn_index])
if (!is.null(weights)) {
info_list$weight <- weights[trn_index]
}
dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list)

dat <- xgboost::xgb.DMatrix(
data = x[trn_index, , drop = FALSE],
missing = NA,
info = info_list
)

} else {
info_list <- list(label = y)
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test_boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,12 @@ test_that('argument checks for data dimensions', {
expect_equal(args$min_instances_per_node, expr(min_rows(1000, x)))
})

test_that('boost_tree can be fit with 1 predictor if validation is used', {
spec <- boost_tree(trees = 1) %>%
set_engine("xgboost", validation = 0.5) %>%
set_mode("regression")

expect_no_error(
fit(spec, mpg ~ disp, data = mtcars)
)
})

0 comments on commit b96ad02

Please sign in to comment.