Skip to content

Commit

Permalink
[R-package] Added unit tests (#2498)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Oct 24, 2019
1 parent bdc310a commit b4bb38d
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 2 deletions.
2 changes: 1 addition & 1 deletion R-package/man/lgb.interprete.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R-package/man/slice.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ test_that("use of multiple eval metrics works", {


test_that("training continuation works", {
testthat::skip("This test is currently broken. See issue #2468 for details.")
dtrain <- lgb.Dataset(train$data, label = train$label, free_raw_data=FALSE)
watchlist = list(train=dtrain)
param <- list(objective = "binary", metric="binary_logloss", num_leaves = 5, learning_rate = 1)
Expand Down
39 changes: 39 additions & 0 deletions R-package/tests/testthat/test_lgb.importance.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
context("lgb.importance")

test_that("lgb.importance() should reject bad inputs", {
bad_inputs <- list(
.Machine$integer.max
, Inf
, -Inf
, NA
, NA_real_
, -10L:10L
, list(c("a", "b", "c"))
, data.frame(
x = rnorm(20)
, y = sample(
x = c(1, 2)
, size = 20
, replace = TRUE
)
)
, data.table::data.table(
x = rnorm(20)
, y = sample(
x = c(1, 2)
, size = 20
, replace = TRUE
)
)
, lgb.Dataset(
data = matrix(rnorm(100), ncol = 2)
, label = matrix(sample(c(0, 1), 50, replace = TRUE))
)
, "lightgbm.model"
)
for (input in bad_inputs){
expect_error({
lgb.importance(input)
}, regexp = "'model' has to be an object of class lgb\\.Booster")
}
})
113 changes: 113 additions & 0 deletions R-package/tests/testthat/test_lgb.interprete.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
context("lgb.interpete")

.sigmoid <- function(x){
1 / (1 + exp(-x))
}
.logit <- function(x){
log(x / (1 - x))
}

test_that("lgb.intereprete works as expected for binary classification", {
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
setinfo(
dataset = dtrain
, "init_score"
, rep(
.logit(mean(train$label))
, length(train$label)
)
)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
params <- list(
objective = "binary"
, learning_rate = 0.01
, num_leaves = 63
, max_depth = -1
, min_data_in_leaf = 1
, min_sum_hessian_in_leaf = 1
)
model <- lgb.train(
params = params
, data = dtrain
, nrounds = 10
)
num_trees <- 5
tree_interpretation <- lgb.interprete(
model = model
, data = test$data
, idxset = 1:num_trees
)
expect_true(methods::is(tree_interpretation, "list"))
expect_true(length(tree_interpretation) == num_trees)
expect_null(names(tree_interpretation))
expect_true(all(
sapply(
X = tree_interpretation
, FUN = function(treeDT){
checks <- c(
data.table::is.data.table(treeDT)
, identical(names(treeDT), c("Feature", "Contribution"))
, is.character(treeDT[, Feature])
, is.numeric(treeDT[, Contribution])
)
return(all(checks))
}
)
))
})

test_that("lgb.intereprete works as expected for multiclass classification", {
data(iris)

# We must convert factors to numeric
# They must be starting from number 0 to use multiclass
# For instance: 0, 1, 2, 3, 4, 5...
iris$Species <- as.numeric(as.factor(iris$Species)) - 1

# Create imbalanced training data (20, 30, 40 examples for classes 0, 1, 2)
train <- as.matrix(iris[c(1:20, 51:80, 101:140), ])
# The 10 last samples of each class are for validation
test <- as.matrix(iris[c(41:50, 91:100, 141:150), ])
dtrain <- lgb.Dataset(data = train[, 1:4], label = train[, 5])
dtest <- lgb.Dataset.create.valid(dtrain, data = test[, 1:4], label = test[, 5])
params <- list(
objective = "multiclass"
, metric = "multi_logloss"
, num_class = 3
, learning_rate = 0.00001
)
model <- lgb.train(
params = params
, data = dtrain
, nrounds = 10
, min_data = 1
)
num_trees <- 5
tree_interpretation <- lgb.interprete(
model = model
, data = test[, 1:4]
, idxset = 1:num_trees
)
expect_true(methods::is(tree_interpretation, "list"))
expect_true(length(tree_interpretation) == num_trees)
expect_null(names(tree_interpretation))
expect_true(all(
sapply(
X = tree_interpretation
, FUN = function(treeDT){
checks <- c(
data.table::is.data.table(treeDT)
, identical(names(treeDT), c("Feature", "Class 0", "Class 1", "Class 2"))
, is.character(treeDT[, Feature])
, is.numeric(treeDT[, `Class 0`])
, is.numeric(treeDT[, `Class 1`])
, is.numeric(treeDT[, `Class 2`])
)
return(all(checks))
}
)
))
})
97 changes: 97 additions & 0 deletions R-package/tests/testthat/test_lgb.plot.interpretation.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
context("lgb.plot.interpretation")

.sigmoid <- function(x){
1 / (1 + exp(-x))
}
.logit <- function(x){
log(x / (1 - x))
}

test_that("lgb.plot.interepretation works as expected for binary classification", {
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
setinfo(
dataset = dtrain
, "init_score"
, rep(
.logit(mean(train$label))
, length(train$label)
)
)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
params <- list(
objective = "binary"
, learning_rate = 0.01
, num_leaves = 63
, max_depth = -1
, min_data_in_leaf = 1
, min_sum_hessian_in_leaf = 1
)
model <- lgb.train(
params = params
, data = dtrain
, nrounds = 10
)
num_trees <- 5
tree_interpretation <- lgb.interprete(
model = model
, data = test$data
, idxset = 1:num_trees
)
expect_true({
lgb.plot.interpretation(
tree_interpretation_dt = tree_interpretation[[1]]
, top_n = 5
)
TRUE
})

# should also work when you explicitly pass cex
plot_res <- lgb.plot.interpretation(
tree_interpretation_dt = tree_interpretation[[1]]
, top_n = 5
, cex = 0.95
)
expect_null(plot_res)
})

test_that("lgb.plot.interepretation works as expected for multiclass classification", {
data(iris)

# We must convert factors to numeric
# They must be starting from number 0 to use multiclass
# For instance: 0, 1, 2, 3, 4, 5...
iris$Species <- as.numeric(as.factor(iris$Species)) - 1

# Create imbalanced training data (20, 30, 40 examples for classes 0, 1, 2)
train <- as.matrix(iris[c(1:20, 51:80, 101:140), ])
# The 10 last samples of each class are for validation
test <- as.matrix(iris[c(41:50, 91:100, 141:150), ])
dtrain <- lgb.Dataset(data = train[, 1:4], label = train[, 5])
dtest <- lgb.Dataset.create.valid(dtrain, data = test[, 1:4], label = test[, 5])
params <- list(
objective = "multiclass"
, metric = "multi_logloss"
, num_class = 3
, learning_rate = 0.00001
)
model <- lgb.train(
params = params
, data = dtrain
, nrounds = 10
, min_data = 1
)
num_trees <- 5
tree_interpretation <- lgb.interprete(
model = model
, data = test[, 1:4]
, idxset = 1:num_trees
)
plot_res <- lgb.plot.interpretation(
tree_interpretation_dt = tree_interpretation[[1]]
, top_n = 5
)
expect_null(plot_res)
})

0 comments on commit b4bb38d

Please sign in to comment.