Skip to content

Commit

Permalink
Draft usage of same splits in all models.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Aug 15, 2024
1 parent b883d62 commit 5afa7af
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 45 deletions.
12 changes: 6 additions & 6 deletions metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,12 @@ def _predict_in_sample(
) -> np.ndarray:
if not self._test_indices:
raise ValueError()
if len(X) != sum(len(fold) for fold in self._test_indices):
raise ValueError(
"Trying to predict in-sample on data that is unlike data encountered in training. "
f"Training data included {sum(len(fold) for fold in self._test_indices)} "
f"observations while prediction data includes {len(X)} observations."
)
# if len(X) != sum(len(fold) for fold in self._test_indices):
# raise ValueError(
# "Trying to predict in-sample on data that is unlike data encountered in training. "
# f"Training data included {sum(len(fold) for fold in self._test_indices)} "
# f"observations while prediction data includes {len(X)} observations."
# )
n_outputs = self._n_outputs(method)
predictions = self._initialize_prediction_tensor(
n_observations=len(X),
Expand Down
114 changes: 75 additions & 39 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,31 +99,36 @@ def fit_all_nuisance(

qualified_fit_params = self._qualified_fit_params(fit_params)

self._cvs: list = []
if not synchronize_cross_fitting:
raise ValueError()

self._cv_split_indices = self._split(X)
self._treatment_cv_split_indices = {}

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
if synchronize_cross_fitting:
cv_split_indices = self._split(
index_matrix(X, self._treatment_variants_indices[treatment_variant])
treatment_indices = np.where(
self._treatment_variants_indices[treatment_variant]
)[0]
self._treatment_cv_split_indices[treatment_variant] = [
(
np.intersect1d(train_indices, treatment_indices),
np.intersect1d(test_indices, treatment_indices),
)
else:
cv_split_indices = None
self._cvs.append(cv_split_indices)
for train_indices, test_indices in self._cv_split_indices
]

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
nuisance_jobs.append(
self._nuisance_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
y=y[self._treatment_variants_indices[treatment_variant]],
X=X,
y=y,
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
cv=self._cvs[treatment_variant],
cv=self._treatment_cv_split_indices[treatment_variant],
)
)

Expand Down Expand Up @@ -160,13 +165,13 @@ def fit_all_treatment(
) -> Self:
if self._treatment_variants_indices is None:
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"The nuisance models need to be fitted before fitting the treatment models. "
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"typically set during nuisance fitting, is None."
)
if not hasattr(self, "_cvs"):
if not hasattr(self, "_treatment_cv_split_indices"):
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"The nuisance models need to be fitted before fitting the treatment models. "
"In particular, the MetaLearner's attribute _cvs, "
"typically set during nuisance fitting, does not exist."
)
Expand All @@ -180,34 +185,31 @@ def fit_all_treatment(
is_oos=False,
)
)

for treatment_variant in range(1, self.n_variants):
imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
y, w, treatment_variant, conditional_average_outcome_estimates
)
treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
X=X,
y=imputed_te_treatment,
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL],
cv=self._cvs[treatment_variant],
cv=self._treatment_cv_split_indices[treatment_variant],
)
)

treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(X, self._treatment_variants_indices[0]),
X=X,
y=imputed_te_control,
model_kind=CONTROL_EFFECT_MODEL,
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL],
cv=self._cvs[0],
cv=self._treatment_cv_split_indices[0],
)
)

Expand Down Expand Up @@ -278,19 +280,18 @@ def predict(
oos_method=oos_method,
)
)

tau_hat_treatment[treatment_variant_indices] = self.predict_treatment(
X=index_matrix(X, treatment_variant_indices),
X=X,
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
)
)[treatment_variant_indices]
tau_hat_control[control_indices] = self.predict_treatment(
X=index_matrix(X, control_indices),
X=X,
model_kind=CONTROL_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
)
)[control_indices]
tau_hat_control[non_control_indices] = self.predict_treatment(
X=index_matrix(X, non_control_indices),
model_kind=CONTROL_EFFECT_MODEL,
Expand Down Expand Up @@ -424,16 +425,8 @@ def _pseudo_outcome(
This function can be used with both in-sample or out-of-sample data.
"""
validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants)

treatment_indices = w == treatment_variant
control_indices = w == 0

treatment_outcome = index_matrix(
conditional_average_outcome_estimates, control_indices
)[:, treatment_variant]
control_outcome = index_matrix(
conditional_average_outcome_estimates, treatment_indices
)[:, 0]
treatment_outcome = conditional_average_outcome_estimates[:, treatment_variant]
control_outcome = conditional_average_outcome_estimates[:, 0]

if self.is_classification:
# Get the probability of positive class, multiclass is currently not supported.
Expand All @@ -443,8 +436,8 @@ def _pseudo_outcome(
control_outcome = control_outcome[:, 0]
treatment_outcome = treatment_outcome[:, 0]

imputed_te_treatment = y[treatment_indices] - control_outcome
imputed_te_control = treatment_outcome - y[control_indices]
imputed_te_treatment = y - control_outcome
imputed_te_control = treatment_outcome - y

return imputed_te_control, imputed_te_treatment

Expand Down Expand Up @@ -534,3 +527,46 @@ def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
final_model = build(input_dict, {output_name: cate})
check_model(final_model, full_check=True)
return final_model

def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
if self._treatment_variants_indices is None:
raise ValueError(
"The metalearner needs to be fitted before predicting."
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"typically set during fitting, is None."
)
# TODO: Consider multiprocessing
n_obs = len(X)
nuisance_tensors = self._nuisance_tensors(n_obs)
conditional_average_outcomes_list = []

for tv in range(self.n_variants):
if is_oos:
conditional_average_outcomes_list.append(
self.predict_nuisance(
X=X,
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=tv,
is_oos=True,
oos_method=oos_method,
)
)
else:
cfe = self._nuisance_models[VARIANT_OUTCOME_MODEL][tv]
conditional_average_outcomes_list.append(
nuisance_tensors[VARIANT_OUTCOME_MODEL][0].copy()
)
for split_index, test_indices in enumerate(cfe._test_indices): # type: ignore[arg-type]
model = cfe._estimators[split_index]
predict_method_name = self.nuisance_model_specifications()[
VARIANT_OUTCOME_MODEL
]["predict_method"](self)
predict_method = getattr(model, predict_method_name)
conditional_average_outcomes_list[tv][test_indices] = (
predict_method(X[test_indices])
)
return np.stack(conditional_average_outcomes_list, axis=1).reshape(
n_obs, self.n_variants, -1
)

0 comments on commit 5afa7af

Please sign in to comment.