diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index 6fa0a30c606f..597ed4c0f25a 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -148,6 +148,15 @@ lgb.cv <- function(params = list() end_iteration <- begin_iteration + nrounds - 1L } + # Check interaction constraints + cnames <- NULL + if (!is.null(colnames)) { + cnames <- colnames + } else if (!is.null(data$get_colnames())) { + cnames <- data$get_colnames() + } + params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params, cnames) + # Check for weights if (!is.null(weight)) { data$setinfo("weight", weight) diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index 1ba5bf086647..54f83bd759b0 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -124,9 +124,14 @@ lgb.train <- function(params = list(), end_iteration <- begin_iteration + nrounds - 1L } - if (!is.null(params[["interaction_constraints"]])) { - stop("lgb.train: interaction_constraints is not implemented") + # Check interaction constraints + cnames <- NULL + if (!is.null(colnames)) { + cnames <- colnames + } else if (!is.null(data$get_colnames())) { + cnames <- data$get_colnames() } + params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params, cnames) # Update parameters with parsed parameters data$update_params(params) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 1e0e759d653b..f21d85001256 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -167,6 +167,65 @@ lgb.params2str <- function(params, ...) { } +lgb.check_interaction_constraints <- function(params, column_names) { + + # Convert interaction constraints to feature numbers + string_constraints <- list() + + if (!is.null(params[["interaction_constraints"]])) { + + # validation + if (!methods::is(params[["interaction_constraints"]], "list")) { + stop("interaction_constraints must be a list") + } + if (!all(sapply(params[["interaction_constraints"]], function(x) {is.character(x) || is.numeric(x)}))) { + stop("every element in interaction_constraints must be a character vector or numeric vector") + } + + for (constraint in params[["interaction_constraints"]]) { + + # Check for character name + if (is.character(constraint)) { + + constraint_indices <- as.integer(match(constraint, column_names) - 1L) + + # Provided indices, but some indices are not existing? + if (sum(is.na(constraint_indices)) > 0L) { + stop( + "supplied an unknown feature in interaction_constraints " + , sQuote(constraint[is.na(constraint_indices)]) + ) + } + + } else { + + # Check that constraint indices are at most number of features + if (max(constraint) > length(column_names)) { + stop( + "supplied a too large value in interaction_constraints: " + , max(constraint) + , " but only " + , length(column_names) + , " features" + ) + } + + # Store indices as [0, n-1] indexed instead of [1, n] indexed + constraint_indices <- as.integer(constraint - 1L) + + } + + # Convert constraint to string + constraint_string <- paste0("[", paste0(constraint_indices, collapse = ","), "]") + string_constraints <- append(string_constraints, constraint_string) + } + + } + + return(string_constraints) + +} + lgb.c_str <- function(x) { # Perform character to raw conversion diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index c7583ac77154..82ba123f65a9 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -1030,3 +1030,103 @@ test_that("using lightgbm() without early stopping, best_iter and best_score com expect_identical(bst$best_iter, which.max(auc_scores)) expect_identical(bst$best_score, auc_scores[which.max(auc_scores)]) }) + +test_that("lgb.train() throws an informative error if interaction_constraints is not a list", { + dtrain <- lgb.Dataset(train$data, label = train$label) + params <- list(objective = "regression", interaction_constraints = "[1,2],[3]") + expect_error({ + bst <- lightgbm( + data = dtrain + , params = params + , nrounds = 2L + ) + }, "interaction_constraints must be a list") +}) + +test_that(paste0("lgb.train() throws an informative error if the members of interaction_constraints ", + "are not character or numeric vectors"), { + dtrain <- lgb.Dataset(train$data, label = train$label) + params <- list(objective = "regression", interaction_constraints = list(list(1L, 2L), list(3L))) + expect_error({ + bst <- lightgbm( + data = dtrain + , params = params + , nrounds = 2L + ) + }, "every element in interaction_constraints must be a character vector or numeric vector") +}) + +test_that("lgb.train() throws an informative error if interaction_constraints contains a too large index", { + dtrain <- lgb.Dataset(train$data, label = train$label) + params <- list(objective = "regression", + interaction_constraints = list(c(1L, length(colnames(train$data)) + 1L), 3L)) + expect_error({ + bst <- lightgbm( + data = dtrain + , params = params + , nrounds = 2L + ) + }, "supplied a too large value in interaction_constraints") +}) + +test_that(paste0("lgb.train() gives same result when interaction_constraints is specified as a list of ", + "character vectors, numeric vectors, or a combination"), { + set.seed(1L) + dtrain <- lgb.Dataset(train$data, label = train$label) + + params <- list(objective = "regression", interaction_constraints = list(c(1L, 2L), 3L)) + bst <- lightgbm( + data = dtrain + , params = params + , nrounds = 2L + ) + pred1 <- bst$predict(test$data) + + cnames <- colnames(train$data) + params <- list(objective = "regression", interaction_constraints = list(c(cnames[[1L]], cnames[[2L]]), cnames[[3L]])) + bst <- lightgbm( + data = dtrain + , params = params + , nrounds = 2L + ) + pred2 <- bst$predict(test$data) + + params <- list(objective = "regression", interaction_constraints = list(c(cnames[[1L]], cnames[[2L]]), 3L)) + bst <- lightgbm( + data = dtrain + , params = params + , nrounds = 2L + ) + pred3 <- bst$predict(test$data) + + expect_equal(pred1, pred2) + expect_equal(pred2, pred3) + +}) + +test_that(paste0("lgb.train() gives same results when using interaction_constraints and specifying colnames"), { + set.seed(1L) + dtrain <- lgb.Dataset(train$data, label = train$label) + + params <- list(objective = "regression", interaction_constraints = list(c(1L, 2L), 3L)) + bst <- lightgbm( + data = dtrain + , params = params + , nrounds = 2L + ) + pred1 <- bst$predict(test$data) + + new_colnames <- paste0(colnames(train$data), "_x") + params <- list(objective = "regression" + , interaction_constraints = list(c(new_colnames[1L], new_colnames[2L]), new_colnames[3L])) + bst <- lightgbm( + data = dtrain + , params = params + , nrounds = 2L + , colnames = new_colnames + ) + pred2 <- bst$predict(test$data) + + expect_equal(pred1, pred2) + +}) diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 01362fb9af34..3d8b23eee22e 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -548,7 +548,7 @@ Learning Control Parameters - for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]`` - - for R-package, **not yet supported** + - for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc - any two features can only appear in the same branch only if there exists a constraint containing both features diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 2a3335c1c0ad..5825ec5f6d43 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -509,7 +509,7 @@ struct Config { // desc = by default interaction constraints are disabled, to enable them you can specify // descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]`` // descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]`` - // descl2 = for R-package, **not yet supported** + // descl2 = for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc // desc = any two features can only appear in the same branch only if there exists a constraint containing both features std::string interaction_constraints = "";