Skip to content

Commit

Permalink
add statsmodels comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrenodet committed Feb 21, 2025
1 parent ddcaaf7 commit f7fe2d1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ extra-args = [
extra-dependencies = [
"mislabeled[examples]",
"pytest-benchmark",
"scipy>=1.15.0"
"scipy>=1.15.0",
"statsmodels"
]

[tool.coverage.run]
Expand Down
30 changes: 29 additions & 1 deletion tests/probe/test_self_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
from sklearn.metrics import log_loss, mean_squared_error
from sklearn.model_selection import LeaveOneOut
from sklearn.preprocessing import LabelBinarizer, StandardScaler
from statsmodels.genmod import families
from statsmodels.genmod.generalized_linear_model import GLM

from mislabeled.probe._influence import ALOO, SelfInfluence
from mislabeled.probe import ALOO, SelfInfluence, linearize


@pytest.mark.parametrize(
Expand Down Expand Up @@ -124,3 +126,29 @@ def eval(model, X, y, train, test):
1,
abs_tol=0.005 if close_form else 0.2,
)


@pytest.mark.parametrize(
"model", [LogisticRegression(fit_intercept=False, penalty=None)]
)
@pytest.mark.parametrize("num_classes", [2])
def test_aloo_against_statmodels(model, num_classes):
X, y = make_blobs(n_samples=30, random_state=1, centers=num_classes)

X = StandardScaler().fit_transform(X)

model.fit(X, y)

aloo = ALOO()

res = GLM(y, X, family=families.Binomial()).fit()
model.coef_ = res.params.reshape(1, -1)

aloo_scores = aloo(model, X, y)

np.testing.assert_allclose(
aloo_scores, -2 * res.get_influence(observed=True).cooks_distance[0]
)
np.testing.assert_allclose(
linearize(model, X, y)[0].hessian(X, y), -res.model.hessian(res.params)
)

0 comments on commit f7fe2d1

Please sign in to comment.