How to plot a decision tree (using a graphics package) #147
I think I've solved most of the questions above, using:

```julia
using AbstractTrees

function AbstractTrees.children(node::DecisionTree.Node)
    return (node.left, node.right)
end

function AbstractTrees.printnode(io::IO, node::DecisionTree.Node)
    print(io, "ID: ", node.featid, " - ", node.featval)
end

function AbstractTrees.printnode(io::IO, leaf::DecisionTree.Leaf)
    print(io, "maj: ", leaf.majority, " - vals: ", length(leaf.values))
end

using GraphRecipes

plot(TreePlot(dtree), method = :tree, nodeshape = :ellipse)
```

Unfortunately …
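For reference: once `children` and `printnode` are defined as above, the generic AbstractTrees machinery can already render the tree textually. A minimal sketch, assuming `dtree` is a fitted `DecisionTree.Node`:

```julia
using AbstractTrees

# textual rendering, using the `children`/`printnode` methods defined above
AbstractTrees.print_tree(stdout, dtree)
```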
Hi @roland-KA, I have created a function to plot decision trees using CairoMakie (https://github.com/Rahulub3r/Julia-ML-Utils/blob/main/decisionTreeUtils.jl). Although this is a minimal working example, I think it can be developed into production material for integration into the package. I am interested in contributing. Let me know if this looks good.
Hi @Rahulub3r, thank you for your efforts!
Sure. An MWE is as follows. Suggestions are appreciated.
Thanks, that looks really good! 👍
Very cool! I wonder what the best way to integrate this contribution might be. It would be great for MLJ users to be able to do this (without adding Makie as a dependency to existing MLJ packages). @Rahulub3r Any interest in helping out with visualisation in MLJ more generally? Probably missing something, but why is it necessary to specify feature names? Are they not part of …
Yes, I am interested. However, I am not sure how we would implement them without Makie.
Yes, of course you're right 👍🏾. Perhaps we could get the MLJ interface to expose the feature names (as we do in, e.g., the MLJGLMInterface).
As far as I can see, we now have three options to visualize a decision tree:

1. Implement the `AbstractTrees.jl` interface and use the generic `TreePlot` recipe from `GraphRecipes.jl` (Plots.jl)
2. The `CairoMakie`-based implementation from @Rahulub3r
3. Convert the tree to a `SimpleDiGraph` and plot it with `GraphMakie.jl`

No. 2 is without doubt visually the most beautiful solution. But I think we should also consider some software engineering aspects:
I've been digging a bit deeper into options 1 & 3 recently, so some more comments on that:

Option 1: BTW, I've started a three-part tutorial on 'Towards Data Science' on how Julia and its ecosystem can be used for ML (for an audience that doesn't know Julia). The first part was published last week. The third part uses exactly option 1 as an example to show how easy it is in Julia to 'glue' several packages together and create a new solution with only little coding. So there will be an extensive description of how to do it. Along the way, I identified an issue with the …

Option 3: I can provide a code example on how to apply this option next week (I'm a bit busy this week).
@roland-KA Very nice summary of the issues and best options. I agree that while @Rahulub3r's code is a lovely bit of work, a more generic solution is preferred (1 or 3). It probably makes sense to have a separate "MLJPlotRecipes" package for plotting things that are very specific to MLJ (we have, for example, a recipe in MLJTuning to plot outcomes of hyper-parameter optimization), but this wouldn't fall into that category. I've opened JuliaAI/MLJDecisionTreeInterface.jl#13. This should be pretty easy (PR welcome 😄). Makie.jl or Plots.jl? That's a very difficult question. I'm waiting for someone to convince me one way or the other 😉. Of course one could do both, starting with Plots.jl, which seems more ready to roll. See also the MLJ visualization project in this GSoC list.
I think a separate "MLJPlotRecipes" package is a good idea, as there are surely more models that could be visualized (and it would be another advantage of MLJ over similar ML-packages in general). As a first step, an implementation for Plots.jl would be preferable in my opinion too, because:
... and if time permits, I will have a look at JuliaAI/MLJDecisionTreeInterface.jl#13 😊
And as promised, here is the code to plot a decision tree using `Graphs.jl` and `GraphMakie.jl`. So I'm basically doing a depth-first traversal of the decision tree, but I collect the nodes (as well as the information for the labels) in a breadth-first order. Sounds a bit strange (and is a bit strange 😀), but it can be done with relatively little effort. The trick is the use of the `counter` dictionary (passed around as `cnt`), which holds for each level the number to be assigned to the next node on that level:

```julia
using Graphs, GraphMakie, CairoMakie

depth = DecisionTree.depth(tree)
g = SimpleDiGraph(2^(depth+1)-1)            # graph (for building the tree) of size given as an argument
label = Dict{Int, String}()                 # for collecting label information
counter = Dict([i => 2^i for i in 0:depth]) # next node number on each level

function traverse_tree(tree::DecisionTree.Node, level::Int, cnt::Dict{Int, Int}, g::SimpleDiGraph, label::Dict{Int, String})
    label[cnt[level]] = "ID = " * string(tree.featid) * " - " * string(tree.featval)
    add_node!(g, tree.left, level, cnt, label)    # left child
    add_node!(g, tree.right, level, cnt, label)   # right child
end

traverse_tree(tree::DecisionTree.Leaf, level::Int, cnt::Dict{Int, Int}, g::SimpleDiGraph, label::Dict{Int, String}) =
    label[cnt[level]] = "maj = " * string(tree.majority)

function add_node!(g::SimpleDiGraph, node::Union{DecisionTree.Node, DecisionTree.Leaf}, level::Int, cnt::Dict{Int, Int}, label::Dict{Int, String})
    add_edge!(g, cnt[level], cnt[level+1])
    traverse_tree(node, level + 1, cnt, g, label)
    cnt[level+1] += 1
end
```

The function `traverse_tree` is then called on the root, and the collected labels are handed to `graphplot`:

```julia
traverse_tree(tree, 0, counter, g, label)
tree_labels = collect(values(sort(label)))
layout = Buchheim()   # assumption: the tree layout from NetworkLayout.jl (`using NetworkLayout`),
                      # as discussed in MakieOrg/GraphMakie.jl#57; the original definition of `layout` was not shown
f, ax, p = graphplot(g, layout = layout, nlabels = tree_labels,
                     nlabels_attr = (; justification = :center, align = (:center, :top), color = :blue))
```

The resulting tree plot is the one depicted in MakieOrg/GraphMakie.jl#57.
Before noticing this issue, I went ahead and adapted the Plots recipe used in EvoTrees for this package. It's rather hacky, because the node shapes don't play well with the text, especially for deep trees, but between custom markers and bypassing the labels with annotations, it mostly works out. I ended up doing basically the same node traversal to get a digraph. Let me know if you think it's worth making a pull request with some modifications (either here or in some eventual MLJPlotRecipes).

As for an example resulting graph, using the example above with some name changes:
I agree this looks nice. However, as it is specific to DecisionTree.jl trees, I suggest a PR either to DecisionTree.jl and/or MLJDecisionTreeInterface.jl. The MLJDecisionTreeInterface version could include the original feature names. Minor suggestion: replace the test … However, as commented above, a more maintainable solution would be to implement the AbstractTrees.jl API for the DecisionTree.jl objects and try to improve the generic tree-plotting capabilities of Plots.jl, if this is lacking. It's just a pity that this approach will not be able to include feature names without some significant changes to DecisionTree.jl, as node objects in DecisionTree don't know the feature names. Perhaps in the MLJ wrapper we could include the encoding in a plot legend?
I've just noticed that DecisionTree.jl is no longer maintained, so a PR to this package is no longer an option. Anyway @ablaom, as you suggest, a more maintainable solution would be less dependent on DecisionTree.jl. As I'm not so familiar with plot recipes (I've just used them, never implemented one), I'm trying to understand what that means:
As the example above shows, making an … An obstacle for a simple implementation is indeed that … @dsweber2, what do you think about such an adaptation of your plot recipe? Is this a way to go?
I think the problem here is that "AbstractTree" is not a type, only an interface. As far as I understand, plot recipes cannot be created in the usual way with trait-dispatch, only type-dispatch. But maybe there is a workaround if you don't use the macro.
I just mean that your nodes are labelled like "Feature 1 < 0.4", say, but you add a legend to the figure that looks like:
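For example (hypothetical feature names, purely to illustrate the idea):

```
Feature 1 → sepal length
Feature 2 → sepal width
Feature 3 → petal length
```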
This might even be an advantage, as long feature names then don't mess up the plot. @bensadeghi Would you be happy to entertain a PR that implements the AbstractTrees.jl interface for …
@ablaom, that sounds fine.

@roland-KA Is this something you might take on, at your leisure?
Well, at least I could give it a try. But I need some help, especially when it comes to implementing a plot recipe. @dsweber2, would you give your support on that topic? So just to make clear what the objectives are, let me summarize. We want:

1. an implementation of the `AbstractTrees.jl` interface for the trees of this package, and
2. a generic plot recipe that can plot any tree offering that interface.

Is that correct? @ablaom, you would provide an extension to the MLJ wrapper that delivers the feature names (perhaps in that extended 'legend' form you described above) as well as class names? So that in the end we could call:

```julia
plot(aDecisionTree_that_acts_like_an_AbstractTree,
     MLJ.feature_names,
     MLJ.class_names)
```
Spending a bit more time on the issue today, I think I missed the point a bit with my last comment. @ablaom your idea is probably that the PR to this package should just implement the `AbstractTrees.jl` API? A plot recipe for plotting decision trees could then be added to MLJ itself, relying just on the `AbstractTrees.jl` interface.
Yes, that's exactly my suggestion. A PR here to implement the AbstractTrees.jl API and that is all (@bensadeghi's other requirements notwithstanding). Very happy to provide feedback on such a PR.
Ok fine, that should be no problem 😊👍. Then we have to decide how we get the information about feature names and class labels into the game (especially when the nodes of the trees don't have that information, as is the case with `DecisionTree.jl`). The idea in … Currently I see the following alternatives on how this could be done:
Approach 1 has to be implemented within the implementation of the … I.e. approach 1 lays more burden on the implementation of the … Approach 2 would allow parameters on the plot recipe to fine-tune the appearance of the tree. E.g. one could choose between showing just feature ids in the nodes, full-text nodes, or ids in the nodes plus a legend with the full text next to the tree plot (well, this would also be doable with approach 1, but it would be sort of a waste of effort). But perhaps I'm missing a point at the moment and things are less complicated ... 🤔
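A minimal sketch of approach 1, assuming a wrapper type that carries the display information in an `info` NamedTuple (essentially the design that emerges later in this thread; the names are illustrative):

```julia
import AbstractTrees

# hypothetical wrapper pairing a raw DecisionTree.jl node with display info
struct InfoNode{S, T}
    node :: DecisionTree.Node{S, T}
    info :: NamedTuple   # e.g. (featurenames = [...], classlabels = [...])
end

# approach 1: `printnode` itself resolves feature ids to readable names,
# so the plot recipe never needs to know anything about feature names
function AbstractTrees.printnode(io::IO, n::InfoNode)
    print(io, n.info.featurenames[n.node.featid], " < ", n.node.featval)
end
```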
@roland-KA I'd suggest that for now we avoid writing plot recipes that are specific to ML decision trees and instead work on building tree representations that implement AbstractTrees.jl, which can then be thrown at generic plotting algorithms - whether Makie or Plots.jl. So my vote is for 1. I think this is simpler and easier to maintain. I very much like your idea to build an enhanced AbstractTree object for DecisionTree.jl and had not realised how far you had already gone with this in the Towards Data Science post - that's a complete POC really. So there is a way to get the feature names into the game after all. This is only slightly more complicated than implementing the AbstractTree API for raw nodes and leaves, and I expect @bensadeghi would not object to its addition here, as there is still no refactoring of existing code.
The objective of using an AbstractTree is to hide all implementation details of specific decision trees from the plot recipe. I've investigated possible alternatives to the two approaches described above, but I didn't find any other way to do it. So I completely agree with you and would implement the concept from my TDS article. BTW: in the meantime I've delved a bit into the documentation of plot recipes, and I think I can do that too. 🤓
I'm currently testing my implementation of the AbstractTrees interface for DecisionTree.jl. Here I came across the fact that … Is it correct that it stores there …
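For reference, the node types of DecisionTree.jl at that time looked roughly like this (a sketch, not a verbatim copy of the source):

```julia
# sketch of DecisionTree.jl's core types, with the field names used throughout this thread
struct Leaf{T}
    majority :: T          # predicted (majority) class of this leaf
    values   :: Vector{T}  # training labels that ended up in this leaf
end

struct Node{S, T}
    featid  :: Int                         # index of the feature used for the split
    featval :: S                           # threshold value of the split
    left    :: Union{Leaf{T}, Node{S, T}}
    right   :: Union{Leaf{T}, Node{S, T}}
end
```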
Off the top of my head I don't know the answer to point one. Is there even a distinction in native DecisionTree.jl between …

Edit: Features are left untouched, and DecisionTree.jl is happy with any eltype which supports … Does this help?
I'd probably not export them. They are not going to be used by the average user of DecisionTree.jl, and …
Here is now a first version of a plot recipe for plotting a decision tree:

Plot Recipe

```julia
"""
A plot recipe (based on `RecipeBase.jl`) to visualize a decision tree (from `DecisionTree.jl`) wrapped in an `AbstractTrees`-interface.
# Overview
The Buchheim-Algorithm (from `NetworkLayout.jl`) is used to generate a layout for the tree.
Therefore the decision tree has to be converted first to a `SimpleDiGraph` (from `Graphs.jl`).
For that purpose we need the number `n` of nodes within the tree and the nodes have to be numbered
from 1 to `n` in a breadth first order (as the edges of a `SimpleDiGraph` are specified by pairs of
such numbers).
The following assumptions are made for the decision tree:
- It is a binary tree
- A node has either no children (`Leaf`) or exactly two children (`Node`)
This implies that the tree may be imbalanced but each single node is always balanced.
The whole process consists of the following steps:
1. Flatten
The tree is converted in a breadth first order into a list (a `Vector`) of `PlotInfo`s.
These elements contain the relevant information for visualizing the tree. The indices
within this array correspond to the above required numbers from 1 to `n` (1 is the root node).
2. Generate layout
The (number) information from the flat structure is taken to create a corresponding `SimpleDiGraph`.
Using the Buchheim-algorithm on this graph, a layout (consisting of the coordinates of the nodes)
is generated.
3. Make a visual description (using a Plot Recipe)
The plot recipe `dt_visualization` creates a visual description of the tree using the information of
the two preceding steps.
"""
import AbstractTrees
using NetworkLayout
using Graphs
using GraphRecipes
# plotting information for each node/leaf of the tree
struct PlotInfo
parent_id :: Int16 # number of the parent node within the tree
is_leaf :: Bool # is the node a `Leaf`?
print_label :: String # text to be printed as a label in the visualization of the node
end
"""
add_level!(plot_infos::Vector{PlotInfo}, nodes, i_crnt)
Traverse the tree recursively beginning from the root, level by level, collecting on each level
all nodes from left to right and adding them to the `plot_infos`-list in form of `PlotInfo`s.
On each call a `Vector` of nodes from the last level processed is given in `nodes`. So on the first
call a one-element array containing the root node is passed. `i_crnt ` is the index of the `PlotInfo`
in `plot_infos` corresponding to the first node in `nodes`.
On the first call, `add_level!` expects the first entry in `plot_infos` for the root node already
to be present (so this has to be done manually).
"""
function add_level!(plot_infos::Vector{PlotInfo}, nodes, i_crnt)
i_next = i_crnt + length(nodes)
child_nodes = []
for n in nodes
cn = AbstractTrees.children(n)
if length(cn) > 0
plot_infos[i_next] = PlotInfo(i_crnt, is_leaf(cn[1]), label(cn[1]))
plot_infos[i_next+1] = PlotInfo(i_crnt, is_leaf(cn[2]), label(cn[2]))
i_next += 2
push!(child_nodes, cn[1])
push!(child_nodes, cn[2])
end
i_crnt += 1
end
if length(child_nodes) > 0
add_level!(plot_infos, child_nodes, i_crnt)
end
end
# extract and format label information from nodes/leaves
function label(i::Union{InfoNode, InfoLeaf})
io = IOBuffer()
AbstractTrees.printnode(io, i)
return(String(take!(io)))
end
# which type of node is it?
is_leaf(i::InfoNode) = false
is_leaf(i::InfoLeaf) = true
depth(tree::InfoNode) = DecisionTree.depth(tree.node) # --> move to `DecisionTree/abstract_trees.jl`
"""
flatten(tree::InfoNode)
Create a list of all nodes/leaves (converted to `PlotInfo`s) within the `tree` in a breadth first order.
Note: If `tree` is imbalanced, the resulting list contains empty elements at its end.
"""
function flatten(tree::InfoNode)
plot_infos = Vector{PlotInfo}(undef, 2^(depth(tree)+1)) # a tree of this depth has at most 2^(depth+1)-1 nodes
plot_infos[1] = PlotInfo(-1, false, label(tree)) # root node is first entry; -1 is a dummy
add_level!(plot_infos, [tree], 1) # recursively add the nodes of all further tree levels to the list
return(plot_infos)
end
"""
layout(plot_infos::Vector{PlotInfo})
Create a tree layout in form of a list of points (`GeometryBasics.Point2`) based on the list of
`PlotInfo`s created by `flatten`. The order of the points in the list corresponds to the information in
the `plot_infos`-list passed.
"""
function layout(plot_infos::Vector{PlotInfo})
size = count([isassigned(plot_infos, i) for i in eachindex(plot_infos)]) # count number of actual entries
g = SimpleDiGraph(size)
for i in 2:size
add_edge!(g, plot_infos[i].parent_id, i)
end
return(buchheim(g))
end
"""
Rectangular shape with centerpoint `center` and dimensions `width` and `height`
"""
function make_rect(center, width, height)
left, right = center[1] - width/2.0, center[1] + width/2.0
upper, lower = center[2] + height/2.0, center[2] - height/2.0
Shape([(left, upper), (right, upper), (right, lower), (left, lower), (left, upper)])
end
"""
Linear curve starting at the `parent`s bottom center leading to the `child`s top center
"""
function make_line(parent, child, height)
parent_bottom = [parent[1], parent[2] - height/2]
child_top = [child[1], child[2] + height/2]
return([parent_bottom[1], child_top[1]], [parent_bottom[2], child_top[2]])
end
"""
dt_visualization(tree::InfoNode)
Graph recipe to draw a `DecsionTree` (wrapped in an `AbstractTree`)
"""
@recipe function dt_visualization(tree::InfoNode, width = 0.7, height = 0.7)
# prepare data
plot_infos = flatten(tree)
coords = layout(plot_infos)
# we paint on a blank paper
framestyle --> :none
legend --> false
# connecting lines are a series of curves
i = 2
while i <= length(plot_infos) && isassigned(plot_infos, i)
@series begin
seriestype := :curves
linecolor --> :silver
line = make_line(coords[plot_infos[i].parent_id], coords[i], height)
i += 1
line   # the last expression of the `@series` block provides the series data
end
end
# nodes are a series of rectangle shapes
anns = plotattributes[:annotations] = [] # for the labels within the rectangles
for i in eachindex(coords)
@series begin
seriestype := :shape
fillcolor --> :deepskyblue3
alpha --> (plot_infos[i].is_leaf ? 0.4 : 0.15)
annotationcolor --> :brown4
annotationfontsize --> 7
c = coords[i]
push!(anns, (c[1], c[2], (plot_infos[i].print_label,)))
make_rect(c, width, height)   # last expression = series data
end
end
end
```

Using the following "handmade" decision tree ...

Demo code

```julia
using DecisionTree
using Plots
function dtree()
l1 = Leaf(1, [1,1,2])
l2 = Leaf(2, [1,2,2])
l3 = Leaf(3, [3,3,1])
l4 = Leaf(1, [1,1,1])
l5 = Leaf(2, [2,2,2])
n4 = Node(4, 0.8, l4, l5)
n3 = Node(3, 0.3, n4, l3)
n2 = Node(2, 0.5, l1, l2)
n1 = Node(1, 0.7, n2, n3)
return(n1)
end
dt = dtree()
feature_names = ["feat1", "feat2", "feat3", "feat4"]
class_labels = ["a", "b", "c"]
infotree = DecisionTree.wrap(dt, (featurenames = feature_names, classlabels = class_labels))
plot(infotree)
```

... we get this nice picture: 😊

@ablaom what do you think about this approach? And is there a place where I could put this code (we've discussed the idea of a repository for plot recipes for MLJ models)?
@roland-KA Thanks for this. This definitely looks like progress. Let me check if I understand what you are proposing to provide here. Correct me if I misunderstand:
Thoughts:
@sylvaticus Do your BetaML trees happen to implement AbstractTrees.jl, including a `printnode` method?
Hi @ablaom, thanks for your feedback. That's really helpful! So let's get through it:
Yes.
I don't quite understand the second half of your comment, but perhaps the following explanation clarifies the point: the whole code above in 'Plot Recipe' is a plot recipe, where I've put some of the code into separate functions (like …).
What does that mean: I really need the … As I think, I could get rid of the technical dependency just by introducing an abstract type … That should do the trick and make the code completely generic. So any decision tree that implements the 'decision tree with depth' interface should then be able to use this plot recipe (after having been transformed by its `wrap` function). Does this sound plausible, or do you see any further obstacles? PS: I should add that the …
And I've just discovered a copy & paste error: in the 'Plot Recipe' code above, the line …
Just a quick reply for now. If we can, I think we should avoid introducing any new abstract type, apart from …
I do have "print node" / "print tree" methods, but they are textual and not linked to the AbstractTrees.jl interface.
@ablaom, unfortunately (contrary to general expectation) there is no type `AbstractTree` in `AbstractTrees.jl` …
I don't think so (or perhaps I don't understand how that would work in this context). It's not about behaviour. As you already mentioned, …
That's the point. When calling … The code of the recipe is already generic. We already have what you desire in point 4 above. The only thing which isn't generic enough is the 'recipe selector'. So, if you have no other idea, I would proceed and adapt the implementation to use such a more generic type. BTW: why do you object to introducing a new abstract type? I don't quite understand the rationale behind it. Apart from that, I've noticed that AbstractTrees.jl got a few updates recently. It now has a depth function (it's called `treeheight`) …
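To illustrate the 'recipe selector' point: `@recipe` dispatches on the type annotation of its argument, so a shared (super)type is needed; an interface alone cannot play this role. A minimal sketch, with hypothetical names:

```julia
using RecipesBase

abstract type AbstractNode end        # hypothetical shared supertype for wrapped trees

struct DemoTree <: AbstractNode end   # any concrete wrapped-tree type

# the recipe fires for *any* subtype of AbstractNode - pure type dispatch,
# which is why a trait/interface like AbstractTrees.jl cannot act as the selector
@recipe function f(tree::AbstractNode)
    seriestype := :scatter            # placeholder series; a real recipe builds the tree shapes
    [0.0], [0.0]
end
```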
Ah, sorry for my hasty response. I keep having to reload this issue into my brain with all the other stuff going on in between. I understand now why you want the abstract type - this is forced upon us by the way recipes work. It looks to me that you are close to an optimal solution and I look forward to your implementation. As suggested earlier, I suggest putting the generic code in a standalone tree-plotting package. JuliaAI could host this if you want.
Hi @ablaom, no problem! Your questions help me to check if things are really ok. And I understand that you strive to minimize dependencies. But in this case, I think there is no alternative.

New abstract types

I have therefore implemented now a version using common supertypes for the `InfoNode` and `InfoLeaf` wrapper types:

```julia
abstract type AbstractInfoNode{S, T} end
abstract type AbstractInfoLeaf{T} end
```

So the definitions in `DecisionTree.jl`'s `abstract_trees.jl` are now:

```julia
struct InfoNode{S, T} <: AbstractInfoNode{S, T}
    node :: DecisionTree.Node{S, T}
    info :: NamedTuple
end

struct InfoLeaf{T} <: AbstractInfoLeaf{T}
    leaf :: DecisionTree.Leaf{T}
    info :: NamedTuple
end
```

Furthermore I've created a package called … Just for convenience, I've placed the new abstract types also in this package. But I think from a software engineering perspective this is not the right place, because a model like …

Do you have an idea where the definition of these two abstract types could be placed, so that all models which want to implement the … can use them?

Updated version of plot recipe

Apart from that, I have an updated version of the plot recipe where I have …
Plot Recipe (updated version)

```julia
"""
A plot recipe (based on `RecipeBase.jl`) to create a graphical representation of a decision tree.
The decision tree must be wrapped in an `AbstractTrees`-interface (see `DecisionTree.jl` as
an example implementation of the concept). This approach ensures that the recipe is independent of the
implementation details of the decision tree.
"""
module DecisionTreesRecipe
include("AbstractInfoTree.jl") # definition of new abstract types --> not a good idea, to put them here
export AbstractInfoNode, AbstractInfoLeaf
import AbstractTrees
using NetworkLayout
using Graphs
using RecipesBase
"""
# Overview
The recipe uses the Buchheim-Algorithm (from `NetworkLayout.jl`) to generate a layout for the decision tree.
As this algorithm requires a `SimpleDiGraph` (from `Graphs.jl`) as its input, the tree has to be converted
first into that structure. For that purpose the nodes of the tree have to be numbered from 1 to `n`
(`n` being the number of nodes in the tree) in a breadth first order (as the edges of a `SimpleDiGraph`
are specified by pairs of such numbers).
So the recipe applies the following steps to convert a decision tree into a graphical representation:
1. Flatten
The tree is converted in a breadth first order into a list (a `Vector`) of `PlotInfo`s.
These elements contain the relevant information for visualizing the tree later on. The indices
within this array correspond to the above required numbers from 1 to `n` (1 is the root node).
The functions `flatten` and `add_level!` implement this step.
2. Generate layout
The (number) information from the flat structure is taken to create a corresponding `SimpleDiGraph`.
Using the Buchheim-algorithm on this graph, a layout (consisting of the coordinates of the nodes)
is generated. This step is implemented by `layout`.
3. Make a visual description (using a plot recipe)
The plot recipe `dt_visualization` creates a visual description of the tree using the information of
the two preceding steps.
"""
### Step 1: Flatten
# plotting information for each node/leaf of the tree
struct PlotInfo
parent_id :: Int16 # number of the parent node within the tree
is_leaf :: Bool # is the node a leaf?
print_label :: String # text to be printed as a label in the visualization of the node
end
"""
add_level!(plot_infos::Vector{PlotInfo}, nodes, i_crnt)
Traverse the tree recursively beginning from the root, level by level, collecting on each level
all nodes from left to right and adding them to the `plot_infos`-list in form of `PlotInfo`s.
On each call, a `Vector` of nodes from the last level processed is given in `nodes`. So on the first
call a one-element array containing the root node is passed. `i_crnt` is the index of the
`PlotInfo`-element in `plot_infos` corresponding to the first node in `nodes`.
On the first call, `add_level!` expects the first entry in `plot_infos` for the root node already
to be present (so this has to be done manually).
"""
function add_level!(plot_infos::Vector{PlotInfo}, nodes, i_crnt)
i_next = i_crnt + length(nodes)
child_nodes = []
for n in nodes
cn = AbstractTrees.children(n)
for c in cn
plot_infos[i_next] = PlotInfo(i_crnt, is_leaf(c), label(c))
push!(child_nodes, c)
i_next += 1
end
i_crnt += 1
end
if length(child_nodes) > 0
add_level!(plot_infos, child_nodes, i_crnt)
end
end
# extract and format label information from nodes/leaves
function label(i::Union{<:AbstractInfoNode, <:AbstractInfoLeaf})
io = IOBuffer()
AbstractTrees.printnode(io, i)
return(String(take!(io)))
end
# which type of node is it?
is_leaf(i::AbstractInfoNode) = false
is_leaf(i::AbstractInfoLeaf) = true
"""
flatten(tree::InfoNode)
Create a list of all nodes/leaves (converted to `PlotInfo`s) within the `tree` in a breadth first order.
"""
function flatten(tree::AbstractInfoNode)
plot_infos = Vector{PlotInfo}(undef, AbstractTrees.treesize(tree)) # tree has `treesize` nodes
plot_infos[1] = PlotInfo(-1, false, label(tree)) # root node is first entry; -1 is a dummy
add_level!(plot_infos, [tree], 1) # recursively add the nodes of all further tree levels to the list
return(plot_infos)
end
### Step 2: Generate layout
"""
layout(plot_infos::Vector{PlotInfo})
Create a tree layout in form of a list of points (`GeometryBasics.Point2`) based on the list of
`PlotInfo`s created by `flatten`. The order of the points in the list corresponds to the information in
the `plot_infos`-list passed.
"""
function layout(plot_infos::Vector{PlotInfo})
g = SimpleDiGraph(length(plot_infos))
for i in 2:length(plot_infos)
add_edge!(g, plot_infos[i].parent_id, i)
end
return(buchheim(g))
end
### Step 3: Make a visual description (using a plot recipe)
"""
Rectangular shape with centerpoint `center` and dimensions `width` and `height`
"""
function make_rect(center, width, height)
left, right = center[1] - width/2.0, center[1] + width/2.0
upper, lower = center[2] + height/2.0, center[2] - height/2.0
return([(left, upper), (right, upper), (right, lower), (left, lower), (left, upper)])
end
"""
Linear curve starting at the `parent`s bottom center leading to the `child`s top center.
`parent` and `child` are the center coordinates of the respective nodes.
"""
function make_line(parent, child, height)
parent_xbottom, parent_ybottom = parent[1], parent[2] - height/2
child_xtop, child_ytop = child[1], child[2] + height/2
return([parent_xbottom, child_xtop], [parent_ybottom, child_ytop])
end
"""
dt_visualization(tree::InfoNode)
Plot recipe to convert a decsion tree (wrapped in an `AbstractTree`) into a graphical representation
"""
@recipe function dt_visualization(tree::AbstractInfoNode, width = 0.7, height = 0.7)
# prepare data
plot_infos = flatten(tree)
coords = layout(plot_infos)
# we paint on a blank paper
framestyle --> :none
legend --> false
# connecting lines are a series of curves
for i in 2:length(plot_infos)
@series begin
seriestype := :curves
linecolor --> :silver
make_line(coords[plot_infos[i].parent_id], coords[i], height)   # last expression = series data
end
end
# nodes are a series of rectangular shapes
anns = plotattributes[:annotations] = [] # for the labels within the rectangles
for i in eachindex(coords)
@series begin
seriestype := :shape
fillcolor --> :deepskyblue3
alpha --> (plot_infos[i].is_leaf ? 0.4 : 0.15)
annotationcolor --> :brown4
annotationfontsize --> 7
c = coords[i]
push!(anns, (c[1], c[2], (plot_infos[i].print_label,)))
make_rect(c, width, height)   # last expression = series data
end
end
end
end # module
```
Hi @sylvaticus, in order to get a quick PoC of the visualisation concept discussed above (which is intended to work for different decision trees), I've forked BetaML with the aim of adding a simple implementation of the `AbstractTrees` interface. Unfortunately I'm not familiar with the … So far I have added/changed the following things in BetaML:

... getting nothing more than a lot of warnings about "could not import X into Y". To my understanding I've put the …
Hello, could you push your code so I can look at it (perhaps in another branch, or maybe this is not needed?)?
@roland-KA thanks for the progress!
So we do:

```julia
abstract type AbstractNode end  # for types implementing the AbstractTrees.jl interface (`children`, `depth`, `print_node`, ?)

struct InfoNode{S, T} <: AbstractNode
    node :: DecisionTree.Node{S, T}
    info :: NamedTuple
end

struct InfoLeaf{T} <: AbstractNode
    leaf :: DecisionTree.Leaf{T}
    info :: NamedTuple
end
```

And, unless I misunderstand (yet again), your plot recipe now works for any subtype of `AbstractNode` …
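For completeness, the wrapping step that goes with these types is then just a pair of constructors plus `children` methods that pass the `info` along (a sketch consistent with the `DecisionTree.wrap` call used in the demo above; the actual implementation may differ):

```julia
# sketch: wrap a raw DecisionTree.jl tree together with extra display info
wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)

# children of a wrapped node are wrapped as well, so the info travels down the tree
AbstractTrees.children(n::InfoNode) = (wrap(n.node.left, n.info), wrap(n.node.right, n.info))
AbstractTrees.children(l::InfoLeaf) = ()
```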
Definitely yes! When I had a look at the BetaML trees these days, it came to my mind that the current implementation is too …

Good idea! That's simpler and more general.

Basically yes. We do this …

Definitely yes, that would be ideal! Would you ask the folks maintaining `AbstractTrees.jl`?

That was my first thought ... but I felt it would be too weird 😀. But you are right. If we don't get it into `AbstractTrees.jl`, … I have made all the changes you suggested and it works. 🤓 In a next step I would try to make an MWE using BetaML trees in order to have a PoC that the current implementation is really as general as we intend it to be.
Okay, great. I suggest you initiate the AbstractTrees.jl request. I expect there may be a discussion about exactly what … Link in the current discussion and copy me in so I can add my 👍🏾. Thanks again for your perseverance. I think the extra effort you are putting into the software engineering details will be well worth it in the long term.
@roland-KA I'm not seeing much action at your AbstractTrees.jl issue, despite your exhaustive efforts (many thanks!). How about we revert to the mini-package … You've put a lot of work into this project and I'm keen to make it available.
Hi @ablaom, I've just added a comment to the issue asking for a (definitive) response and giving it a (last) chance, since it is in my opinion the better solution. If that doesn't work, then we should really proceed with the mini-package solution. I'm currently on vacation until next weekend, so perhaps we'll have an answer by then.
Hi @ablaom, it seems now that we will get the abstract type in … In order to advance faster with the plot recipe, I would suggest the following steps:
In my experience these temporary solutions suck up time best invested elsewhere. The idea of depending on an unregistered package does not excite me either. Despite the very long wait so far, my inclination would be to wait. If you still feel otherwise, I'd suggest a separate package, rather than putting it in some MLJ package, which sounds more messy. As it's only temporary, I don't think it matters who owns it. It's very easy to transfer ownership in any case - trivial in fact, if the package is not registered.
Thanks, btw, for your advocacy at AbstractTrees.jl 🙏🏾
Well, thanks to you for your support! I don't think it would have worked without that.
The second question above (about placing the plot recipe) was not about a temporary package, but about the final solution. I.e. I wanted to know where I should put the plot recipe as it will be used in the future (because I can start working on that now). It seems that you would suggest also for that solution a separate new package in my repo, which I would then register? In order to test that package under "real world conditions", I will need the abstract type `AbstractNode`. So for the time being, I'm using the following temporary module:

```julia
module AbstractNodeType
export AbstractNode
abstract type AbstractNode end
end
```

It won't be necessary at all, if the next release of `AbstractTrees.jl` …
Yes, you're right, I completely misunderstood. 😳 Yes, a separate package for the recipe sounds good to me. You could own that or move it to JuliaAI as you wish.
@roland-KA …
Wow, that's great news 🤗. Unfortunately I've been quite busy during the last few weeks, so I couldn't start creating the package for the plot recipe as intended. But November looks promising 😊.
Hi @ablaom, sorry for the delay. But now I have a package for the plot recipe ready in my repo: https://github.com/roland-KA/TreeRecipe.jl. Apart from that, I've created a PR for `DecisionTree.jl` … Could you perhaps have a look at both and tell me if this is the right way to go?
Hi @sylvaticus, I've just created a PR for BetaML … There is an example of how to plot a BetaML decision tree using the recipe in … Note: Sorry for leaving your package so long in limbo. But now it should work!
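Putting it all together, usage ends up looking like this (a sketch based on the demo earlier in this thread; details of the registered `TreeRecipe.jl` may differ):

```julia
using DecisionTree, TreeRecipe, Plots

features, labels = load_data("iris")        # example dataset shipped with DecisionTree.jl
model = build_tree(labels, features)

wrapped = DecisionTree.wrap(model,
    (featurenames = ["sepal length", "sepal width", "petal length", "petal width"],))
plot(wrapped)                               # TreeRecipe's plot recipe draws the tree
```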
@roland-KA Thanks for your patience. Just back from leave. Looking awesome. Merged your DecisionTree PR: JuliaRegistries/General#72785. Posted suggestions for the new package at JuliaAI/TreeRecipe.jl#1.
Is there a possibility to plot a decision tree using `Plots.jl` (or some other graphics package)?

I'm using MLJ and the only means to visualize a decision tree seems to be `report(mach).print_tree(n)`, where `mach` is the trained machine.

If there is no such possibility: How can I access the tree (data structure) directly in MLJ?