Skip to content

Commit

Permalink
Check output is binary (#72)
Browse files Browse the repository at this point in the history
Co-authored-by: Kevin Klein <[email protected]>
  • Loading branch information
FrancescMartiEscofetQC and kklein authored May 10, 2024
1 parent 860473c commit 305557f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 7 deletions.
15 changes: 15 additions & 0 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def _validate_params(self, **kwargs): ...
@abstractmethod
def _supports_multi_treatment(cls) -> bool: ...

@classmethod
@abstractmethod
def _supports_multi_class(cls) -> bool: ...

@classmethod
def _check_treatment(cls, w: Vector) -> None:
if (
Expand All @@ -87,6 +91,17 @@ def _check_treatment(cls, w: Vector) -> None:
f"Yet we found the values {set(np.unique(w))}."
)

def _check_outcome(self, y: Vector) -> None:
if (
self.is_classification
and not self._supports_multi_class()
and len(np.unique(y)) > 2
):
raise ValueError(
f"{self.__class__.__name__} does not support multiclass classification."
f" Yet we found {len(np.unique(y))} classes."
)

@abstractmethod
def _validate_models(self) -> None:
"""Validate that the models are of the correct type (classifier or regressor)"""
Expand Down
13 changes: 6 additions & 7 deletions metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def treatment_model_names(cls) -> set[str]:
def _supports_multi_treatment(cls) -> bool:
return True

@classmethod
def _supports_multi_class(cls) -> bool:
return True

def _validate_params(self, feature_set, **kwargs):
if feature_set is not None:
# For SLearner it does not make sense to allow feature set as we only have one model
Expand Down Expand Up @@ -112,15 +116,10 @@ def _nuisance_predict_methods(

def fit(self, X: Matrix, y: Vector, w: Vector) -> Self:
"""Fit all models of the S-Learner."""
self._check_treatment(w)
self._check_outcome(y)
self._n_variants = len(np.unique(w))
self._fitted_treatments = convert_treatment(w)
# TODO: add support for different encoding of treatment variants (str, not consecutive ints...)
if set(np.unique(w)) != set(range(self._n_variants)):
raise ValueError(
"Treatment variant should be encoded with values "
f"{{0...{self._n_variants -1}}} and all variants should be present. "
f"Yet we found the values {set(np.unique(w))}."
)

mock_model = self.nuisance_model_factory[_BASE_MODEL](
**self.nuisance_model_params[_BASE_MODEL]
Expand Down
5 changes: 5 additions & 0 deletions metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def treatment_model_names(cls) -> set[str]:
def _supports_multi_treatment(cls) -> bool:
return False

@classmethod
def _supports_multi_class(cls) -> bool:
return True

def _validate_models(self) -> None:
if self.is_classification and not is_classifier(
self.nuisance_model_factory[_TREATMENT_MODEL]
Expand Down Expand Up @@ -79,6 +83,7 @@ def _nuisance_predict_methods(self) -> dict[str, PredictMethod]:
def fit(self, X: Matrix, y: Vector, w: Vector) -> Self:
"""Fit all models of the T-Learner."""
self._check_treatment(w)
self._check_outcome(y)
self._treatment_indices = w == 1
self._control_indices = w == 0
# TODO: Consider multiprocessing
Expand Down
4 changes: 4 additions & 0 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def treatment_model_names(cls):
def _supports_multi_treatment(cls) -> bool:
return False

@classmethod
def _supports_multi_class(cls) -> bool:
return False

def _validate_models(self) -> None: ...

@property
Expand Down

0 comments on commit 305557f

Please sign in to comment.