From 142c94d0aab51584a81e31c38d687a897b11cb60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20Sta=C5=84do?= <56126191+adrianstando@users.noreply.github.com> Date: Wed, 22 Mar 2023 22:16:11 +0100 Subject: [PATCH] Add support for kernelshap and update aggregated shap code (#553) --- DESCRIPTION | 5 +- NAMESPACE | 3 + NEWS.md | 4 ++ R/predict_parts.R | 55 ++++++++++++++-- R/shap_aggregated.R | 66 +++++++------------ R/shap_utils.R | 83 ++++++++++++++++++++++++ man/predict_parts.Rd | 15 ++++- man/shap_aggregated.Rd | 13 +++- tests/testthat/test_predict_parts.R | 98 ++++++++++++++++++++++++++++- 9 files changed, 288 insertions(+), 54 deletions(-) create mode 100644 R/shap_utils.R diff --git a/DESCRIPTION b/DESCRIPTION index 3b83c8333..2f4611518 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: DALEX Title: moDel Agnostic Language for Exploration and eXplanation -Version: 2.5.0 +Version: 2.5.1 Authors@R: c(person("Przemyslaw", "Biecek", email = "przemyslaw.biecek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-8423-1823")), person("Szymon", "Maksymiuk", role = "aut", @@ -27,7 +27,8 @@ Depends: R (>= 3.5) Imports: ggplot2, iBreakDown (>= 1.3.1), - ingredients (>= 2.0) + ingredients (>= 2.0), + kernelshap Suggests: gower, ranger, diff --git a/NAMESPACE b/NAMESPACE index 2400cb9f9..afd19ed02 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -75,6 +75,9 @@ export(predict_diagnostics) export(predict_parts) export(predict_parts_break_down) export(predict_parts_break_down_interactions) +export(predict_parts_kernel_shap) +export(predict_parts_kernel_shap_aggreagted) +export(predict_parts_kernel_shap_break_down) export(predict_parts_oscillations) export(predict_parts_oscillations_emp) export(predict_parts_oscillations_uni) diff --git a/NEWS.md b/NEWS.md index a8e9e05d5..f0cca4776 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +DALEX 2.5.1 +--------------------------------------------------------------- +* adding the support for calculating kernel SHAP values via `predict_parts()` function + DALEX 2.5.0 --------------------------------------------------------------- * breaking change: change the name of `loss_yardstick()` to `get_loss_yardstick()` and `loss_default()` to `get_loss_default()` diff --git a/R/predict_parts.R b/R/predict_parts.R index 08e123d90..23d67eacc 100644 --- a/R/predict_parts.R +++ b/R/predict_parts.R @@ -4,10 +4,12 @@ #' Model prediction is decomposed into parts that are attributed for particular variables. #' From DALEX version 1.0 this function calls the \code{\link[iBreakDown]{break_down}} or #' \code{\link[iBreakDown:break_down_uncertainty]{shap}} functions from the \code{iBreakDown} package or -#' \code{\link[ingredients:ceteris_paribus]{ceteris_paribus}} from the \code{ingredients} package. +#' \code{\link[ingredients:ceteris_paribus]{ceteris_paribus}} from the \code{ingredients} package or +#' \code{\link[kernelshap:kernelshap]{kernelshap}} from the \code{kernelshap} package. #' Find information how to use the \code{break_down} method here: \url{https://ema.drwhy.ai/breakDown.html}. #' Find information how to use the \code{shap} method here: \url{https://ema.drwhy.ai/shapley.html}. #' Find information how to use the \code{oscillations} method here: \url{https://ema.drwhy.ai/ceterisParibusOscillations.html}. +#' Find information how to use the \code{kernelshap} method here: \url{https://modeloriented.github.io/kernelshap/} #' aSHAP method provides explanations for a set of observations based on SHAP. #' #' @param explainer a model to be explained, preprocessed by the \code{explain} function @@ -18,7 +20,7 @@ #' @param N the maximum number of observations used for calculation of attributions. By default NULL (use all) or 500 (for oscillations). #' @param variable_splits_type how variable grids shall be calculated? Will be passed to \code{\link[ingredients]{ceteris_paribus}}. #' @param type the type of variable attributions. Either \code{shap}, \code{aggregated_shap}, \code{oscillations}, \code{oscillations_uni}, -#' \code{oscillations_emp}, \code{break_down} or \code{break_down_interactions}. +#' \code{oscillations_emp}, \code{break_down}, \code{break_down_interactions}, \code{kernel_shap}, \code{kernel_shap_break_down} or \code{kernel_shap_aggregated}. #' #' @return Depending on the \code{type} there are different classes of the resulting object. #' It's a data frame with calculated average response. @@ -84,7 +86,11 @@ predict_parts <- function(explainer, new_observation, ..., N = if(substr(type, 1 "oscillations_uni" = predict_parts_oscillations_uni(explainer, new_observation, ...), "oscillations_emp" = predict_parts_oscillations_emp(explainer, new_observation, ...), "shap_aggregated" = predict_parts_shap_aggregated(explainer, new_observation, ...), - stop("The type argument shall be either 'shap' or 'break_down' or 'break_down_interactions' or 'oscillations' or 'oscillations_uni' or 'oscillations_emp' or 'shap_aggregated'") + "kernel_shap" = predict_parts_kernel_shap(explainer, new_observation, ...), + "kernel_shap_break_down" = predict_parts_kernel_shap_break_down(explainer, new_observation, ...), + "kernel_shap_aggregated" = predict_parts_kernel_shap_aggreagted(explainer, new_observation, ...), + stop("The type argument shall be either 'shap' or 'break_down' or 'break_down_interactions' or 'oscillations' or 'oscillations_uni' or 'oscillations_emp' or 'shap_aggregated' + or 'kernel_shap' or 'kernel_shap_break_down' or 'kernel_shap_aggregated'") ) } @@ -193,10 +199,51 @@ predict_parts_shap_aggregated <- function(explainer, new_observation, ...) { res <- shap_aggregated(explainer, new_observations = new_observation, + kernelshap = FALSE, ...) - class(res) <- c('predict_parts', class(res)) + res +} + +#' @name predict_parts +#' @export +predict_parts_kernel_shap <- function(explainer, new_observation, ...) { + test_explainer(explainer, has_data = TRUE, function_name = "kernel_shap") + ks <- kernelshap::kernelshap( + object = explainer$model, + X = new_observation, + bg_X = explainer$data, + pred_fun = explainer$predict_function, + verbose = FALSE, + ..., + ) + + res <- kernelshap_to_shap(ks, new_observation, explainer, agg=FALSE) + class(res) <- c('predict_parts', 'kernel_shap', 'shap', 'break_down_uncertainty', 'data.frame') + res +} +#' @name predict_parts +#' @export +predict_parts_kernel_shap_break_down <- function(explainer, new_observation, ...){ + test_explainer(explainer, has_data = TRUE, function_name = "kernel_shap_break_down") + shaps <- predict_parts_kernel_shap(explainer, new_observation, ...) + ret <- transform_shap_to_break_down(shaps, explainer$label, + attr(shaps, 'intercept'), attr(shaps, 'prediction'), + order_of_variables=NULL, agg=FALSE) + class(ret) <- c('predict_parts', 'break_down', 'break_down_kernel_shap', 'data.frame') + ret +} + +#' @name predict_parts +#' @export +predict_parts_kernel_shap_aggreagted <- function(explainer, new_observation, ...){ + test_explainer(explainer, has_data = TRUE, function_name = "kernel_shap_break_down") + res <- shap_aggregated(explainer, + new_observations = new_observation, + kernelshap = TRUE, + ...) + class(res) <- c('predict_parts', class(res)) res } diff --git a/R/shap_aggregated.R b/R/shap_aggregated.R index 25c4627da..6b61ff88b 100644 --- a/R/shap_aggregated.R +++ b/R/shap_aggregated.R @@ -6,7 +6,8 @@ #' @param new_observations a set of new observations with columns that correspond to variables used in the model. #' @param order if not \code{NULL}, then it will be a fixed order of variables. It can be a numeric vector or vector with names of variables. #' @param ... other parameters like \code{label}, \code{predict_function}, \code{data}, \code{x} -#' @param B number of random paths +#' @param B number of random paths; works only if kernelshap=FALSE +#' @param kernelshap indicates whether the kernelshap method should be used #' #' @return an object of the \code{shap_aggregated} class. #' @@ -29,14 +30,27 @@ #' plot(bd_glm, max_features = 3) #' } #' @export -shap_aggregated <- function(explainer, new_observations, order = NULL, B = 25, ...) { +shap_aggregated <- function(explainer, new_observations, order = NULL, B = 25, kernelshap = FALSE, ...) { ret_raw <- data.frame(contribution = c(), variable_name = c(), label = c()) - for(i in 1:nrow(new_observations)){ - new_obs <- new_observations[i,] - shap_vals <- iBreakDown::shap(explainer, new_observation = new_obs, B = B, ...) - shap_vals <- shap_vals[shap_vals$B != 0, c('contribution', 'variable_name', 'label')] - ret_raw <- rbind(ret_raw, shap_vals) + if(kernelshap) { + ks <- kernelshap::kernelshap( + object = explainer$model, + X = new_observations, + bg_X = explainer$data, + pred_fun = explainer$predict_function, + verbose = FALSE, + ..., + ) + res <- kernelshap_to_shap(ks, new_observations, explainer, agg=TRUE) + ret_raw <- res[ ,c('contribution', 'variable_name', 'label')] + } else { + for(i in 1:nrow(new_observations)){ + new_obs <- new_observations[i,] + shap_vals <- iBreakDown::shap(explainer, new_observation = new_obs, B = B, ...) + shap_vals <- shap_vals[shap_vals$B != 0, c('contribution', 'variable_name', 'label')] + ret_raw <- rbind(ret_raw, shap_vals) + } } data_preds <- predict(explainer, explainer$data) @@ -64,46 +78,12 @@ shap_aggregated <- function(explainer, new_observations, order = NULL, B = 25, . raw_to_aggregated <- function(ret_raw, mean_prediction, mean_subset, order, label){ ret <- aggregate(ret_raw$contribution, list(ret_raw$variable_name, ret_raw$label), FUN=mean) colnames(ret) <- c('variable', 'label', 'contribution') - ret$variable <- as.character(ret$variable) - rownames(ret) <- ret$variable - - ret <- ret[order,] - - ret$position <- (nrow(ret) + 1):2 - ret$sign <- ifelse(ret$contribution >= 0, "1", "-1") - - ret <- rbind(ret, data.frame(variable = "intercept", - label = label, - contribution = mean_prediction, - position = max(ret$position) + 1, - sign = "X"), - make.row.names=FALSE) - - ret <- rbind(ret, data.frame(variable = "prediction", - label = label, - contribution = mean_subset, - position = 1, - sign = "X"), - make.row.names=FALSE) - - ret <- ret[call_order_func(ret$position, decreasing = TRUE), ] - - ret$cumulative <- cumsum(ret$contribution) - ret$cumulative[nrow(ret)] <- ret$contribution[nrow(ret)] - ret$variable_name <- ret$variable - ret$variable_name <- factor(ret$variable_name, levels=c(ret$variable_name, '')) - ret$variable_name[nrow(ret)] <- '' - - ret$variable_value <- '' # column for consistency - + ret <- transform_shap_to_break_down(ret, label, mean_subset, mean_prediction, order, agg=TRUE) + class(ret) <- c('data.frame') ret } -call_order_func <- function(...) { - order(...) -} - calculate_1d_changes <- function(model, new_observation, data, predict_function) { average_yhats <- list() j <- 1 diff --git a/R/shap_utils.R b/R/shap_utils.R new file mode 100644 index 000000000..23ba4878a --- /dev/null +++ b/R/shap_utils.R @@ -0,0 +1,83 @@ +transform_shap_to_break_down <- function(shaps, label, intercept, prediction, order_of_variables=NULL, agg=FALSE) { + ret <- as.data.frame(shaps) + ret <- ret[, c('variable', 'label', 'contribution')] + ret$variable <- as.character(ret$variable) + rownames(ret) <- ret$variable + + if(!is.null(order_of_variables) && is.vector(order_of_variables)){ + ret <- ret[order_of_variables,] + } + + ret$position <- (nrow(ret) + 1):2 + ret$sign <- ifelse(ret$contribution >= 0, "1", "-1") + + ret <- rbind(ret, data.frame(variable = "intercept", + label = label, + contribution = intercept, + position = max(ret$position) + 1, + sign = "X"), + make.row.names=FALSE) + + ret <- rbind(ret, data.frame(variable = "prediction", + label = label, + contribution = prediction, + position = 1, + sign = "X"), + make.row.names=FALSE) + + ret <- ret[order(ret$position, decreasing = TRUE), ] + + ret$cumulative <- cumsum(ret$contribution) + ret$cumulative[nrow(ret)] <- ret$contribution[nrow(ret)] + + ret$variable_name <- c('intercept', as.data.frame(shaps)[, c('variable_name')], '') + ret$variable_name <- factor(ret$variable_name, levels=ret$variable_name) + ret$variable_name[nrow(ret)] <- '' + + if(agg) { + ret$variable_value <- '' + } else { + ret$variable_value <- c(1, as.data.frame(shaps)[, c('variable_name')], '') + } + + + class(ret) <- c('predict_parts', 'break_down', 'data.frame') + + ret +} + +kernelshap_to_shap <- function(ks, new_observation, explainer, agg=FALSE) { + res <- as.data.frame(t(ks$S)) + + colnames(res) <- c('contribution') + res$variable_name <- rownames(res) + + if(agg) { + res$variable_value <- '' + } else { + res$variable_value <- unname(unlist(new_observation)) + } + + res$variable <- paste0(res$variable_name, ' = ', nice_format(res$variable_value)) + res$sign <- ifelse(res$contribution > 0, 1, -1) + res$label <- explainer$label + res$B <- 0 + + attr(res, "prediction") <- as.numeric( + explainer$predict_function(explainer$model, new_observation) + ) + attr(res, "intercept") <- as.numeric(ks$baseline) + + class(res) <- c('predict_parts', 'shap', 'break_down_uncertainty', 'data.frame') + res +} + +nice_format <- function(x) { + if (is.numeric(x)) { + as.character(signif(x, 4)) + } else if ("tbl" %in% class(x)) { + as.character(x[[1]]) + } else { + as.character(x) + } +} diff --git a/man/predict_parts.Rd b/man/predict_parts.Rd index cd4cda5a9..8a5d26636 100644 --- a/man/predict_parts.Rd +++ b/man/predict_parts.Rd @@ -10,6 +10,9 @@ \alias{predict_parts_oscillations_emp} \alias{predict_parts_break_down_interactions} \alias{predict_parts_shap_aggregated} +\alias{predict_parts_kernel_shap} +\alias{predict_parts_kernel_shap_break_down} +\alias{predict_parts_kernel_shap_aggreagted} \alias{variable_attribution} \title{Instance Level Parts of the Model Predictions} \usage{ @@ -46,6 +49,12 @@ predict_parts_shap(explainer, new_observation, ...) predict_parts_shap_aggregated(explainer, new_observation, ...) +predict_parts_kernel_shap(explainer, new_observation, ...) + +predict_parts_kernel_shap_break_down(explainer, new_observation, ...) + +predict_parts_kernel_shap_aggreagted(explainer, new_observation, ...) + variable_attribution( explainer, new_observation, @@ -64,7 +73,7 @@ variable_attribution( \item{N}{the maximum number of observations used for calculation of attributions. By default NULL (use all) or 500 (for oscillations).} \item{type}{the type of variable attributions. Either \code{shap}, \code{aggregated_shap}, \code{oscillations}, \code{oscillations_uni}, -\code{oscillations_emp}, \code{break_down} or \code{break_down_interactions}.} +\code{oscillations_emp}, \code{break_down}, \code{break_down_interactions}, \code{kernel_shap}, \code{kernel_shap_break_down} or \code{kernel_shap_aggregated}.} \item{variable_splits_type}{how variable grids shall be calculated? Will be passed to \code{\link[ingredients]{ceteris_paribus}}.} @@ -81,10 +90,12 @@ Instance Level Variable Attributions as Break Down, SHAP, aggregated SHAP or Osc Model prediction is decomposed into parts that are attributed for particular variables. From DALEX version 1.0 this function calls the \code{\link[iBreakDown]{break_down}} or \code{\link[iBreakDown:break_down_uncertainty]{shap}} functions from the \code{iBreakDown} package or -\code{\link[ingredients:ceteris_paribus]{ceteris_paribus}} from the \code{ingredients} package. +\code{\link[ingredients:ceteris_paribus]{ceteris_paribus}} from the \code{ingredients} package or +\code{\link[kernelshap:kernelshap]{kernelshap}} from the \code{kernelshap} package. Find information how to use the \code{break_down} method here: \url{https://ema.drwhy.ai/breakDown.html}. Find information how to use the \code{shap} method here: \url{https://ema.drwhy.ai/shapley.html}. Find information how to use the \code{oscillations} method here: \url{https://ema.drwhy.ai/ceterisParibusOscillations.html}. +Find information how to use the \code{kernelshap} method here: \url{https://modeloriented.github.io/kernelshap/} aSHAP method provides explanations for a set of observations based on SHAP. } \examples{ diff --git a/man/shap_aggregated.Rd b/man/shap_aggregated.Rd index cb337c639..ac6f5f822 100644 --- a/man/shap_aggregated.Rd +++ b/man/shap_aggregated.Rd @@ -4,7 +4,14 @@ \alias{shap_aggregated} \title{SHAP aggregated values} \usage{ -shap_aggregated(explainer, new_observations, order = NULL, B = 25, ...) +shap_aggregated( + explainer, + new_observations, + order = NULL, + B = 25, + kernelshap = FALSE, + ... +) } \arguments{ \item{explainer}{a model to be explained, preprocessed by the \code{explain} function} @@ -13,7 +20,9 @@ shap_aggregated(explainer, new_observations, order = NULL, B = 25, ...) \item{order}{if not \code{NULL}, then it will be a fixed order of variables. It can be a numeric vector or vector with names of variables.} -\item{B}{number of random paths} +\item{B}{number of random paths; works only if kernelshap=FALSE} + +\item{kernelshap}{indicates whether the kernelshap method should be used} \item{...}{other parameters like \code{label}, \code{predict_function}, \code{data}, \code{x}} } diff --git a/tests/testthat/test_predict_parts.R b/tests/testthat/test_predict_parts.R index 6686db7e6..560cca6df 100644 --- a/tests/testthat/test_predict_parts.R +++ b/tests/testthat/test_predict_parts.R @@ -8,6 +8,9 @@ test_that("data not provided",{ expect_error(predict_parts(explainer_wo_data, type = "break_down_interactions")) expect_error(predict_parts(explainer_wo_data, type = "shap")) expect_error(predict_parts(explainer_wo_data, type = "shap_aggregated")) + expect_error(predict_parts(explainer_wo_data, type = "kernel_shap")) + expect_error(predict_parts(explainer_wo_data, type = "kernel_shap_break_down")) + expect_error(predict_parts(explainer_wo_data, type = "kernel_shap_aggregated")) }) test_that("wrong type value",{ @@ -18,7 +21,10 @@ test_that("Wrong object class (not explainer)", { expect_error(predict_parts(list(1), type = "break_down")) expect_error(predict_parts(list(1), type = "break_down_interactions")) expect_error(predict_parts(list(1), type = "shap")) - expect_error(predict_parts(explainer_wo_data, type = "shap_aggregated")) + expect_error(predict_parts(list(1), type = "shap_aggregated")) + expect_error(predict_parts(list(1), type = "kernel_shap")) + expect_error(predict_parts(list(1), type = "kernel_shap_break_down")) + expect_error(predict_parts(list(1), type = "kernel_shap_aggregated")) }) test_that("Output format",{ @@ -34,6 +40,15 @@ test_that("Output format",{ pp_lm_agg_shap_set <- predict_parts(explainer_regr_lm, new_observation = new_apartments_set, type = "shap_aggregated") pp_ranger_agg_shap_set <- predict_parts(explainer_regr_ranger, new_observation = new_apartments_set, type = "shap_aggregated") + pp_lm_kernel_shap <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap") + pp_ranger_kernel_shap <- predict_parts(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap") + pp_lm_kernel_shap_break_down <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_break_down") + pp_ranger_kernel_shap_break_down <- predict_parts(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_break_down") + pp_lm_kernel_shap_aggregated <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_aggregated") + pp_ranger_kernel_shap_aggregated <- predict_parts(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_aggregated") + pp_lm_kernel_shap_aggregated_set <- predict_parts(explainer_regr_lm, new_observation = new_apartments_set, type = "kernel_shap_aggregated") + pp_ranger_kernel_shap_aggregated_set <- predict_parts(explainer_regr_ranger, new_observation = new_apartments_set, type = "kernel_shap_aggregated") + pp_lm_osc <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "oscillations") pp_ranger_osc <- predict_parts(explainer_regr_ranger, new_observation = new_apartments, type = "oscillations") pp_lm_osc_uni <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "oscillations_uni") @@ -52,6 +67,14 @@ test_that("Output format",{ expect_is(pp_ranger_agg_shap, c("shap_aggregated", 'predict_parts')) expect_is(pp_lm_agg_shap_set, c("shap_aggregated", 'predict_parts')) expect_is(pp_ranger_agg_shap_set, c("shap_aggregated", 'predict_parts')) + expect_is(pp_lm_kernel_shap, c("kernel_shap", 'predict_parts')) + expect_is(pp_ranger_kernel_shap, c("kernel_shap", 'predict_parts')) + expect_is(pp_lm_kernel_shap_break_down, c("kernel_shap_break_down", 'predict_parts')) + expect_is(pp_ranger_kernel_shap_break_down, c("kernel_shap_break_down", 'predict_parts')) + expect_is(pp_lm_kernel_shap_aggregated, c("kernel_shap_aggregated", 'predict_parts')) + expect_is(pp_ranger_kernel_shap_aggregated, c("kernel_shap_aggregated", 'predict_parts')) + expect_is(pp_lm_kernel_shap_aggregated_set, c("kernel_shap_aggregated", 'predict_parts')) + expect_is(pp_ranger_kernel_shap_aggregated_set, c("kernel_shap_aggregated", 'predict_parts')) expect_is(pp_lm_osc, c("oscillations", 'predict_parts')) expect_is(pp_ranger_osc, c("oscillations", 'predict_parts')) expect_is(pp_lm_osc_uni, c("oscillations_uni", 'predict_parts')) @@ -72,6 +95,15 @@ test_that("Output format - plot",{ pp_lm_agg_shap_set <- predict_parts(explainer_regr_lm, new_observation = new_apartments_set, type = "shap_aggregated") pp_ranger_agg_shap_set <- predict_parts(explainer_regr_ranger, new_observation = new_apartments_set, type = "shap_aggregated") + pp_lm_kernel_shap <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap") + pp_ranger_kernel_shap <- predict_parts(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap") + pp_lm_kernel_shap_break_down <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_break_down") + pp_ranger_kernel_shap_break_down <- predict_parts(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_break_down") + pp_lm_kernel_shap_aggregated <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_aggregated") + pp_ranger_kernel_shap_aggregated <- predict_parts(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_aggregated") + pp_lm_kernel_shap_aggregated_set <- predict_parts(explainer_regr_lm, new_observation = new_apartments_set, type = "kernel_shap_aggregated") + pp_ranger_kernel_shap_aggregated_set <- predict_parts(explainer_regr_ranger, new_observation = new_apartments_set, type = "kernel_shap_aggregated") + pp_lm_osc <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "oscillations") pp_ranger_osc <- predict_parts(explainer_regr_ranger, new_observation = new_apartments, type = "oscillations") pp_lm_osc_uni <- predict_parts(explainer_regr_lm, new_observation = new_apartments, type = "oscillations_uni") @@ -89,6 +121,16 @@ test_that("Output format - plot",{ expect_is(plot(pp_ranger_agg_shap), "gg") expect_is(plot(pp_lm_agg_shap_set), "gg") expect_is(plot(pp_ranger_agg_shap_set), "gg") + + expect_is(plot(pp_lm_kernel_shap), "gg") + expect_is(plot(pp_ranger_kernel_shap), "gg") + expect_is(plot(pp_lm_kernel_shap_break_down), "gg") + expect_is(plot(pp_ranger_kernel_shap_break_down), "gg") + expect_is(plot(pp_lm_kernel_shap_aggregated), "gg") + expect_is(plot(pp_ranger_kernel_shap_aggregated), "gg") + expect_is(plot(pp_lm_kernel_shap_aggregated_set), "gg") + expect_is(plot(pp_ranger_kernel_shap_aggregated_set), "gg") + expect_is(plot(pp_lm_osc), "gg") expect_is(plot(pp_ranger_osc), "gg") expect_is(plot(pp_lm_osc_uni), "gg") @@ -111,6 +153,15 @@ test_that("Output format",{ va_lm_agg_shap_set <- variable_attribution(explainer_regr_lm, new_observation = new_apartments_set, type = "shap_aggregated") va_ranger_agg_shap_set <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments_set, type = "shap_aggregated") + va_lm_kernel_shap <- variable_attribution(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap") + va_ranger_kernel_shap <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap") + va_lm_kernel_shap_break_down <- variable_attribution(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_break_down") + va_ranger_kernel_shap_break_down <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_break_down") + va_lm_kernel_shap_aggregated <- variable_attribution(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_aggregated") + va_ranger_kernel_shap_aggregated <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_aggregated") + va_lm_kernel_shap_aggregated_set <- variable_attribution(explainer_regr_lm, new_observation = new_apartments_set, type = "kernel_shap_aggregated") + va_ranger_kernel_shap_aggregated_set <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments_set, type = "kernel_shap_aggregated") + expect_is(va_lm_break_down, c("break_down", 'predict_parts')) expect_is(va_ranger_break_down, c("break_down", 'predict_parts')) expect_is(va_lm_ibreak_down, c("break_down", 'predict_parts')) @@ -121,6 +172,15 @@ test_that("Output format",{ expect_is(va_ranger_agg_shap, c("shap_aggregated", 'predict_parts')) expect_is(va_lm_agg_shap_set, c("shap_aggregated", 'predict_parts')) expect_is(va_ranger_agg_shap_set, c("shap_aggregated", 'predict_parts')) + + expect_is(va_lm_kernel_shap, c("kernel_shap", 'predict_parts')) + expect_is(va_ranger_kernel_shap, c("kernel_shap", 'predict_parts')) + expect_is(va_lm_kernel_shap_break_down, c("kernel_shap_break_down", 'predict_parts')) + expect_is(va_ranger_kernel_shap_break_down, c("kernel_shap_break_down", 'predict_parts')) + expect_is(va_lm_kernel_shap_aggregated, c("kernel_shap_aggregated", 'predict_parts')) + expect_is(va_ranger_kernel_shap_aggregated, c("kernel_shap_aggregated", 'predict_parts')) + expect_is(va_lm_kernel_shap_aggregated_set, c("kernel_shap_aggregated", 'predict_parts')) + expect_is(va_ranger_kernel_shap_aggregated_set, c("kernel_shap_aggregated", 'predict_parts')) }) test_that("Output format - plot",{ @@ -135,6 +195,15 @@ test_that("Output format - plot",{ va_lm_agg_shap_set <- variable_attribution(explainer_regr_lm, new_observation = new_apartments_set, type = "shap_aggregated") va_ranger_agg_shap_set <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments_set, type = "shap_aggregated") + va_lm_kernel_shap <- variable_attribution(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap") + va_ranger_kernel_shap <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap") + va_lm_kernel_shap_break_down <- variable_attribution(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_break_down") + va_ranger_kernel_shap_break_down <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_break_down") + va_lm_kernel_shap_aggregated <- variable_attribution(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_aggregated") + va_ranger_kernel_shap_aggregated <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_aggregated") + va_lm_kernel_shap_aggregated_set <- variable_attribution(explainer_regr_lm, new_observation = new_apartments_set, type = "kernel_shap_aggregated") + va_ranger_kernel_shap_aggregated_set <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments_set, type = "kernel_shap_aggregated") + expect_is(plot(va_ranger_break_down), "gg") expect_is(plot(va_ranger_break_down, va_lm_break_down), "gg") expect_is(plot(va_ranger_ibreak_down), "gg") @@ -145,6 +214,15 @@ test_that("Output format - plot",{ expect_is(plot(va_ranger_agg_shap), "gg") expect_is(plot(va_lm_agg_shap_set), "gg") expect_is(plot(va_ranger_agg_shap_set), "gg") + + expect_is(plot(va_lm_kernel_shap), "gg") + expect_is(plot(va_ranger_kernel_shap), "gg") + expect_is(plot(va_lm_kernel_shap_break_down), "gg") + expect_is(plot(va_ranger_kernel_shap_break_down), "gg") + expect_is(plot(va_lm_kernel_shap_aggregated), "gg") + expect_is(plot(va_ranger_kernel_shap_aggregated), "gg") + expect_is(plot(va_lm_kernel_shap_aggregated_set), "gg") + expect_is(plot(va_ranger_kernel_shap_aggregated_set), "gg") }) test_that("Output format - plot with subset",{ @@ -159,6 +237,15 @@ test_that("Output format - plot with subset",{ va_lm_agg_shap_set <- variable_attribution(explainer_regr_lm, new_observation = new_apartments_set, type = "shap_aggregated", N=200) va_ranger_agg_shap_set <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments_set, type = "shap_aggregated", N=200) + va_lm_kernel_shap <- variable_attribution(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap", N=200) + va_ranger_kernel_shap <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap", N=200) + va_lm_kernel_shap_break_down <- variable_attribution(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_break_down", N=200) + va_ranger_kernel_shap_break_down <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_break_down", N=200) + va_lm_kernel_shap_aggregated <- variable_attribution(explainer_regr_lm, new_observation = new_apartments, type = "kernel_shap_aggregated", N=200) + va_ranger_kernel_shap_aggregated <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments, type = "kernel_shap_aggregated", N=200) + va_lm_kernel_shap_aggregated_set <- variable_attribution(explainer_regr_lm, new_observation = new_apartments_set, type = "kernel_shap_aggregated", N=200) + va_ranger_kernel_shap_aggregated_set <- variable_attribution(explainer_regr_ranger, new_observation = new_apartments_set, type = "kernel_shap_aggregated", N=200) + expect_is(plot(va_ranger_break_down), "gg") expect_is(plot(va_ranger_break_down, va_lm_break_down), "gg") expect_is(plot(va_ranger_ibreak_down), "gg") @@ -169,4 +256,13 @@ test_that("Output format - plot with subset",{ expect_is(plot(va_ranger_agg_shap), "gg") expect_is(plot(va_lm_agg_shap_set), "gg") expect_is(plot(va_ranger_agg_shap_set), "gg") + + expect_is(plot(va_lm_kernel_shap), "gg") + expect_is(plot(va_ranger_kernel_shap), "gg") + expect_is(plot(va_lm_kernel_shap_break_down), "gg") + expect_is(plot(va_ranger_kernel_shap_break_down), "gg") + expect_is(plot(va_lm_kernel_shap_aggregated), "gg") + expect_is(plot(va_ranger_kernel_shap_aggregated), "gg") + expect_is(plot(va_lm_kernel_shap_aggregated_set), "gg") + expect_is(plot(va_ranger_kernel_shap_aggregated_set), "gg") })