---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%"
)
set.seed(21)
```
# treeshap
<!-- badges: start -->
[![R-CMD-check](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml/badge.svg)](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml)
[![CRAN status](https://www.r-pkg.org/badges/version/treeshap)](https://CRAN.R-project.org/package=treeshap)
<!-- badges: end -->
In an era of increasingly complex classifiers, sometimes even the authors of an algorithm do not know exactly how a tree ensemble model is built. The complexity of these models' structures is one reason why most users treat them as black boxes. But how can they know whether a prediction made by the model is reasonable? `treeshap` is an efficient answer to this question. By implementing an algorithm optimized for tree ensemble models (called TreeSHAP), it calculates SHAP values in polynomial (instead of exponential) time. Currently, `treeshap` supports models produced with the `xgboost`, `lightgbm`, `gbm`, `ranger`, and `randomForest` packages. Support for `catboost` is available only in the [`catboost` branch](https://github.com/ModelOriented/treeshap/tree/catboost) (see why [here](#catboost)).
## Installation
The package is available on CRAN:
``` r
install.packages('treeshap')
```
You can install the latest development version from GitHub using `devtools` with:
``` r
devtools::install_github('ModelOriented/treeshap')
```
## Example
First, let's see how to represent an `xgboost` model as a unified model object:
```{r unifier-example, warning=FALSE, message=FALSE}
library(treeshap)
library(xgboost)
data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
target <- fifa20$target
param <- list(objective = "reg:squarederror", max_depth = 6)
xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, nrounds = 200, verbose = 0)
unified <- unify(xgb_model, data)
head(unified$model)
```
With the unified object in hand, producing SHAP values for specific observations is straightforward. The `treeshap()` function requires two data arguments: the unified representation of the ensemble model and the observations we want to explain. The latter should contain the same columns as the data used to build the model.
```{r treeshap-example}
treeshap1 <- treeshap(unified, data[700:800, ], verbose = 0)
treeshap1$shaps[1:3, 1:6]
```
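For intuition about what each entry of `shaps` represents, the sketch below computes exact Shapley values for a hand-written toy function by enumerating every coalition of features. It is not the TreeSHAP algorithm and does not use `treeshap`: `toy_model()`, `coalition_value()`, and `shapley_brute_force()` are hypothetical helpers, and replacing absent features with a fixed baseline is a simplifying assumption. It only illustrates the exponential-time computation that TreeSHAP avoids.

``` r
# Toy "model" (hypothetical, not a tree ensemble): one additive term and one interaction.
toy_model <- function(x) 2 * x[1] + x[2] * x[3]

x_explained <- c(1, 2, 3)  # observation to explain
x_baseline  <- c(0, 0, 0)  # assumed reference point for "absent" features

# Value of a coalition: features in `present` take their observed values,
# all other features are replaced with the baseline.
coalition_value <- function(present) {
  x <- x_baseline
  x[present] <- x_explained[present]
  toy_model(x)
}

# Exact Shapley values by enumerating all 2^(M - 1) coalitions for each feature.
shapley_brute_force <- function(value_fn, n_features) {
  phi <- numeric(n_features)
  for (i in seq_len(n_features)) {
    others <- setdiff(seq_len(n_features), i)
    for (mask in 0:(2^length(others) - 1)) {
      # Decode the bit mask into a coalition that excludes feature i.
      in_coalition <- (mask %/% 2^(seq_along(others) - 1)) %% 2 == 1
      coalition <- others[in_coalition]
      k <- length(coalition)
      weight <- factorial(k) * factorial(n_features - k - 1) / factorial(n_features)
      phi[i] <- phi[i] + weight * (value_fn(c(coalition, i)) - value_fn(coalition))
    }
  }
  phi
}

phi <- shapley_brute_force(coalition_value, 3)
phi

# The contributions add up to the gap between the prediction and the baseline prediction.
sum(phi) - (toy_model(x_explained) - toy_model(x_baseline))
```

In this toy case the additive term is attributed entirely to the first feature and the interaction term is split evenly between the second and third. The number of coalitions doubles with every added feature, which is why this enumeration is impractical for real models.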
We can also compute SHAP interaction values. As an example, we calculate them for a model built on simpler data (only 5 columns) and explain the first 100 observations.
```{r interactions-example}
data2 <- fifa20$data[, 1:5]
xgb_model2 <- xgboost::xgboost(as.matrix(data2), params = param, label = target, nrounds = 200, verbose = 0)
unified2 <- unify(xgb_model2, data2)
treeshap_interactions <- treeshap(unified2, data2[1:100, ], interactions = TRUE, verbose = 0)
treeshap_interactions$interactions[, , 1:2]
```
## Plotting results
The explanation results can be visualized using the [`shapviz`](https://github.com/ModelOriented/shapviz/) package; see [here](https://modeloriented.github.io/shapviz/articles/basic_use.html#treeshap).
`treeshap` also provides four plotting functions of its own:
### Feature Contribution (Break-Down)
This plot shows how features contribute to the prediction for a single observation. It is similar to the Break Down plot from the [iBreakDown](https://github.com/ModelOriented/iBreakDown) package, which uses a different method to approximate SHAP values.
```{r plot_contribution_example}
plot_contribution(treeshap1, obs = 1, min_max = c(0, 16000000))
```
### Feature Importance
This plot shows the average absolute impact of each feature on the model's prediction.
```{r plot_importance_example}
plot_feature_importance(treeshap1, max_vars = 6)
```
### Feature Dependence
This plot shows how a single feature contributes to the prediction depending on its value.
```{r plot_dependence_example}
plot_feature_dependence(treeshap1, "height_cm")
```
### Interaction Plot
A simple plot that visualizes the SHAP interaction value of two features depending on their values.
```{r plot_interaction}
plot_interaction(treeshap_interactions, "height_cm", "overall")
```
## How to use the unifying functions?
For your convenience, you can simply call the `unify()` function with your model and a reference dataset. Behind the scenes, it uses one of the six functions from the `.unify()` family (`xgboost.unify()`, `lightgbm.unify()`, `gbm.unify()`, `catboost.unify()`, `randomForest.unify()`, `ranger.unify()`). Although the objects produced by these functions share an identical structure, the packages save and represent trees differently, so the usage of the model-specific functions may differ slightly. You can therefore call them directly or pass additional parameters to `unify()`.
```{r gbm, eval=FALSE}
library(treeshap)
library(gbm)
x <- fifa20$data[colnames(fifa20$data) != 'work_rate']
x['value_eur'] <- fifa20$target
gbm_model <- gbm::gbm(
formula = value_eur ~ .,
data = x,
distribution = "laplace",
n.trees = 200,
cv.folds = 2,
interaction.depth = 2
)
unified_gbm <- unify(gbm_model, x)
unified_gbm2 <- gbm.unify(gbm_model, x) # legacy API
```
## Setting reference dataset
The dataset used as a reference for calculating SHAP values is stored in the unified model representation object. It can be changed at any time using the `set_reference_dataset()` function.
```{r set_reference_dataset, eval=FALSE}
library(treeshap)
library(ranger)
data_fifa <- fifa20$data[!colnames(fifa20$data) %in%
c('work_rate', 'value_eur', 'gk_diving', 'gk_handling',
'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')]
data <- na.omit(cbind(data_fifa, target = fifa20$target))
rf <- ranger::ranger(target~., data = data, max.depth = 10, num.trees = 10)
unified_ranger_model <- unify(rf, data)
unified_ranger_model2 <- set_reference_dataset(unified_ranger_model, data[c(1000:2000), ])
```
## Other functionalities
The package also implements a `predict()` method that calculates the model's predictions from the unified representation.
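A minimal sketch of calling `predict()` on a unified model, reusing the `ranger` setup from the reference-dataset example above. The call shape (unified model first, then the observations to predict) is an assumption here; check `?predict.model_unified` in your installed version for the exact signature.

``` r
library(treeshap)
library(ranger)

# Same preprocessing as in the reference-dataset example.
data_fifa <- fifa20$data[!colnames(fifa20$data) %in%
                           c('work_rate', 'value_eur', 'gk_diving', 'gk_handling',
                             'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')]
data <- na.omit(cbind(data_fifa, target = fifa20$target))

rf <- ranger::ranger(target ~ ., data = data, max.depth = 10, num.trees = 10)
unified_rf <- unify(rf, data)

# Assumed call shape: unified model first, then the observations to predict.
preds_unified <- predict(unified_rf, data[1:5, ])
preds_unified
```

Comparing the result with the original model's own predictions on the same rows is a quick sanity check that the unified representation was built as expected.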
## How fast does it work?
The complexity of TreeSHAP is $\mathcal{O}(TLD^2)$, where $T$ is the number of trees, $L$ is the number of leaves in a tree, and $D$ is the depth of a tree.
Our implementation runs at a speed comparable to Lundberg's original Python package `shap`, which is implemented in C and Python.
The complexity of SHAP interaction values computation is $\mathcal{O}(MTLD^2)$, where $M$ is the number of explanatory variables used by the explained model, $T$ is the number of trees, $L$ is the number of leaves in a tree, and $D$ is the depth of a tree.
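To give a feel for these bounds, the sketch below plugs in sizes matching the `xgboost` examples above: 200 trees of depth at most 6, hence at most $2^6 = 64$ leaves per tree. The helper names are hypothetical, constant factors are ignored, and real trees are often smaller than the worst case, so the numbers are rough operation counts rather than timings.

``` r
# Hypothetical helpers: operation-count bounds up to a constant factor.
treeshap_cost <- function(n_trees, n_leaves, depth) {
  n_trees * n_leaves * depth^2                           # O(T L D^2)
}
interaction_cost <- function(n_features, n_trees, n_leaves, depth) {
  n_features * treeshap_cost(n_trees, n_leaves, depth)   # O(M T L D^2)
}

n_trees  <- 200      # nrounds in the examples
depth    <- 6        # max_depth in the examples
n_leaves <- 2^depth  # worst case for a binary tree of this depth

treeshap_cost(n_trees, n_leaves, depth)        # 200 * 64 * 36 = 460800
interaction_cost(5, n_trees, n_leaves, depth)  # 5 * 460800 = 2304000

# For contrast, the number of feature coalitions that exact enumeration
# would visit grows as 2^M, e.g. for M = 30 features:
2^30                                           # 1073741824
```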
## CatBoost
Originally, `treeshap` also supported CatBoost models from the `catboost` package, but because this package is not available on CRAN or R-universe (see `catboost` issues [#439](https://github.com/catboost/catboost/issues/439), [#1846](https://github.com/catboost/catboost/issues/1846)), we decided to remove support from the main version of our package.
However, you can still use the `treeshap` implementation for `catboost` by installing our package from [`catboost` branch](https://github.com/ModelOriented/treeshap/tree/catboost).
This branch can be installed with:
``` r
devtools::install_github('ModelOriented/treeshap@catboost')
```
## References
- Lundberg, S.M., Erion, G., Chen, H. et al. "From local explanations to global understanding with explainable AI for trees", Nature Machine Intelligence 2, 56–67 (2020).