diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index a406a4358478..f23e80426f3d 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -78,3 +78,5 @@ List of Contributors * [Pierre de Sahb](https://github.com/pdesahb) * [liuliang01](https://github.com/liuliang01) - liuliang01 added support for the qid column for LibSVM input format. This makes ranking task easier in distributed setting. +* [Andrew Thia](https://github.com/BlueTea88) + - Andrew Thia implemented feature interaction constraints diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 3a4b59f7ff49..6d618709a368 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -74,6 +74,19 @@ check.booster.params <- function(params, ...) { params[['monotone_constraints']] = vec2str } + # interaction constraints parser (convert from list of column indices to string) + if (!is.null(params[['interaction_constraints']]) && + typeof(params[['interaction_constraints']]) != "character"){ + # check input class + if (class(params[['interaction_constraints']]) != 'list') stop('interaction_constraints should be class list') + if (!all(unique(sapply(params[['interaction_constraints']], class)) %in% c('numeric','integer'))) { + stop('interaction_constraints should be a list of numeric/integer vectors') + } + + # recast parameter as string + interaction_constraints <- sapply(params[['interaction_constraints']], function(x) paste0('[', paste(x, collapse=','), ']')) + params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse=','), ']') + } return(params) } diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 80ade2b43a37..7061114ca1bc 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -26,6 +26,7 @@ #' \item \code{colsample_bytree} subsample ratio of columns when constructing each tree. Default: 1 #' \item \code{num_parallel_tree} Experimental parameter. number of trees to grow per round. Useful to test Random Forest through Xgboost (set \code{colsample_bytree < 1}, \code{subsample < 1} and \code{round = 1}) accordingly. Default: 1 #' \item \code{monotone_constraints} A numerical vector consists of \code{1}, \code{0} and \code{-1} with its length equals to the number of features in the training data. \code{1} is increasing, \code{-1} is decreasing and \code{0} is no constraint. +#' \item \code{interaction_constraints} A list of vectors specifying feature indices of permitted interactions. Each item of the list represents one permitted interaction where specified features are allowed to interact with each other. Feature index values should start from \code{0} (\code{0} references the first column). Leave argument unspecified for no interaction constraints. #' } #' #' 2.2. Parameter for Linear Booster diff --git a/R-package/demo/interaction_constraints.R b/R-package/demo/interaction_constraints.R new file mode 100644 index 000000000000..2f2edb15548c --- /dev/null +++ b/R-package/demo/interaction_constraints.R @@ -0,0 +1,105 @@ +library(xgboost) +library(data.table) + +set.seed(1024) + +# Function to obtain a list of interactions fitted in trees, requires input of maximum depth +treeInteractions <- function(input_tree, input_max_depth){ + trees <- copy(input_tree) # copy tree input to prevent overwriting + if (input_max_depth < 2) return(list()) # no interactions if max depth < 2 + if (nrow(input_tree) == 1) return(list()) + + # Attach parent nodes + for (i in 2:input_max_depth){ + if (i == 2) trees[, ID_merge:=ID] else trees[, ID_merge:=get(paste0('parent_',i-2))] + parents_left <- trees[!is.na(Split), list(i.id=ID, i.feature=Feature, ID_merge=Yes)] + parents_right <- trees[!is.na(Split), list(i.id=ID, i.feature=Feature, ID_merge=No)] + + setorderv(trees, 'ID_merge') + setorderv(parents_left, 'ID_merge') + setorderv(parents_right, 'ID_merge') + + trees <- merge(trees, parents_left, by='ID_merge', all.x=T) + trees[!is.na(i.id), c(paste0('parent_', i-1), paste0('parent_feat_', i-1)):=list(i.id, i.feature)] + trees[, c('i.id','i.feature'):=NULL] + + trees <- merge(trees, parents_right, by='ID_merge', all.x=T) + trees[!is.na(i.id), c(paste0('parent_', i-1), paste0('parent_feat_', i-1)):=list(i.id, i.feature)] + trees[, c('i.id','i.feature'):=NULL] + } + + # Extract nodes with interactions + interaction_trees <- trees[!is.na(Split) & !is.na(parent_1), + c('Feature',paste0('parent_feat_',1:(input_max_depth-1))), with=F] + interaction_trees_split <- split(interaction_trees, 1:nrow(interaction_trees)) + interaction_list <- lapply(interaction_trees_split, as.character) + + # Remove NAs (no parent interaction) + interaction_list <- lapply(interaction_list, function(x) x[!is.na(x)]) + + # Remove non-interactions (same variable) + interaction_list <- lapply(interaction_list, unique) # remove same variables + interaction_length <- sapply(interaction_list, length) + interaction_list <- interaction_list[interaction_length > 1] + interaction_list <- unique(lapply(interaction_list, sort)) + return(interaction_list) +} + +# Generate sample data +x <- list() +for (i in 1:10){ + x[[i]] = i*rnorm(1000, 10) +} +x <- as.data.table(x) + +y = -1*x[, rowSums(.SD)] + x[['V1']]*x[['V2']] + x[['V3']]*x[['V4']]*x[['V5']] + rnorm(1000, 0.001) + 3*sin(x[['V7']]) + +train = as.matrix(x) + +# Interaction constraint list (column names form) +interaction_list <- list(c('V1','V2'),c('V3','V4','V5')) + +# Convert interaction constraint list into feature index form +cols2ids <- function(object, col_names) { + LUT <- seq_along(col_names) - 1 + names(LUT) <- col_names + rapply(object, function(x) LUT[x], classes="character", how="replace") +} +interaction_list_fid = cols2ids(interaction_list, colnames(train)) + +# Fit model with interaction constraints +bst = xgboost(data = train, label = y, max_depth = 4, + eta = 0.1, nthread = 2, nrounds = 1000, + interaction_constraints = interaction_list_fid) + +bst_tree <- xgb.model.dt.tree(colnames(train), bst) +bst_interactions <- treeInteractions(bst_tree, 4) # interactions constrained to combinations of V1*V2 and V3*V4*V5 + +# Fit model without interaction constraints +bst2 = xgboost(data = train, label = y, max_depth = 4, + eta = 0.1, nthread = 2, nrounds = 1000) + +bst2_tree <- xgb.model.dt.tree(colnames(train), bst2) +bst2_interactions <- treeInteractions(bst2_tree, 4) # much more interactions + +# Fit model with both interaction and monotonicity constraints +bst3 = xgboost(data = train, label = y, max_depth = 4, + eta = 0.1, nthread = 2, nrounds = 1000, + interaction_constraints = interaction_list_fid, + monotone_constraints = c(-1,0,0,0,0,0,0,0,0,0)) + +bst3_tree <- xgb.model.dt.tree(colnames(train), bst3) +bst3_interactions <- treeInteractions(bst3_tree, 4) # interactions still constrained to combinations of V1*V2 and V3*V4*V5 + +# Show monotonic constraints still apply by checking scores after incrementing V1 +x1 <- sort(unique(x[['V1']])) +for (i in 1:length(x1)){ + testdata <- copy(x[, -c('V1')]) + testdata[['V1']] <- x1[i] + testdata <- testdata[, paste0('V',1:10), with=F] + pred <- predict(bst3, as.matrix(testdata)) + + # Should not print out anything due to monotonic constraints + if (i > 1) if (any(pred > prev_pred)) print(i) + prev_pred <- pred +} diff --git a/R-package/tests/testthat/test_interaction_constraints.R b/R-package/tests/testthat/test_interaction_constraints.R new file mode 100644 index 000000000000..1b4902576e61 --- /dev/null +++ b/R-package/tests/testthat/test_interaction_constraints.R @@ -0,0 +1,38 @@ +require(xgboost) + +context("interaction constraints") + +set.seed(1024) +x1 <- rnorm(1000, 1) +x2 <- rnorm(1000, 1) +x3 <- sample(c(1,2,3), size=1000, replace=TRUE) +y <- x1 + x2 + x3 + x1*x2*x3 + rnorm(1000, 0.001) + 3*sin(x1) +train <- matrix(c(x1,x2,x3), ncol = 3) + +test_that("interaction constraints for regression", { + # Fit a model that only allows interaction between x1 and x2 + bst <- xgboost(data = train, label = y, max_depth = 3, + eta = 0.1, nthread = 2, nrounds = 100, verbose = 0, + interaction_constraints = list(c(0,1))) + + # Set all observations to have the same x3 values then increment + # by the same amount + preds <- lapply(c(1,2,3), function(x){ + tmat <- matrix(c(x1,x2,rep(x,1000)), ncol=3) + return(predict(bst, tmat)) + }) + + # Check incrementing x3 has the same effect on all observations + # since x3 is constrained to be independent of x1 and x2 + # and all observations start off from the same x3 value + diff1 <- preds[[2]] - preds[[1]] + test1 <- all(abs(diff1 - diff1[1]) < 1e-4) + + diff2 <- preds[[3]] - preds[[2]] + test2 <- all(abs(diff2 - diff2[1]) < 1e-4) + + expect_true({ + test1 & test2 + }, "Interaction Contraint Satisfied") + +}) diff --git a/doc/conf.py b/doc/conf.py index 759a785a4472..268f553a73d4 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -41,7 +41,7 @@ # -- mock out modules import mock -MOCK_MODULES = ['numpy', 'scipy', 'scipy.sparse', 'sklearn', 'matplotlib', 'pandas', 'graphviz'] +MOCK_MODULES = ['scipy', 'scipy.sparse', 'sklearn', 'pandas'] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() @@ -62,6 +62,8 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones extensions = [ + 'matplotlib.sphinxext.only_directives', + 'matplotlib.sphinxext.plot_directive', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.mathjax', @@ -69,6 +71,11 @@ 'breathe' ] +graphviz_output_format = 'png' +plot_formats = [('svg', 300), ('png', 100), ('hires.png', 300)] +plot_html_show_source_link = False +plot_html_show_formats = False + # Breathe extension variables breathe_projects = {"xgboost": "doxyxml/"} breathe_default_project = "xgboost" diff --git a/doc/tutorials/feature_interaction_constraint.rst b/doc/tutorials/feature_interaction_constraint.rst new file mode 100644 index 000000000000..dcd653d70097 --- /dev/null +++ b/doc/tutorials/feature_interaction_constraint.rst @@ -0,0 +1,177 @@ +############################### +Feature Interaction Constraints +############################### + +The decision tree is a powerful tool to discover interaction among independent +variables (features). Variables that appear together in a traversal path +are interacting with one another, since the condition of a child node is +predicated on the condition of the parent node. For example, the highlighted +red path in the diagram below contains three variables: :math:`x_1`, :math:`x_7`, +and :math:`x_{10}`, so the highlighted prediction (at the highlighted leaf node) +is the product of interaction between :math:`x_1`, :math:`x_7`, and +:math:`x_{10}`. + +.. plot:: + :nofigs: + + from graphviz import Source + source = r""" + digraph feature_interaction_illustration1 { + graph [fontname = "helvetica"]; + node [fontname = "helvetica"]; + edge [fontname = "helvetica"]; + 0 [label=10 < -1.5 ?>, shape=box, color=red, fontcolor=red]; + 1 [label=2 < 2 ?>, shape=box]; + 2 [label=7 < 0.3 ?>, shape=box, color=red, fontcolor=red]; + 3 [label="...", shape=none]; + 4 [label="...", shape=none]; + 5 [label=1 < 0.5 ?>, shape=box, color=red, fontcolor=red]; + 6 [label="...", shape=none]; + 7 [label="...", shape=none]; + 8 [label="Predict +1.3", color=red, fontcolor=red]; + 0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "]; + 0 -> 2 [labeldistance=2.0, labelangle=-45, + headlabel="No", color=red, fontcolor=red]; + 1 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes"]; + 1 -> 4 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"]; + 2 -> 5 [labeldistance=2.0, labelangle=-45, headlabel="Yes", + color=red, fontcolor=red]; + 2 -> 6 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"]; + 5 -> 7; + 5 -> 8 [color=red]; + } + """ + Source(source, format='png').render('../_static/feature_interaction_illustration1', view=False) + Source(source, format='svg').render('../_static/feature_interaction_illustration1', view=False) + +.. raw:: html + +

+ +

+ +When the tree depth is larger than one, many variables interact on +the sole basis of minimizing training loss, and the resulting decision tree may +capture a spurious relationship (noise) rather than a legitimate relationship +that generalizes across different datasets. **Feature interaction constraints** +allow users to decide which variables are allowed to interact and which are not. + +Potential benefits include: + +* Better predictive performance from focusing on interactions that work -- + whether through domain specific knowledge or algorithms that rank interactions +* Less noise in predictions; better generalization +* More control to the user on what the model can fit. For example, the user may + want to exclude some interactions even if they perform well due to regulatory + constraints + +**************** +A Simple Example +**************** + +Feature interaction constraints are expressed in terms of groups of variables +that are allowed to interact. For example, the constraint +``[0, 1]`` indicates that variables :math:`x_0` and :math:`x_1` are allowed to +interact with each other but with no other variable. Similarly, ``[2, 3, 4]`` +indicates that :math:`x_2`, :math:`x_3`, and :math:`x_4` are allowed to +interact with one another but with no other variable. A set of feature +interaction constraints is expressed as a nested list, e.g. +``[[0, 1], [2, 3, 4]]``, where each inner list is a group of indices of features +that are allowed to interact with each other. + +In the following diagram, the left decision tree is in violation of the first +constraint (``[0, 1]``), whereas the right decision tree complies with both the +first and second constraints (``[0, 1]``, ``[2, 3, 4]``). + +.. plot:: + :nofigs: + + from graphviz import Source + source = r""" + digraph feature_interaction_illustration2 { + graph [fontname = "helvetica"]; + node [fontname = "helvetica"]; + edge [fontname = "helvetica"]; + 0 [label=0 < 5.0 ?>, shape=box]; + 1 [label=2 < -3.0 ?>, shape=box]; + 2 [label="+0.6"]; + 3 [label="-0.4"]; + 4 [label="+1.2"]; + 0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "]; + 0 -> 2 [labeldistance=2.0, labelangle=-45, headlabel="No"]; + 1 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes"]; + 1 -> 4 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"]; + } + """ + Source(source, format='png').render('../_static/feature_interaction_illustration2', view=False) + Source(source, format='svg').render('../_static/feature_interaction_illustration2', view=False) + +.. plot:: + :nofigs: + + from graphviz import Source + source = r""" + digraph feature_interaction_illustration3 { + graph [fontname = "helvetica"]; + node [fontname = "helvetica"]; + edge [fontname = "helvetica"]; + 0 [label=3 < 2.5 ?>, shape=box]; + 1 [label="+1.6"]; + 2 [label=2 < -1.2 ?>, shape=box]; + 3 [label="+0.1"]; + 4 [label="-0.3"]; + 0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes"]; + 0 -> 2 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"]; + 2 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "]; + 2 -> 4 [labeldistance=2.0, labelangle=-45, headlabel="No"]; + } + """ + Source(source, format='png').render('../_static/feature_interaction_illustration3', view=False) + Source(source, format='svg').render('../_static/feature_interaction_illustration3', view=False) + +.. raw:: html + +

+ + +

+ +**************************************************** +Enforcing Feature Interaction Constraints in XGBoost +**************************************************** + +It is very simple to enforce monotonicity constraints in XGBoost. Here we will +give an example using Python, but the same general idea generalizes to other +platforms. + +Suppose the following code fits your model without monotonicity constraints: + +.. code-block:: python + + model_no_constraints = xgb.train(params, dtrain, + num_boost_round = 1000, evals = evallist, + early_stopping_rounds = 10) + +Then fitting with monotonicity constraints only requires adding a single +parameter: + +.. code-block:: python + + params_constrained = params.copy() + # Use nested list to define feature interaction constraints + params_constrained['interaction_constraints'] = '[[0, 2], [1, 3, 4], [5, 6]]' + # Features 0 and 2 are allowed to interact with each other but with no other feature + # Features 1, 3, 4 are allowed to interact with one another but with no other feature + # Features 5 and 6 are allowed to interact with each other but with no other feature + + model_with_constraints = xgb.train(params_constrained, dtrain, + num_boost_round = 1000, evals = evallist, + early_stopping_rounds = 10) + +**Choice of tree construction algorithm**. To use feature interaction +constraints, be sure to set the ``tree_method`` parameter to either ``exact`` +or ``hist``. Currently, GPU algorithms (``gpu_hist``, ``gpu_exact``) do not +support feature interaction constraints. diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst index db48d9fbf84a..ef47f63f7ca4 100644 --- a/doc/tutorials/index.rst +++ b/doc/tutorials/index.rst @@ -14,6 +14,7 @@ See `Awesome XGBoost `_ for mo Distributed XGBoost with XGBoost4J-Spark dart monotonic + feature_interaction_constraint input_format param_tuning external_memory diff --git a/doc/tutorials/monotonic.rst b/doc/tutorials/monotonic.rst index f3dcc7c058d4..1cecac5b1817 100644 --- a/doc/tutorials/monotonic.rst +++ b/doc/tutorials/monotonic.rst @@ -82,7 +82,7 @@ Some other examples: - ``(1,0)``: An increasing constraint on the first predictor and no constraint on the second. - ``(0,-1)``: No constraint on the first predictor and a decreasing constraint on the second. -**Choise of tree construction algorithm**. To use monotonic constraints, be +**Choice of tree construction algorithm**. To use monotonic constraints, be sure to set the ``tree_method`` parameter to one of ``exact``, ``hist``, and ``gpu_hist``. diff --git a/src/tree/param.h b/src/tree/param.h index 43d653e9b99d..ce3183b711b8 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -194,7 +194,7 @@ struct TrainParam : public dmlc::Parameter { .describe("Number of rows in a GPU batch, used for finding quantiles on GPU; " "-1 to use all rows assignted to a GPU, and 0 to auto-deduce"); DMLC_DECLARE_FIELD(split_evaluator) - .set_default("elastic_net,monotonic") + .set_default("elastic_net,monotonic,interaction") .describe("The criteria to use for ranking splits"); // add alias of parameters DMLC_DECLARE_ALIAS(reg_lambda, lambda); diff --git a/src/tree/split_evaluator.cc b/src/tree/split_evaluator.cc index dc9da278dca6..1c161d7ea5c5 100644 --- a/src/tree/split_evaluator.cc +++ b/src/tree/split_evaluator.cc @@ -4,8 +4,11 @@ * \brief Contains implementations of different split evaluators. */ #include "split_evaluator.h" +#include #include #include +#include +#include #include #include #include @@ -303,5 +306,196 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(MonotonicConstraint, "monotonic") return new MonotonicConstraint(std::move(inner)); }); +/*! \brief Encapsulates the parameters required by the InteractionConstraint + split evaluator +*/ +struct InteractionConstraintParams + : public dmlc::Parameter { + std::string interaction_constraints; + bst_uint num_feature; + + DMLC_DECLARE_PARAMETER(InteractionConstraintParams) { + DMLC_DECLARE_FIELD(interaction_constraints) + .set_default("") + .describe("Constraints for interaction representing permitted interactions." + "The constraints must be specified in the form of a nest list," + "e.g. [[0, 1], [2, 3, 4]], where each inner list is a group of" + "indices of features that are allowed to interact with each other." + "See tutorial for more information"); + DMLC_DECLARE_FIELD(num_feature) + .describe("Number of total features used"); + } +}; + +DMLC_REGISTER_PARAMETER(InteractionConstraintParams); + +/*! \brief Enforces that the tree is monotonically increasing/decreasing with respect to a user specified set of + features. +*/ +class InteractionConstraint final : public SplitEvaluator { + public: + explicit InteractionConstraint(std::unique_ptr inner) { + if (!inner) { + LOG(FATAL) << "InteractionConstraint must be given an inner evaluator"; + } + inner_ = std::move(inner); + } + + void Init(const std::vector >& args) + override { + inner_->Init(args); + params_.InitAllowUnknown(args); + Reset(); + } + + void Reset() override { + if (params_.interaction_constraints.empty()) { + return; // short-circuit if no constraint is specified + } + + // Parse interaction constraints + std::istringstream iss(params_.interaction_constraints); + dmlc::JSONReader reader(&iss); + // Read std::vector> first and then + // convert to std::vector> + std::vector> tmp; + reader.Read(&tmp); + for (const auto& e : tmp) { + interaction_constraints_.emplace_back(e.begin(), e.end()); + } + + // Initialise interaction constraints record with all variables permitted for the first node + int_cont_.clear(); + int_cont_.resize(1, std::unordered_set()); + int_cont_[0].reserve(params_.num_feature); + for (bst_uint i = 0; i < params_.num_feature; ++i) { + int_cont_[0].insert(i); + } + + // Initialise splits record + splits_.clear(); + splits_.resize(1, std::unordered_set()); + } + + SplitEvaluator* GetHostClone() const override { + if (params_.interaction_constraints.empty()) { + // No interaction constraints specified, just return a clone of inner + return inner_->GetHostClone(); + } else { + auto c = new InteractionConstraint( + std::unique_ptr(inner_->GetHostClone())); + c->params_ = this->params_; + c->Reset(); + return c; + } + } + + bst_float ComputeSplitScore(bst_uint nodeid, + bst_uint featureid, + const GradStats& left_stats, + const GradStats& right_stats, + bst_float left_weight, + bst_float right_weight) const override { + // Return negative infinity score if feature is not permitted by interaction constraints + if (!CheckInteractionConstraint(featureid, nodeid)) { + return -std::numeric_limits::infinity(); + } + + // Otherwise, get score from inner evaluator + bst_float score = inner_->ComputeSplitScore( + nodeid, featureid, left_stats, right_stats, left_weight, right_weight); + return score; + } + + bst_float ComputeScore(bst_uint parentID, const GradStats& stats, bst_float weight) + const override { + return inner_->ComputeScore(parentID, stats, weight); + } + + bst_float ComputeWeight(bst_uint parentID, const GradStats& stats) + const override { + return inner_->ComputeWeight(parentID, stats); + } + + void AddSplit(bst_uint nodeid, + bst_uint leftid, + bst_uint rightid, + bst_uint featureid, + bst_float leftweight, + bst_float rightweight) override { + inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight); + + if (params_.interaction_constraints.empty()) { + return; // short-circuit if no constraint is specified + } + bst_uint newsize = std::max(leftid, rightid) + 1; + + // Record previous splits for child nodes + std::unordered_set feature_splits = splits_[nodeid]; // fid history of current node + feature_splits.insert(featureid); // add feature of current node + splits_.resize(newsize); + splits_[leftid] = feature_splits; + splits_[rightid] = feature_splits; + + // Resize constraints record, initialise all features to be not permitted for new nodes + int_cont_.resize(newsize, std::unordered_set()); + + // Permit features used in previous splits + for (bst_uint fid : feature_splits) { + int_cont_[leftid].insert(fid); + int_cont_[rightid].insert(fid); + } + + // Loop across specified interactions in constraints + for (const auto& constraint : interaction_constraints_) { + bst_uint flag = 1; // flags whether the specified interaction is still relevant + + // Test relevance of specified interaction by checking all previous features are included + for (bst_uint checkvar : feature_splits) { + if (constraint.count(checkvar) == 0) { + flag = 0; + break; // interaction is not relevant due to unmet constraint + } + } + + // If interaction is still relevant, permit all other features in the interaction + if (flag == 1) { + for (bst_uint k : constraint) { + int_cont_[leftid].insert(k); + int_cont_[rightid].insert(k); + } + } + } + } + + private: + InteractionConstraintParams params_; + std::unique_ptr inner_; + // interaction_constraints_[constraint_id] contains a single interaction + // constraint, which specifies a group of feature IDs that can interact + // with each other + std::vector< std::unordered_set > interaction_constraints_; + // int_cont_[nid] contains the set of all feature IDs that are allowed to + // be used for a split at node nid + std::vector< std::unordered_set > int_cont_; + // splits_[nid] contains the set of all feature IDs that have been used for + // splits in node nid and its parents + std::vector< std::unordered_set > splits_; + + // Check interaction constraints. Returns true if a given feature ID is + // permissible in a given node; returns false otherwise + inline bool CheckInteractionConstraint(bst_uint featureid, bst_uint nodeid) const { + // short-circuit if no constraint is specified + return (params_.interaction_constraints.empty() + || int_cont_[nodeid].count(featureid) > 0); + } +}; + +XGBOOST_REGISTER_SPLIT_EVALUATOR(InteractionConstraint, "interaction") +.describe("Enforces interaction constraints on tree features") +.set_body([](std::unique_ptr inner) { + return new InteractionConstraint(std::move(inner)); + }); + } // namespace tree } // namespace xgboost diff --git a/tests/python/test_interaction_constraints.py b/tests/python/test_interaction_constraints.py new file mode 100644 index 000000000000..c8c06bbe4453 --- /dev/null +++ b/tests/python/test_interaction_constraints.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +import numpy as np +import xgboost +import unittest + +dpath = 'demo/data/' +rng = np.random.RandomState(1994) + + +class TestInteractionConstraints(unittest.TestCase): + + def test_interaction_constraints(self): + x1 = np.random.normal(loc=1.0, scale=1.0, size=1000) + x2 = np.random.normal(loc=1.0, scale=1.0, size=1000) + x3 = np.random.choice([1, 2, 3], size=1000, replace=True) + y = x1 + x2 + x3 + x1 * x2 * x3 \ + + np.random.normal(loc=0.001, scale=1.0, size=1000) + 3 * np.sin(x1) + X = np.column_stack((x1, x2, x3)) + dtrain = xgboost.DMatrix(X, label=y) + + params = {'max_depth': 3, 'eta': 0.1, 'nthread': 2, 'silent': 1, + 'interaction_constraints': '[[0, 1]]'} + num_boost_round = 100 + # Fit a model that only allows interaction between x1 and x2 + bst = xgboost.train(params, dtrain, num_boost_round, evals=[(dtrain, 'train')]) + + # Set all observations to have the same x3 values then increment + # by the same amount + def f(x): + tmat = xgboost.DMatrix(np.column_stack((x1, x2, np.repeat(x, 1000)))) + return bst.predict(tmat) + preds = [f(x) for x in [1, 2, 3]] + + # Check incrementing x3 has the same effect on all observations + # since x3 is constrained to be independent of x1 and x2 + # and all observations start off from the same x3 value + diff1 = preds[1] - preds[0] + assert np.all(np.abs(diff1 - diff1[0]) < 1e-4) + diff2 = preds[2] - preds[1] + assert np.all(np.abs(diff2 - diff2[0]) < 1e-4)