-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[R-package] add a tree plotting function #6729
base: master
Are you sure you want to change the base?
Conversation
Added DiagrammeR as suggested in DESCRIPTION Added lgb.plot.tree in _pkgdown.yml Roxygenized.
@microsoft-github-policy-service agree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your interest in LightGBM.
As I mentioned in the discussion on #1222, I'm supportive of trying to add something like this (especially since xgboost has it as well).
But I hope you'll see from the first round of suggestions I left here... significant work remains before I'd support merging this change into the package. If you are willing to work with us on this and go through multiple rounds of reviews and suggestions, we'd be grateful for the help! But if you don't have the time/interest to get this ready for inclusion in the package, please let me know and we'll close this PR and leave #1222 open for someone else to pick up.
@@ -49,7 +49,8 @@ Suggests: | |||
markdown, | |||
processx, | |||
RhpcBLASctl, | |||
testthat | |||
testthat, | |||
DiagrammeR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please keep this list in alphabetical order (move DiagrammeR
to the top of this list).
You will also need to add DiagrammeR
to every place in continuous integration scripts that installs optional dependencies for the project. You can find those like this:
git grep processx
@@ -0,0 +1,184 @@ | |||
#' @name lgb.plot.tree | |||
#' @title Plot a single LightGBM tree using DiagrammeR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#' @title Plot a single LightGBM tree using DiagrammeR. | |
#' @title Plot a single LightGBM tree. |
Let's simplify this, please.
|
||
# function to plot a single LightGBM tree using DiagrammeR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# function to plot a single LightGBM tree using DiagrammeR |
We do not need to repeat in a comment here the same information that's already in the roxygen comments.
if (!inherits(model, "lgb.Booster")) { | ||
stop("model: Has to be an object of class lgb.Booster") | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (!inherits(model, "lgb.Booster")) { | |
stop("model: Has to be an object of class lgb.Booster") | |
} | |
if (!.is_Booster(x = model)) { | |
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster")) | |
} |
Please follow the patterns used elsewhere in the library for this:
LightGBM/R-package/R/lgb.restore_handle.R
Lines 42 to 44 in 83c0ff3
if (!.is_Booster(x = model)) { | |
stop("lgb.restore_handle: model should be an ", sQuote("lgb.Booster")) | |
} |
stop("tree: Has to be an integer numeric") | ||
} | ||
# extract data.table model structure | ||
dt <- lgb.model.dt.tree(model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dt <- lgb.model.dt.tree(model) | |
modelDT <- lgb.model.dt.tree(model) |
Please don't use the name dt
. That is a function in the {stats}
package (for finding the density of a t-distribution)... try ?dt
to see that.
Shadowing names from the standard library can lead to confusing errors. Please use modelDT
as the name for this data.table
instead.
nodes_df = nodes, | ||
edges_df = edges, | ||
attr_theme = NULL | ||
) %>% |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this project, by convention we:
- do not use the
%>%
operator - use comma-first style everywhere
Please update this code and all the other code you're adding to follow that. Keeping all of the code looking the same across the codebase helps us to develop and review changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xgboost
's implementation of similar functionality might be useful as a reference. See https://github.com/dmlc/xgboost/blob/e988b7cf1515b08ad0f949c26beb043ce0b33fe8/R-package/R/xgb.plot.tree.R#L159-L181
@@ -0,0 +1,59 @@ | |||
test_that("lgb.plot.tree works as expected"){ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also add tests for the other types of machine learning tasks LightGBM can be used for:
- binary classification
- multiclass classification (where, please note, there are
num_classes
trees produced per iteration) - learning-to-rank
And for the following model situations:
- uses categorical features
These are all cases that could affect the code as written... for example, categorical features have different splitting rules.
dt[, Value := 0.0] | ||
dt[, Value := leaf_value] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dt[, Value := 0.0] | |
dt[, Value := leaf_value] | |
dt[, Value := leaf_value] |
I don't understand this... what's the purpose of setting all rows to 0.0
and then immediately overwriting them? It seems to me that the 0.0
could probably be removed.
dt[is.na(Value), Value := internal_value] | ||
dt[is.na(Gain), Gain := leaf_value] | ||
dt[is.na(Feature), Feature := "Leaf"] | ||
dt[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count] | ||
dt[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL] | ||
dt[, Node := split_index] | ||
max_node <- max(dt[["Node"]], na.rm = TRUE) | ||
dt[is.na(Node), Node := max_node + leaf_index + 1] | ||
dt[, ID := paste(Tree, Node, sep = "-")] | ||
dt[, c("depth", "leaf_index") := NULL] | ||
dt[, parent := node_parent][is.na(parent), parent := leaf_parent] | ||
dt[, c("node_parent", "leaf_parent", "split_index") := NULL] | ||
dt[, Yes := dt$ID[match(dt$Node, dt$parent)]] | ||
dt <- dt[nrow(dt):1, ] | ||
dt[, No := dt$ID[match(dt$Node, dt$parent)]] | ||
# which way do the NA's go (this path will get a thicker arrow) | ||
# for categorical features, NA gets put into the zero group | ||
dt[default_left == TRUE, Missing := Yes] | ||
dt[default_left == FALSE, Missing := No] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add some comments to make it a bit easier to understand what's happening in this wall of code? It's very difficult to read (at least for me) as currently written).
# trees start from 0 in lgb.model.dt.tree | ||
tree_table <- lgb.model.dt.tree(model) | ||
expect_error({ | ||
lgb.plot.tree(model, 999)TRUE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgb.plot.tree(model, 999)TRUE | |
lgb.plot.tree(model, 999) |
This looks like it was included accidentally?
Feature requested in #1222
Added a R function to plot trees.
Basically used the code posted in #1222 by @SpeckledJim2 and followed the given instruction.