
Commit

Merge branch 'main' into survival_example
FrancescMartiEscofetQC authored Jul 12, 2024
2 parents 1f00886 + 9e00603 commit 95cabbb
Showing 9 changed files with 81 additions and 10 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.rst
@@ -7,7 +7,7 @@
Changelog
=========

0.6.1 (2024-07-xx)
0.7.0 (2024-07-12)
------------------

**New features**
@@ -16,7 +16,7 @@ Changelog

**Other changes**

* Changed the index columns order in ``MetaLearnerGridSearch.results_``.
* Change the index columns order in ``MetaLearnerGridSearch.results_``.

* Raise a custom error if only one class is present in a classification outcome.

* Raise a custom error if some treatment variants have seen classification outcomes that have not appeared for other treatment variants.


0.6.0 (2024-07-08)
------------------
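The new validation errors described in the changelog entries above surface when fitting a metalearner. Here is a minimal sketch of the single-class case (illustrative only, not part of this commit; the top-level import path is assumed, and the constructor argument order of is_classification, n_variants, then the three base-model factories mirrors the tests further down in this diff):

import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression
from metalearners import TLearner  # import path assumed; the class lives in metalearners/tlearner.py

X = np.random.default_rng(0).standard_normal((10, 2))
y = np.zeros(10)          # classification outcome containing a single class
w = np.array([0, 1] * 5)  # two treatment variants

ml = TLearner(True, 2, LogisticRegression, LinearRegression, LogisticRegression)
try:
    ml.fit(X, y, w)
except ValueError as exc:
    print(exc)  # "There is only one class present in the classification outcome: ..."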
2 changes: 1 addition & 1 deletion metalearners/drlearner.py
@@ -139,7 +139,7 @@ def fit(
n_jobs_base_learners: int | None = None,
) -> Self:
self._validate_treatment(w)
self._validate_outcome(y)
self._validate_outcome(y, w)

self._treatment_variants_indices = []

13 changes: 12 additions & 1 deletion metalearners/metalearner.py
@@ -317,7 +317,7 @@ def _validate_treatment(self, w: Vector) -> None:
f"Yet we found the values {set(np.unique(w))}."
)

def _validate_outcome(self, y: Vector) -> None:
def _validate_outcome(self, y: Vector, w: Vector) -> None:
if (
self.is_classification
and not self._supports_multi_class()
@@ -327,6 +327,17 @@ def _validate_outcome(self, y: Vector) -> None:
f"{self.__class__.__name__} does not support multiclass classification."
f" Yet we found {len(np.unique(y))} classes."
)
if self.is_classification:
classes_0 = set(np.unique(y[w == 0]))
for tv in range(self.n_variants):
if set(np.unique(y[w == tv])) != classes_0:
raise ValueError(
f"Variants 0 and {tv} have seen different sets of classification outcomes. Please check your data."
)
if len(classes_0) == 1:
raise ValueError(
f"There is only one class present in the classification outcome: {classes_0}. Please check your data."
)

def _validate_models(self) -> None:
"""Validate that the base models are appropriate.
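For clarity, a standalone sketch of the check added above, run on the data used in the mismatched-classes test at the bottom of this diff; n_variants here is a local stand-in for the learner attribute of the same name:

import numpy as np

y = np.array([0, 1, 0, 0])  # outcomes
w = np.array([0, 0, 1, 1])  # treatment variant per observation
n_variants = 2

classes_0 = set(np.unique(y[w == 0]))  # variant 0 has seen {0, 1}
for tv in range(n_variants):
    if set(np.unique(y[w == tv])) != classes_0:  # variant 1 has only seen {0}
        print(f"Variants 0 and {tv} have seen different sets of classification outcomes.")
if len(classes_0) == 1:  # not triggered here since variant 0 has seen two classes
    print(f"There is only one class present in the classification outcome: {classes_0}.")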
2 changes: 1 addition & 1 deletion metalearners/rlearner.py
@@ -169,7 +169,7 @@ def fit(
) -> Self:

self._validate_treatment(w)
self._validate_outcome(y)
self._validate_outcome(y, w)

self._variants_indices = []

2 changes: 1 addition & 1 deletion metalearners/slearner.py
@@ -152,7 +152,7 @@ def fit(
n_jobs_base_learners: int | None = None,
) -> Self:
self._validate_treatment(w)
self._validate_outcome(y)
self._validate_outcome(y, w)
self._fitted_treatments = convert_treatment(w)

mock_model = self.nuisance_model_factory[_BASE_MODEL](
2 changes: 1 addition & 1 deletion metalearners/tlearner.py
@@ -59,7 +59,7 @@ def fit(
n_jobs_base_learners: int | None = None,
) -> Self:
self._validate_treatment(w)
self._validate_outcome(y)
self._validate_outcome(y, w)

self._treatment_variants_indices = []

2 changes: 1 addition & 1 deletion metalearners/xlearner.py
@@ -85,7 +85,7 @@ def fit(
n_jobs_base_learners: int | None = None,
) -> Self:
self._validate_treatment(w)
self._validate_outcome(y)
self._validate_outcome(y, w)

self._treatment_variants_indices = []

4 changes: 2 additions & 2 deletions tests/test_learner.py
@@ -706,8 +706,8 @@ def test_validate_treatment_error_different_instantiation(metalearner_prefix):
)
def test_validate_outcome_multi_class(metalearner_prefix, success):
covariates = np.zeros((20, 1))
w = np.array([0, 1] * 10)
y = np.array([0, 1] * 8 + [2] * 4)
w = np.array([0] * 10 + [1] * 10)
y = np.array([0, 1, 2, 3, 4] * 4)

factory = metalearner_factory(metalearner_prefix)
learner = factory(
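Presumably the fixture data above was changed so that both treatment variants see the identical set of five outcome classes, keeping this test focused on the multiclass error rather than the newly added mismatched-classes check. A quick sanity check of that reading (a sketch, not part of the commit):

import numpy as np

w = np.array([0] * 10 + [1] * 10)
y = np.array([0, 1, 2, 3, 4] * 4)
# Both variants see the same five classes, so only the multiclass validation can fire.
assert set(np.unique(y[w == 0])) == set(np.unique(y[w == 1])) == {0, 1, 2, 3, 4}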
55 changes: 55 additions & 0 deletions tests/test_metalearner.py
@@ -1052,3 +1052,58 @@ def test_n_jobs_base_learners(implementation, rng):

np.testing.assert_allclose(ml.predict(X, False), ml_2.predict(X, False))
np.testing.assert_allclose(ml.predict(X, True), ml_2.predict(X, True))


@pytest.mark.parametrize(
"implementation",
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
@pytest.mark.parametrize("use_pandas", [False, True])
def test_validate_outcome_one_class(implementation, use_pandas, rng):
X = rng.standard_normal((10, 2))
y = np.zeros(10)
w = rng.integers(0, 2, 10)
if use_pandas:
X = pd.DataFrame(X)
y = pd.Series(y)
w = pd.Series(w)

ml = implementation(
True,
2,
LogisticRegression,
LinearRegression,
LogisticRegression,
)
with pytest.raises(
ValueError,
match="There is only one class present in the classification outcome",
):
ml.fit(X, y, w)


@pytest.mark.parametrize(
"implementation",
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
@pytest.mark.parametrize("use_pandas", [False, True])
def test_validate_outcome_different_classes(implementation, use_pandas, rng):
X = rng.standard_normal((4, 2))
y = np.array([0, 1, 0, 0])
w = np.array([0, 0, 1, 1])
if use_pandas:
X = pd.DataFrame(X)
y = pd.Series(y)
w = pd.Series(w)

ml = implementation(
True,
2,
LogisticRegression,
LinearRegression,
LogisticRegression,
)
with pytest.raises(
ValueError, match="have seen different sets of classification outcomes."
):
ml.fit(X, y, w)
