Skip to content

Commit

Permalink
[R-package] Interface for interaction constraints (#3136)
Browse files Browse the repository at this point in the history
* Add interaction constraints functionality.

* Minor fixes.

* Minor fixes.

* Change lambda to function.

* Fix gpu bug, remove extra blank lines.

* Fix gpu bug.

* Fix style issues.

* Try to fix segfault on MACOS.

* Fix bug.

* Fix bug.

* Fix bugs.

* Change parameter format for R.

* Fix R style issues.

* Change string formatting code.

* Change docs to say R package not supported.

* Refactor check_interaction_constraints into separate function, add validation.

* Fix error messages.

* Add tests.

* Update docs.

* Fix tests, minor refactoring.

* Fix style issues.

* Fix R style issue.

* Remove old code.

* Fix existing test and add new one.

* Fix R lint error.
  • Loading branch information
btrotta authored Jul 2, 2020
1 parent cfc5e4f commit 4f8c32d
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 4 deletions.
9 changes: 9 additions & 0 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ lgb.cv <- function(params = list()
end_iteration <- begin_iteration + nrounds - 1L
}

# Check interaction constraints
cnames <- NULL
if (!is.null(colnames)) {
cnames <- colnames
} else if (!is.null(data$get_colnames())) {
cnames <- data$get_colnames()
}
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params, cnames)

# Check for weights
if (!is.null(weight)) {
data$setinfo("weight", weight)
Expand Down
9 changes: 7 additions & 2 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,14 @@ lgb.train <- function(params = list(),
end_iteration <- begin_iteration + nrounds - 1L
}

if (!is.null(params[["interaction_constraints"]])) {
stop("lgb.train: interaction_constraints is not implemented")
# Check interaction constraints
cnames <- NULL
if (!is.null(colnames)) {
cnames <- colnames
} else if (!is.null(data$get_colnames())) {
cnames <- data$get_colnames()
}
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params, cnames)

# Update parameters with parsed parameters
data$update_params(params)
Expand Down
59 changes: 59 additions & 0 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,65 @@ lgb.params2str <- function(params, ...) {

}

# Validate params[["interaction_constraints"]] and convert it to the string
# form used for the "interaction_constraints" parameter.
#
# Arguments:
#   params        list of parameters; only the "interaction_constraints" entry
#                 is read. When present it must be a list whose elements are
#                 character vectors (feature names) or numeric vectors
#                 (1-based feature positions).
#   column_names  character vector of feature names used to resolve names and
#                 to bound numeric positions. If it is NULL, no name can be
#                 resolved and every numeric position is reported as too large.
#
# Returns:
#   A list with one string per constraint, formatted as "[i,j,...]" where the
#   indices are 0-based. An empty list is returned when no constraints are set.
#
# Errors:
#   Stops if interaction_constraints is not a list, if an element is neither a
#   character nor a numeric vector, if a feature name is not in column_names,
#   if a numeric position is missing, smaller than 1 or not a whole number, or
#   if a numeric position exceeds the number of features.
lgb.check_interaction_constraints <- function(params, column_names) {

  constraints <- params[["interaction_constraints"]]

  # Nothing was supplied, so there is nothing to convert
  if (is.null(constraints)) {
    return(list())
  }

  # validation
  if (!methods::is(constraints, "list")) {
    stop("interaction_constraints must be a list")
  }
  # vapply() always yields a logical vector, even for an empty list
  has_valid_type <- vapply(
    X = constraints
    , FUN = function(x) {
      is.character(x) || is.numeric(x)
    }
    , FUN.VALUE = logical(1L)
  )
  if (!all(has_valid_type)) {
    stop("every element in interaction_constraints must be a character vector or numeric vector")
  }

  num_features <- length(column_names)

  # Preallocate the result instead of growing it inside the loop
  string_constraints <- vector(mode = "list", length = length(constraints))

  for (i in seq_along(constraints)) {

    constraint <- constraints[[i]]

    if (is.character(constraint)) {

      # Feature names: look up positions and shift them to 0-based indices
      constraint_indices <- match(constraint, column_names) - 1L

      # Names that are not present in column_names come back as NA
      is_unknown <- is.na(constraint_indices)
      if (any(is_unknown)) {
        # stop() glues its arguments together without a separator, so the
        # names are joined explicitly to keep the message readable
        stop(
          "supplied an unknown feature in interaction_constraints "
          , paste0(sQuote(constraint[is_unknown]), collapse = ", ")
        )
      }

    } else {

      # Positions must be whole numbers starting at 1. Without this check a
      # value such as 0 would silently turn into the index -1, and NA would
      # fail later with an unrelated error.
      is_invalid <- is.na(constraint) | constraint < 1L | constraint != floor(constraint)
      if (any(is_invalid)) {
        stop(
          "supplied an invalid value in interaction_constraints: "
          , paste0(constraint[is_invalid], collapse = ", ")
          , "; numeric constraints must be whole numbers greater than or equal to 1"
        )
      }

      # Check that constraint indices are at most number of features.
      # any() is used so that an empty vector does not trigger max()'s warning.
      if (any(constraint > num_features)) {
        stop(
          "supplied a too large value in interaction_constraints: "
          , max(constraint)
          , " but only "
          , num_features
          , " features"
        )
      }

      # Store indices as [0, n-1] indexed instead of [1, n] indexed
      constraint_indices <- as.integer(constraint - 1L)

    }

    # Convert constraint to string
    string_constraints[[i]] <- paste0("[", paste0(constraint_indices, collapse = ","), "]")

  }

  return(string_constraints)

}

lgb.c_str <- function(x) {

# Perform character to raw conversion
Expand Down
100 changes: 100 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -1030,3 +1030,103 @@ test_that("using lightgbm() without early stopping, best_iter and best_score com
expect_identical(bst$best_iter, which.max(auc_scores))
expect_identical(bst$best_score, auc_scores[which.max(auc_scores)])
})

test_that("lgb.train() throws an informative error if interaction_constraints is not a list", {
  # The constraints are given as one string instead of a list of vectors,
  # which must be rejected before training starts.
  bad_params <- list(
    objective = "regression"
    , interaction_constraints = "[1,2],[3]"
  )
  train_set <- lgb.Dataset(train$data, label = train$label)
  expect_error(
    lightgbm(
      data = train_set
      , params = bad_params
      , nrounds = 2L
    )
    , regexp = "interaction_constraints must be a list"
  )
})

test_that(paste0("lgb.train() throws an informative error if the members of interaction_constraints ",
                 "are not character or numeric vectors"), {
  # Each member is itself a list, which is neither a character nor a numeric vector
  nested_constraints <- list(list(1L, 2L), list(3L))
  bad_params <- list(
    objective = "regression"
    , interaction_constraints = nested_constraints
  )
  train_set <- lgb.Dataset(train$data, label = train$label)
  expect_error(
    lightgbm(
      data = train_set
      , params = bad_params
      , nrounds = 2L
    )
    , regexp = "every element in interaction_constraints must be a character vector or numeric vector"
  )
})

test_that("lgb.train() throws an informative error if interaction_constraints contains a too large index", {
  # One position past the last named column of the training data
  out_of_range_index <- length(colnames(train$data)) + 1L
  bad_params <- list(
    objective = "regression"
    , interaction_constraints = list(c(1L, out_of_range_index), 3L)
  )
  train_set <- lgb.Dataset(train$data, label = train$label)
  expect_error(
    lightgbm(
      data = train_set
      , params = bad_params
      , nrounds = 2L
    )
    , regexp = "supplied a too large value in interaction_constraints"
  )
})

test_that(paste0("lgb.train() gives same result when interaction_constraints is specified as a list of ",
                 "character vectors, numeric vectors, or a combination"), {
  set.seed(1L)
  dtrain <- lgb.Dataset(train$data, label = train$label)

  # Train a 2-round regression model with the given constraints on the shared
  # dataset and return its predictions on the test data
  fit_and_predict <- function(constraints) {
    model <- lightgbm(
      data = dtrain
      , params = list(objective = "regression", interaction_constraints = constraints)
      , nrounds = 2L
    )
    return(model$predict(test$data))
  }

  # 1) numeric positions only
  pred1 <- fit_and_predict(list(c(1L, 2L), 3L))

  # 2) the same features referenced by name
  cnames <- colnames(train$data)
  first_two_names <- c(cnames[[1L]], cnames[[2L]])
  pred2 <- fit_and_predict(list(first_two_names, cnames[[3L]]))

  # 3) a mix of names and positions
  pred3 <- fit_and_predict(list(first_two_names, 3L))

  expect_equal(pred1, pred2)
  expect_equal(pred2, pred3)

})

test_that(paste0("lgb.train() gives same results when using interaction_constraints and specifying colnames"), {
  set.seed(1L)
  dtrain <- lgb.Dataset(train$data, label = train$label)

  # Train a 2-round regression model with the given constraints (extra
  # arguments are forwarded to lightgbm()) and predict on the test data
  fit_and_predict <- function(constraints, ...) {
    model <- lightgbm(
      data = dtrain
      , params = list(objective = "regression", interaction_constraints = constraints)
      , nrounds = 2L
      , ...
    )
    return(model$predict(test$data))
  }

  # Baseline: constraints given as numeric positions, no colnames override
  pred1 <- fit_and_predict(list(c(1L, 2L), 3L))

  # Same constraints expressed with renamed columns passed through `colnames`
  new_colnames <- paste0(colnames(train$data), "_x")
  renamed_constraints <- list(
    c(new_colnames[1L], new_colnames[2L])
    , new_colnames[3L]
  )
  pred2 <- fit_and_predict(renamed_constraints, colnames = new_colnames)

  expect_equal(pred1, pred2)

})
2 changes: 1 addition & 1 deletion docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ Learning Control Parameters

- for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``

- for R-package, **not yet supported**
- for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc.

- any two features can appear in the same branch only if there exists a constraint containing both features

Expand Down
2 changes: 1 addition & 1 deletion include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ struct Config {
// desc = by default interaction constraints are disabled, to enable them you can specify
// descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
// descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
// descl2 = for R-package, **not yet supported**
// descl2 = for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc.
// desc = any two features can appear in the same branch only if there exists a constraint containing both features
std::string interaction_constraints = "";

Expand Down

0 comments on commit 4f8c32d

Please sign in to comment.