From 9bae0d1542ef417676f419546228f4fde62fe9a9 Mon Sep 17 00:00:00 2001 From: Roel Visser Date: Mon, 16 Dec 2024 12:52:52 +0100 Subject: [PATCH] Add tests for isolation forest conversion --- tests/conftest.py | 21 ++++++++++++++++++- .../test_tree_explainer_conversion.py | 17 ++++++++++++++- .../test_tree_explainer_validate.py | 6 +++++- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6bef04ef..d1daad68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ import pytest from PIL import Image from sklearn.datasets import make_classification, make_regression -from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.ensemble import IsolationForest, RandomForestClassifier, RandomForestRegressor from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor @@ -142,6 +142,25 @@ def rf_clf_model() -> RandomForestClassifier: return model +# Isolation forest model +@pytest.fixture +def if_clf_model() -> IsolationForest: + n_samples, n_outliers = 120, 40 + rng = np.random.RandomState(0) + covariance = np.array([[0.5, -0.1], [0.7, 0.4]]) + cluster_1 = 0.4 * rng.randn(n_samples, 2) @ covariance + np.array([2, 2]) # general + cluster_2 = 0.3 * rng.randn(n_samples, 2) + np.array([-2, -2]) # spherical + outliers = rng.uniform(low=-4, high=4, size=(n_outliers, 2)) + + X = np.concatenate([cluster_1, cluster_2, outliers]) + y = np.concatenate([np.ones((2 * n_samples), dtype=int), -np.ones((n_outliers), dtype=int)]) + + # X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42) + model = IsolationForest(random_state=42, n_estimators=3) + model.fit(X, y) + return model + + @pytest.fixture def xgb_reg_model(): """Return a simple xgboost regression model.""" diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py 
b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py index 2119b57e..cfb1b07b 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py @@ -5,7 +5,11 @@ from shapiq.explainer.tree.base import TreeModel from shapiq.explainer.tree.conversion.edges import create_edge_tree -from shapiq.explainer.tree.conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree +from shapiq.explainer.tree.conversion.sklearn import ( + convert_sklearn_forest, + convert_sklearn_isolation_forest, + convert_sklearn_tree, +) from shapiq.utils import safe_isinstance @@ -123,3 +127,14 @@ def test_skleanr_rf_conversion(rf_clf_model, rf_reg_model): assert isinstance(tree_model, list) assert safe_isinstance(tree_model[0], tree_model_class_path_str) assert tree_model[0].empty_prediction is not None + + +def test_sklearn_if_conversion(if_clf_model): + """Test the conversion of a scikit-learn isolation forest model.""" + tree_model_class_path_str = ["shapiq.explainer.tree.base.TreeModel"] + + # test the isolation forest model + tree_model = convert_sklearn_isolation_forest(if_clf_model) + assert isinstance(tree_model, list) + assert safe_isinstance(tree_model[0], tree_model_class_path_str) + assert tree_model[0].empty_prediction is not None diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py index e23ea81d..cc453044 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py @@ -7,7 +7,7 @@ from shapiq.explainer.tree.validation import validate_tree_model -def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model): +def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model, if_clf_model): 
"""Test the validation of the model.""" class_path_str = ["shapiq.explainer.tree.base.TreeModel"] # sklearn dt models are supported @@ -20,6 +20,10 @@ def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model): for tree in tree_model: assert safe_isinstance(tree, class_path_str) tree_model = validate_tree_model(rf_reg_model) + for tree in tree_model: + assert safe_isinstance(tree, class_path_str) + # sklearn isolation forest is supported + tree_model = validate_tree_model(if_clf_model) for tree in tree_model: assert safe_isinstance(tree, class_path_str)