Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add FastTreeSHAPv1 implementation for EvoTrees #270

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/EvoTrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ end

include("callback.jl")
include("importance.jl")
include("treeshapv1.jl")
include("plot.jl")
include("MLJ.jl")

Expand Down
144 changes: 144 additions & 0 deletions src/treeshapv1.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
FastTreeSHAPv1 SHAP values for a EvoTrees.EvoTree model

## References
* Jilei Yang, [*Fast TreeSHAP: Accelerating SHAP Value Computation for Trees*](https://arxiv.org/abs/2109.09847), 2022, pp. 16-17
"""
function treeshapv1(model::EvoTrees.EvoTree, x::AbstractVector; fnames=model.info[:fnames])
chan = Channel{typeof(x)}(length(model.trees))
@inbounds Threads.@threads for tree in model.trees
put!(chan, treeshapv1!(tree, zero(x), x))
end
ϕ = take!(chan)
while !isempty(chan)
ϕ += take!(chan)
end

pairs = string.(fnames) .=> ϕ
sort!(pairs, by=x->abs(x[2]), rev=true)
pairs
end

"""
FastTreeSHAPv1 SHAP values for a single EvoTrees.Tree

## References
* Jilei Yang, [*Fast TreeSHAP: Accelerating SHAP Value Computation for Trees*](https://arxiv.org/abs/2109.09847), 2022, pp. 16-17
"""
function treeshapv1!(tree::EvoTrees.Tree, φ::AbstractVector, x::AbstractVector)
recurse!(tree, firstindex(tree.feat), PathValue[], Float64[], 1., PathValue(0, 1., true), φ, x)
φ
end

"""
Element type used to make *m*, the path vector of unique features we have split on so far.

## Fields
* `d`: feature index of the jth internal node
* `z`: covering ratio for the jth internal node
* `o`: the threshold condition for the jth internal node
"""
struct PathValue
d::Int
z::Float64
o::Bool
end

function nodevalue(tree::EvoTrees.Tree{L, K}, j::Integer) where {L<:EvoTrees.GradientRegression, K}
tree.pred[1, j]
end

function nodevalue(tree::EvoTrees.Tree{L, K}, j::Integer) where {L<:Union{EvoTrees.LogLoss,EvoTrees.MLogLoss}, K}
sum(tree.pred[:, j]) # logit sum
end

function recurse!(tree::EvoTrees.Tree, j::Integer, m::AbstractVector{PathValue}, w::AbstractVector, q::Real, p::PathValue, φ::AbstractVector, x::AbstractVector)
m, w, q = extend(m, w, q, p)

if !tree.split[j] # leaf node
if length(m) > length(w)
s₀ = -sum(unwind(m, w, q)[2])
end

for i=2:lastindex(m)
vⱼ = nodevalue(tree, j)
if !m[i].o
update = s₀ * q * vⱼ
else
s = sum(unwind(m, w, q, i)[2])
update = s * q * (1 - m[i].z) * vⱼ
end
φ[m[i].d] += update
end
else
dⱼ = tree.feat[j]
aⱼ, bⱼ = (2j, 2j+1) # j's (left, right) child
h, c = x[dⱼ] ≤ tree.cond_float[j] ? (aⱼ, bⱼ) : (bⱼ, aⱼ) # (hot, cold) child
iz, io = 1, true
k = findfirst(x -> x.d == dⱼ, m)

if !isnothing(k) # Undo previous updates if feature was already used
iz, io = m[k].z, m[k].o
m, w, q = unwind(m, w, q, k)
end
recurse!(tree, h, m, w, q, PathValue(dⱼ, iz * (tree.gain[h] / tree.gain[j]), io), φ, x)
recurse!(tree, c, m, w, q, PathValue(dⱼ, iz * (tree.gain[c] / tree.gain[j]), false), φ, x)
end
end

function extend(m::AbstractVector{PathValue}, w::AbstractVector{T}, q::Real, p::PathValue) where {T<:Real}
l, lw = length(m), length(w)
m, w = copy(m), copy(w)
push!(m, p)

if !p.o
q *= p.z
for i=lw:-1:1
w[i] *= (l - (i - 1)) / (l + 1)
end
else
push!(w, T(lw == 0))
for i=lw:-1:1
w[i + 1] += w[i] * (i / (l + 1))
w[i] *= p.z * ((l - (i - 1)) / (l + 1))
end
end
m, w, q
end

function unwind(m::AbstractVector{PathValue}, w::AbstractVector, q::Real)
l, lw = length(m)-1, length(w)-1
m, w = copy(view(m, 1:l)), copy(w)

for j=lw+1:-1:1
w[j] *= (l + 1) / (l - (j - 1))
end
m, w, q
end

function unwind(m::AbstractVector{PathValue}, w::AbstractVector, q::Real, i::Integer)
l, lw = length(m)-1, length(w)-1
mᵢ = m[i]
m = copy(view(m, 1:l))

if !mᵢ.o
w = copy(w)
for j=lw+1:-1:1
w[j] *= (l + 1) / (l - (j - 1))
end
q /= mᵢ.z
else
n = w[lw+1]
w = copy(view(w, 1:lw))
for j=lw:-1:1
t = w[j]
w[j] = n * ((l + 1) / (j + 1))
n = t - w[j] * mᵢ.z * ((l - (j - 1)) / (l + 1))
end
end
for j=i:l-1
m[j] = m[j + 1]
end
m, w, q
end