From c32bf0ca46a1d3ebfdd7c4dda162626e3a889a90 Mon Sep 17 00:00:00 2001 From: mayer79 Date: Wed, 8 Jun 2022 09:50:54 +0200 Subject: [PATCH 1/3] potential new LGB predict() interface --- R/SHAP_funcs.R | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/R/SHAP_funcs.R b/R/SHAP_funcs.R index eb156df..5c1e3bb 100644 --- a/R/SHAP_funcs.R +++ b/R/SHAP_funcs.R @@ -38,9 +38,12 @@ shap.values <- function(xgb_model, X_train){ - shap_contrib <- predict(xgb_model, - (X_train), - predcontrib = TRUE) + # New predict() interface for LGB 4 + if (inherits(xgb_model, "lgb.Booster") && utils::packageVersion("lightgbm") >= 4) { + shap_contrib <- predict(xgb_model, X_train, type = "contrib") + } else { + shap_contrib <- predict(xgb_model, X_train, predcontrib = TRUE) + } # Add colnames if not already there (required for LightGBM) if (is.null(colnames(shap_contrib))) { From e253528fccc577c25aec0bbc47b21628e404100c Mon Sep 17 00:00:00 2001 From: mayer79 Date: Wed, 8 Jun 2022 09:56:13 +0200 Subject: [PATCH 2/3] updated description and news --- DESCRIPTION | 1 + NEWS.md | 1 + 2 files changed, 2 insertions(+) diff --git a/DESCRIPTION b/DESCRIPTION index 3b6ac0a..6c5c220 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,6 +25,7 @@ VignetteBuilder: knitr Imports: stats, + utils, ggplot2 (>= 3.0.0), xgboost (>= 0.81.0.0), data.table (>= 1.12.0), diff --git a/NEWS.md b/NEWS.md index 9fa942f..5b49e36 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # SHAPforxgboost 0.1.2 +* 08/06/2022 Comply to LightGBM's 4.0 `predict()` interface. * 17/05/2022 Added option `kind = "bar"` to `shap.plot.summary()`. # SHAPforxgboost 0.1.1 From 56d96f608173120ef80654cb1a780f641c0388b3 Mon Sep 17 00:00:00 2001 From: mayer79 Date: Tue, 19 Jul 2022 19:48:51 +0200 Subject: [PATCH 3/3] Better condition --- R/SHAP_funcs.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/SHAP_funcs.R b/R/SHAP_funcs.R index 5c1e3bb..89bc662 100644 --- a/R/SHAP_funcs.R +++ b/R/SHAP_funcs.R @@ -38,8 +38,9 @@ shap.values <- function(xgb_model, X_train){ - # New predict() interface for LGB 4 - if (inherits(xgb_model, "lgb.Booster") && utils::packageVersion("lightgbm") >= 4) { + # New predict() interface for LGB > 3.3.2 + new_lgb <- utils::packageVersion("lightgbm") > package_version("3.3.2") + if (inherits(xgb_model, "lgb.Booster") && new_lgb) { shap_contrib <- predict(xgb_model, X_train, type = "contrib") } else { shap_contrib <- predict(xgb_model, X_train, predcontrib = TRUE)