
How to plot a decision tree (using a graphics package) #147

Open
roland-KA opened this issue Jan 19, 2022 · 57 comments

@roland-KA
Collaborator

Is there a possibility to plot a decision tree using Plots.jl (or some other graphics package)?

I'm using MLJ and the only means to visualize a decision tree seems to be report(mach).print_tree(n) where mach is the trained machine.

If there is no such possibility: How can I access the tree (data structure) directly in MLJ?

@roland-KA
Collaborator Author

I think I've solved most of the questions above: Using AbstractTrees.jl and GraphRecipes.jl it's relatively easy to implement.

  • The decision tree can be accessed via fitted_params(mach).tree.
  • Then the AbstractTrees functions children and printnode have to be implemented:
using AbstractTrees
import DecisionTree   # provides the `Node`/`Leaf` types used below

function AbstractTrees.children(node::DecisionTree.Node)
	return(node.left, node.right)
end

function AbstractTrees.printnode(io::IO, node::DecisionTree.Node)
	print(io, "ID: ", node.featid, " - ", node.featval)
end

function AbstractTrees.printnode(io::IO, leaf::DecisionTree.Leaf)
	print(io, "maj: ", leaf.majority, " - vals: ", length(leaf.values))
end
  • Finally GraphRecipes can be used together with Plots (dtree is the decision tree to be plotted):
using GraphRecipes, Plots

plot(TreePlot(dtree), method = :tree, nodeshape = :ellipse)

Unfortunately DecisionTree.Node stores only the id of the feature used for a split (featid). It would be nice if the feature name could also be shown. Within the DecisionTree package there exists an array feature_names with these names, but I didn't find a way to access it. How can this be done (in printnode)?

@Rahulub3r

Rahulub3r commented Feb 7, 2022

Hi @roland-KA, I have created a function to plot a decision tree using CairoMakie (https://github.com/Rahulub3r/Julia-ML-Utils/blob/main/decisionTreeUtils.jl). Although this is a minimal working example, I think it can be turned into production material for integration into the package. I am interested in contributing. Let me know if this looks good.

@roland-KA
Collaborator Author

Hi @Rahulub3r , thank you for your efforts!
Could you perhaps give a working example of how to call drawTree? I.e. which initial values should be used to draw a tree from its root down to the leaves?

@Rahulub3r

Sure. An MWE is as follows. Suggestions are appreciated.

using MLJ
using DecisionTree
using CairoMakie
using Random

X, y = make_blobs(300;rng=MersenneTwister(1234))

dtc = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
dtc_model = dtc(min_purity_increase=0.005, min_samples_leaf=1, min_samples_split=2, max_depth=3)
dtc_mach = machine(dtc_model, X, y)
MLJ.fit!(dtc_mach)
x = fitted_params(dtc_mach)
#print_tree(x.tree)

f = Figure(;resolution=(1000, 800))
ax1 = Axis(f[1,1])
drawTree(x.tree, x.encoding, ax1; feature_names=["X1", "X2"], 
        nodetextsize=20, nodetextcolor=:black, nodewth=12,
        linetextsize=13, leaftextsize=13, leafwth=4)
hidespines!(ax1)
hidedecorations!(ax1)
f

[image: plot of the resulting decision tree]

@roland-KA
Collaborator Author

Thanks, that looks really good! 👍

@ablaom
Member

ablaom commented Feb 13, 2022

Very cool! I wonder what the best way to integrate this contribution might be. It would be great for MLJ users to be able to do this (without adding Makie as a dependency to existing MLJ packages).

@Rahulub3r Any interest in helping out with visualisation in MLJ more generally?

Probably missing something but why is it necessary to specify feature names? Are they not part of x.encoding?

@Rahulub3r

Yes, I am interested. However, I am not sure how we would implement them without Makie.
RE feature names, if you do not specify them, they will be shown as Feature 1, Feature 2, ... in the plot, but if you specify the names, you will see them as shown in the plot above. x.encoding contains the class label information and not the features.

@ablaom
Member

ablaom commented Feb 14, 2022

x.encoding contains the class label information and not the features.

Yes, of course you're right 👍🏾. Perhaps we could get the MLJ interface to expose the feature names (as we do in, e.g., the MLJGLMInterface).
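
For reference, a sketch of how that could surface to the user; the `features` field shown here is an assumption at this point (a field of that name is used in a later comment in this thread):

fp = fitted_params(mach)   # sketch; `mach` is a trained DecisionTreeClassifier machine
fp.tree       # the raw DecisionTree.Node
fp.features   # the feature names seen during fit, e.g. [:x1, :x2]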

@roland-KA
Collaborator Author

As far as I can see, we now have three options to visualize a decision tree:

  1. Using GraphRecipes (with Plots in the background and the decision tree 'extended' to an AbstractTree) as described in my post above.
  2. The code developed by @Rahulub3r, which implements its own tree layout algorithm using Makie.jl.
  3. Using Graphs (with GraphMakie in the background).

No. 2 is without doubt visually the most beautiful solution. But I think we should also consider some software engineering aspects:

  • Maintenance: I think it is better to build on existing efforts to keep our own development and maintenance effort low (our resources are scarce). No. 1 & 3 only need some 'translation' of the decision tree to another data structure and then use ready-made layout and plotting algorithms.
  • Loose coupling: As @ablaom mentioned, there shouldn't be a direct dependency between MLJ and a graphics package. With option 1 every package capable of plotting an AbstractTree will do the job and with option 3 the same holds for every package capable of plotting a Graph. So we have a decoupling using these abstract data types.
    And of course it would be helpful if MLJ could provide the class labels as well as the feature names through a clear interface. That's a good idea @ablaom!

I've been digging a bit deeper into options 1 & 3 recently, so some more comments on that:

Option 1:
Using class labels and attribute names (from 'somewhere', ideally from an MLJ interface 😊), trees with all that information can be plotted using the means stated above.
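
A sketch of what such a printnode could look like; the name vectors here are hypothetical stand-ins for whatever MLJ would provide (note that under MLJ, leaf.majority holds an integer id, as discussed further below):

import AbstractTrees
import DecisionTree

feature_names = ["sepal_length", "sepal_width"]   # hypothetical
class_labels  = ["setosa", "versicolor"]          # hypothetical

function AbstractTrees.printnode(io::IO, node::DecisionTree.Node)
    print(io, feature_names[node.featid], " < ", node.featval)
end

function AbstractTrees.printnode(io::IO, leaf::DecisionTree.Leaf)
    # assumes `majority` holds an integer index into `class_labels` (the MLJ case)
    print(io, class_labels[leaf.majority], " - vals: ", length(leaf.values))
end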

BTW: I've started a three-part tutorial on 'Towards Data Science' on how Julia and its ecosystem can be used for ML (for an audience that doesn't know Julia). The first part was published last week. The third part uses exactly no. 1 as an example to show how easy it is in Julia to 'glue' several packages together and create a new solution with only little coding. So there will be an extensive description of how to do it.

Along the way, I identified an issue with the TreePlot recipe (JuliaPlots/GraphRecipes.jl#172), which has to be resolved before we can use it to produce correct tree plots.

Option 3:
In order to get a 'beautiful' tree plot, some extra work is needed, as the standard output from GraphMakie is a bit basic and doesn't offer more out-of-the-box, as I learned here: MakieOrg/GraphMakie.jl#57

I can provide a code example on how to apply this option next week (I'm a bit busy this week).

@ablaom
Member

ablaom commented Feb 18, 2022

@roland-KA Very nice summary of the issues and best options. I agree that while @Rahulub3r's code is a lovely bit of work, a more generic solution is preferred (1 or 3).

It probably makes sense to have a separate "MLJPlotRecipes" package for plotting things that are very specific to MLJ (we have, for example, a recipe in MLJTuning to plot outcomes of hyper-parameter optimization). But tree plotting wouldn't fall into this category.

I've opened JuliaAI/MLJDecisionTreeInterface.jl#13. This should be pretty easy (PR welcome 😄 ).

Makie.jl or Plots.jl? That's a very difficult question. I'm waiting for someone to convince me one way or the other 😉 . Of course one could do both, starting with Plots.jl which seems more ready to roll.

See also the MLJ visualization project in this GSoC list.

@roland-KA
Collaborator Author

I think a separate "MLJPlotRecipes" package is a good idea, as there are surely more models that could be visualized (and it would be another advantage of MLJ over similar ML-packages in general).

As a first step, an implementation for Plots.jl would be preferable in my opinion too, because:

  • Plots.jl is more mature in comparison. In Makie.jl there are still quite a few loose ends and a lot of changes going on.
  • The maintainer of GraphMakie.jl informed me that he is planning to extend that package with the functionality we need (see: Text boxes instead of text labels MakieOrg/GraphMakie.jl#57). So it doesn't have to be implemented for decision trees in MLJPlotRecipes.

... and if time permits, I will have a look at JuliaAI/MLJDecisionTreeInterface.jl#13 😊

@roland-KA
Collaborator Author

And as promised, here the code to plot a decision tree using Graphs and GraphMakie. This variant needs a bit more coding than GraphRecipes & Plots, since Graphs expects the tree in a breadth-first order whereas the structures of the decision tree are better suited for a depth-first traversal.

So I'm doing basically a depth-first traversal of the decision tree but I collect the nodes (as well as the information for the labels) in a breadth-first order. Sounds a bit strange (and is a bit strange 😀), but can be done with relatively little effort.

The trick is the use of the counter dictionary below, which has a counter for each level of the (binary) tree and delivers a numbering of the nodes in breadth-first order (1st level starts with 1, 2nd level with 2, 3rd level with 4 etc.).
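
To illustrate, for a binary tree of depth 2 the numbering and the initial counters look like this:

# Breadth-first numbering of a binary tree of depth 2 (illustration):
#
#          1           level 0: counter starts at 2^0 = 1
#        2   3         level 1: counter starts at 2^1 = 2
#       4 5 6 7        level 2: counter starts at 2^2 = 4
#
counter = Dict(0 => 1, 1 => 2, 2 => 4)   # each visit on a level advances that level's counter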

The function traverse_tree visits each node, adds the node's label information to label, and adds the edges to its left and right children to the graph g (in the calls to add_node!). We assume that the decision tree is in tree and has been created beforehand.

using Graphs, GraphMakie, CairoMakie
using NetworkLayout   # provides the `Buchheim()` tree layout used below

depth = DecisionTree.depth(tree)
g = SimpleDiGraph(2^(depth+1)-1)  # graph (for building the tree) of size given as an argument
label = Dict{Int, String}()  # for collecting label information
counter = Dict([i => 2^i for i in 0:depth])

function traverse_tree(tree::DecisionTree.Node, level::Int, cnt::Dict{Int, Int}, g::SimpleDiGraph, label::Dict{Int, String})
	label[cnt[level]] = "ID = " * string(tree.featid) * " - " * string(tree.featval)
	add_node!(g, tree.left, level, cnt, label) 	# left child
	add_node!(g, tree.right, level, cnt, label) # right child
end

traverse_tree(tree::DecisionTree.Leaf, level::Int, cnt::Dict{Int, Int}, g::SimpleDiGraph, label::Dict{Int, String}) = 
	label[cnt[level]] = "maj = " * string(tree.majority)

function add_node!(g::SimpleDiGraph, node::Union{DecisionTree.Node, DecisionTree.Leaf}, level::Int, cnt::Dict{Int, Int}, label::Dict{Int, String})
	add_edge!(g, cnt[level], cnt[level+1])
	traverse_tree(node, level + 1, cnt, g, label)
	cnt[level+1] += 1
end

The function traverse_tree has to be applied to the initial data structures defined above and a level of 0. Afterwards we have a tree structure in g similar to the decision tree and all label information in label. The keys of label give us the (breadth-first) order we need for plotting. Therefore we have to extract it in that order (call to sort) and then everything can be plotted using graphplot.

traverse_tree(tree, 0, counter, g, label)
tree_labels = collect(values(sort(label)))
layout = Buchheim()   # tree layout algorithm from NetworkLayout.jl
f, ax, p = graphplot(g, layout = layout, nlabels = tree_labels, nlabels_attr=(;justification = :center, align = (:center, :top), color = :blue))

The resulting tree plot is the one depicted in MakieOrg/GraphMakie.jl#57.

@dsweber2

dsweber2 commented Apr 8, 2022

Before noticing this issue, I went ahead and adapted the Plots recipe used in EvoTrees for this package. It's rather hacky, because the node shapes don't play well with the text, especially for deep trees, but between custom markers and bypassing the labels with annotations, it mostly works out. I ended up doing basically the same node traversal to get a digraph. Let me know if you think it's worth making a pull request with some modifications (either here or in some eventual MLJPlotRecipes).

using RecipesBase
using Graphs          # SimpleDiGraph, add_vertex!, add_edge!
using NetworkLayout   # buchheim layout
@recipe function plot(tree::DecisionTree.Node, var_names::Vector{Symbol}; widthAdjust=0.8, nodeWidth=0.45, falseColor="#FFC7AC", trueColor="#DAFBAA", decisionColor="#D6CCFC")
    g, reorderedNames, reorderedValues, reorderedLabels, n = buildGraph(tree, var_names)
    adjList = g.fadjlist
    size_base = floor(log2(length(adjList)))
    sz = (128 * 2^(size_base * widthAdjust), 96 * (1 + size_base))
    xBuf, yBuf = (nodeWidth, 0.45)
    buch = buchheim(adjList)
    xBuch = [x[1] for x in buch]
    yBuch = [x[2] for x in buch]
    shapes = [[(xBuch[ii] - xBuf, yBuch[ii] + yBuf), (xBuch[ii] + xBuf, yBuch[ii] + yBuf), (xBuch[ii] + xBuf, yBuch[ii] - yBuf - 0.15), (xBuch[ii] - xBuf, yBuch[ii] - yBuf - 0.15)] for ii = 1:length(xBuch)]
    curves = Vector{Tuple{Vector{Float64},Vector{Float64}}}(undef, sum(length.(adjList)))
    iCurve = 1
    for ii = 1:length(reorderedLabels)
        for jj = 1:length(adjList[ii])
            curves[iCurve] = ([xBuch[ii], xBuch[adjList[ii][jj]]], [shapes[ii][3][2], shapes[adjList[ii][jj]][1][2]])
            iCurve += 1
        end
    end
    annotate = [(xBuch[ii], yBuch[ii], reorderedLabels[ii], 9) for ii = 1:length(reorderedLabels)]
    background_color --> :white
    linecolor --> :black
    legend --> nothing
    axis --> nothing
    framestyle --> :none
    size --> sz
    annotations --> annotate
    for ii = 1:length(shapes)
        @series begin
            if length(adjList[ii]) == 0
                if reorderedValues[ii] > 0.5
                    fillColor = trueColor
                else
                    fillColor = falseColor
                end
            else
                fillColor = decisionColor
            end
            fillcolor --> fillColor
            seriestype --> :shape
            return shapes[ii]
        end
    end
    for ii = 1:length(curves)
        @series begin
            seriestype --> :curves
            return curves[ii]
        end
    end
end
function buildGraph(tree::DecisionTree.Node, givenNames)
    g = SimpleDiGraph()
    reorderedNames = Vector{String}()
    reorderedValues = tuple()
    reorderedLabels = Vector{String}()
    g, reorderedNames, reorderedValues, reorderedLabels, n = addNode(tree, g, givenNames, reorderedNames, reorderedValues, reorderedLabels, 1)
    return g, reorderedNames, reorderedValues, reorderedLabels, n
end
function addNode(node::DecisionTree.Node, g, givenNames, reorderedNames, reorderedValues, reorderedLabels, n)
    add_vertex!(g)
    append!(reorderedNames, [String(givenNames[node.featid])])
    reorderedValues = (reorderedValues..., node.featval)
    if node.featval isa Float64
        showVal = round(node.featval, sigdigits=3)
    else
        showVal = node.featval
    end
    append!(reorderedLabels, ["$(String(givenNames[node.featid]))\n ≥ $(showVal)"])
    g, reorderedNames, reorderedValues, reorderedLabels, nLeft = addNode(node.left, g, givenNames, reorderedNames, reorderedValues, reorderedLabels, n + 1)
    add_edge!(g, n, n + 1)
    g, reorderedNames, reorderedValues, reorderedLabels, nRight = addNode(node.right, g, givenNames, reorderedNames, reorderedValues, reorderedLabels, nLeft + 1)
    add_edge!(g, n, nLeft + 1)
    return g, reorderedNames, reorderedValues, reorderedLabels, nRight
end
function addNode(node::DecisionTree.Leaf, g, givenNames, reorderedNames, reorderedValues, reorderedLabels, n)
    add_vertex!(g)
    leafVal = sum(node.values .== 2) // length(node.values)
    append!(reorderedNames, [String(givenNames[end])])
    append!(reorderedLabels, ["$(String(givenNames[end])):\n $(leafVal.num)/$(leafVal.den)"])
    reorderedValues = (reorderedValues..., Float64(leafVal))
    return g, reorderedNames, reorderedValues, reorderedLabels, n
end

As an example of a resulting graph, using the example above with some name changes:

using MLJ, DecisionTree
using Random, DataFrames, Tables
using Plots
X, y = make_blobs(300; rng=MersenneTwister(1234))
df = DataFrame(theFirstThing=Tables.matrix(X)[:, 1], theSecondThing=Tables.matrix(X)[:, 2])
dtc = @load DecisionTreeClassifier pkg = DecisionTree verbosity = 0
dtc_model = dtc(min_purity_increase=0.005, min_samples_leaf=1, min_samples_split=2, max_depth=3)
dtc_mach = machine(dtc_model, df, y)
MLJ.fit!(dtc_mach)
x = fitted_params(dtc_mach)
Plots.plot(x[:tree], [x[:features]..., :y])

[image: plot of the resulting decision tree]

@roland-KA
Collaborator Author

Hi @dsweber2, this looks quite good to me! 👍 ... and since you've implemented it using Plots recipes, we get the independence between MLJ and the graphics package (as discussed above).

@ablaom, wouldn't this be a good start for a MLJPlotRecipes package?

@ablaom
Member

ablaom commented Apr 19, 2022

I agree this looks nice. However, as it is specific to DecisionTree.jl trees, I suggest a PR either to DecisionTree.jl and/or MLJDecisionTreeInterface.jl. The MLJDecisionTreeInterface version could include original feature names.

Minor suggestion: replace the test if node.featval isa Float64 with if node.featval isa AbstractFloat.

However, as commented above, a more maintainable solution would be to implement the AbstractTrees.jl API for the DecisionTree.jl objects and try to improve the generic tree-plotting capabilities of Plots.jl, if these are lacking. It's just a pity that this approach will not be able to include feature names without some significant changes to DecisionTree.jl, as node objects in DecisionTree don't know the feature names. Perhaps in the MLJ wrapper we could include the encoding in a plot legend?

@roland-KA
Collaborator Author

I've just noticed that DecisionTree.jl is no longer maintained. So a PR to this package is no longer an option.

Anyway @ablaom, as you suggest, a more maintainable solution would be less dependent on DecisionTree.jl. As I'm not so familiar with plot recipes (I've just used them, never implemented one), I'm trying to understand what that means:

  • @dsweber2's recipe would then have an AbstractTree as its first argument (instead of a DecisionTree.Node), right?
  • And it should be a GraphRecipe replacing the TreePlot recipe in my example above?

As that example above shows, making an AbstractTree from a DecisionTree can be achieved relatively easily: only the AbstractTree functions children and printnode have to be implemented.

An obstacle for a simple implementation is indeed that DecisionTree doesn't know the feature names. In this tutorial, I've explained how that information could be added. But it is just sort of a work-around, not a desirable solution. So if MLJ could deliver that information directly, a simple implementation would be possible. @ablaom could you describe more precisely what that 'encoding in a plot legend' would look like?

@dsweber2 what do you think about such an adaption of your plot recipe? Is this a way to go?

@ablaom
Member

ablaom commented Apr 19, 2022

@dsweber2's recipe would then have an AbstractTree as its first argument (instead of a DecisionTree.Node), right?

I think the problem here is that "AbstractTree" is not a type, only an interface. As far as I understand, plot recipes cannot be created in the usual way with trait-dispatch, only type-dispatch. But maybe there is a workaround if you don't use the macro.

@ablaom could you describe more precisely what that 'encoding in a plot legend' would look like?

I just mean that your nodes are labelled like "Feature 1 < 0.4", say, but you add a legend to the figure that looks like

Feature 1 - number_of_bedrooms
Feature 2 - floor_area
Feature 3 - median_price_neighborhood
...

This might even be an advantage, as long feature names then don't mess up the plot.

@bensadeghi Would you be happy to entertain a PR that implements the AbstractTrees.jl interface for Node and Leaf objects? This would be pretty minimal - see above comment. AbstractTrees.jl is a popular package with no dependencies. The existing print_tree functionality could be left untouched, or replaced by the AbstractTrees.jl (text-based) version, which would eliminate some code.

@bensadeghi
Member


@ablaom, that sounds fine.
But please make sure that the PR includes appropriate unit tests and documentation in README.

@ablaom
Member

ablaom commented Apr 21, 2022

@roland-KA Is this something you might take on, at your leisure?

@roland-KA
Collaborator Author

Well, at least I could give it a try. But I need some help, especially when it comes to implementing a plot recipe. @dsweber2 would you give your support on that topic?

So just to make clear what the objectives are, let me summarize. We want

  • a PR to DecisionTree.jl that implements
  • an AbstractTree wrapper (by implementing children and printnode)
  • a plot recipe that is able to plot a DecisionTree (in form of an AbstractTree)

Is that correct?

@ablaom and you would provide an extension to the MLJ wrapper that delivers the feature names (perhaps in that extended 'legend' form you described above) as well as class names?

So that in the end we could call plot (more or less) in the following way:

plot(aDecisionTree_that_acts_like_an_AbstractTree, 
         MLJ.feature_names, 
         MLJ.class_names)

@roland-KA
Collaborator Author

Spending a bit more time on the issue today, I think I missed the point a bit with my last comment.

@ablaom your idea is probably that the PR to DecisionTree.jl only encompasses the AbstractTree wrapper?

A plot recipe for plotting decision trees could then be added to MLJ itself, relying just on the AbstractTree interface. That would open the possibility for every other decision tree within the MLJ universe to use the same generic plot recipe (it would only have to implement the AbstractTree interface). Right?

@ablaom
Member

ablaom commented Apr 25, 2022

Yes, that's exactly my suggestion. A PR here to implement the AbstractTrees.jl API and that is all (@bensadeghi's other requirements notwithstanding). Very happy to provide feedback on such a PR.

@roland-KA
Collaborator Author

Ok fine, that should be no problem 😊👍.

Then we have to decide how we get the information about feature names and class labels into the game (especially when the nodes of the trees don't have that information, as is the case with DecisionTree.jl).

The idea in AbstractTrees is that the function printnode produces the text that should be displayed inside the nodes (and leaves). So this function needs access to feature names and class labels.

Currently I see the following alternatives on how this could be done:

  1. The 'pure' nodes and leaves that come from a decision tree can be wrapped in an enriched structure which has that additional knowledge. That's the variant I described in my article in Towards Data Science. In this case printnode can deliver a ready-to-use text.
  2. The information is added in a later step, as arguments to the plot recipe (which would result in a usage of the recipe like plot(decisiontree, feature_names, class_labels)). Here the result of printnode would have to be combined with the additional information (in the best case) or (more probably) it has to be ignored and replaced by something else. I.e. the result of printnode would only be used for a simple default representation.

Approach 1 has to be implemented within the implementation of the AbstractTrees traits, whereas no. 2 will be implemented in the context of the recipe (the implementation of the TreePlot recipe may serve as an example).

I.e. approach 1 lays more burden on the implementation of the AbstractTrees traits when the tree structure doesn't have that information. But it would be relatively straightforward for a decision tree implementation like BetaML.jl (as far as I understand their code, labels and names are included).

Approach 2 would allow parameters on the plot recipe to fine tune the appearance of the tree. E.g. one could choose between showing just feature ids in the nodes, full text nodes or ids in the node and a legend with full text next to the tree plot (well, this would also be doable with approach 1, but it would be sort of a waste of effort).
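
To make the difference concrete, here is a sketch of what approach 2's entry point could look like. The signature and the stub body are illustrative only: instead of a real tree layout it just draws one labelled marker per node, to stay self-contained.

using RecipesBase
import AbstractTrees
import DecisionTree

# Hypothetical recipe for approach 2: names are supplied at plot time.
@recipe function f(tree::DecisionTree.Node, feature_names::Vector{String},
                   class_labels::Vector{String})
    # `PreOrderDFS` relies on `children(::DecisionTree.Node)` as defined near the
    # top of this thread; `Leaf`s are treated as childless by AbstractTrees' fallback.
    nodes = collect(AbstractTrees.PreOrderDFS(tree))
    framestyle --> :none
    legend --> false
    seriestype := :scatter
    series_annotations := [n isa DecisionTree.Node ?
                               feature_names[n.featid] :
                               class_labels[n.majority]   # assumes integer ids
                           for n in nodes]
    (1:length(nodes), zeros(length(nodes)))
end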

But perhaps I'm missing a point at the moment and things are less complicated ... 🤔

@ablaom
Member

ablaom commented Apr 28, 2022

@roland-KA I'd suggest that for now we avoid writing plot recipes that are specific to ML decision trees and instead work on building tree representations that implement AbstractTrees.jl, which can then be thrown at generic plotting algorithms - whether Makie or Plots.jl. So my vote is for 1. I think this is simpler and easier to maintain.

I very much like your idea to build an enhanced AbstractTree object for DecisionTree.jl and had not realised how far you had already gone with this in the Towards Data Science post - that's a complete POC really. So there is a way to get the feature names into the game after all. This is only slightly more complicated than implementing the AbstractTree API for raw nodes and leaves, and I expect @bensadeghi would not object to its addition here, as there is still no refactoring of existing code.

@roland-KA
Collaborator Author

The objective for using an AbstractTree is to hide all implementation details of specific decision trees from the plot recipe. I've investigated possible alternatives to the two approaches described above, but I didn't find any other ways to do it.

So I completely agree with you and would implement the concept from my TDS article.

BTW: In the meantime I've delved a bit into the documentation of plot recipes and I think I can do that too. 🤓

@roland-KA
Collaborator Author

I'm currently testing my implementation of the AbstractTree-interface for DecisionTree.jl.

Here I came across the fact that DecisionTree sometimes stores class labels in the majority field of its Leafs, and in some cases ids of class labels (i.e. an index value).

Is it correct that it stores there

  • class labels when used with its native interface (outside of MLJ)
  • ids to class labels when used within MLJ?

@ablaom
Member

ablaom commented May 5, 2022

Off the top of my head I don't know the answer to point one. Is there even a distinction in native DecisionTree.jl between ids and class labels?

Currently, the MLJ interface always converts target class labels (categorical values) to integers, for passing to DecisionTree.jl.

Features are left untouched and DecisionTree.jl is happy with any eltype which supports <, I'm pretty sure.

Does this help?
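
To make the conversion concrete, a small sketch using MLJ's int and decoder utilities:

using MLJ

y = coerce(["a", "b", "a"], Multiclass)   # categorical target
int.(y)            # UInt32[0x01, 0x02, 0x01]; presumably the ids that end up in `majority`
d = decoder(y[1])  # callable mapping the integers back to the labels
d(int(y[2]))       # → "b"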

@ablaom
Member

ablaom commented May 8, 2022

In src/DecisionTree.jl I've exported (among others) the functions children and printnode (coming from AbstractTrees). Any opinion about this being good or bad?

I'd probably not export them. They are not going to be used by the average user of DecisionTree.jl, and children might conflict with some other usage.
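
Without the exports, users can simply qualify the calls (sketch; the toy node reuses the constructors from the demo further below and assumes the children/printnode definitions from the top of this thread):

import AbstractTrees
import DecisionTree

node = DecisionTree.Node(1, 0.5, DecisionTree.Leaf(1, [1]), DecisionTree.Leaf(2, [2]))
AbstractTrees.children(node)    # instead of an exported `children`
AbstractTrees.print_tree(node)  # the text rendering comes for free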

@roland-KA
Collaborator Author

roland-KA commented Jul 15, 2022

Here is now a first version of a plot recipe for plotting a DecisionTree wrapped in an AbstractTrees-interface:

Plot Recipe
"""
A plot recipe (based on `RecipeBase.jl`) to visualize a decision tree (from `DecisionTree.jl`) wrapped in an `AbstractTrees`-interface.

# Overview

The Buchheim-Algorithm (from `NetworkLayout.jl`) is used to generate a layout for the tree. 
Therefore the decision tree has to be converted first to a `SimpleDiGraph` (from `Graphs.jl`).
For that purpose we need the number `n` of nodes within the tree and the nodes have to be numbered
from 1 to `n` in a breadth first order (as the edges of a `SimpleDiGraph` are specified by pairs of
such numbers).

The following assumptions are made for the decision tree:
- It is a binary tree
- A node has either no children (`Leaf`) or exactly two children (`Node`)
This implies that the tree may be imbalanced but each single node is always balanced.

The whole process consists of the following steps:

1. Flatten
   The tree is converted in a breadth first order into a list (a `Vector`) of `PlotInfo`s. 
   These elements contain the relevant information for visualizing the tree. The indices
   within this array correspond to the above required numbers from 1 to `n` (1 is the root node).

2. Generate layout
   The (number) information from the flat structure is taken to create a corresponding `SimpleDiGraph`.
   Using the Buchheim-algorithm on this graph, a layout (consisting of the coordinates of the nodes)
   is generated.

3. Make a visual description (using a Plot Recipe)
   The plot recipe `dt_visualization` creates a visual description of the tree using the information of 
   the two preceding steps.
"""

import AbstractTrees
using NetworkLayout
using Graphs
using GraphRecipes


# plotting information for each node/leaf of the tree
struct PlotInfo
    parent_id   :: Int16        # number of the parent node within the tree
    is_leaf     :: Bool         # is the node a `Leaf`?
    print_label :: String       # text to be printed as a label in the visualization of the node 
end

"""
    add_level!(plot_infos::Vector{PlotInfo}, nodes, i_crnt)

Traverse the tree recursively beginning from the root, level by level, collecting on each level
all nodes from left to right and adding them to the `plot_infos`-list in form of `PlotInfo`s.

On each call, a `Vector` of nodes from the last level processed is given in `nodes`. So on the first 
call a one-element array containing the root node is passed. `i_crnt` is the index of the `PlotInfo`
in `plot_infos` corresponding to the first node in `nodes`. 

On the first call, `add_level!` expects the first entry in `plot_infos` for the root node already 
to be  present (so this has to be done manually).
"""
function add_level!(plot_infos::Vector{PlotInfo}, nodes, i_crnt)
    i_next = i_crnt + length(nodes) 
    child_nodes = []    
    for n in nodes
        cn = AbstractTrees.children(n)
        if length(cn) > 0 
            plot_infos[i_next]   = PlotInfo(i_crnt, is_leaf(cn[1]), label(cn[1]))
            plot_infos[i_next+1] = PlotInfo(i_crnt, is_leaf(cn[2]), label(cn[2]))
            i_next += 2
            push!(child_nodes, cn[1])
            push!(child_nodes, cn[2])
        end
        i_crnt += 1
    end
    if length(child_nodes) > 0
        add_level!(plot_infos, child_nodes, i_crnt)
    end
end

# extract and format label information from nodes/leaves
function label(i::Union{InfoNode, InfoLeaf})
    io = IOBuffer()
    AbstractTrees.printnode(io, i)
    return(String(take!(io)))
end

# which type of node is it?
is_leaf(i::InfoNode) = false
is_leaf(i::InfoLeaf) = true

depth(tree::InfoNode) = DecisionTree.depth(tree.node)    # --> move to `DecisionTree/abstract_trees.jl`

"""
    flatten(tree::InfoNode)

Create a list of all nodes/leaves (converted to `PlotInfo`s) within the `tree` in a breadth first order.
Note: If `tree` is imbalanced, the resulting list contains empty elements at its end.
"""
function flatten(tree::InfoNode)
    plot_infos = Vector{PlotInfo}(undef, 2^(depth(tree)+1))     # tree has max. 2^(depth+1) nodes 
    plot_infos[1] = PlotInfo(-1, false, label(tree))            # root node is first entry; -1 is a dummy 
    add_level!(plot_infos, [tree], 1)                           # recursively add nodes of all further tree levels to the list 
    return(plot_infos)
end

"""
    layout()

Create a tree layout in form of a list of points (`GeometryBasics.Point2`) based on the list of
`PlotInfo`s created by `flatten`. The order of the points in the list corresponds to the information in 
the `plot_infos`-list passed. 
"""
function layout(plot_infos::Vector{PlotInfo})
    size = count([isassigned(plot_infos, i) for i in eachindex(plot_infos)])    # count number of actual entries
    g = SimpleDiGraph(size)
    for i in 2:size
        add_edge!(g, plot_infos[i].parent_id, i)
    end
    return(buchheim(g))
end

"""
Rectangular shape with centerpoint `center` and dimensions `width` and `height`
"""
function make_rect(center, width, height)
    left, right  = center[1] - width/2.0,  center[1] + width/2.0
    upper, lower = center[2] + height/2.0, center[2] - height/2.0
    Shape([(left, upper), (right, upper), (right, lower), (left, lower), (left, upper)])
end

"""
Linear curve starting at the `parent`s bottom center leading to the `child`s top center
"""
function make_line(parent, child, height)
    parent_bottom = [parent[1], parent[2] - height/2]
    child_top     = [child[1],  child[2]  + height/2]
    return([parent_bottom[1], child_top[1]], [parent_bottom[2], child_top[2]])
end

"""
    dt_visualization(tree::InfoNode)

Plot recipe to draw a `DecisionTree` (wrapped in an `AbstractTree`)
"""
@recipe function dt_visualization(tree::InfoNode, width = 0.7, height = 0.7)
    # prepare data 
    plot_infos = flatten(tree)
    coords = layout(plot_infos)

    # we paint on a blank paper
    framestyle --> :none
    legend --> false
    
    # connecting lines are a series of curves 
    i = 2
    while i <= length(plot_infos) && isassigned(plot_infos, i)
        @series begin
            seriestype := :curves
            linecolor --> :silver 
            line = make_line(coords[plot_infos[i].parent_id], coords[i], height)
            i += 1
            return line
        end
    end

    # nodes are a series of rectangle shapes
    anns = plotattributes[:annotations] = []        # for the labels within the rectangles
    for i in eachindex(coords)
        @series begin
            seriestype := :shape
            fillcolor --> :deepskyblue3
            alpha --> (plot_infos[i].is_leaf ? 0.4 : 0.15)
            annotationcolor --> :brown4
            annotationfontsize --> 7
            c = coords[i]
            push!(anns, (c[1], c[2], (plot_infos[i].print_label,)))
            return make_rect(c, width, height)
        end
    end
end

Using the following "handmade" decision tree ...

Demo code
using DecisionTree    
using Plots             

function dtree()
    l1 = Leaf(1, [1,1,2])
    l2 = Leaf(2, [1,2,2])
    l3 = Leaf(3, [3,3,1])
    l4 = Leaf(1, [1,1,1])
    l5 = Leaf(2, [2,2,2])
    n4 = Node(4, 0.8, l4, l5)
    n3 = Node(3, 0.3, n4, l3)
    n2 = Node(2, 0.5, l1, l2)
    n1 = Node(1, 0.7, n2, n3)
    return(n1)
end

dt = dtree()
feature_names = ["feat1", "feat2", "feat3", "feat4"]
class_labels = ["a", "b", "c"]
infotree = DecisionTree.wrap(dt, (featurenames = feature_names, classlabels = class_labels))

plot(infotree) 

... we get this nice picture: 😊

[image: plot of the handmade decision tree]

@ablaom what do you think about this approach? And is there a place where I could put this code (we've discussed the idea of a repository for plot recipes for MLJ models)?

@ablaom
Copy link
Member

ablaom commented Jul 18, 2022

@roland-KA Thanks for this. This definitely looks like progress.

Let me check if I understand what you are proposing to provide here. Correct me if I misunderstand:

  1. The first part is pretty generic. Your apparatus provides a new method depth, and any tree implementing the AbstractTrees.jl interface, together with your depth method, is a tree to which your generic methods apply. Let's call these abstract trees with depth.

  2. Given any abstract tree with depth, you provide utilities for computing a collection of objects sufficient to visualise (in a rather sophisticated way) the tree, using generic functions (and no recipes) provided by Plots.jl and/or some other libraries.

  3. You provide a Plots.jl recipe that will automate the visualisation afforded by Step 2 with a single call plot(tree::DecisionTree.InfoNode, ...). However, this is specific to the wrapped trees provided by DecisionTree.jl. And we can't really make this recipe more generic - we need a type for dispatch, right?

Thoughts:

  4. Is it possible to close the gap between 2 and 3? That is, can we have a generic function for plotting any abstract tree with depth, and refactor the recipe to be a simple call to that function? In that way, I can readily implement a plot recipe for the trees from other models (eg, from BetaML) without duplicating a lot of code.

  5. As far as where the code should live: I think you could make a nice standalone package to provide the generic plotting tools (minus recipe). The package docs could provide the DecisionTree.jl-specific recipe as an illustration. And then each MLJ decision tree model interface could provide a model-specific recipe. At least, this makes sense, if my suggestion No. 4 is workable.

@sylvaticus Do your BetaML trees happen to implement AbstractTrees.jl, including a printnode method like we have here?

@roland-KA
Collaborator Author

Hi @ablaom, thanks for your feedback. That's really helpful!

So let's get through it:

  1. The first part is pretty generic. Your apparatus provides a new method depth and any tree implementing AbstractTrees.jl interface, together with your depth method, is a tree to which your generic methods apply. Let's call these abstract trees with depth.

Yes. depth should be defined together with the AbstractTrees-interface (children, printnode, wrap), so that we get from there an 'abstract tree with depth'.

  2. Given any abstract tree with depth, you provide utilities for computing a collection of objects sufficient to visualise (in a rather sophisticated way) the tree, using generic functions (and no recipes) provided by Plots.jl and/or some other libraries.

I don't quite understand the second half of your comment, but perhaps the following explanation clarifies the point: The whole code above in 'Plot Recipe' is a plot recipe where I've put some of the code into separate functions (like flatten and layout) to make it more readable and maintainable.

  • The objective of the code is to translate a 'abstract tree with depth' (i.e. any decision tree, not only a DecisionTree.jl) into a series of coordinates and strings that can be processed by Plots.jl (or another graphics package) to visualize the tree.
  • So, the code relies only on the 'abstract tree with depth'-interface
  • It is independent of Plots.jl
  • It is independent of DecisionTree.jl
    ... well not quite 😀, it is conceptually independent from DecisionTree.jl but not technically.

What does that mean? I really need the InfoNode (and InfoLeaf) type only for dispatch: foremost in the parameter list of the dt_visualization recipe (I think that's your point no 3), and then also in the parameter lists of is_leaf and label.

As InfoNode is defined within DecisionTree.jl I have here (technically) a dependency. But I never access the attributes of InfoNode or InfoLeaf (that's what I mean by 'the code is conceptually independent of DecisionTree.jl').

I think I could get rid of the technical dependency just by introducing abstract types AbstractInfoNode & AbstractInfoLeaf (which must be defined outside of DecisionTree.jl and can be used by any other decision tree implementation). InfoNode in DecisionTree.jl must then be a subtype of AbstractInfoNode and the parameter list of dt_visualization should define the first parameter as tree::AbstractInfoNode (the same holds in an analogous way for label and is_leaf).

That should do the trick and make the code completely generic.

So any decision tree that implements the 'decision tree with depth'-interface should then be able to use this plot recipe (after having been transformed by its wrap-implementation to some subtype of AbstractInfoNode).

Does this sound plausible or do you see any further obstacles?

PS: I should add that the add_level! function isn't fully generic at the moment, because it expects a binary tree and some decision trees are n-ary trees. But that can be solved with little effort.

@roland-KA
Collaborator Author

And I've just discovered a copy&paste error:

In the 'Plot Recipe' code above the line using GraphRecipes has to be replaced by using RecipesBase.

@ablaom
Member

ablaom commented Jul 19, 2022

Just a quick reply for now. If we can, I think we should avoid introducing any new abstract type, apart from AbstractTrees.AbstractTree. Can we get the genericity we want some other way, eg, trait-dispatch?

@sylvaticus

@sylvaticus Do your BetaML trees happen to implement AbstractTrees.jl, including a printnode method like we have here?

I do have "print node" / "print tree" methods, but they are textual and not linked to the AbstractTrees.jl interface.
However, I am open to looking at it. Actually, I am rewriting the "internal" BetaML API, and if this isn't too long I can try to implement it.
But I am leaving tomorrow for the holidays.. with 3 small kids I'll have no time to touch the laptop :-) :-) Late August ??

@roland-KA
Collaborator Author

roland-KA commented Jul 19, 2022

Just a quick reply for now. If we can, I think we should avoid introducing any new abstract type, apart from AbstractTrees.AbstractTree.

@ablaom, unfortunately (contrary to general expectation) there is no type AbstractTree defined in AbstractTrees.jl. That would of course solve the problem.

Can we get the genericity we want some other way, eg, trait-dispatch?

I don't think so (or perhaps I don't understand how that would work in this context). It's not about behaviour.

As you already mentioned

we need a type for dispatch, right?

That's the point. When calling plot the correct recipe is selected depending on the type of the argument we give to plot. In the example above the argument is an InfoNode, which is not generic enough (as it is defined on the level of DecisionTree). If the recipe should work on all decision trees, we need a more generic type.

The code of the recipe is already generic. We already have what you desire in point 4 above. The only thing which isn't generic enough is the 'recipe selector'.

So, if you have no other idea, I would proceed and adapt the implementation to use such a more generic type.

BTW: Why do you object to introducing a new abstract type? I don't quite understand the rationale behind it.

Apart from that, I've noticed that AbstractTrees.jl got a few updates recently. It now has a depth function (it's called treeheight). So we can use that and already have 'abstract trees with depth' 😊.
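
For example (a sketch using the infotree from the demo above; both functions ship with recent AbstractTrees.jl):

import AbstractTrees

AbstractTrees.treeheight(infotree)   # replaces a custom `depth` helper
AbstractTrees.treesize(infotree)     # number of nodes; used in the updated recipe below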

@ablaom
Member

ablaom commented Jul 24, 2022

Ah, sorry for my hasty response. I keep having to reload this issue into my brain with all the other stuff going on in between.

I understand now why you want the abstract type - this is forced upon us by the way recipes work. It looks to me that you are close to an optimal solution and I look forward to your implementation. As suggested earlier, I suggest putting the generic code in a standalone tree-plotting package. JuliaAI could host this if you want.

@roland-KA
Collaborator Author

Hi @ablaom, no problem! Your questions help me to check, if things are really ok. And I understand that you strive to minimize dependencies. But in this case, I think there is no alternative.

New abstract types

I have therefore now implemented a version using a common supertype for the InfoNode and InfoLeaf types. These new abstract types are as follows:

abstract type AbstractInfoNode{S, T} end
abstract type AbstractInfoLeaf{T} end

So the definitions in DecisionTree.jl changed to:

struct InfoNode{S, T} <: AbstractInfoNode{S, T}
    node    :: DecisionTree.Node{S, T}
    info    :: NamedTuple
end

struct InfoLeaf{T} <: AbstractInfoLeaf{T}
    leaf    :: DecisionTree.Leaf{T}
    info    :: NamedTuple
end

Furthermore I've created a package called DecisionTreesRecipe.jl where I've put the plot recipe (I think in an MLJ context this could rather become something like ModelRecipes.jl in order to have a place for all plot recipes for the visualisation of different models).

Just for convenience, I've placed the new abstract types also in this package. But I think from a software engineering perspective this is not the right place, because a model like DecisionTree.jl only needs to know these abstract types, but not the plot recipes; that's an unnecessary dependency.

Do you have an idea, where the definition of these two abstract types could be placed so that all models which want to implement the AbstractTrees-interface would have easy access to them? Is there a place in the MLJ-universe where they would fit in?

Updated version of plot recipe

Apart from that I have an updated version of the plot recipe where I have

  • used the new functionality from AbstractTrees which makes a few things easier
  • adapted the documentation to make it (hopefully) more understandable
  • repaired a few minor glitches
Plot Recipe (updated version)
"""
A plot recipe (based on `RecipeBase.jl`) to create a graphical representation of a decision tree.

The decision tree must be wrapped in an `AbstractTrees`-interface (see `DecisionTree.jl` as 
an example implementation of the concept). This approach ensures that the recipe is independent of the
implementation details of the decision tree.
"""
module DecisionTreesRecipe

include("AbstractInfoTree.jl")   # definition of new abstract types --> not a good idea, to put them here
export AbstractInfoNode, AbstractInfoLeaf    

import AbstractTrees
using NetworkLayout
using Graphs
using RecipesBase

"""
# Overview

The recipe uses the Buchheim-Algorithm (from `NetworkLayout.jl`) to generate a layout for the decision tree. 
As this algorithm requires a `SimpleDiGraph` (from `Graphs.jl`) as its input, the tree has to be converted
first into that structure. For that purpose the nodes of the tree have to be numbered from 1 to `n` 
(`n` being the number of nodes in the tree) in a breadth first order (as the edges of a `SimpleDiGraph` 
are specified by pairs of such numbers).

So the recipe applies the following steps to convert a decision tree into a graphical representation:

1. Flatten
   The tree is converted in a breadth first order into a list (a `Vector`) of `PlotInfo`s. 
   These elements contain the relevant information for visualizing the tree later on. The indices
   within this array correspond to the above required numbers from 1 to `n` (1 is the root node).
   The functions `flatten` and `add_level!` implement this step. 

2. Generate layout
   The (number) information from the flat structure is taken to create a corresponding `SimpleDiGraph`.
   Using the Buchheim-algorithm on this graph, a layout (consisting of the coordinates of the nodes)
   is generated. This step is implemented by `layout`.

3. Make a visual description (using a plot recipe)
   The plot recipe `dt_visualization` creates a visual description of the tree using the information of 
   the two preceding steps.
"""

### Step 1: Flatten

# plotting information for each node/leaf of the tree
struct PlotInfo
    parent_id   :: Int16        # number of the parent node within the tree
    is_leaf     :: Bool         # is the node a leaf?
    print_label :: String       # text to be printed as a label in the visualization of the node 
end

"""
    add_level!(plot_infos::Vector{PlotInfo}, nodes, i_crnt)

Traverse the tree recursively beginning from the root, level by level, collecting on each level
all nodes from left to right and adding them to the `plot_infos`-list in form of `PlotInfo`s.

On each call, a `Vector` of nodes from the last level processed is given in `nodes`. So on the first 
call a one-element array containing the root node is passed. `i_crnt` is the index of the 
`PlotInfo`-element in `plot_infos` corresponding to the first node in `nodes`. 

On the first call, `add_level!` expects the first entry in `plot_infos` for the root node already 
to be present (so this has to be done manually).
"""
function add_level!(plot_infos::Vector{PlotInfo}, nodes, i_crnt)
    i_next = i_crnt + length(nodes) 
    child_nodes = []    
    for n in nodes
        cn = AbstractTrees.children(n)
        for c in cn
            plot_infos[i_next]  = PlotInfo(i_crnt, is_leaf(c), label(c))
            push!(child_nodes, c)
            i_next += 1
        end
        i_crnt += 1
    end
    if length(child_nodes) > 0
        add_level!(plot_infos, child_nodes, i_crnt)
    end
end

# extract and format label information from nodes/leaves
function label(i::Union{<:AbstractInfoNode, <:AbstractInfoLeaf})
    io = IOBuffer()
    AbstractTrees.printnode(io, i)
    return(String(take!(io)))
end

# which type of node is it?
is_leaf(i::AbstractInfoNode) = false
is_leaf(i::AbstractInfoLeaf) = true

"""
    flatten(tree::InfoNode)

Create a list of all nodes/leaves (converted to `PlotInfo`s) within the `tree` in a breadth first order.
"""
function flatten(tree::AbstractInfoNode)
    plot_infos = Vector{PlotInfo}(undef, AbstractTrees.treesize(tree))          # tree has `treesize` nodes 
    plot_infos[1] = PlotInfo(-1, false, label(tree))                            # root node is first entry; -1 is a dummy 
    add_level!(plot_infos, [tree], 1)                                           # recursively add nodes of all further tree levels to the list 
    return(plot_infos)
end


### Step 2: Generate layout

"""
    layout()

Create a tree layout in form of a list of points (`GeometryBasics.Point2`) based on the list of
`PlotInfo`s created by `flatten`. The order of the points in the list corresponds to the information in 
the `plot_infos`-list passed. 
"""
function layout(plot_infos::Vector{PlotInfo})
    g = SimpleDiGraph(length(plot_infos))
    for i in 2:length(plot_infos)
        add_edge!(g, plot_infos[i].parent_id, i)
    end
    return(buchheim(g))
end


### Step 3: Make a visual description (using a plot recipe)

"""
Rectangular shape with centerpoint `center` and dimensions `width` and `height`
"""
function make_rect(center, width, height)
    left, right  = center[1] - width/2.0,  center[1] + width/2.0
    upper, lower = center[2] + height/2.0, center[2] - height/2.0
    return([(left, upper), (right, upper), (right, lower), (left, lower), (left, upper)])
end

"""
Linear curve starting at the `parent`s bottom center leading to the `child`s top center.
`parent` and `child` are the center coordinates of the respective nodes.
"""
function make_line(parent, child, height)
    parent_xbottom, parent_ybottom = parent[1], parent[2] - height/2
    child_xtop,     child_ytop     = child[1],  child[2]  + height/2
    return([parent_xbottom, child_xtop], [parent_ybottom, child_ytop])
end

"""
    dt_visualization(tree::InfoNode)

Plot recipe to convert a decision tree (wrapped in an `AbstractTree`) into a graphical representation
"""
@recipe function dt_visualization(tree::AbstractInfoNode, width = 0.7, height = 0.7)
    # prepare data 
    plot_infos = flatten(tree)
    coords = layout(plot_infos)

    # we paint on a blank paper
    framestyle --> :none
    legend --> false
    
    # connecting lines are a series of curves 
    for i in 2:length(plot_infos) 
        @series begin
            seriestype := :curves
            linecolor --> :silver 
            line = make_line(coords[plot_infos[i].parent_id], coords[i], height)
            return line
        end
    end

    # nodes are a series of rectangular shapes
    anns = plotattributes[:annotations] = []        # for the labels within the rectangles
    for i in eachindex(coords)
        @series begin
            seriestype := :shape
            fillcolor --> :deepskyblue3
            alpha --> (plot_infos[i].is_leaf ? 0.4 : 0.15)
            annotationcolor --> :brown4
            annotationfontsize --> 7
            c = coords[i]
            push!(anns, (c[1], c[2], (plot_infos[i].print_label,)))
            return make_rect(c, width, height)
        end
    end
end


end # module

@roland-KA
Collaborator Author

roland-KA commented Jul 26, 2022

Hi @sylvaticus, in order to get a quick PoC of the visualisation concept discussed above (which is intended to work for different decision trees), I've forked BetaML with the aim to add a simple implementation of the AbstractTrees-interface to the BetaML-trees in my local environment.

Unfortunately I'm not familiar with the ForceImport package you are using and therefore got stuck. Perhaps you can give me some hints on how to proceed?

So far I have added/changed the following things in BetaML:

  • Added the file abstract_trees.jl containing the following simple (and so far untested 😇) implementation of the AbstractTrees-interface to src/Trees.jl:

      struct InfoNode{S, T} <: AbstractInfoNode{S, T}
          node    :: DecisionNode{T}
          info    :: NamedTuple
      end
      
      struct InfoLeaf{T} <: AbstractInfoLeaf{T}
          leaf    :: Leaf{T}
          info    :: NamedTuple
      end
      
      wrap(node::DecisionNode, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
      wrap(leaf::Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
      
      AbstractTrees.children(node::InfoNode) = (
          wrap(node.node.trueBranch, node.info),
          wrap(node.node.falseBranch, node.info)
      )
      AbstractTrees.children(node::InfoLeaf) = ()
      
      function AbstractTrees.printnode(io::IO, node::InfoNode)
          print(io, node.node)
      end
      
      function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
          print(io, leaf.leaf)
      end
  • Added the dependencies:

    using AbstractTrees
    using DecisionTreesRecipe
  • Added the exports

    export InfoNode, InfoLeaf, wrap

... getting nothing more than a lot of warnings about "could not import X into Y". To my understanding I've put the using- and export-clauses in the wrong places. Where do I have to put them and do I have to apply more changes to make this work?

@sylvaticus

Hello, could you push your code so I can look at it (perhaps in another branch, or maybe this is not needed)?
I assume you did the include of your file within the Trees submodule.
When you add a function you want to export you need to restart Julia, as even Revise doesn't "see" the newly added functions.

@ablaom
Member

ablaom commented Jul 27, 2022

@roland-KA thanks for the progress!

  1. I have a question about the choice to include type parameters in AbstractInfoNode{S,T} and so forth. As I understand it, the appearance of these parameters in the concrete type Node{S,T} is specific to DecisionTree.jl. A different implementation might have fewer or different type parameters. Also, I don't think you need to know these parameters for dispatch in the plotting. So shouldn't we just drop them?

  2. I'd suggest renaming AbstractInfoNode to AbstractNode. To declare MyNodeType <: AbstractNode is a contract that MyNodeType implements (some part of the) AbstractTrees.jl interface, and that's all, right? And then it's actually enough to have a single abstract type AbstractNode (for internal and leaf nodes) and detect a leaf by checking isempty(children(node)).

So we do

abstract type AbstractNode end    # for types implementing the AbstractTrees.jl interface (`children`, `depth`, `printnode`, ?)

struct InfoNode{S, T} <: AbstractNode
    node    :: DecisionTree.Node{S, T}
    info    :: NamedTuple
end

struct InfoLeaf{T} <: AbstractNode
    leaf    :: DecisionTree.Leaf{T}
    info    :: NamedTuple
end
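
Leaf detection then reduces to a one-liner (sketch, using the types above):

import AbstractTrees

is_leaf(n::AbstractNode) = isempty(AbstractTrees.children(n))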

And, unless I misunderstand (yet again) your plot recipe now works for any subtype of AbstractTree, not just decision trees, right?

  3. As to where the type(s) should live? Most natural: AbstractTrees.jl. Maybe worth asking if they would consider it. I can't think of a natural place in the MLJ ecosystem. Otherwise, they could go in a standalone package (yes! a package defining a single type - why not?)

@roland-KA
Copy link
Collaborator Author

1 ... Also, I don't think you need to know these parameters for dispatch in the plotting. So shouldn't we just drop them?

Definitely yes! When I had a look at the BetaML-trees these days, it came to my mind that the current implementation is too DecisionTree-specific.

2 ... And then it's actually enough to have a single abstract type AbstractNode (for internal and leaf nodes) and detect a leaf by checking isempty(children(node)).

Good idea! That's simpler and more general.

And, unless I misunderstand (yet again), your plot recipe now works for any subtype of AbstractNode, not just decision trees, right?

Basically yes. We do this wrap-thing in order to get more information for printing labels (which is rather decision-tree-specific), but the recipe doesn't access that directly. It's encapsulated in printnode.
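
For illustration, a sketch of what that encapsulation can look like, assuming a DecisionTree.jl-style node with featid/featval fields and a featurenames entry passed in via wrap (both names are illustrative):

function AbstractTrees.printnode(io::IO, node::InfoNode)
    # use a human-readable feature name if one was supplied via `wrap`
    names = get(node.info, :featurenames, nothing)
    label = names === nothing ? "Feature $(node.node.featid)" : string(names[node.node.featid])
    print(io, label, " < ", node.node.featval)
end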

As to where the type(s) should live? Most natural: AbstractTrees.jl.

Definitely yes, that would be ideal! Would you ask the folks maintaining AbstractTrees.jl or should I do that?

Otherwise, they could go in a standalone package (yes! a package defining a single type - why not?)

That was my first thought ... but I felt it would be too weird 😀. But you are right. If we don't get it into AbstractTrees.jl, that would be a solution.

I have made all the changes you suggested and it works. 🤓

As a next step, I'll try to build an MWE using BetaML trees, as a proof of concept that the current implementation is really as general as we intend it to be.

@ablaom
Copy link
Member

ablaom commented Jul 28, 2022

Okay, great. I suggest you initiate the AbstractTrees.jl request. I expect there may be a discussion about exactly what A <: AbstractNode promises, and you need to be happy with the outcome. It may be that A <: AbstractNode will only promise children. Is that enough? There is a fallback for printnode, right?
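
(If I remember correctly, AbstractTrees.jl ships a generic fallback roughly along these lines; paraphrased, not quoted from its source:

# default `printnode`: just `show` the node compactly
printnode(io::IO, node) = show(IOContext(io, :compact => true, :limit => true), node)

so children alone would already yield something printable.)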

Link in the current discussion and copy me in so I can add my 👍🏾

Thanks again for your perseverance. I think the extra effort you are putting into the software engineering details will be well worth it in the long term.

@ablaom
Copy link
Member

ablaom commented Sep 4, 2022

@roland-KA I'm not seeing much action at your AbstractTrees.jl issue, despite your exhaustive efforts (many thanks!). How about we revert to the mini-package AbstractTreeTypes.jl idea?

You've put a lot of work into this project and I'm keen to make it available.

@roland-KA
Copy link
Collaborator Author

Hi @ablaom, I've just added a comment to the issue asking for a (definitive) response and giving it a (last) chance, since in my opinion it is the better solution. If that doesn't work, then we should really proceed with the mini-package solution.

I'm currently on vacation until next weekend. So perhaps we'll have an answer by then.

@roland-KA
Copy link
Collaborator Author

Hi @ablaom, it seems now that we will get the abstract type into AbstractTrees.jl 🙏😊. But my guess is that it could take some weeks until the next release of that package is ready.

In order to advance faster with the plot recipe, I would suggest the following steps:

  • I could create a temporary package containing an abstract type AbstractNode. I wouldn't register that package, so it would have to be referenced by its full path. As soon as the next release of AbstractTrees.jl is ready, we can just delete the references to the temporary package.
  • This would open the path to publishing the plot recipe. Here I have a question: should I place it in a new package, which could then be registered, or do you already have some ideas where it could be placed within the MLJ universe?

@ablaom
Copy link
Member

ablaom commented Sep 24, 2022

In my experience these temporary solutions suck up time best invested elsewhere. The idea of depending on an unregistered package does not excite me either. Despite the very long wait so far, my inclination would be to wait.

If you still feel otherwise, I'd suggest a separate package, rather than in some MLJ package, which sounds more messy. As it's only temporary, I don't think it matters who owns it. It's very easy to transfer ownership in any case - trivial in fact, if the package is not registered.

@ablaom
Copy link
Member

ablaom commented Sep 24, 2022

Thanks, btw, for your advocacy at AbstractTrees.jl 🙏🏾

@roland-KA
Copy link
Collaborator Author

Thanks, btw, for your advocacy at AbstractTrees.jl 🙏🏾

Well, thanks to you for your support! I don't think it would have worked without that.

If you still feel otherwise, I'd suggest a separate package, rather than in some MLJ package

The second question above (about placing the plot recipe) was not about a temporary package but about the final solution. That is, I wanted to know where I should put the plot recipe as it will be used in the future (because I can start working on that now).

It seems that for that solution, too, you would suggest a separate new package in my repo, which I would then register?

In order to test that package under "real world conditions", I will need the abstract type AbstractNode. To run these tests before the type becomes available from AbstractTrees.jl, I thought about creating a temporary package containing the following few lines of code (just for testing):

module AbstractNodeType

export AbstractNode
abstract type AbstractNode end

end
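
Loading it by path would then look roughly like this (a sketch; the local path and the subtype are placeholders):

# in the consuming package's environment:
# pkg> dev /local/path/to/AbstractNodeType
using AbstractNodeType

# hypothetical subtype, just to test that dispatch on AbstractNode works
struct MyTestNode <: AbstractNode end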

It won't be necessary at all if the next release of AbstractTrees.jl appears soon enough.

@ablaom
Copy link
Member

ablaom commented Sep 25, 2022

Yes, you're right, I completely misunderstood. 😳

Yes, a separate package for the recipe sounds good to me. You could own that or move it to JuliaAI as you wish.

@ablaom
Copy link
Member

ablaom commented Oct 20, 2022

@roland-KA AbstractNode is now in AbstractTrees.jl 0.4.3, just released 🥳

@roland-KA
Copy link
Collaborator Author

Wow, that's great news 🤗.

Unfortunately I've been quite busy during the last few weeks, so I couldn't start to create the package for the plot recipe as intended. But November looks promising 😊.

@roland-KA
Copy link
Collaborator Author

Hi @ablaom, sorry for the delay. But now I have a package for the plot recipe ready in my repo: https://github.com/roland-KA/TreeRecipe.jl.

Apart from that, I've created a PR for DecisionTree.jl in order to use the new AbstractTrees.AbstractNode.
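
For a first impression, here is a minimal usage sketch (it assumes the wrap helper from that PR and that the recipe hooks into Plots.plot; details may differ from the final API):

using Plots
using DecisionTree
using TreeRecipe   # unregistered; add it via its GitHub URL for now

# toy data, just to have something to fit
features = rand(100, 2)
labels   = rand(["a", "b"], 100)
tree     = DecisionTree.build_tree(labels, features)

# wrap the tree together with feature names so the splits get readable labels
wrapped = DecisionTree.wrap(tree, (featurenames = ["x1", "x2"],))
plot(wrapped)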

Could you perhaps have a look at both and tell me if this is the right way to go?

@roland-KA
Copy link
Collaborator Author

Hi @sylvaticus, I've just created a PR for BetaML.jl in order to make it work with the TreeRecipe.jl I've created in the meantime. It uses the new abstract type AbstractNode from the AbstractTrees.jl package.

There is an example of how to plot a BetaML decision tree using the recipe in roland-KA/TreeRecipe.jl/test/testBetaMLtrees.jl.

Note: TreeRecipe.jl isn't a registered package yet, so for now it has to be added directly from roland-KA/TreeRecipe.jl.

Sorry for leaving your package so long in limbo. But now it should work!

@ablaom
Copy link
Member

ablaom commented Nov 23, 2022

@roland-KA Thanks for your patience. Just back from leave.

Looking awesome. Merged your DecisionTree PR: JuliaRegistries/General#72785

Posted suggestions for new package at JuliaAI/TreeRecipe.jl#1
