Skip to content

Commit

Permalink
Lint and format
Browse files Browse the repository at this point in the history
  • Loading branch information
r-visser committed Dec 16, 2024
1 parent 9eb6ba4 commit 4d83df4
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 27 deletions.
35 changes: 24 additions & 11 deletions check_shap.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""This script checks if the SV computed for an IsoForest model are the same for shap and shapiq."""

import copy
from dataclasses import dataclass

import numpy as np
import shap
import shapiq
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest
from sklearn.model_selection import train_test_split

import shapiq


def generate_random_perturb_features(settings):
Expand Down Expand Up @@ -46,8 +48,9 @@ def generate_random_perturb_features(settings):
data = np.vstack(clusters)

# Function to perturb a subset of features to create outliers
def create_outliers(data, n_outliers, n_features, n_perturb_features,
always_perturb_same_features):
def create_outliers(
data, n_outliers, n_features, n_perturb_features, always_perturb_same_features
):
original_indices = rng.choice(len(data), n_outliers, replace=False)
outliers = data[original_indices].copy()

Expand All @@ -62,8 +65,9 @@ def create_outliers(data, n_outliers, n_features, n_perturb_features,
return outliers, original_indices

# Create outliers and get the indices of the original points
outliers, original_indices = create_outliers(data, n_outliers, n_features, n_perturb_features,
always_perturb_same_features)
outliers, original_indices = create_outliers(
data, n_outliers, n_features, n_perturb_features, always_perturb_same_features
)

# Store original samples and perturbed samples separately
original_samples = data[original_indices].copy()
Expand All @@ -84,12 +88,21 @@ def create_outliers(data, n_outliers, n_features, n_perturb_features,
# print("Final Data:\n", final_data)

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(final_data, labels, test_size=0.33,
random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
final_data, labels, test_size=0.33, random_state=42
)

# Create a mapping between original points and their perturbed versions
original_to_perturbed = {i: i for i in range(n_outliers)}
return X_train, X_test, y_train, y_test, original_samples, perturbed_samples, original_to_perturbed
return (
X_train,
X_test,
y_train,
y_test,
original_samples,
perturbed_samples,
original_to_perturbed,
)


@dataclass
Expand All @@ -104,7 +117,7 @@ class SyntheticOutlierInlierSettings:
random_state: int = None


if __name__ == '__main__':
if __name__ == "__main__":

# create data
settings = SyntheticOutlierInlierSettings(
Expand All @@ -114,7 +127,7 @@ class SyntheticOutlierInlierSettings:
n_features=12,
n_perturb_features=2,
always_perturb_same_features=True,
random_state=0
random_state=0,
)
d = generate_random_perturb_features(settings)
X_train, X_test, y_train, y_test, original_samples, perturbed_samples, original_to_perturbed = d
Expand Down
36 changes: 25 additions & 11 deletions shapiq/explainer/tree/conversion/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Optional

import numpy as np

from sklearn.ensemble._iforest import _average_path_length

from shapiq.utils import safe_isinstance
Expand Down Expand Up @@ -33,6 +32,7 @@ def convert_sklearn_forest(
for tree in tree_model.estimators_
]


def convert_sklearn_tree(
tree_model: Model, class_label: Optional[int] = None, scaling: float = 1.0
) -> TreeModel:
Expand Down Expand Up @@ -76,11 +76,15 @@ def convert_sklearn_tree(
original_output_type=output_type,
)


def average_path_length(isolation_forest):
max_samples = isolation_forest._max_samples
average_path_length = _average_path_length([max_samples]) # NOTE: _average_path_length func is equivalent to equation 1 in Isolation Forest paper Lui2008
average_path_length = _average_path_length(
[max_samples]
) # NOTE: _average_path_length func is equivalent to equation 1 in Isolation Forest paper Lui2008
return average_path_length


def convert_sklearn_isolation_forest(
tree_model: Model,
) -> list[TreeModel]:
Expand All @@ -100,8 +104,13 @@ def convert_sklearn_isolation_forest(
for tree, features in zip(tree_model.estimators_, tree_model.estimators_features_)
]


def convert_isolation_tree(
tree_model: Model, tree_features, class_label: Optional[int] = None, scaling: float = 1.0, average_path_length: float = 1.0 # TODO fix default value
tree_model: Model,
tree_features,
class_label: Optional[int] = None,
scaling: float = 1.0,
average_path_length: float = 1.0, # TODO fix default value
) -> TreeModel:
"""Convert a scikit-learn decision tree to the format used by shapiq.
Expand All @@ -117,7 +126,9 @@ def convert_isolation_tree(
output_type = "raw"
tree_values = tree_model.tree_.value.copy()
tree_values = tree_values.flatten()
features_updated, values_updated = isotree_value_traversal(tree_model.tree_, tree_features, normalize=False, scaling=1.0)
features_updated, values_updated = isotree_value_traversal(
tree_model.tree_, tree_features, normalize=False, scaling=1.0
)
values_updated = values_updated * scaling
values_updated = values_updated.flatten()

Expand All @@ -132,20 +143,23 @@ def convert_isolation_tree(
original_output_type=output_type,
)

def isotree_value_traversal(tree, tree_features, normalize=False, scaling=1.0, data=None, data_missing=None):

def isotree_value_traversal(
tree, tree_features, normalize=False, scaling=1.0, data=None, data_missing=None
):
features = tree.feature.copy()
corrected_values = tree.value.copy()
if safe_isinstance(tree, "sklearn.tree._tree.Tree"):

def _recalculate_value(tree, i , level):
def _recalculate_value(tree, i, level):
if tree.children_left[i] == -1 and tree.children_right[i] == -1:
value = level + _average_path_length(np.array([tree.n_node_samples[i]]))[0]
corrected_values[i, 0] = value
corrected_values[i, 0] = value
return value * tree.n_node_samples[i]
else:
value_left = _recalculate_value(tree, tree.children_left[i] , level + 1)
value_right = _recalculate_value(tree, tree.children_right[i] , level + 1)
corrected_values[i, 0] = (value_left + value_right) / tree.n_node_samples[i]
value_left = _recalculate_value(tree, tree.children_left[i], level + 1)
value_right = _recalculate_value(tree, tree.children_right[i], level + 1)
corrected_values[i, 0] = (value_left + value_right) / tree.n_node_samples[i]
return value_left + value_right

_recalculate_value(tree, 0, 0)
Expand All @@ -154,4 +168,4 @@ def _recalculate_value(tree, i , level):
corrected_values = corrected_values * scaling
# re-number the features if each tree gets a different set of features
features = np.where(features >= 0, tree_features[features], features)
return features, corrected_values
return features, corrected_values
11 changes: 7 additions & 4 deletions shapiq/explainer/tree/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from .base import TreeModel
from .conversion.lightgbm import convert_lightgbm_booster
from .conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree, convert_sklearn_isolation_forest
from .conversion.sklearn import (
convert_sklearn_forest,
convert_sklearn_isolation_forest,
convert_sklearn_tree,
)
from .conversion.xgboost import convert_xgboost_booster

SUPPORTED_MODELS = {
Expand Down Expand Up @@ -71,9 +75,8 @@ def validate_tree_model(
or safe_isinstance(model, "sklearn.ensemble._forest.ExtraTreesClassifier")
):
tree_model = convert_sklearn_forest(model, class_label=class_label)
elif (
safe_isinstance(model, "sklearn.ensemble.IsolationForest")
or safe_isinstance(model, "sklearn.ensemble._iforest.IsolationForest")
elif safe_isinstance(model, "sklearn.ensemble.IsolationForest") or safe_isinstance(
model, "sklearn.ensemble._iforest.IsolationForest"
):
tree_model = convert_sklearn_isolation_forest(model)
elif safe_isinstance(model, "lightgbm.sklearn.LGBMRegressor") or safe_isinstance(
Expand Down
1 change: 0 additions & 1 deletion shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,6 @@ def plot_force(
Returns:
The force plot as a matplotlib figure (if show is ``False``).
"""
# from shapiq import force_plot
from .plot import force_plot

return force_plot(
Expand Down

0 comments on commit 4d83df4

Please sign in to comment.