Skip to content

Commit

Permalink
Use functions/methods instead of lambdas for ModelSpecifications (#…
Browse files Browse the repository at this point in the history
…140)

Co-authored-by: Francesc Martí Escofet <[email protected]>
Co-authored-by: Francesc Martí Escofet <[email protected]>
  • Loading branch information
3 people authored May 31, 2024
1 parent 336058b commit 6942c93
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 32 deletions.
12 changes: 12 additions & 0 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,15 @@ def load_twins_data(
categorical_feature_columns,
true_cate_column,
)


def get_one(*args, **kwargs) -> int:
return 1


def get_predict(*args, **kwargs) -> PredictMethod:
return "predict"


def get_predict_proba(*args, **kwargs) -> PredictMethod:
return "predict_proba"
17 changes: 10 additions & 7 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
Matrix,
Vector,
clip_element_absolute_value_to_epsilon,
get_one,
get_predict,
get_predict_proba,
index_matrix,
validate_valid_treatment_variant_not_control,
)
Expand All @@ -19,6 +22,7 @@
TREATMENT,
TREATMENT_MODEL,
VARIANT_OUTCOME_MODEL,
MetaLearner,
_ConditionalAverageOutcomeMetaLearner,
_ModelSpecifications,
)
Expand Down Expand Up @@ -49,22 +53,21 @@ class DRLearner(_ConditionalAverageOutcomeMetaLearner):
def nuisance_model_specifications(cls) -> dict[str, _ModelSpecifications]:
return {
PROPENSITY_MODEL: _ModelSpecifications(
cardinality=lambda _: 1, predict_method=lambda _: "predict_proba"
cardinality=get_one,
predict_method=get_predict_proba,
),
VARIANT_OUTCOME_MODEL: _ModelSpecifications(
cardinality=lambda ml: ml.n_variants,
predict_method=lambda ml: (
"predict_proba" if ml.is_classification else "predict"
),
cardinality=MetaLearner._get_n_variants,
predict_method=MetaLearner._outcome_predict_method,
),
}

@classmethod
def treatment_model_specifications(cls) -> dict[str, _ModelSpecifications]:
return {
TREATMENT_MODEL: _ModelSpecifications(
cardinality=lambda ml: ml.n_variants - 1,
predict_method=lambda _: "predict",
cardinality=MetaLearner._get_n_variants_minus_one,
predict_method=get_predict,
)
}

Expand Down
9 changes: 9 additions & 0 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,15 @@ def _supports_multi_treatment(cls) -> bool: ...
@abstractmethod
def _supports_multi_class(cls) -> bool: ...

def _outcome_predict_method(self):
return "predict_proba" if self.is_classification else "predict"

def _get_n_variants(self):
return self.n_variants

def _get_n_variants_minus_one(self):
return self.n_variants - 1

@classmethod
def _validate_n_variants(cls, n_variants: int) -> None:
if not isinstance(n_variants, int) or n_variants < 2:
Expand Down
16 changes: 9 additions & 7 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
Vector,
clip_element_absolute_value_to_epsilon,
function_has_argument,
get_one,
get_predict,
get_predict_proba,
index_matrix,
validate_all_vectors_same_index,
validate_valid_treatment_variant_not_control,
Expand Down Expand Up @@ -124,22 +127,21 @@ def _validate_fit_params(cls, fit_params: dict[str, dict[str, dict]]) -> None:
def nuisance_model_specifications(cls) -> dict[str, _ModelSpecifications]:
return {
PROPENSITY_MODEL: _ModelSpecifications(
cardinality=lambda _: 1, predict_method=lambda _: "predict_proba"
cardinality=get_one,
predict_method=get_predict_proba,
),
OUTCOME_MODEL: _ModelSpecifications(
cardinality=lambda _: 1,
predict_method=lambda ml: (
"predict_proba" if ml.is_classification else "predict"
),
cardinality=get_one,
predict_method=MetaLearner._outcome_predict_method,
),
}

@classmethod
def treatment_model_specifications(cls) -> dict[str, _ModelSpecifications]:
return {
TREATMENT_MODEL: _ModelSpecifications(
cardinality=lambda ml: ml.n_variants - 1,
predict_method=lambda _: "predict",
cardinality=MetaLearner._get_n_variants_minus_one,
predict_method=get_predict,
)
}

Expand Down
7 changes: 3 additions & 4 deletions metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Matrix,
Vector,
convert_treatment,
get_one,
supports_categoricals,
)
from metalearners.cross_fit_estimator import OVERALL
Expand Down Expand Up @@ -73,10 +74,8 @@ class SLearner(MetaLearner):
def nuisance_model_specifications(cls) -> dict[str, _ModelSpecifications]:
return {
_BASE_MODEL: _ModelSpecifications(
cardinality=lambda _: 1,
predict_method=lambda ml: (
"predict_proba" if ml.is_classification else "predict"
),
cardinality=get_one,
predict_method=MetaLearner._outcome_predict_method,
)
}

Expand Down
7 changes: 3 additions & 4 deletions metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from metalearners.metalearner import (
NUISANCE,
VARIANT_OUTCOME_MODEL,
MetaLearner,
_ConditionalAverageOutcomeMetaLearner,
_ModelSpecifications,
)
Expand All @@ -28,10 +29,8 @@ class TLearner(_ConditionalAverageOutcomeMetaLearner):
def nuisance_model_specifications(cls) -> dict[str, _ModelSpecifications]:
return {
VARIANT_OUTCOME_MODEL: _ModelSpecifications(
cardinality=lambda ml: ml.n_variants,
predict_method=lambda ml: (
"predict_proba" if ml.is_classification else "predict"
),
cardinality=MetaLearner._get_n_variants,
predict_method=MetaLearner._outcome_predict_method,
),
}

Expand Down
22 changes: 12 additions & 10 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from metalearners._utils import (
Matrix,
Vector,
get_one,
get_predict,
get_predict_proba,
index_matrix,
validate_valid_treatment_variant_not_control,
)
Expand All @@ -18,6 +21,7 @@
PROPENSITY_MODEL,
TREATMENT,
VARIANT_OUTCOME_MODEL,
MetaLearner,
_ConditionalAverageOutcomeMetaLearner,
_ModelSpecifications,
)
Expand All @@ -38,27 +42,25 @@ class XLearner(_ConditionalAverageOutcomeMetaLearner):
def nuisance_model_specifications(cls) -> dict[str, _ModelSpecifications]:
return {
VARIANT_OUTCOME_MODEL: _ModelSpecifications(
cardinality=lambda ml: ml.n_variants,
predict_method=lambda ml: (
"predict_proba" if ml.is_classification else "predict"
),
cardinality=MetaLearner._get_n_variants,
predict_method=MetaLearner._outcome_predict_method,
),
PROPENSITY_MODEL: _ModelSpecifications(
cardinality=lambda _: 1,
predict_method=lambda _: "predict_proba",
cardinality=get_one,
predict_method=get_predict_proba,
),
}

@classmethod
def treatment_model_specifications(cls) -> dict[str, _ModelSpecifications]:
return {
CONTROL_EFFECT_MODEL: _ModelSpecifications(
cardinality=lambda ml: ml.n_variants - 1,
predict_method=lambda _: "predict",
cardinality=MetaLearner._get_n_variants_minus_one,
predict_method=get_predict,
),
TREATMENT_EFFECT_MODEL: _ModelSpecifications(
cardinality=lambda ml: ml.n_variants - 1,
predict_method=lambda _: "predict",
cardinality=MetaLearner._get_n_variants_minus_one,
predict_method=get_predict,
),
}

Expand Down

0 comments on commit 6942c93

Please sign in to comment.