Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

193 add isolation forest conversion to treeexplainer #289

Merged
merged 18 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions shapiq/explainer/tree/conversion/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

import numpy as np
from sklearn.ensemble._iforest import _average_path_length

from shapiq.utils import safe_isinstance
from shapiq.utils.types import Model
Expand Down Expand Up @@ -74,3 +75,97 @@ def convert_sklearn_tree(
empty_prediction=None, # compute empty prediction later
original_output_type=output_type,
)


def average_path_length(isolation_forest):
    """Return the average path length for the forest's per-tree sample size.

    Args:
        isolation_forest: An isolation forest exposing the ``_max_samples`` attribute.

    Returns:
        The output of scikit-learn's ``_average_path_length`` for the single-entry input
        ``[max_samples]``.
    """
    samples_per_tree = isolation_forest._max_samples
    # NOTE: per the original author's note, ``_average_path_length`` is equivalent to
    # equation 1 in the Isolation Forest paper (Liu et al., 2008).
    return _average_path_length([samples_per_tree])


def convert_sklearn_isolation_forest(
    tree_model: Model,
) -> list[TreeModel]:
    """Transforms a scikit-learn isolation forest to the format used by shapiq.

    Each estimator is converted together with its entry of ``estimators_features_`` and its
    values are scaled by ``1 / n_estimators``.

    Args:
        tree_model: The scikit-learn isolation forest model to convert.

    Returns:
        One converted tree per estimator of the isolation forest.
    """
    per_tree_scaling = 1.0 / len(tree_model.estimators_)

    converted_trees = []
    for estimator, feature_subset in zip(tree_model.estimators_, tree_model.estimators_features_):
        converted_trees.append(
            convert_isolation_tree(estimator, feature_subset, scaling=per_tree_scaling)
        )
    return converted_trees


def convert_isolation_tree(
    tree_model: Model,
    tree_features,
    class_label: Optional[int] = None,
    scaling: float = 1.0,
    average_path_length: float = 1.0,  # TODO fix default value
) -> TreeModel:
    """Convert a single tree of a scikit-learn isolation forest to the format used by shapiq.

    The node values stored in the scikit-learn tree are not used directly. They are replaced
    by the values computed in ``isotree_value_traversal`` and then multiplied by ``scaling``.

    Args:
        tree_model: The scikit-learn tree estimator to convert; its ``tree_`` attribute is read.
        tree_features: The feature indices belonging to this tree. They are forwarded to
            ``isotree_value_traversal`` to re-number the split features.
        class_label: Currently unused by this function.
        scaling: The scaling factor for the tree values.
        average_path_length: Currently unused by this function (see the TODO on the default).

    Returns:
        The converted isolation tree.
    """
    output_type = "raw"
    tree = tree_model.tree_
    # the traversal is run unscaled; the scaling factor is applied once afterwards
    features_updated, values_updated = isotree_value_traversal(
        tree, tree_features, normalize=False, scaling=1.0
    )
    values_updated = (values_updated * scaling).flatten()

    return TreeModel(
        children_left=tree.children_left,
        children_right=tree.children_right,
        features=features_updated,
        thresholds=tree.threshold,
        values=values_updated,
        node_sample_weight=tree.weighted_n_node_samples,
        empty_prediction=None,  # compute empty prediction later
        original_output_type=output_type,
    )


def isotree_value_traversal(
    tree, tree_features, normalize=False, scaling=1.0, data=None, data_missing=None
):
    """Recompute the node values of an isolation tree and re-number its split features.

    For a ``sklearn.tree._tree.Tree`` every leaf receives its depth plus
    ``_average_path_length`` of its ``n_node_samples``. Every internal node receives the
    ``n_node_samples``-weighted sum of the leaf values below it, divided by its own
    ``n_node_samples``. For any other tree type the copied features and values are returned
    unchanged.

    Args:
        tree: The tree structure to traverse (reads ``feature``, ``value``, ``children_left``,
            ``children_right`` and ``n_node_samples``).
        tree_features: Feature indices of this tree; split features of the tree are used as
            positions into this sequence.
        normalize: Whether to divide the values by their sum over the last axis.
        scaling: Factor the corrected values are multiplied with.
        data: Unused.
        data_missing: Unused.

    Returns:
        A tuple ``(features, corrected_values)`` with the re-numbered split features and the
        recomputed node values.
    """
    features = tree.feature.copy()
    corrected_values = tree.value.copy()
    if safe_isinstance(tree, "sklearn.tree._tree.Tree"):

        def _recalculate_value(tree, i, level):
            # returns the sample-weighted sum of the leaf values in the subtree rooted at i
            if tree.children_left[i] == -1 and tree.children_right[i] == -1:
                value = level + _average_path_length(np.array([tree.n_node_samples[i]]))[0]
                corrected_values[i, 0] = value
                return value * tree.n_node_samples[i]
            value_left = _recalculate_value(tree, tree.children_left[i], level + 1)
            value_right = _recalculate_value(tree, tree.children_right[i], level + 1)
            corrected_values[i, 0] = (value_left + value_right) / tree.n_node_samples[i]
            return value_left + value_right

        _recalculate_value(tree, 0, 0)
        if normalize:
            # keepdims makes this valid for the 3-D layout of ``tree.value`` as well as for
            # 2-D arrays, where it matches the previous transpose-based formula
            corrected_values = corrected_values / corrected_values.sum(axis=-1, keepdims=True)
        corrected_values = corrected_values * scaling
        # re-number the features if each tree gets a different set of features; only split
        # nodes are looked up, because the negative leaf markers are not valid positions in
        # ``tree_features`` (they raise IndexError when the tree has fewer than two features)
        feature_lookup = np.asarray(tree_features)
        is_split_node = features >= 0
        features[is_split_node] = feature_lookup[features[is_split_node]]
    return features, corrected_values
3 changes: 2 additions & 1 deletion shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class TreeExplainer(Explainer):

def __init__(
self,
model: Union[dict, TreeModel, Any],
model: Union[dict, TreeModel, list, Any],
max_order: int = 2,
min_order: int = 1,
index: str = "k-SII",
Expand All @@ -61,6 +61,7 @@ def __init__(
# validate and parse model
validated_model = validate_tree_model(model, class_label=class_index)
self._trees: list[TreeModel] = copy.deepcopy(validated_model)
# TODO: trees are wrapped in a list here, but validation also builds a list and then unwraps it to a single element when the list has length 1
if not isinstance(self._trees, list):
self._trees = [self._trees]
self._n_trees = len(self._trees)
Expand Down
19 changes: 16 additions & 3 deletions shapiq/explainer/tree/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from .base import TreeModel
from .conversion.lightgbm import convert_lightgbm_booster
from .conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree
from .conversion.sklearn import (
convert_sklearn_forest,
convert_sklearn_isolation_forest,
convert_sklearn_tree,
)
from .conversion.xgboost import convert_xgboost_booster

SUPPORTED_MODELS = {
Expand All @@ -20,6 +24,8 @@
"sklearn.ensemble._forest.ExtraTreesClassifier",
"sklearn.ensemble.RandomForestRegressor",
"sklearn.ensemble._forest.RandomForestRegressor",
"sklearn.ensemble.IsolationForest",
"sklearn.ensemble._iforest.IsolationForest",
"lightgbm.sklearn.LGBMRegressor",
"lightgbm.sklearn.LGBMClassifier",
"lightgbm.basic.Booster",
Expand All @@ -42,8 +48,11 @@ def validate_tree_model(
# tree model (is already in the correct format)
if type(model).__name__ == "TreeModel":
tree_model = model
elif isinstance(model, list) and all([type(m).__name__ == "TreeModel" for m in model]):
tree_model = model
# direct return if list of tree models
elif type(model).__name__ == "list":
# check if all elements are TreeModel
if all([type(tree).__name__ == "TreeModel" for tree in model]):
tree_model = model
# dict as model is parsed to TreeModel (the dict needs to have the correct format and names)
elif type(model).__name__ == "dict":
tree_model = TreeModel(**model)
Expand All @@ -66,6 +75,10 @@ def validate_tree_model(
or safe_isinstance(model, "sklearn.ensemble._forest.ExtraTreesClassifier")
):
tree_model = convert_sklearn_forest(model, class_label=class_label)
elif safe_isinstance(model, "sklearn.ensemble.IsolationForest") or safe_isinstance(
model, "sklearn.ensemble._iforest.IsolationForest"
):
tree_model = convert_sklearn_isolation_forest(model)
elif safe_isinstance(model, "lightgbm.sklearn.LGBMRegressor") or safe_isinstance(
model, "lightgbm.sklearn.LGBMClassifier"
):
Expand Down
2 changes: 1 addition & 1 deletion shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def plot_force(
Returns:
The force plot as a matplotlib figure (if show is ``False``).
"""
from shapiq import force_plot
from .plot import force_plot

return force_plot(
self,
Expand Down
21 changes: 20 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from PIL import Image
from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.ensemble import IsolationForest, RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

Expand Down Expand Up @@ -142,6 +142,25 @@ def rf_clf_model() -> RandomForestClassifier:
return model


# Isolation forest model
@pytest.fixture
def if_clf_model() -> IsolationForest:
    """Return a small isolation forest fitted on two normal clusters plus uniform outliers."""
    n_per_cluster, n_outliers = 120, 40
    random_state = np.random.RandomState(0)
    mixing = np.array([[0.5, -0.1], [0.7, 0.4]])

    # the draw order from the shared random state is: general, spherical, outliers
    general = 0.4 * random_state.randn(n_per_cluster, 2) @ mixing + np.array([2, 2])
    spherical = 0.3 * random_state.randn(n_per_cluster, 2) + np.array([-2, -2])
    uniform_outliers = random_state.uniform(low=-4, high=4, size=(n_outliers, 2))

    points = np.concatenate([general, spherical, uniform_outliers])
    inlier_labels = np.ones((2 * n_per_cluster), dtype=int)
    outlier_labels = -np.ones((n_outliers), dtype=int)
    labels = np.concatenate([inlier_labels, outlier_labels])

    forest = IsolationForest(random_state=42, n_estimators=3)
    forest.fit(points, labels)
    return forest


@pytest.fixture
def xgb_reg_model():
"""Return a simple xgboost regression model."""
Expand Down
26 changes: 26 additions & 0 deletions tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,29 @@ def test_xgboost_shap_error(xgb_clf_model, background_clf_data):

# now the values surprisingly are the same
assert np.allclose(sv_shap, sv_shapiq_rounded_values, rtol=1e-5)


def test_iso_forest_shap(if_clf_model):
    """Tests the shapiq implementation of TreeSHAP vs. SHAP's implementation for Isolation Forest."""
    x_explain = np.array([0.125, 0.05])

    # The reference values below were obtained from the SHAP implementation with:
    #   import shap
    #   model_copy = copy.deepcopy(if_clf_model)
    #   explainer_shap = shap.TreeExplainer(model=model_copy)
    #   baseline_shap = float(explainer_shap.expected_value)
    #   sv_shap = explainer_shap.shap_values(x_explain)
    #   print(sv_shap)
    #   print(baseline_shap)
    expected_values = np.array([-2.34951688, -4.55545493])
    expected_baseline = 12.238305148044713

    # compute the same quantities with shapiq
    shapiq_explainer = TreeExplainer(model=if_clf_model, max_order=1, index="SV")
    explanation = shapiq_explainer.explain(x=x_explain)
    first_order_values = explanation.get_n_order_values(1)
    shapiq_baseline = explanation.baseline_value

    assert expected_baseline == pytest.approx(shapiq_baseline, rel=1e-6)
    assert np.allclose(expected_values, first_order_values, rtol=1e-5)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

from shapiq.explainer.tree.base import TreeModel
from shapiq.explainer.tree.conversion.edges import create_edge_tree
from shapiq.explainer.tree.conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree
from shapiq.explainer.tree.conversion.sklearn import (
convert_sklearn_forest,
convert_sklearn_isolation_forest,
convert_sklearn_tree,
)
from shapiq.utils import safe_isinstance


Expand Down Expand Up @@ -123,3 +127,14 @@ def test_skleanr_rf_conversion(rf_clf_model, rf_reg_model):
assert isinstance(tree_model, list)
assert safe_isinstance(tree_model[0], tree_model_class_path_str)
assert tree_model[0].empty_prediction is not None


def test_sklearn_if_conversion(if_clf_model):
    """Check that an isolation forest is converted into a list of ``TreeModel`` objects."""
    expected_class_paths = ["shapiq.explainer.tree.base.TreeModel"]

    # convert the isolation forest and inspect the first converted tree
    converted_trees = convert_sklearn_isolation_forest(if_clf_model)
    assert isinstance(converted_trees, list)
    first_tree = converted_trees[0]
    assert safe_isinstance(first_tree, expected_class_paths)
    assert first_tree.empty_prediction is not None
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from shapiq.explainer.tree.validation import validate_tree_model


def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model):
def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model, if_clf_model):
"""Test the validation of the model."""
class_path_str = ["shapiq.explainer.tree.base.TreeModel"]
# sklearn dt models are supported
Expand All @@ -20,6 +20,10 @@ def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model):
for tree in tree_model:
assert safe_isinstance(tree, class_path_str)
tree_model = validate_tree_model(rf_reg_model)
for tree in tree_model:
assert safe_isinstance(tree, class_path_str)
# sklearn isolation forest is supported
tree_model = validate_tree_model(if_clf_model)
for tree in tree_model:
assert safe_isinstance(tree, class_path_str)

Expand Down
Loading