Skip to content

Commit

Permalink
Merge branch 'master' into check-regul-scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrenodet authored Feb 24, 2025
2 parents f7fe2d1 + fc8781c commit a8c6b16
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 9 deletions.
6 changes: 3 additions & 3 deletions mislabeled/detect/detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ProgressiveEnsemble,
)
from mislabeled.probe import (
ALOO,
ApproximateLOO,
Confidence,
CrossEntropy,
FiniteDiffSensitivity,
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, base_model):
)


class ALOODetector(ModelProbingDetector):
class ApproximateLOODetector(ModelProbingDetector):
"""Detector based on Approximate LeaveOneOut for GLM models.
References
Expand All @@ -85,7 +85,7 @@ def __init__(self, base_model):
super().__init__(
base_model=base_model,
ensemble=NoEnsemble(),
probe=ALOO(),
probe=ApproximateLOO(),
aggregate="sum",
)

Expand Down
4 changes: 2 additions & 2 deletions mislabeled/probe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ._complexity import ParameterCount, ParamNorm2
from ._grads import GradSimilarity
from ._influence import (
ALOO,
ApproximateLOO,
GradNorm2,
Representer,
SelfInfluence,
Expand Down Expand Up @@ -47,7 +47,7 @@
"ParameterCount",
"ParamNorm2",
"SelfInfluence",
"ALOO",
"ApproximateLOO",
"Representer",
"GradNorm2",
"GradSimilarity",
Expand Down
2 changes: 1 addition & 1 deletion mislabeled/probe/_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __call__(self, estimator, X, y):
return self_influence


class ALOO(Maximize):
class ApproximateLOO(Maximize):
def __init__(self):
pass

Expand Down
1 change: 1 addition & 0 deletions mislabeled/probe/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def pseudo(self, X):
return X

def variance(self, p):
# variance of the GLM link function
if self.loss == "l2":
return np.eye(self.out_dim)[None, :, :] * np.ones(p.shape[0])[:, None, None]
elif self.loss == "log_loss":
Expand Down
2 changes: 1 addition & 1 deletion tests/probe/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def f(prc):
print(np.round(jacobian(vectorized_objective, packed_raveled_coef).df, 2))

# I dont know why the gradient should not take into account the regul
# to compute ALOO and SelfInfluence ...
# to compute ApproximateLOO and SelfInfluence ...
# np.testing.assert_allclose(
# linearized.grad_p(X, y).sum(axis=0),
# jacobian(vectorized_objective, packed_raveled_coef).df,
Expand Down
4 changes: 2 additions & 2 deletions tests/probe/test_self_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from statsmodels.genmod import families
from statsmodels.genmod.generalized_linear_model import GLM

from mislabeled.probe import ALOO, SelfInfluence, linearize
from mislabeled.probe import ApproximateLOO, SelfInfluence


@pytest.mark.parametrize(
Expand Down Expand Up @@ -92,7 +92,7 @@ def loss_fn(model, X, y):
model.fit(X, y)

si = SelfInfluence()
aloo = ALOO()
aloo = ApproximateLOO()

si_scores = si(model, X, y)
aloo_scores = aloo(model, X, y)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from mislabeled.aggregate import oob, sum
from mislabeled.detect import ModelProbingDetector
from mislabeled.detect.detectors import (
ApproximateLOODetector,
AreaUnderMargin,
Classifier,
ConfidentLearning,
Expand Down Expand Up @@ -165,6 +166,12 @@ def simple_detect_test(n_classes, detector):
)
),
VoLG(MLPClassifier(random_state=seed)),
ApproximateLOODetector(
make_pipeline(
Nystroem(gamma=0.1, n_components=100, random_state=seed),
LogisticRegression(random_state=seed, C=10),
)
),
]


Expand Down

0 comments on commit a8c6b16

Please sign in to comment.