Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

X-Learner: Use the same sample splits in all base models. #84

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
12 changes: 6 additions & 6 deletions metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,12 @@ def _predict_in_sample(
) -> np.ndarray:
if not self._test_indices:
raise ValueError()
if len(X) != sum(len(fold) for fold in self._test_indices):
raise ValueError(
"Trying to predict in-sample on data that is unlike data encountered in training. "
f"Training data included {sum(len(fold) for fold in self._test_indices)} "
f"observations while prediction data includes {len(X)} observations."
)
# if len(X) != sum(len(fold) for fold in self._test_indices):
# raise ValueError(
# "Trying to predict in-sample on data that is unlike data encountered in training. "
# f"Training data included {sum(len(fold) for fold in self._test_indices)} "
# f"observations while prediction data includes {len(X)} observations."
# )
n_outputs = self._n_outputs(method)
predictions = self._initialize_prediction_tensor(
n_observations=len(X),
Expand Down
146 changes: 101 additions & 45 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
from joblib import Parallel, delayed
from typing_extensions import Self

from metalearners._typing import Matrix, OosMethod, Scoring, Vector, _ScikitModel
from metalearners._typing import (
Matrix,
OosMethod,
Scoring,
SplitIndices,
Vector,
_ScikitModel,
)
from metalearners._utils import (
check_spox_installed,
copydoc,
Expand Down Expand Up @@ -99,31 +106,39 @@ def fit_all_nuisance(

qualified_fit_params = self._qualified_fit_params(fit_params)

self._cvs: list = []
# TODO: Move this to object initialization.
if not synchronize_cross_fitting:
raise ValueError(
"The X-Learner does not support synchronize_cross_fitting=False."
)

self._cv_split_indices: SplitIndices = self._split(X)
self._treatment_cv_split_indices: dict[int, SplitIndices] = {}

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
if synchronize_cross_fitting:
cv_split_indices = self._split(
index_matrix(X, self._treatment_variants_indices[treatment_variant])
treatment_indices = np.where(
Copy link
Collaborator Author

@kklein kklein Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an opaque way of turning an array [True, True, False, False, True] into an array [0, 1, 4]. Not sure if there's a neater way of doing that.

Copy link
Contributor

@MatthiasLoefflerQC MatthiasLoefflerQC Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[index for index, value in enumerate(vector) if value] would work too, I guess, and is more verbose, but I like the np.where :)

self._treatment_variants_indices[treatment_variant]
)[0]
self._treatment_cv_split_indices[treatment_variant] = [
(
np.intersect1d(train_indices, treatment_indices),
np.intersect1d(test_indices, treatment_indices),
)
else:
cv_split_indices = None
self._cvs.append(cv_split_indices)
for train_indices, test_indices in self._cv_split_indices
MatthiasLoefflerQC marked this conversation as resolved.
Show resolved Hide resolved
]

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
nuisance_jobs.append(
self._nuisance_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
y=y[self._treatment_variants_indices[treatment_variant]],
X=X,
y=y,
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
cv=self._cvs[treatment_variant],
cv=self._treatment_cv_split_indices[treatment_variant],
)
)

Expand Down Expand Up @@ -160,14 +175,14 @@ def fit_all_treatment(
) -> Self:
if self._treatment_variants_indices is None:
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"The nuisance models need to be fitted before fitting the treatment models. "
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"typically set during nuisance fitting, is None."
)
if not hasattr(self, "_cvs"):
if not hasattr(self, "_treatment_cv_split_indices"):
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _cvs, "
"The nuisance models need to be fitted before fitting the treatment models. "
"In particular, the MetaLearner's attribute _treatment_cv_split_indices, "
"typically set during nuisance fitting, does not exist."
)
qualified_fit_params = self._qualified_fit_params(fit_params)
Expand All @@ -180,34 +195,32 @@ def fit_all_treatment(
is_oos=False,
)
)

for treatment_variant in range(1, self.n_variants):
imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
y, w, treatment_variant, conditional_average_outcome_estimates
)

treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
X=X,
y=imputed_te_treatment,
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL],
cv=self._cvs[treatment_variant],
cv=self._treatment_cv_split_indices[treatment_variant],
)
)

treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(X, self._treatment_variants_indices[0]),
X=X,
y=imputed_te_control,
model_kind=CONTROL_EFFECT_MODEL,
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL],
cv=self._cvs[0],
cv=self._treatment_cv_split_indices[0],
)
)

Expand All @@ -216,6 +229,7 @@ def fit_all_treatment(
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
)
self._assign_joblib_treatment_results(results)

return self

def predict(
Expand Down Expand Up @@ -278,19 +292,18 @@ def predict(
oos_method=oos_method,
)
)

tau_hat_treatment[treatment_variant_indices] = self.predict_treatment(
X=index_matrix(X, treatment_variant_indices),
X=X,
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
)
)[treatment_variant_indices]
tau_hat_control[control_indices] = self.predict_treatment(
X=index_matrix(X, control_indices),
X=X,
model_kind=CONTROL_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
)
)[control_indices]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need is_oos=False below (and likewise for tau_hat_treatment)? Might be worth a try.

tau_hat_control[non_control_indices] = self.predict_treatment(
X=index_matrix(X, non_control_indices),
model_kind=CONTROL_EFFECT_MODEL,
Expand Down Expand Up @@ -337,8 +350,8 @@ def evaluate(

variant_outcome_evaluation = _evaluate_model_kind(
cfes=self._nuisance_models[VARIANT_OUTCOME_MODEL],
Xs=[X[w == tv] for tv in range(self.n_variants)],
ys=[y[w == tv] for tv in range(self.n_variants)],
Xs=[X] * self.n_variants,
ys=[y] * self.n_variants,
scorers=safe_scoring[VARIANT_OUTCOME_MODEL],
model_kind=VARIANT_OUTCOME_MODEL,
is_oos=is_oos,
Expand Down Expand Up @@ -378,7 +391,7 @@ def evaluate(

te_treatment_evaluation = _evaluate_model_kind(
self._treatment_models[TREATMENT_EFFECT_MODEL],
Xs=[X[w == tv] for tv in range(1, self.n_variants)],
Xs=[X] * self.n_variants,
ys=imputed_te_treatment,
scorers=safe_scoring[TREATMENT_EFFECT_MODEL],
model_kind=TREATMENT_EFFECT_MODEL,
Expand All @@ -390,7 +403,7 @@ def evaluate(

te_control_evaluation = _evaluate_model_kind(
self._treatment_models[CONTROL_EFFECT_MODEL],
Xs=[X[w == 0] for _ in range(1, self.n_variants)],
Xs=[X] * self.n_variants,
ys=imputed_te_control,
scorers=safe_scoring[CONTROL_EFFECT_MODEL],
model_kind=CONTROL_EFFECT_MODEL,
Expand Down Expand Up @@ -424,16 +437,8 @@ def _pseudo_outcome(
This function can be used with both in-sample or out-of-sample data.
"""
validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants)

treatment_indices = w == treatment_variant
control_indices = w == 0

treatment_outcome = index_matrix(
conditional_average_outcome_estimates, control_indices
)[:, treatment_variant]
control_outcome = index_matrix(
conditional_average_outcome_estimates, treatment_indices
)[:, 0]
treatment_outcome = conditional_average_outcome_estimates[:, treatment_variant]
control_outcome = conditional_average_outcome_estimates[:, 0]

if self.is_classification:
# Get the probability of positive class, multiclass is currently not supported.
Expand All @@ -443,8 +448,8 @@ def _pseudo_outcome(
control_outcome = control_outcome[:, 0]
treatment_outcome = treatment_outcome[:, 0]

imputed_te_treatment = y[treatment_indices] - control_outcome
imputed_te_control = treatment_outcome - y[control_indices]
imputed_te_treatment = y - control_outcome
imputed_te_control = treatment_outcome - y

return imputed_te_control, imputed_te_treatment

Expand Down Expand Up @@ -534,3 +539,54 @@ def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
final_model = build(input_dict, {output_name: cate})
check_model(final_model, full_check=True)
return final_model

def predict_conditional_average_outcomes(
    self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
    """Predict the conditional average outcome of every treatment variant.

    For out-of-sample data (``is_oos=True``) the prediction of each variant's
    outcome model is delegated to ``predict_nuisance``. For in-sample data,
    every observation is predicted by the fold estimator whose prediction
    indices in ``self._cv_split_indices`` contain that observation.

    Args:
        X: Covariates to predict on.
        is_oos: Whether ``X`` is out-of-sample data.
        oos_method: Forwarded to ``predict_nuisance``; only used when
            ``is_oos`` is true.

    Returns:
        An array reshaped to ``(len(X), self.n_variants, -1)``.

    Raises:
        ValueError: If ``self._treatment_variants_indices`` is ``None``.
    """
    if self._treatment_variants_indices is None:
        raise ValueError(
            # Trailing space keeps the concatenated sentences separated.
            "The metalearner needs to be fitted before predicting. "
            "In particular, the MetaLearner's attribute _treatment_variant_indices, "
            "typically set during fitting, is None."
        )
    # TODO: Consider multiprocessing
    n_obs = len(X)
    # Template that is copied once per variant and filled fold by fold.
    # NOTE(review): presumably an (n_obs, ...) array of placeholders - confirm
    # against _nuisance_tensors.
    cao_tensor = self._nuisance_tensors(n_obs)[VARIANT_OUTCOME_MODEL][0]
    predict_method_name = self.nuisance_model_specifications()[
        VARIANT_OUTCOME_MODEL
    ]["predict_method"](self)
    conditional_average_outcomes_list = []

    for tv in range(self.n_variants):
        if is_oos:
            conditional_average_outcomes_list.append(
                self.predict_nuisance(
                    X=X,
                    model_kind=VARIANT_OUTCOME_MODEL,
                    model_ord=tv,
                    is_oos=is_oos,
                    oos_method=oos_method,
                )
            )
            continue

        # TODO: Consider moving this logic to CrossFitEstimator.predict.
        cfe = self._nuisance_models[VARIANT_OUTCOME_MODEL][tv]
        conditional_average_outcome_estimates = cao_tensor.copy()

        # Only the prediction part of each split is needed here; the
        # training indices are ignored.
        for fold_index, (_, prediction_indices) in enumerate(
            self._cv_split_indices
        ):
            fold_model = cfe._estimators[fold_index]
            predict_method = getattr(fold_model, predict_method_name)
            conditional_average_outcome_estimates[prediction_indices] = (
                predict_method(index_matrix(X, prediction_indices))
            )

        conditional_average_outcomes_list.append(
            conditional_average_outcome_estimates
        )

    return np.stack(conditional_average_outcomes_list, axis=1).reshape(
        n_obs, self.n_variants, -1
    )
14 changes: 11 additions & 3 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,9 +727,17 @@ def test_fit_params(metalearner_factory, fit_params, expected_keys, dummy_datase
is_classification=False,
n_folds=1,
)
# Using cross-fitting is not possible with a single fold.
if metalearner_factory == XLearner:
# TODO: The X-Learner doesn't support using synchronize_cross_fitting=False.
# As a consequence, it doesn't support n_folds=1 either.
# We should find an alternative to testing this property for the X-Learner.
pytest.skip()
metalearner.fit(
X=X, y=y, w=w, fit_params=fit_params, synchronize_cross_fitting=False
X=X,
y=y,
w=w,
fit_params=fit_params,
synchronize_cross_fitting=False,
)


Expand Down Expand Up @@ -994,9 +1002,9 @@ def test_shap_values_smoke(
[
TLearner,
SLearner,
XLearner,
RLearner,
DRLearner,
# The X-Learner does not support synchronize_cross_fitting=False.
],
)
@pytest.mark.parametrize("n_variants", [2, 5])
Expand Down
Loading