From 4e744033221f5366007130da33ef74f7a465f587 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Thu, 13 Jun 2024 14:49:14 +0200 Subject: [PATCH] [R-package] ensure use of interaction_constraints does not lead to features being ignored (#6377) --- .ci/lint-python.sh | 8 +- R-package/R/utils.R | 100 ++++++++++---------- R-package/tests/testthat/test_basic.R | 45 +++++++-- R-package/tests/testthat/test_lgb.Booster.R | 2 +- R-package/tests/testthat/test_utils.R | 18 ++++ 5 files changed, 110 insertions(+), 63 deletions(-) diff --git a/.ci/lint-python.sh b/.ci/lint-python.sh index e1e9e306c883..edab8993a799 100755 --- a/.ci/lint-python.sh +++ b/.ci/lint-python.sh @@ -2,9 +2,11 @@ set -e -E -u -o pipefail -echo "running pre-commit checks" -pre-commit run --all-files || exit 1 -echo "done running pre-commit checks" +# this can be re-enabled when this is fixed: +# https://github.com/tox-dev/filelock/issues/337 +# echo "running pre-commit checks" +# pre-commit run --all-files || exit 1 +# echo "done running pre-commit checks" echo "running mypy" mypy \ diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 646a306c97f6..9fbdba778cc4 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -59,68 +59,66 @@ } +# [description] +# +# Besides applying checks, this function +# +# 1. turns feature *names* into 1-based integer positions, then +# 2. adds an extra list element with skipped features, then +# 3. turns 1-based integer positions into 0-based positions, and finally +# 4. collapses the values of each list element into a string like "[0, 1]". +# .check_interaction_constraints <- function(interaction_constraints, column_names) { + if (is.null(interaction_constraints)) { + return(list()) + } + if (!identical(class(interaction_constraints), "list")) { + stop("interaction_constraints must be a list") + } - # Convert interaction constraints to feature numbers - string_constraints <- list() + column_indices <- seq_along(column_names) - if (!is.null(interaction_constraints)) { + # Convert feature names to 1-based integer positions and apply checks + for (j in seq_along(interaction_constraints)) { + constraint <- interaction_constraints[[j]] - if (!methods::is(interaction_constraints, "list")) { - stop("interaction_constraints must be a list") - } - constraint_is_character_or_numeric <- sapply( - X = interaction_constraints - , FUN = function(x) { - return(is.character(x) || is.numeric(x)) - } - ) - if (!all(constraint_is_character_or_numeric)) { - stop("every element in interaction_constraints must be a character vector or numeric vector") + if (is.character(constraint)) { + constraint_indices <- match(constraint, column_names) + } else if (is.numeric(constraint)) { + constraint_indices <- as.integer(constraint) + } else { + stop("every element in interaction_constraints must be a character vector or numeric vector") } - for (constraint in interaction_constraints) { - - # Check for character name - if (is.character(constraint)) { - - constraint_indices <- as.integer(match(constraint, column_names) - 1L) - - # Provided indices, but some indices are not existing? - if (sum(is.na(constraint_indices)) > 0L) { - stop( - "supplied an unknown feature in interaction_constraints " - , sQuote(constraint[is.na(constraint_indices)]) - ) - } - - } else { - - # Check that constraint indices are at most number of features - if (max(constraint) > length(column_names)) { - stop( - "supplied a too large value in interaction_constraints: " - , max(constraint) - , " but only " - , length(column_names) - , " features" - ) - } - - # Store indices as [0, n-1] indexed instead of [1, n] indexed - constraint_indices <- as.integer(constraint - 1L) - - } - - # Convert constraint to string - constraint_string <- paste0("[", paste0(constraint_indices, collapse = ","), "]") - string_constraints <- append(string_constraints, constraint_string) + # Features outside range? + bad <- !(constraint_indices %in% column_indices) + if (any(bad)) { + stop( + "unknown feature(s) in interaction_constraints: " + , toString(sQuote(constraint[bad], q = "'")) + ) } + interaction_constraints[[j]] <- constraint_indices } - return(string_constraints) + # Add missing features as new interaction set + remaining_indices <- setdiff( + column_indices, sort(unique(unlist(interaction_constraints))) + ) + if (length(remaining_indices) > 0L) { + interaction_constraints <- c( + interaction_constraints, list(remaining_indices) + ) + } + # Turn indices 0-based and convert to string + for (j in seq_along(interaction_constraints)) { + interaction_constraints[[j]] <- paste0( + "[", paste0(interaction_constraints[[j]] - 1L, collapse = ","), "]" + ) + } + return(interaction_constraints) } diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 74c46dcef141..ed477a42c00b 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -2776,14 +2776,12 @@ test_that(paste0("lgb.train() throws an informative error if the members of inte test_that("lgb.train() throws an informative error if interaction_constraints contains a too large index", { dtrain <- lgb.Dataset(train$data, label = train$label) params <- list(objective = "regression", - interaction_constraints = list(c(1L, length(colnames(train$data)) + 1L), 3L)) - expect_error({ - bst <- lightgbm( - data = dtrain - , params = params - , nrounds = 2L - ) - }, "supplied a too large value in interaction_constraints") + interaction_constraints = list(c(1L, ncol(train$data) + 1L:2L), 3L)) + expect_error( + lightgbm(data = dtrain, params = params, nrounds = 2L) + , "unknown feature(s) in interaction_constraints: '127', '128'" + , fixed = TRUE + ) }) test_that(paste0("lgb.train() gives same result when interaction_constraints is specified as a list of ", @@ -2876,6 +2874,37 @@ test_that(paste0("lgb.train() gives same results when using interaction_constrai }) +test_that("Interaction constraints add missing features correctly as new group", { + dtrain <- lgb.Dataset( + train$data[, 1L:6L] # Pick only some columns + , label = train$label + , params = list(num_threads = .LGB_MAX_THREADS) + ) + + list_of_constraints <- list( + list(3L, 1L:2L) + , list("cap-shape=convex", c("cap-shape=bell", "cap-shape=conical")) + ) + + for (constraints in list_of_constraints) { + params <- list( + objective = "regression" + , interaction_constraints = constraints + , verbose = .LGB_VERBOSITY + , num_threads = .LGB_MAX_THREADS + ) + bst <- lightgbm(data = dtrain, params = params, nrounds = 10L) + + expected_list <- list("[2]", "[0,1]", "[3,4,5]") + expect_equal(bst$params$interaction_constraints, expected_list) + + expected_string <- "[interaction_constraints: [2],[0,1],[3,4,5]]" + expect_true( + grepl(expected_string, bst$save_model_to_string(), fixed = TRUE) + ) + } +}) + .generate_trainset_for_monotone_constraints_tests <- function(x3_to_categorical) { n_samples <- 3000L x1_positively_correlated_with_y <- runif(n = n_samples, min = 0.0, max = 1.0) diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index 8e49c7b7069b..e81dc89673e0 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -174,7 +174,7 @@ test_that("Loading a Booster from a text file works", { , bagging_freq = 1L , boost_from_average = FALSE , categorical_feature = c(1L, 2L) - , interaction_constraints = list(c(1L, 2L), 1L) + , interaction_constraints = list(1L:2L, 3L, 4L:ncol(train$data)) , feature_contri = rep(0.5, ncol(train$data)) , metric = c("mape", "average_precision") , learning_rate = 1.0 diff --git a/R-package/tests/testthat/test_utils.R b/R-package/tests/testthat/test_utils.R index 898aed9b0915..2534cb24cb13 100644 --- a/R-package/tests/testthat/test_utils.R +++ b/R-package/tests/testthat/test_utils.R @@ -147,3 +147,21 @@ test_that(".equal_or_both_null produces expected results", { expect_false(.equal_or_both_null(10.0, 1L)) expect_true(.equal_or_both_null(0L, 0L)) }) + +test_that(".check_interaction_constraints() adds skipped features", { + ref <- letters[1L:5L] + ic_num <- list(1L, c(2L, 3L)) + ic_char <- list("a", c("b", "c")) + expected <- list("[0]", "[1,2]", "[3,4]") + + ic_checked_num <- .check_interaction_constraints( + interaction_constraints = ic_num, column_names = ref + ) + + ic_checked_char <- .check_interaction_constraints( + interaction_constraints = ic_char, column_names = ref + ) + + expect_equal(ic_checked_num, expected) + expect_equal(ic_checked_char, expected) +})