Commit

allow 'objective' for lightgbm to be passed as an engine argument #24
Athospd committed Jan 15, 2021
1 parent c15533a commit 62fe946
Showing 2 changed files with 30 additions and 15 deletions.
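
In practice, the change lets callers override the objective that train_lightgbm() would otherwise pick from the outcome type. A minimal usage sketch, based on the new test added in this commit (boost_tree(), set_engine(), and fit() are parsnip functions; loading the treesnip engine registration is assumed):

library(parsnip)
library(treesnip)

# Pass `objective` through set_engine(); with this commit it is no longer
# overwritten by the regression/binary/multiclass default.
spec <- boost_tree(trees = 50) %>%
  set_engine("lightgbm", objective = "huber") %>%
  set_mode("regression")

lgb_fit <- fit(spec, mpg ~ ., data = mtcars)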
30 changes: 15 additions & 15 deletions R/lightgbm.R
@@ -205,6 +205,7 @@ train_lightgbm <- function(x, y, max_depth = 17, num_iterations = 10, learning_r
 
   force(x)
   force(y)
+  others <- list(...)
 
   # feature_fraction ------------------------------
   if(!is.null(feature_fraction)) {
@@ -220,25 +221,25 @@ train_lightgbm <- function(x, y, max_depth = 17, num_iterations = 10, learning_r
   }
 
   # loss and num_class -------------------------
-  if (is.numeric(y)) {
-    num_class <- 1
-    objective <- "regression"
-  } else {
-    lvl <- levels(y)
-    lvls <- length(lvl)
-    y <- as.numeric(y) - 1
-    if (lvls == 2) {
-      num_class <- 1
-      objective <- "binary"
-    } else {
-      num_class <- lvls
-      objective <- "multiclass"
-    }
-  }
+  if (!any(names(others) %in% c("objective"))) {
+    if (is.numeric(y)) {
+      others$num_class <- 1
+      others$objective <- "regression"
+    } else {
+      lvl <- levels(y)
+      lvls <- length(lvl)
+      y <- as.numeric(y) - 1
+      if (lvls == 2) {
+        others$num_class <- 1
+        others$objective <- "binary"
+      } else {
+        others$num_class <- lvls
+        others$objective <- "multiclass"
+      }
+    }
+  }
 
   arg_list <- list(
-    num_class = num_class,
-    objective = objective,
     num_iterations = num_iterations,
     learning_rate = learning_rate,
     max_depth = max_depth,
@@ -249,7 +250,6 @@ train_lightgbm <- function(x, y, max_depth = 17, num_iterations = 10, learning_r
   )
 
   # override or add some other args
-  others <- list(...)
   others <- others[!(names(others) %in% c("data", names(arg_list)))]
 
   # parallelism should be explicitly specified by the user
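
The guard added above fills in the objective and num_class defaults only when no objective came in through the dots. A standalone illustration of that check (plain R, independent of lightgbm; the "huber" value is just an example taken from the new test):

# Engine arguments captured from `...`:
others <- list(objective = "huber")

if (!any(names(others) %in% c("objective"))) {
  # Default branch: only reached when the caller did not supply an objective.
  others$objective <- "regression"
  others$num_class <- 1
}

others$objective
#> [1] "huber"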
15 changes: 15 additions & 0 deletions tests/testthat/test-lightgbm.R
@@ -6,6 +6,21 @@ test_that("lightgbm", {
 })
 
 
+test_that('lightgbm alternate objective', {
+  skip_if_not_installed("lightgbm")
+
+  spec <- boost_tree(mtry = 1, trees = 50, tree_depth = 15, min_n = 1) %>%
+    set_engine("lightgbm", objective = "huber") %>%
+    set_mode("regression")
+
+  lgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)
+
+  info <- jsonlite::fromJSON(lightgbm::lgb.dump(lgb_fit$fit))
+
+  expect_equal(info$objective, "huber")
+  expect_all_modes_works(spec, 'lightgbm')
+})
+
 test_that("lightgbm with tune", {
 
   model <- parsnip::boost_tree(
