-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TREE] add interaction constraints (#3466)
* add interaction constraints * enable both interaction and monotonic constraints at the same time * fix lint * add R test, fix lint, update demo * Use dmlc::JSONReader to express interaction constraints as nested lists; Use sparse arrays for bookkeeping * Add Python test for interaction constraints * make R interaction constraints parameter based on feature index instead of column names, fix R coding style * Fix lint * Add BlueTea88 to CONTRIBUTORS.md * Short circuit when no constraint is specified; address review comments * Add tutorial for feature interaction constraints * allow interaction constraints to be passed as string, remove redundant column_names argument * Fix typo * Address review comments * Add comments to Python test
- Loading branch information
Showing
12 changed files
with
581 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
library(xgboost) | ||
library(data.table) | ||
|
||
set.seed(1024) | ||
|
||
# Function to obtain a list of interactions fitted in trees, requires input of maximum depth
#
# Arguments:
#   input_tree       data.table of tree nodes with columns ID, Feature, Split,
#                    Yes and No (the demo passes the output of
#                    xgb.model.dt.tree()).
#   input_max_depth  maximum tree depth used when fitting; ancestors are
#                    attached for (input_max_depth - 1) levels.
#
# Returns a list of character vectors. Each vector is the sorted, de-duplicated
# set of features formed by a splitting node and the splitting nodes above it;
# sets with a single feature are dropped. Returns an empty list when no
# interaction can exist.
treeInteractions <- function(input_tree, input_max_depth) {
  # Guard clauses first, so trivial inputs never pay for the copy below
  if (input_max_depth < 2) return(list())  # no interactions if max depth < 2
  if (nrow(input_tree) == 1) return(list())

  trees <- copy(input_tree)  # copy tree input to prevent overwriting

  # Attach parent nodes, one ancestor level per iteration
  for (i in 2:input_max_depth) {
    parent_cols <- c(paste0("parent_", i - 1), paste0("parent_feat_", i - 1))

    # Level 1 looks up the node itself; deeper levels look up the ancestor
    # found in the previous iteration.
    if (i == 2) {
      trees[, ID_merge := ID]
    } else {
      trees[, ID_merge := get(paste0("parent_", i - 2))]
    }

    # Every splitting node is the parent of its Yes child and of its No child
    parents_left <- trees[!is.na(Split),
                          list(i.id = ID, i.feature = Feature, ID_merge = Yes)]
    parents_right <- trees[!is.na(Split),
                           list(i.id = ID, i.feature = Feature, ID_merge = No)]

    setorderv(trees, "ID_merge")
    setorderv(parents_left, "ID_merge")
    setorderv(parents_right, "ID_merge")

    trees <- merge(trees, parents_left, by = "ID_merge", all.x = TRUE)
    trees[!is.na(i.id), (parent_cols) := list(i.id, i.feature)]
    trees[, c("i.id", "i.feature") := NULL]

    trees <- merge(trees, parents_right, by = "ID_merge", all.x = TRUE)
    trees[!is.na(i.id), (parent_cols) := list(i.id, i.feature)]
    trees[, c("i.id", "i.feature") := NULL]
  }

  # Extract nodes with interactions: splitting nodes that have a parent
  feature_cols <- c("Feature",
                    paste0("parent_feat_", seq_len(input_max_depth - 1)))
  interaction_trees <- trees[!is.na(Split) & !is.na(parent_1),
                             feature_cols, with = FALSE]

  # No splitting node below a root (e.g. every tree is a stump): splitting a
  # zero-row table by 1:0 would fail, so report "no interactions" instead.
  if (nrow(interaction_trees) == 0) return(list())

  interaction_trees_split <- split(interaction_trees,
                                   seq_len(nrow(interaction_trees)))
  interaction_list <- lapply(interaction_trees_split, as.character)

  # Remove NAs (no parent interaction)
  interaction_list <- lapply(interaction_list, function(x) x[!is.na(x)])

  # Remove non-interactions (same variable)
  interaction_list <- lapply(interaction_list, unique)
  interaction_list <- interaction_list[lengths(interaction_list) > 1]

  # Sort within each set so the same combination is only reported once
  unique(lapply(interaction_list, sort))
}
|
||
# Generate sample data: ten columns (auto-named V1..V10), where column i is
# i times a draw of 1000 values from rnorm(1000, 10)
x <- as.data.table(lapply(1:10, function(i) i * rnorm(1000, 10)))

# Response: negated row sums, a V1*V2 term, a V3*V4*V5 term, noise and a
# sine term in V7, all evaluated inside the data.table
y <- x[, -1 * rowSums(.SD) + V1 * V2 + V3 * V4 * V5 +
         rnorm(1000, 0.001) + 3 * sin(V7)]

train <- as.matrix(x)

# Interaction constraint list (column names form)
interaction_list <- list(c("V1", "V2"), c("V3", "V4", "V5"))
|
||
# Convert interaction constraint list into feature index form
#
# Arguments:
#   object     (possibly nested) list whose character vectors hold column names.
#   col_names  character vector of all column names, in column order.
#
# Returns `object` with every character vector replaced by a numeric vector of
# zero-based column positions, named by the column they refer to.
# Non-character elements are left as they are. Stops with an error if a name
# is not present in `col_names`, instead of silently yielding NA indices.
cols2ids <- function(object, col_names) {
  # Lookup table: column name -> zero-based position
  LUT <- seq_along(col_names) - 1
  names(LUT) <- col_names

  rapply(object, function(x) {
    unknown <- setdiff(x, col_names)
    if (length(unknown) > 0) {
      stop("Unknown column name(s) in interaction constraints: ",
           paste(unknown, collapse = ", "), call. = FALSE)
    }
    LUT[x]
  }, classes = "character", how = "replace")
}
interaction_list_fid <- cols2ids(interaction_list, colnames(train))

# Fit model with interaction constraints
bst <- xgboost(data = train, label = y, max_depth = 4,
               eta = 0.1, nthread = 2, nrounds = 1000,
               interaction_constraints = interaction_list_fid)

bst_tree <- xgb.model.dt.tree(colnames(train), bst)
# interactions constrained to combinations of V1*V2 and V3*V4*V5
bst_interactions <- treeInteractions(bst_tree, 4)

# Fit model without interaction constraints
bst2 <- xgboost(data = train, label = y, max_depth = 4,
                eta = 0.1, nthread = 2, nrounds = 1000)

bst2_tree <- xgb.model.dt.tree(colnames(train), bst2)
# many more interactions
bst2_interactions <- treeInteractions(bst2_tree, 4)

# Fit model with both interaction and monotonicity constraints
bst3 <- xgboost(data = train, label = y, max_depth = 4,
                eta = 0.1, nthread = 2, nrounds = 1000,
                interaction_constraints = interaction_list_fid,
                monotone_constraints = c(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0))

bst3_tree <- xgb.model.dt.tree(colnames(train), bst3)
# interactions still constrained to combinations of V1*V2 and V3*V4*V5
bst3_interactions <- treeInteractions(bst3_tree, 4)

# Show monotonic constraints still apply by checking scores after incrementing V1
x1 <- sort(unique(x[["V1"]]))
for (i in seq_along(x1)) {
  # Set V1 to the same value for every row, keeping the other columns as is
  testdata <- copy(x[, -c("V1")])
  testdata[["V1"]] <- x1[i]
  # Restore the V1..V10 column order the model was trained on
  testdata <- testdata[, paste0("V", 1:10), with = FALSE]
  pred <- predict(bst3, as.matrix(testdata))

  # Should not print out anything due to monotonic constraints.
  # prev_pred does not exist on the first pass; && short-circuits before it
  # is evaluated.
  if (i > 1 && any(pred > prev_pred)) {
    print(i)
  }
  prev_pred <- pred
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# library() rather than require(): a missing package should stop the test file
# instead of returning FALSE and failing later with an unrelated error.
library(xgboost)

context("interaction constraints")

# Synthetic regression data: the response contains an x1*x2*x3 term
set.seed(1024)
x1 <- rnorm(1000, 1)
x2 <- rnorm(1000, 1)
x3 <- sample(c(1, 2, 3), size = 1000, replace = TRUE)
y <- x1 + x2 + x3 + x1 * x2 * x3 + rnorm(1000, 0.001) + 3 * sin(x1)
train <- matrix(c(x1, x2, x3), ncol = 3)

test_that("interaction constraints for regression", {
  # Fit a model that only allows interaction between x1 and x2
  bst <- xgboost(data = train, label = y, max_depth = 3,
                 eta = 0.1, nthread = 2, nrounds = 100, verbose = 0,
                 interaction_constraints = list(c(0, 1)))

  # Set all observations to have the same x3 values then increment
  # by the same amount
  preds <- lapply(c(1, 2, 3), function(x) {
    tmat <- matrix(c(x1, x2, rep(x, 1000)), ncol = 3)
    predict(bst, tmat)
  })

  # Check incrementing x3 has the same effect on all observations
  # since x3 is constrained to be independent of x1 and x2
  # and all observations start off from the same x3 value
  diff1 <- preds[[2]] - preds[[1]]
  diff2 <- preds[[3]] - preds[[2]]

  # One expectation per increment, so a failure names the step that broke
  expect_true(all(abs(diff1 - diff1[1]) < 1e-4),
              "Interaction Constraint Satisfied (x3: 1 -> 2)")
  expect_true(all(abs(diff2 - diff2[1]) < 1e-4),
              "Interaction Constraint Satisfied (x3: 2 -> 3)")
})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
############################### | ||
Feature Interaction Constraints | ||
############################### | ||
|
||
The decision tree is a powerful tool to discover interaction among independent | ||
variables (features). Variables that appear together in a traversal path | ||
are interacting with one another, since the condition of a child node is | ||
predicated on the condition of the parent node. For example, the highlighted | ||
red path in the diagram below contains three variables: :math:`x_1`, :math:`x_7`, | ||
and :math:`x_{10}`, so the highlighted prediction (at the highlighted leaf node) | ||
is the product of interaction between :math:`x_1`, :math:`x_7`, and | ||
:math:`x_{10}`. | ||
|
||
.. plot:: | ||
:nofigs: | ||
|
||
from graphviz import Source | ||
source = r""" | ||
digraph feature_interaction_illustration1 { | ||
graph [fontname = "helvetica"]; | ||
node [fontname = "helvetica"]; | ||
edge [fontname = "helvetica"]; | ||
0 [label=<x<SUB><FONT POINT-SIZE="11">10</FONT></SUB> < -1.5 ?>, shape=box, color=red, fontcolor=red]; | ||
1 [label=<x<SUB><FONT POINT-SIZE="11">2</FONT></SUB> < 2 ?>, shape=box]; | ||
2 [label=<x<SUB><FONT POINT-SIZE="11">7</FONT></SUB> < 0.3 ?>, shape=box, color=red, fontcolor=red]; | ||
3 [label="...", shape=none]; | ||
4 [label="...", shape=none]; | ||
5 [label=<x<SUB><FONT POINT-SIZE="11">1</FONT></SUB> < 0.5 ?>, shape=box, color=red, fontcolor=red]; | ||
6 [label="...", shape=none]; | ||
7 [label="...", shape=none]; | ||
8 [label="Predict +1.3", color=red, fontcolor=red]; | ||
0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "]; | ||
0 -> 2 [labeldistance=2.0, labelangle=-45, | ||
headlabel="No", color=red, fontcolor=red]; | ||
1 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes"]; | ||
1 -> 4 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"]; | ||
2 -> 5 [labeldistance=2.0, labelangle=-45, headlabel="Yes", | ||
color=red, fontcolor=red]; | ||
2 -> 6 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"]; | ||
5 -> 7; | ||
5 -> 8 [color=red]; | ||
} | ||
""" | ||
Source(source, format='png').render('../_static/feature_interaction_illustration1', view=False) | ||
Source(source, format='svg').render('../_static/feature_interaction_illustration1', view=False) | ||
|
||
.. raw:: html | ||
|
||
<p> | ||
<img src="../_static/feature_interaction_illustration1.svg" | ||
onerror="this.src='../_static/feature_interaction_illustration1.png'; this.onerror=null;"> | ||
</p> | ||
|
||
When the tree depth is larger than one, many variables interact on | ||
the sole basis of minimizing training loss, and the resulting decision tree may | ||
capture a spurious relationship (noise) rather than a legitimate relationship | ||
that generalizes across different datasets. **Feature interaction constraints** | ||
allow users to decide which variables are allowed to interact and which are not. | ||
|
||
Potential benefits include: | ||
|
||
* Better predictive performance from focusing on interactions that work -- | ||
whether through domain specific knowledge or algorithms that rank interactions | ||
* Less noise in predictions; better generalization | ||
* More control to the user on what the model can fit. For example, the user may | ||
want to exclude some interactions even if they perform well due to regulatory | ||
constraints | ||
|
||
**************** | ||
A Simple Example | ||
**************** | ||
|
||
Feature interaction constraints are expressed in terms of groups of variables | ||
that are allowed to interact. For example, the constraint | ||
``[0, 1]`` indicates that variables :math:`x_0` and :math:`x_1` are allowed to | ||
interact with each other but with no other variable. Similarly, ``[2, 3, 4]`` | ||
indicates that :math:`x_2`, :math:`x_3`, and :math:`x_4` are allowed to | ||
interact with one another but with no other variable. A set of feature | ||
interaction constraints is expressed as a nested list, e.g. | ||
``[[0, 1], [2, 3, 4]]``, where each inner list is a group of indices of features | ||
that are allowed to interact with each other. | ||
|
||
In the following diagram, the left decision tree is in violation of the first | ||
constraint (``[0, 1]``), whereas the right decision tree complies with both the | ||
first and second constraints (``[0, 1]``, ``[2, 3, 4]``). | ||
|
||
.. plot:: | ||
:nofigs: | ||
|
||
from graphviz import Source | ||
source = r""" | ||
digraph feature_interaction_illustration2 { | ||
graph [fontname = "helvetica"]; | ||
node [fontname = "helvetica"]; | ||
edge [fontname = "helvetica"]; | ||
0 [label=<x<SUB><FONT POINT-SIZE="11">0</FONT></SUB> < 5.0 ?>, shape=box]; | ||
1 [label=<x<SUB><FONT POINT-SIZE="11">2</FONT></SUB> < -3.0 ?>, shape=box]; | ||
2 [label="+0.6"]; | ||
3 [label="-0.4"]; | ||
4 [label="+1.2"]; | ||
0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "]; | ||
0 -> 2 [labeldistance=2.0, labelangle=-45, headlabel="No"]; | ||
1 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes"]; | ||
1 -> 4 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"]; | ||
} | ||
""" | ||
Source(source, format='png').render('../_static/feature_interaction_illustration2', view=False) | ||
Source(source, format='svg').render('../_static/feature_interaction_illustration2', view=False) | ||
|
||
.. plot:: | ||
:nofigs: | ||
|
||
from graphviz import Source | ||
source = r""" | ||
digraph feature_interaction_illustration3 { | ||
graph [fontname = "helvetica"]; | ||
node [fontname = "helvetica"]; | ||
edge [fontname = "helvetica"]; | ||
0 [label=<x<SUB><FONT POINT-SIZE="11">3</FONT></SUB> < 2.5 ?>, shape=box]; | ||
1 [label="+1.6"]; | ||
2 [label=<x<SUB><FONT POINT-SIZE="11">2</FONT></SUB> < -1.2 ?>, shape=box]; | ||
3 [label="+0.1"]; | ||
4 [label="-0.3"]; | ||
0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes"]; | ||
0 -> 2 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"]; | ||
2 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "]; | ||
2 -> 4 [labeldistance=2.0, labelangle=-45, headlabel="No"]; | ||
} | ||
""" | ||
Source(source, format='png').render('../_static/feature_interaction_illustration3', view=False) | ||
Source(source, format='svg').render('../_static/feature_interaction_illustration3', view=False) | ||
|
||
.. raw:: html | ||
|
||
<p> | ||
<img src="../_static/feature_interaction_illustration2.svg" | ||
onerror="this.src='../_static/feature_interaction_illustration2.png'; this.onerror=null;"> | ||
<img src="../_static/feature_interaction_illustration3.svg" | ||
onerror="this.src='../_static/feature_interaction_illustration3.png'; this.onerror=null;"> | ||
</p> | ||
|
||
**************************************************** | ||
Enforcing Feature Interaction Constraints in XGBoost | ||
**************************************************** | ||
|
||
It is very simple to enforce feature interaction constraints in XGBoost. Here we will | ||
give an example using Python, but the same general idea generalizes to other | ||
platforms. | ||
|
||
Suppose the following code fits your model without feature interaction constraints: | ||
|
||
.. code-block:: python | ||
model_no_constraints = xgb.train(params, dtrain, | ||
num_boost_round = 1000, evals = evallist, | ||
early_stopping_rounds = 10) | ||
Then fitting with feature interaction constraints only requires adding a single | ||
parameter: | ||
|
||
.. code-block:: python | ||
params_constrained = params.copy() | ||
# Use nested list to define feature interaction constraints | ||
params_constrained['interaction_constraints'] = '[[0, 2], [1, 3, 4], [5, 6]]' | ||
# Features 0 and 2 are allowed to interact with each other but with no other feature | ||
# Features 1, 3, 4 are allowed to interact with one another but with no other feature | ||
# Features 5 and 6 are allowed to interact with each other but with no other feature | ||
model_with_constraints = xgb.train(params_constrained, dtrain, | ||
num_boost_round = 1000, evals = evallist, | ||
early_stopping_rounds = 10) | ||
**Choice of tree construction algorithm**. To use feature interaction | ||
constraints, be sure to set the ``tree_method`` parameter to either ``exact`` | ||
or ``hist``. Currently, GPU algorithms (``gpu_hist``, ``gpu_exact``) do not | ||
support feature interaction constraints. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.