Skip to content

Commit

Permalink
Add support for kernelshap and update aggregated shap code (#553)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianstando authored Mar 22, 2023
1 parent a6fb717 commit 142c94d
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 54 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: DALEX
Title: moDel Agnostic Language for Exploration and eXplanation
Version: 2.5.0
Version: 2.5.1
Authors@R: c(person("Przemyslaw", "Biecek", email = "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-8423-1823")),
person("Szymon", "Maksymiuk", role = "aut",
Expand All @@ -27,7 +27,8 @@ Depends: R (>= 3.5)
Imports:
ggplot2,
iBreakDown (>= 1.3.1),
ingredients (>= 2.0)
ingredients (>= 2.0),
kernelshap
Suggests:
gower,
ranger,
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ export(predict_diagnostics)
export(predict_parts)
export(predict_parts_break_down)
export(predict_parts_break_down_interactions)
export(predict_parts_kernel_shap)
export(predict_parts_kernel_shap_aggreagted)
export(predict_parts_kernel_shap_break_down)
export(predict_parts_oscillations)
export(predict_parts_oscillations_emp)
export(predict_parts_oscillations_uni)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
DALEX 2.5.1
---------------------------------------------------------------
* added support for calculating kernel SHAP values via the `predict_parts()` function (types `kernel_shap`, `kernel_shap_break_down` and `kernel_shap_aggregated`)

DALEX 2.5.0
---------------------------------------------------------------
* breaking change: change the name of `loss_yardstick()` to `get_loss_yardstick()` and `loss_default()` to `get_loss_default()`
Expand Down
55 changes: 51 additions & 4 deletions R/predict_parts.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
#' Model prediction is decomposed into parts that are attributed for particular variables.
#' From DALEX version 1.0 this function calls the \code{\link[iBreakDown]{break_down}} or
#' \code{\link[iBreakDown:break_down_uncertainty]{shap}} functions from the \code{iBreakDown} package or
#' \code{\link[ingredients:ceteris_paribus]{ceteris_paribus}} from the \code{ingredients} package.
#' \code{\link[ingredients:ceteris_paribus]{ceteris_paribus}} from the \code{ingredients} package or
#' \code{\link[kernelshap:kernelshap]{kernelshap}} from the \code{kernelshap} package.
#' Find information how to use the \code{break_down} method here: \url{https://ema.drwhy.ai/breakDown.html}.
#' Find information how to use the \code{shap} method here: \url{https://ema.drwhy.ai/shapley.html}.
#' Find information how to use the \code{oscillations} method here: \url{https://ema.drwhy.ai/ceterisParibusOscillations.html}.
#' Find information how to use the \code{kernelshap} method here: \url{https://modeloriented.github.io/kernelshap/}
#' aSHAP method provides explanations for a set of observations based on SHAP.
#'
#' @param explainer a model to be explained, preprocessed by the \code{explain} function
Expand All @@ -18,7 +20,7 @@
#' @param N the maximum number of observations used for calculation of attributions. By default NULL (use all) or 500 (for oscillations).
#' @param variable_splits_type how variable grids shall be calculated? Will be passed to \code{\link[ingredients]{ceteris_paribus}}.
#' @param type the type of variable attributions. Either \code{shap}, \code{shap_aggregated}, \code{oscillations}, \code{oscillations_uni},
#' \code{oscillations_emp}, \code{break_down} or \code{break_down_interactions}.
#' \code{oscillations_emp}, \code{break_down}, \code{break_down_interactions}, \code{kernel_shap}, \code{kernel_shap_break_down} or \code{kernel_shap_aggregated}.
#'
#' @return Depending on the \code{type} there are different classes of the resulting object.
#' It's a data frame with calculated average response.
Expand Down Expand Up @@ -84,7 +86,11 @@ predict_parts <- function(explainer, new_observation, ..., N = if(substr(type, 1
"oscillations_uni" = predict_parts_oscillations_uni(explainer, new_observation, ...),
"oscillations_emp" = predict_parts_oscillations_emp(explainer, new_observation, ...),
"shap_aggregated" = predict_parts_shap_aggregated(explainer, new_observation, ...),
stop("The type argument shall be either 'shap' or 'break_down' or 'break_down_interactions' or 'oscillations' or 'oscillations_uni' or 'oscillations_emp' or 'shap_aggregated'")
"kernel_shap" = predict_parts_kernel_shap(explainer, new_observation, ...),
"kernel_shap_break_down" = predict_parts_kernel_shap_break_down(explainer, new_observation, ...),
"kernel_shap_aggregated" = predict_parts_kernel_shap_aggreagted(explainer, new_observation, ...),
stop("The type argument shall be either 'shap' or 'break_down' or 'break_down_interactions' or 'oscillations' or 'oscillations_uni' or 'oscillations_emp' or 'shap_aggregated'
or 'kernel_shap' or 'kernel_shap_break_down' or 'kernel_shap_aggregated'")
)
}

Expand Down Expand Up @@ -193,10 +199,51 @@ predict_parts_shap_aggregated <- function(explainer, new_observation, ...) {

res <- shap_aggregated(explainer,
new_observations = new_observation,
kernelshap = FALSE,
...)

class(res) <- c('predict_parts', class(res))
res
}

#' @name predict_parts
#' @export
predict_parts_kernel_shap <- function(explainer, new_observation, ...) {
  # Kernel SHAP attributions for `new_observation`.
  #
  # explainer        object carrying `model`, `data`, `predict_function` and
  #                  `label` (all four are read here or in kernelshap_to_shap())
  # new_observation  observation to explain; passed to kernelshap as `X`
  # ...              forwarded unchanged to kernelshap::kernelshap()
  #
  # Returns a data frame with classes 'predict_parts', 'kernel_shap', 'shap',
  # 'break_down_uncertainty', 'data.frame'.
  test_explainer(explainer, has_data = TRUE, function_name = "kernel_shap")

  # explainer$data serves as the background dataset (bg_X).
  # Note: there must be no comma after `...` -- a trailing comma adds an extra
  # empty argument to the call.
  ks <- kernelshap::kernelshap(
    object = explainer$model,
    X = new_observation,
    bg_X = explainer$data,
    pred_fun = explainer$predict_function,
    verbose = FALSE,
    ...
  )

  res <- kernelshap_to_shap(ks, new_observation, explainer, agg = FALSE)
  # Overrides the class set by kernelshap_to_shap() to add the 'kernel_shap' tag.
  class(res) <- c('predict_parts', 'kernel_shap', 'shap', 'break_down_uncertainty', 'data.frame')
  res
}

#' @name predict_parts
#' @export
predict_parts_kernel_shap_break_down <- function(explainer, new_observation, ...) {
  # Kernel SHAP attributions re-shaped into a break-down style table.
  test_explainer(explainer, has_data = TRUE, function_name = "kernel_shap_break_down")

  # The kernel SHAP result carries the baseline and the model prediction as
  # the 'intercept' and 'prediction' attributes.
  kernel_shaps <- predict_parts_kernel_shap(explainer, new_observation, ...)
  baseline <- attr(kernel_shaps, 'intercept')
  predicted <- attr(kernel_shaps, 'prediction')

  bd <- transform_shap_to_break_down(
    shaps = kernel_shaps,
    label = explainer$label,
    intercept = baseline,
    prediction = predicted,
    order_of_variables = NULL,
    agg = FALSE
  )

  class(bd) <- c('predict_parts', 'break_down', 'break_down_kernel_shap', 'data.frame')
  bd
}

#' @name predict_parts
#' @export
predict_parts_kernel_shap_aggreagted <- function(explainer, new_observation, ...) {
  # Aggregated kernel SHAP attributions for a set of observations.
  #
  # NOTE: the function name is misspelled ("aggreagted"); it is kept as is
  # because it is exported and dispatched to by predict_parts() under
  # type = "kernel_shap_aggregated".
  #
  # `function_name` names this explanation type, in line with the sibling
  # kernel SHAP wrappers, which each pass their own type name.
  test_explainer(explainer, has_data = TRUE, function_name = "kernel_shap_aggregated")

  # Arguments in `...` are forwarded to shap_aggregated().
  res <- shap_aggregated(
    explainer,
    new_observations = new_observation,
    kernelshap = TRUE,
    ...
  )

  class(res) <- c('predict_parts', class(res))
  res
}

Expand Down
66 changes: 23 additions & 43 deletions R/shap_aggregated.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#' @param new_observations a set of new observations with columns that correspond to variables used in the model.
#' @param order if not \code{NULL}, then it will be a fixed order of variables. It can be a numeric vector or vector with names of variables.
#' @param ... other parameters like \code{label}, \code{predict_function}, \code{data}, \code{x}
#' @param B number of random paths
#' @param B number of random paths; works only if kernelshap=FALSE
#' @param kernelshap indicates whether the kernelshap method should be used
#'
#' @return an object of the \code{shap_aggregated} class.
#'
Expand All @@ -29,14 +30,27 @@
#' plot(bd_glm, max_features = 3)
#' }
#' @export
shap_aggregated <- function(explainer, new_observations, order = NULL, B = 25, ...) {
shap_aggregated <- function(explainer, new_observations, order = NULL, B = 25, kernelshap = FALSE, ...) {
ret_raw <- data.frame(contribution = c(), variable_name = c(), label = c())

for(i in 1:nrow(new_observations)){
new_obs <- new_observations[i,]
shap_vals <- iBreakDown::shap(explainer, new_observation = new_obs, B = B, ...)
shap_vals <- shap_vals[shap_vals$B != 0, c('contribution', 'variable_name', 'label')]
ret_raw <- rbind(ret_raw, shap_vals)
if(kernelshap) {
ks <- kernelshap::kernelshap(
object = explainer$model,
X = new_observations,
bg_X = explainer$data,
pred_fun = explainer$predict_function,
verbose = FALSE,
...,
)
res <- kernelshap_to_shap(ks, new_observations, explainer, agg=TRUE)
ret_raw <- res[ ,c('contribution', 'variable_name', 'label')]
} else {
for(i in 1:nrow(new_observations)){
new_obs <- new_observations[i,]
shap_vals <- iBreakDown::shap(explainer, new_observation = new_obs, B = B, ...)
shap_vals <- shap_vals[shap_vals$B != 0, c('contribution', 'variable_name', 'label')]
ret_raw <- rbind(ret_raw, shap_vals)
}
}

data_preds <- predict(explainer, explainer$data)
Expand Down Expand Up @@ -64,46 +78,12 @@ shap_aggregated <- function(explainer, new_observations, order = NULL, B = 25, .
# Average raw per-observation SHAP contributions and lay them out as a
# break-down style table.
#
# ret_raw          data frame with columns `contribution`, `variable_name`
#                  and `label`
# mean_prediction  value shown in the "intercept" row
# mean_subset      value shown in the "prediction" row
# order            row index (numeric or variable names) fixing the order of
#                  variables; NULL keeps the order produced by aggregate()
# label            label written into the "intercept" and "prediction" rows
#
# Returns a plain data.frame.
raw_to_aggregated <- function(ret_raw, mean_prediction, mean_subset, order, label) {
  ret <- aggregate(ret_raw$contribution, list(ret_raw$variable_name, ret_raw$label), FUN = mean)
  colnames(ret) <- c('variable', 'label', 'contribution')
  ret$variable <- as.character(ret$variable)

  # Apply the requested order here, while variable names are the row names,
  # so every per-variable column stays aligned with its row.
  rownames(ret) <- ret$variable
  if (!is.null(order)) {
    ret <- ret[order, ]
  }

  # transform_shap_to_break_down() reads a `variable_name` column; aggregated
  # data only has `variable`, so mirror it.
  ret$variable_name <- ret$variable

  # intercept / prediction are passed by name: the baseline row is the mean
  # prediction, the final row is the mean over the explained subset.
  ret <- transform_shap_to_break_down(
    ret,
    label,
    intercept = mean_prediction,
    prediction = mean_subset,
    order_of_variables = NULL,
    agg = TRUE
  )
  class(ret) <- c('data.frame')
  ret
}

# Thin indirection over order(): every argument is forwarded unchanged and
# the resulting permutation is returned.
call_order_func <- function(...) {
  permutation <- order(...)
  permutation
}

calculate_1d_changes <- function(model, new_observation, data, predict_function) {
average_yhats <- list()
j <- 1
Expand Down
83 changes: 83 additions & 0 deletions R/shap_utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Convert a table of SHAP attributions into a break-down style table.
#
# shaps               data frame (or something as.data.frame() accepts) with
#                     columns `variable`, `label`, `contribution`; optionally
#                     `variable_name` and `variable_value`
# label               label written into the "intercept" / "prediction" rows
# intercept           contribution shown in the "intercept" row
# prediction          contribution shown in the "prediction" row
# order_of_variables  optional row index (numeric or variable names) used to
#                     reorder the variables; ignored when NULL or not a vector
# agg                 TRUE for aggregated explanations: `variable_value` is
#                     then left empty in every row
#
# Returns a data frame of class c('predict_parts', 'break_down', 'data.frame')
# with rows intercept, <variables>, prediction and columns variable, label,
# contribution, position, sign, cumulative, variable_name, variable_value.
transform_shap_to_break_down <- function(shaps, label, intercept, prediction,
                                         order_of_variables = NULL, agg = FALSE) {
  shaps <- as.data.frame(shaps)
  ret <- shaps[, c('variable', 'label', 'contribution')]
  ret$variable <- as.character(ret$variable)

  # Keep name/value next to each contribution *before* any reordering, so
  # they cannot get out of step with the rows they describe.
  ret$variable_name <- if ('variable_name' %in% colnames(shaps)) {
    as.character(shaps[['variable_name']])
  } else {
    # aggregated input may only carry `variable`
    ret$variable
  }
  ret$variable_value <- if (!agg && 'variable_value' %in% colnames(shaps)) {
    as.character(shaps[['variable_value']])
  } else {
    rep('', nrow(ret))
  }
  rownames(ret) <- ret$variable

  if (!is.null(order_of_variables) && is.vector(order_of_variables)) {
    ret <- ret[order_of_variables, ]
  }

  # Variables occupy positions n + 1 .. 2; the intercept sits above them
  # (n + 2) and the prediction below (1).
  n_vars <- nrow(ret)
  ret$position <- rev(seq_len(n_vars)) + 1
  ret$sign <- ifelse(ret$contribution >= 0, "1", "-1")

  ret <- rbind(ret,
               data.frame(variable = "intercept",
                          label = label,
                          contribution = intercept,
                          variable_name = "intercept",
                          variable_value = if (agg) '' else '1',
                          position = n_vars + 2,
                          sign = "X",
                          stringsAsFactors = FALSE),
               data.frame(variable = "prediction",
                          label = label,
                          contribution = prediction,
                          variable_name = '',
                          variable_value = '',
                          position = 1,
                          sign = "X",
                          stringsAsFactors = FALSE),
               make.row.names = FALSE)

  ret <- ret[order(ret$position, decreasing = TRUE), ]

  # Running total from the intercept down; the last row shows the prediction
  # itself rather than a running sum.
  ret$cumulative <- cumsum(ret$contribution)
  ret$cumulative[nrow(ret)] <- ret$contribution[nrow(ret)]

  # Factor levels follow the row order.
  ret$variable_name <- factor(ret$variable_name, levels = unique(ret$variable_name))

  ret <- ret[, c('variable', 'label', 'contribution', 'position', 'sign',
                 'cumulative', 'variable_name', 'variable_value')]

  class(ret) <- c('predict_parts', 'break_down', 'data.frame')

  ret
}

# Re-shape the result of kernelshap::kernelshap() into a SHAP attribution table.
#
# ks               kernelshap result; `ks$S` (SHAP values, one row per
#                  observation, one column per variable) and `ks$baseline`
#                  are read
# new_observation  the observation(s) that were explained
# explainer        object providing `label`, `model` and `predict_function`
# agg              FALSE: exactly one observation, one output row per variable
#                  with its value; TRUE: one output row per observation and
#                  variable (long format), values left empty, so the caller
#                  can average contributions per variable
#
# Returns a data frame of class c('predict_parts', 'shap',
# 'break_down_uncertainty', 'data.frame') with columns contribution,
# variable_name, variable_value, variable, sign, label, B and the attributes
# "prediction" and "intercept".
kernelshap_to_shap <- function(ks, new_observation, explainer, agg = FALSE) {
  shap_matrix <- ks$S
  if (is.null(dim(shap_matrix))) {
    # a bare vector is treated as the SHAP values of a single observation
    shap_matrix <- t(shap_matrix)
  }
  var_names <- colnames(shap_matrix)
  if (is.null(var_names)) {
    var_names <- as.character(seq_len(ncol(shap_matrix)))
  }

  if (agg) {
    # Long format: as.vector() walks the matrix column by column, i.e. all
    # observations of the first variable, then of the second, and so on.
    n_obs <- nrow(shap_matrix)
    res <- data.frame(
      contribution = as.vector(shap_matrix),
      variable_name = rep(var_names, each = n_obs),
      stringsAsFactors = FALSE
    )
    res$variable_value <- rep('', nrow(res))
    formatted_values <- ''
  } else {
    if (!identical(nrow(shap_matrix), 1L)) {
      stop("kernelshap_to_shap() with agg = FALSE expects SHAP values for exactly one observation",
           call. = FALSE)
    }
    res <- data.frame(
      contribution = as.numeric(shap_matrix[1, ]),
      variable_name = var_names,
      stringsAsFactors = FALSE
    )
    rownames(res) <- var_names

    if (is.data.frame(new_observation) && all(var_names %in% colnames(new_observation))) {
      # Look values up by variable name, one column at a time: extra columns
      # in new_observation are ignored and factors keep their labels instead
      # of collapsing to integer codes.
      cells <- lapply(var_names, function(v) {
        cell <- new_observation[[v]]
        if (is.factor(cell)) as.character(cell) else cell
      })
      res$variable_value <- unname(unlist(cells))
      formatted_values <- vapply(cells, function(cell) nice_format(cell)[1], character(1))
    } else {
      # positional fallback when values cannot be matched by name
      res$variable_value <- unname(unlist(new_observation))
      formatted_values <- nice_format(res$variable_value)
    }
  }

  res$variable <- paste0(res$variable_name, ' = ', formatted_values)
  res$sign <- ifelse(res$contribution > 0, 1, -1)
  res$label <- explainer$label
  res$B <- 0

  attr(res, "prediction") <- as.numeric(
    explainer$predict_function(explainer$model, new_observation)
  )
  attr(res, "intercept") <- as.numeric(ks$baseline)

  class(res) <- c('predict_parts', 'shap', 'break_down_uncertainty', 'data.frame')
  res
}

# Turn a value into a character representation for display: numeric input is
# rounded to 4 significant digits first; for objects whose class includes
# "tbl" the first element is used; anything else goes through as.character().
nice_format <- function(x) {
  if (is.numeric(x)) {
    return(as.character(signif(x, 4)))
  }
  is_tbl <- "tbl" %in% class(x)
  displayed <- if (is_tbl) x[[1]] else x
  as.character(displayed)
}
15 changes: 13 additions & 2 deletions man/predict_parts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions man/shap_aggregated.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 142c94d

Please sign in to comment.