From 5758c2261d96c5e7f2419f11fe5afc528f024a27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Biecek?= Date: Sat, 21 May 2022 00:46:59 +0200 Subject: [PATCH] candidate fix for https://github.com/ModelOriented/DALEX/issues/495 --- DESCRIPTION | 5 +-- NAMESPACE | 1 + NEWS.md | 3 +- R/explain.R | 19 +++++++---- R/misc_loss_functions.R | 73 +++++++++++++++++++++++++++++++++++++++++ man/loss_yardstick.Rd | 66 +++++++++++++++++++++++++++++++++++++ 6 files changed, 157 insertions(+), 10 deletions(-) create mode 100644 man/loss_yardstick.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 290b77129..02637a96d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: DALEX Title: moDel Agnostic Language for Exploration and eXplanation -Version: 2.4.0.9001 +Version: 2.4.1 Authors@R: c(person("Przemyslaw", "Biecek", email = "przemyslaw.biecek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-8423-1823")), person("Szymon", "Maksymiuk", role = "aut", @@ -22,7 +22,7 @@ Description: Any unverified black box model is the path to failure. Opaqueness l License: GPL Encoding: UTF-8 LazyData: true -RoxygenNote: 7.1.2 +RoxygenNote: 7.2.0 Depends: R (>= 3.5) Imports: ggplot2, @@ -31,6 +31,7 @@ Imports: Suggests: gower, ranger, + yardstick, testthat, methods URL: https://modeloriented.github.io/DALEX, https://dalex.drwhy.ai diff --git a/NAMESPACE b/NAMESPACE index f0a1ec528..b78509d78 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -59,6 +59,7 @@ export(loss_default) export(loss_one_minus_auc) export(loss_root_mean_square) export(loss_sum_of_squares) +export(loss_yardstick) export(model_diagnostics) export(model_info) export(model_parts) diff --git a/NEWS.md b/NEWS.md index 86b5d161e..658334ed0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,8 @@ -DALEX (development) +DALEX 2.4.1 --------------------------------------------------------------- * changed URLs in the DESCRIPTION as requested in ([#484](https://github.com/ModelOriented/DALEX/issues/484)) * Fix model_info documentation ([#498](https://github.com/ModelOriented/DALEX/issues/498)) +* Support for yardstic metrics ([#495](https://github.com/ModelOriented/DALEX/issues/495)) DALEX 2.4.0 --------------------------------------------------------------- diff --git a/R/explain.R b/R/explain.R index 46ee77699..bb8e41413 100644 --- a/R/explain.R +++ b/R/explain.R @@ -330,13 +330,18 @@ explain.default <- function(model, data = NULL, y = NULL, predict_function = NUL # REPORT: checks for residual_function if (is.null(residual_function)) { # residual_function not specified - # try the default - if (!is.null(predict_function) & model_info$type != "multiclass") { - residual_function <- residual_function_default - verbose_cat(" -> residual function : difference between y and yhat (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose) - } else if (!is.null(predict_function) & model_info$type == "multiclass") { - residual_function <- residual_function_multiclass - verbose_cat(" -> residual function : difference between 1 and probability of true class (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose) + + # if y_hat is not numeric, then do not calculate residuals + # calculate only if y_hat is NULL or numeric + if (is.null(y_hat) | is.numeric(y_hat)) { + # try the default + if (!is.null(predict_function) & model_info$type != "multiclass") { + residual_function <- residual_function_default + verbose_cat(" -> residual function : difference between y and yhat (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose) + } else if (!is.null(predict_function) & model_info$type == "multiclass") { + residual_function <- residual_function_multiclass + verbose_cat(" -> residual function : difference between 1 and probability of true class (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose) + } } } else { if (!"function" %in% class(residual_function)) { diff --git a/R/misc_loss_functions.R b/R/misc_loss_functions.R index cbe566bc9..dcb41ed36 100644 --- a/R/misc_loss_functions.R +++ b/R/misc_loss_functions.R @@ -77,3 +77,76 @@ loss_default <- function(x) { ) } + + + +#' Wrapper for Loss Functions from the Yarstick Package +#' +#' The yardstick package provides many auxiliary functions for calculating +#' the predictive performance of the model. However, they have an interface +#' that is consistent with the tidyverse philosophy. The loss_yardstick +#' function adapts loss functions from the yardstick package to functions +#' understood by DALEX. Type compatibility for y-values and for predictions +#' must be guaranteed by the user. +#' +#' @param loss loss function from the yardstick package +#' +#' @return loss function that can be used in the model_parts function +#' +#' @export +#' @examples +#' \donttest{ +#' # Classification Metrics +#' # y and y_hat are factors!!! +#' library("yardstick") +#' +#' titanic_glm_model <- glm(survived~., data = titanic_imputed, family = "binomial") +#' explainer_glm <- DALEX::explain(titanic_glm_model, +#' data = titanic_imputed[,-8], +#' y = factor(titanic_imputed$survived), +#' predict_function = function(m, x) { +#' factor((predict(m, x, type = "response") > 0.5) + 0) +#' }) +#' +#' model_parts_accuracy <- model_parts(explainer_glm, type = "raw", +#' loss_function = loss_yardstick(accuracy)) +#' plot(model_parts_accuracy) +#' +#' # Class Probability Metrics +#' # y is a factor while y_hat is a numeric!!! +#' +#' titanic_glm_model <- glm(survived~., data = titanic_imputed, family = "binomial") +#' explainer_glm <- DALEX::explain(titanic_glm_model, +#' data = titanic_imputed[,-8], +#' y = factor(titanic_imputed$survived)) +#' +#' model_parts_accuracy <- model_parts(explainer_glm, type = "raw", +#' loss_function = loss_yardstick(roc_auc)) +#' plot(model_parts_accuracy) +#' +#' # Regression Metrics +#' # y and y_hat are numeric!!! +#' +#' library("ranger") +#' apartments_ranger <- ranger(m2.price~., data = apartments, num.trees = 50) +#' explainer_ranger <- DALEX::explain(apartments_ranger, data = apartments[,-1], +#' y = apartments$m2.price, label = "Ranger Apartments") +#' model_parts_ranger <- model_parts(explainer_ranger, type = "raw", +#' loss_function = loss_yardstick(rsq)) +#' plot(model_parts_ranger) +#' +#' } +#' +#' @rdname loss_yardstick +#' @export +loss_yardstick <- function(loss) { + # wrapper for yardstick loss functions + custom_loss <- function(observed, predicted) { + df <- data.frame(observed, predicted) + loss(df, observed, predicted)$.estimate + } + attr(custom_loss, "loss_name") <- deparse(substitute(loss)) + custom_loss +} + + diff --git a/man/loss_yardstick.Rd b/man/loss_yardstick.Rd new file mode 100644 index 000000000..262d60711 --- /dev/null +++ b/man/loss_yardstick.Rd @@ -0,0 +1,66 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc_loss_functions.R +\name{loss_yardstick} +\alias{loss_yardstick} +\title{Wrapper for Loss Functions from the Yarstick Package} +\usage{ +loss_yardstick(loss) +} +\arguments{ +\item{loss}{loss function from the yardstick package} +} +\value{ +loss function that can be used in the model_parts function +} +\description{ +The yardstick package provides many auxiliary functions for calculating +the predictive performance of the model. However, they have an interface +that is consistent with the tidyverse philosophy. The loss_yardstick +function adapts loss functions from the yardstick package to functions +understood by DALEX. Type compatibility for y-values and for predictions +must be guaranteed by the user. +} +\examples{ + \donttest{ + # Classification Metrics + # y and y_hat are factors!!! + library("yardstick") + + titanic_glm_model <- glm(survived~., data = titanic_imputed, family = "binomial") + explainer_glm <- DALEX::explain(titanic_glm_model, + data = titanic_imputed[,-8], + y = factor(titanic_imputed$survived), + predict_function = function(m, x) { + factor((predict(m, x, type = "response") > 0.5) + 0) + }) + +model_parts_accuracy <- model_parts(explainer_glm, type = "raw", + loss_function = loss_yardstick(accuracy)) +plot(model_parts_accuracy) + +# Class Probability Metrics +# y is a factor while y_hat is a numeric!!! + + titanic_glm_model <- glm(survived~., data = titanic_imputed, family = "binomial") + explainer_glm <- DALEX::explain(titanic_glm_model, + data = titanic_imputed[,-8], + y = factor(titanic_imputed$survived)) + +model_parts_accuracy <- model_parts(explainer_glm, type = "raw", + loss_function = loss_yardstick(roc_auc)) +plot(model_parts_accuracy) + +# Regression Metrics +# y and y_hat are numeric!!! + +library("ranger") +apartments_ranger <- ranger(m2.price~., data = apartments, num.trees = 50) +explainer_ranger <- DALEX::explain(apartments_ranger, data = apartments[,-1], + y = apartments$m2.price, label = "Ranger Apartments") +model_parts_ranger <- model_parts(explainer_ranger, type = "raw", + loss_function = loss_yardstick(rsq)) +plot(model_parts_ranger) + +} + +}