Skip to content

Commit

Permalink
candidate fix for #495
Browse files Browse the repository at this point in the history
  • Loading branch information
pbiecek committed May 20, 2022
1 parent f41ad09 commit 5758c22
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 10 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: DALEX
Title: moDel Agnostic Language for Exploration and eXplanation
Version: 2.4.0.9001
Version: 2.4.1
Authors@R: c(person("Przemyslaw", "Biecek", email = "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-8423-1823")),
person("Szymon", "Maksymiuk", role = "aut",
Expand All @@ -22,7 +22,7 @@ Description: Any unverified black box model is the path to failure. Opaqueness l
License: GPL
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.1.2
RoxygenNote: 7.2.0
Depends: R (>= 3.5)
Imports:
ggplot2,
Expand All @@ -31,6 +31,7 @@ Imports:
Suggests:
gower,
ranger,
yardstick,
testthat,
methods
URL: https://modeloriented.github.io/DALEX, https://dalex.drwhy.ai
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export(loss_default)
export(loss_one_minus_auc)
export(loss_root_mean_square)
export(loss_sum_of_squares)
export(loss_yardstick)
export(model_diagnostics)
export(model_info)
export(model_parts)
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
DALEX (development)
DALEX 2.4.1
---------------------------------------------------------------
* changed URLs in the DESCRIPTION as requested in ([#484](https://github.com/ModelOriented/DALEX/issues/484))
* Fix model_info documentation ([#498](https://github.com/ModelOriented/DALEX/issues/498))
* Support for yardstic metrics ([#495](https://github.com/ModelOriented/DALEX/issues/495))

DALEX 2.4.0
---------------------------------------------------------------
Expand Down
19 changes: 12 additions & 7 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,18 @@ explain.default <- function(model, data = NULL, y = NULL, predict_function = NUL
# REPORT: checks for residual_function
if (is.null(residual_function)) {
# residual_function not specified
# try the default
if (!is.null(predict_function) & model_info$type != "multiclass") {
residual_function <- residual_function_default
verbose_cat(" -> residual function : difference between y and yhat (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose)
} else if (!is.null(predict_function) & model_info$type == "multiclass") {
residual_function <- residual_function_multiclass
verbose_cat(" -> residual function : difference between 1 and probability of true class (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose)

# if y_hat is not numeric, then do not calculate residuals
# calculate only if y_hat is NULL or numeric
if (is.null(y_hat) | is.numeric(y_hat)) {
# try the default
if (!is.null(predict_function) & model_info$type != "multiclass") {
residual_function <- residual_function_default
verbose_cat(" -> residual function : difference between y and yhat (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose)
} else if (!is.null(predict_function) & model_info$type == "multiclass") {
residual_function <- residual_function_multiclass
verbose_cat(" -> residual function : difference between 1 and probability of true class (",color_codes$yellow_start,"default",color_codes$yellow_end,")\n", verbose = verbose)
}
}
} else {
if (!"function" %in% class(residual_function)) {
Expand Down
73 changes: 73 additions & 0 deletions R/misc_loss_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,76 @@ loss_default <- function(x) {
)
}




#' Wrapper for Loss Functions from the Yarstick Package
#'
#' The yardstick package provides many auxiliary functions for calculating
#' the predictive performance of the model. However, they have an interface
#' that is consistent with the tidyverse philosophy. The loss_yardstick
#' function adapts loss functions from the yardstick package to functions
#' understood by DALEX. Type compatibility for y-values and for predictions
#' must be guaranteed by the user.
#'
#' @param loss loss function from the yardstick package
#'
#' @return loss function that can be used in the model_parts function
#'
#' @export
#' @examples
#' \donttest{
#' # Classification Metrics
#' # y and y_hat are factors!!!
#' library("yardstick")
#'
#' titanic_glm_model <- glm(survived~., data = titanic_imputed, family = "binomial")
#' explainer_glm <- DALEX::explain(titanic_glm_model,
#' data = titanic_imputed[,-8],
#' y = factor(titanic_imputed$survived),
#' predict_function = function(m, x) {
#' factor((predict(m, x, type = "response") > 0.5) + 0)
#' })
#'
#' model_parts_accuracy <- model_parts(explainer_glm, type = "raw",
#' loss_function = loss_yardstick(accuracy))
#' plot(model_parts_accuracy)
#'
#' # Class Probability Metrics
#' # y is a factor while y_hat is a numeric!!!
#'
#' titanic_glm_model <- glm(survived~., data = titanic_imputed, family = "binomial")
#' explainer_glm <- DALEX::explain(titanic_glm_model,
#' data = titanic_imputed[,-8],
#' y = factor(titanic_imputed$survived))
#'
#' model_parts_accuracy <- model_parts(explainer_glm, type = "raw",
#' loss_function = loss_yardstick(roc_auc))
#' plot(model_parts_accuracy)
#'
#' # Regression Metrics
#' # y and y_hat are numeric!!!
#'
#' library("ranger")
#' apartments_ranger <- ranger(m2.price~., data = apartments, num.trees = 50)
#' explainer_ranger <- DALEX::explain(apartments_ranger, data = apartments[,-1],
#' y = apartments$m2.price, label = "Ranger Apartments")
#' model_parts_ranger <- model_parts(explainer_ranger, type = "raw",
#' loss_function = loss_yardstick(rsq))
#' plot(model_parts_ranger)
#'
#' }
#'
#' @rdname loss_yardstick
#' @export
loss_yardstick <- function(loss) {
# wrapper for yardstick loss functions
custom_loss <- function(observed, predicted) {
df <- data.frame(observed, predicted)
loss(df, observed, predicted)$.estimate
}
attr(custom_loss, "loss_name") <- deparse(substitute(loss))
custom_loss
}


66 changes: 66 additions & 0 deletions man/loss_yardstick.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5758c22

Please sign in to comment.