Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] add a tree plotting function #6729

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ Suggests:
markdown,
processx,
RhpcBLASctl,
testthat
testthat,
DiagrammeR
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please keep this list in alphabetical order (move DiagrammeR to the top of this list).

You will also need to add DiagrammeR to every place in continuous integration scripts that installs optional dependencies for the project. You can find those like this:

git grep processx

Depends:
R (>= 3.5)
Imports:
Expand Down
1 change: 1 addition & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export(lgb.make_serializable)
export(lgb.model.dt.tree)
export(lgb.plot.importance)
export(lgb.plot.interpretation)
export(lgb.plot.tree)
export(lgb.restore_handle)
export(lgb.save)
export(lgb.slice.Dataset)
Expand Down
184 changes: 184 additions & 0 deletions R-package/R/lgb.plot.tree.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#' @name lgb.plot.tree
#' @title Plot a single LightGBM tree using DiagrammeR.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#' @title Plot a single LightGBM tree using DiagrammeR.
#' @title Plot a single LightGBM tree.

Let's simplify this, please.

#' @description The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree.
#' @param model a \code{lgb.Booster} object.
#' @param tree an integer specifying the tree to plot.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#' @param tree an integer specifying the tree to plot.
#' @param tree an integer specifying the tree to plot. This is 1-based, so e.g. a value of '7' means 'the 7th tree' (tree_index=6 in LightGBM's underlying representation).

Let's please make this 1-based, as that's a direction we eventually want to move in the package: #4970 (review)

#' @param rules a list of rules to replace the split values with feature levels.
#'
#' @return
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot.
#'
#' @details
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. The tree is extracted from the model and displayed as a directed graph. The nodes are labelled with the feature, split value, gain, cover and value. The edges are labelled with the decision type and split value. The nodes are styled with a rectangle shape and filled with a beige colour. Leaf nodes are styled with an oval shape and filled with a khaki colour. The graph is rendered using the dot layout with a left-to-right rank direction. The nodes are coloured dim gray with a filled style and a Helvetica font. The edges are coloured dim gray with a solid style, a 1.5 arrow size, a vee arrowhead and a Helvetica font.
#'
#' @examples
#' \donttest{
#' # EXAMPLE: use the LightGBM example dataset to build a model with a single tree
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' # define model parameters and build a single tree
#' params <- list(
#' objective = "regression",
#' metric = "l2",
#' min_data = 1L,
#' learning_rate = 1.0
#' )
Comment on lines +24 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#' params <- list(
#' objective = "regression",
#' metric = "l2",
#' min_data = 1L,
#' learning_rate = 1.0
#' )
#' params <- list(
#' objective = "regression",
#' min_data = 1L
#' )

I understand how min_data = 1L is related (growing a deeper tree makes the resulting plot more interesting). But I think we can safely remove metric = "l2" (that will be the default for the regression objective) and any customization of the learning rate (since here we're only interested in showing the structure of one tree).

Let's simplify this, please.

#' valids <- list(test = dtest)
#' model <- lgb.train(
#' params = params,
#' data = dtrain,
#' nrounds = 1L,
#' valids = valids,
#' early_stopping_rounds = 1L
#' )
#' # plot the tree and compare to the tree table
#' # trees start from 0 in lgb.model.dt.tree
#' tree_table <- lgb.model.dt.tree(model)
#' lgb.plot.tree(model, 0)
#' }
#'
#' @export

# function to plot a single LightGBM tree using DiagrammeR
Comment on lines +45 to +46
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# function to plot a single LightGBM tree using DiagrammeR

We do not need to repeat in a comment here the same information that's already in the roxygen comments.

lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
# check model is lgb.Booster
if (!inherits(model, "lgb.Booster")) {
stop("model: Has to be an object of class lgb.Booster")
}
Comment on lines +49 to +51
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!inherits(model, "lgb.Booster")) {
stop("model: Has to be an object of class lgb.Booster")
}
if (!.is_Booster(x = model)) {
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))
}

Please follow the patterns used elsewhere in the library for this:

if (!.is_Booster(x = model)) {
stop("lgb.restore_handle: model should be an ", sQuote("lgb.Booster"))
}

# check DiagrammeR is available
if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
stop("DiagrammeR package is required for lgb.plot.tree",
call. = FALSE
)
}
# tree must be numeric
if (!inherits(tree, "numeric")) {
stop("tree: Has to be an integer numeric")
}
# tree must be integer
if (tree %% 1 != 0) {
stop("tree: Has to be an integer numeric")
}
# extract data.table model structure
dt <- lgb.model.dt.tree(model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dt <- lgb.model.dt.tree(model)
modelDT <- lgb.model.dt.tree(model)

Please don't use the name dt. That is a function in the {stats} package (for finding the density of a t-distribution)... try ?dt to see that.

Shadowing names from the standard library can lead to confusing errors. Please use modelDT as the name for this data.table instead.

# check that tree is less than or equal to the maximum tree index in the model
if (tree > max(dt$tree_index)) {
stop("tree: has to be less than the number of trees in the model")
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please modify this error message so that it has enough information for someone to quickly debug the issue, like the provided value of tree and the number of trees in the model. And please combine it with the other check that the value is `>=01.

Something like this:

lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (125). Got: 181.

# filter dt to just the rows for the selected tree
dt <- dt[tree_index == tree, ]
# change the column names to shorter more diagram friendly versions
data.table::setnames(dt, old = c("tree_index", "split_feature", "threshold", "split_gain"), new = c("Tree", "Feature", "Split", "Gain"))
dt[, Value := 0.0]
dt[, Value := leaf_value]
Comment on lines +76 to +77
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dt[, Value := 0.0]
dt[, Value := leaf_value]
dt[, Value := leaf_value]

I don't understand this... what's the purpose of setting all rows to 0.0 and then immediately overwriting them? It seems to me that the 0.0 could probably be removed.

dt[is.na(Value), Value := internal_value]
dt[is.na(Gain), Gain := leaf_value]
dt[is.na(Feature), Feature := "Leaf"]
dt[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count]
dt[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL]
dt[, Node := split_index]
max_node <- max(dt[["Node"]], na.rm = TRUE)
dt[is.na(Node), Node := max_node + leaf_index + 1]
dt[, ID := paste(Tree, Node, sep = "-")]
dt[, c("depth", "leaf_index") := NULL]
dt[, parent := node_parent][is.na(parent), parent := leaf_parent]
dt[, c("node_parent", "leaf_parent", "split_index") := NULL]
dt[, Yes := dt$ID[match(dt$Node, dt$parent)]]
dt <- dt[nrow(dt):1, ]
dt[, No := dt$ID[match(dt$Node, dt$parent)]]
# which way do the NA's go (this path will get a thicker arrow)
# for categorical features, NA gets put into the zero group
dt[default_left == TRUE, Missing := Yes]
dt[default_left == FALSE, Missing := No]
Comment on lines +78 to +96
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add some comments to make it a bit easier to understand what's happening in this wall of code? It's very difficult to read (at least for me) as currently written).

zero_present <- function(x) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's please avoid re-defining internal helper functions every time lgb.plot.tree() is called. This is a little bit expensive, and makes the code harder to read and develop.

Please move this up near the top of the file, and give it a name beginning with a . to clarify that it's internaly-only, like .zero_present.

sapply(strsplit(as.character(x), "||", fixed = TRUE), function(el) {
any(el == "0")
})
}
dt[zero_present(Split), Missing := Yes]
# dt[, c('parent', 'default_left') := NULL]
# data.table::setcolorder(dt, c('Tree','Node','ID','Feature','decision_type','Split','Yes','No','Missing','Gain','Cover','Value'))
# create the label text
dt[, label := paste0(
Feature,
"\nCover: ", Cover,
ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf", "", round(Gain, 4)),
"\nValue: ", round(Value, 4)
)]
# style the nodes - same format as xgboost
dt[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
dt[, shape := "rectangle"][Feature == "Leaf", shape := "oval"]
dt[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"]
# in order to draw the first tree on top:
dt <- dt[order(-Tree)]
nodes <- DiagrammeR::create_node_df(
n = nrow(dt),
ID = dt$ID,
label = dt$label,
fillcolor = dt$filledcolor,
shape = dt$shape,
data = dt$Feature,
fontcolor = "black"
)
# round the edge labels to 4 s.f. if they are numeric
# as otherwise get too many decimal places and the diagram looks bad
# would rather not use suppressWarnings
numeric_idx <- suppressWarnings(!is.na(as.numeric(dt[["Split"]])))
dt[numeric_idx, Split := round(as.numeric(Split), 4)]
# replace indices with feature levels if rules supplied
levels.to.names <- function(x, feature_name, rules) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to my previous comment, please move this up out of the definition of lgb.plot.tree() and give it a name beginning with a ., and without any other inner ., like .levels_to_names.

Avoiding the inner dots is useful to reduce the risk of that function accidentally being interpreted as an S3 method in the future.

lvls <- sort(rules[[feature_name]])
result <- strsplit(x, "||", fixed = TRUE)
result <- lapply(result, as.numeric)
levels_to_names <- function(x) {
names(lvls)[as.numeric(x)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
names(lvls)[as.numeric(x)]
names(lvls)[as.numeric(x)]
return(invisible(NULL))

In this project, we prefer having an explicit return() statement in every function... to make the intention clearer and to avoid accidentally returning data unintentionally. See #3352 for some background.

Please add an explicit return statement to every function you're defining here.

}
result <- lapply(result, levels_to_names)
result <- lapply(result, paste, collapse = "\n")
result <- as.character(result)
}
if (!is.null(rules)) {
for (f in names(rules)) {
dt[Feature == f & decision_type == "==", Split := levels.to.names(Split, f, rules)]
}
}
# replace long split names with a message
dt[nchar(Split) > 500, Split := "Split too long to render"]
# create the edge labels
edges <- DiagrammeR::create_edge_df(
from = match(dt[Feature != "Leaf", c(ID)] %>% rep(2), dt$ID),
to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
label = dt[Feature != "Leaf", paste(decision_type, Split)] %>%
c(rep("", nrow(dt[Feature != "Leaf"]))),
style = dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>%
c(dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]),
rel = "leading_to"
)
# create the graph
graph <- DiagrammeR::create_graph(
nodes_df = nodes,
edges_df = edges,
attr_theme = NULL
) %>%
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this project, by convention we:

  • do not use the %>% operator
  • use comma-first style everywhere

Please update this code and all the other code you're adding to follow that. Keeping all of the code looking the same across the codebase helps us to develop and review changes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xgboost's implementation of similar functionality might be useful as a reference. See https://github.com/dmlc/xgboost/blob/e988b7cf1515b08ad0f949c26beb043ce0b33fe8/R-package/R/xgb.plot.tree.R#L159-L181

DiagrammeR::add_global_graph_attrs(
attr_type = "graph",
attr = c("layout", "rankdir"),
value = c("dot", "LR")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "node",
attr = c("color", "style", "fontname"),
value = c("DimGray", "filled", "Helvetica")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "edge",
attr = c("color", "arrowsize", "arrowhead", "fontname"),
value = c("DimGray", "1.5", "vee", "Helvetica")
)
# render the graph
DiagrammeR::render_graph(graph)
}
55 changes: 55 additions & 0 deletions R-package/man/lgb.plot.tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions R-package/pkgdown/_pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ reference:
- '`lgb.interprete`'
- '`lgb.plot.importance`'
- '`lgb.plot.interpretation`'
- '`lgb.plot.tree`'
- '`print.lgb.Booster`'
- '`summary.lgb.Booster`'
- title: Multithreading Control
Expand Down
59 changes: 59 additions & 0 deletions R-package/tests/testthat/test_lgb.plot.tree.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
test_that("lgb.plot.tree works as expected"){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add tests for the other types of machine learning tasks LightGBM can be used for:

  • binary classification
  • multiclass classification (where, please note, there are num_classes trees produced per iteration)
  • learning-to-rank

And for the following model situations:

  • uses categorical features

These are all cases that could affect the code as written... for example, categorical features have different splitting rules.

data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
# define model parameters and build a single tree
params <- list(
objective = "regression"
, metric = "l2"
, min_data = 1L
, learning_rate = 1.0
)
valids <- list(test = dtest)
model <- lgb.train(
params = params
, data = dtrain
, nrounds = 1L
, valids = valids
, early_stopping_rounds = 1L
)
Comment on lines +9 to +22
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
params <- list(
objective = "regression"
, metric = "l2"
, min_data = 1L
, learning_rate = 1.0
)
valids <- list(test = dtest)
model <- lgb.train(
params = params
, data = dtrain
, nrounds = 1L
, valids = valids
, early_stopping_rounds = 1L
)
model <- lgb.train(
params = list(
objective = "regression"
, num_threads = .LGB_MAX_THREADS
)
, data = dtrain
, nrounds = 1L
, verbose = .LGB_VERBOSITY
)

This is part of some suggested changes, please apply other changes following from it and to other examples and tests.

  1. every call to lgb.train() should set num_threads = .LGB_MAX_THREADS in params, to avoid using too many CPUs on the CRAN check machines (see [R-package] limit number of threads used in tests and examples (fixes #5987) #5988 for background)
  2. every call to lightgbm functions should set verbosity to .LGB_VERBOSITY to allow globally controlling the amount of log messages produced across all tests (see https://github.com/microsoft/LightGBM/blob/master/R-package/README.md#running-the-tests)
  3. since params is small and only being used once in this test code, just define it inline
  4. only specify things in params which are necessary for the test to be effective (e.g., no need to set learning_rate to a non-default value)

# plot the tree and compare to the tree table
# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_true({
lgb.plot.tree(model, 0)TRUE
})
}

test_that("lgb.plot.tree fails when a non existing tree is selected"){
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
Comment on lines +36 to +37
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please remove all these uses of a validation set? This feature is about plotting the trained model, and you are not using early stopping, so all of this work to create validation sets is unnecessary.

Keeping the tests and examples as small and simple as possible makes the code easier to read / develop, and makes it clearer how test cases differ from each other.

# define model parameters and build a single tree
params <- list(
objective = "regression"
, metric = "l2"
, min_data = 1L
, learning_rate = 1.0
)
Comment on lines +39 to +44
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
params <- list(
objective = "regression"
, metric = "l2"
, min_data = 1L
, learning_rate = 1.0
)
params <- list(
objective = "regression"
)

Similar to my comments on the docs... I strongly suspect we could just use default parameters here.

valids <- list(test = dtest)
model <- lgb.train(
params = params
, data = dtrain
, nrounds = 1L
, valids = valids
, early_stopping_rounds = 1L
)
# plot the tree and compare to the tree table
# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_error({
lgb.plot.tree(model, 999)TRUE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lgb.plot.tree(model, 999)TRUE
lgb.plot.tree(model, 999)

This looks like it was included accidentally?

})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For every use of expect_error() here, please check for the specific error you are expecting, like this:

https://github.com/microsoft/LightGBM/blob/83c0ff3de1925b0e2d4831a9ccb6ffc196aa795b/R-package/tests/testthat/test_lgb.importance.R#L33-35

That way, the test will be able to catch the case where some other unexpected issue causes this code path to fail.

}
Loading