Skip to content

Commit

Permalink
Add tests for isolation forest conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
r-visser committed Dec 16, 2024
1 parent 4d83df4 commit 9bae0d1
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
21 changes: 20 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from PIL import Image
from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.ensemble import IsolationForest, RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

Expand Down Expand Up @@ -142,6 +142,25 @@ def rf_clf_model() -> RandomForestClassifier:
return model


# Isolationforest model
@pytest.fixture
def if_clf_model() -> IsolationForest:
n_samples, n_outliers = 120, 40
rng = np.random.RandomState(0)
covariance = np.array([[0.5, -0.1], [0.7, 0.4]])
cluster_1 = 0.4 * rng.randn(n_samples, 2) @ covariance + np.array([2, 2]) # general
cluster_2 = 0.3 * rng.randn(n_samples, 2) + np.array([-2, -2]) # spherical
outliers = rng.uniform(low=-4, high=4, size=(n_outliers, 2))

X = np.concatenate([cluster_1, cluster_2, outliers])
y = np.concatenate([np.ones((2 * n_samples), dtype=int), -np.ones((n_outliers), dtype=int)])

# X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)
model = IsolationForest(random_state=42, n_estimators=3)
model.fit(X, y)
return model


@pytest.fixture
def xgb_reg_model():
"""Return a simple xgboost regression model."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

from shapiq.explainer.tree.base import TreeModel
from shapiq.explainer.tree.conversion.edges import create_edge_tree
from shapiq.explainer.tree.conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree
from shapiq.explainer.tree.conversion.sklearn import (
convert_sklearn_forest,
convert_sklearn_isolation_forest,
convert_sklearn_tree,
)
from shapiq.utils import safe_isinstance


Expand Down Expand Up @@ -123,3 +127,14 @@ def test_skleanr_rf_conversion(rf_clf_model, rf_reg_model):
assert isinstance(tree_model, list)
assert safe_isinstance(tree_model[0], tree_model_class_path_str)
assert tree_model[0].empty_prediction is not None


def test_sklearn_if_conversion(if_clf_model):
"""Test the conversion of a scikit-learn isolation forest model."""
tree_model_class_path_str = ["shapiq.explainer.tree.base.TreeModel"]

# test the isolation forest model
tree_model = convert_sklearn_isolation_forest(if_clf_model)
assert isinstance(tree_model, list)
assert safe_isinstance(tree_model[0], tree_model_class_path_str)
assert tree_model[0].empty_prediction is not None
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from shapiq.explainer.tree.validation import validate_tree_model


def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model):
def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model, if_clf_model):
"""Test the validation of the model."""
class_path_str = ["shapiq.explainer.tree.base.TreeModel"]
# sklearn dt models are supported
Expand All @@ -20,6 +20,10 @@ def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model):
for tree in tree_model:
assert safe_isinstance(tree, class_path_str)
tree_model = validate_tree_model(rf_reg_model)
for tree in tree_model:
assert safe_isinstance(tree, class_path_str)
# sklearn isolation forest is supported
tree_model = validate_tree_model(if_clf_model)
for tree in tree_model:
assert safe_isinstance(tree, class_path_str)

Expand Down

0 comments on commit 9bae0d1

Please sign in to comment.