diff --git a/metalearners/cross_fit_estimator.py b/metalearners/cross_fit_estimator.py index aa112c0..66d86a0 100644 --- a/metalearners/cross_fit_estimator.py +++ b/metalearners/cross_fit_estimator.py @@ -242,12 +242,12 @@ def _predict_in_sample( ) -> np.ndarray: if not self._test_indices: raise ValueError() - if len(X) != sum(len(fold) for fold in self._test_indices): - raise ValueError( - "Trying to predict in-sample on data that is unlike data encountered in training. " - f"Training data included {sum(len(fold) for fold in self._test_indices)} " - f"observations while prediction data includes {len(X)} observations." - ) + # if len(X) != sum(len(fold) for fold in self._test_indices): + # raise ValueError( + # "Trying to predict in-sample on data that is unlike data encountered in training. " + # f"Training data included {sum(len(fold) for fold in self._test_indices)} " + # f"observations while prediction data includes {len(X)} observations." + # ) n_outputs = self._n_outputs(method) predictions = self._initialize_prediction_tensor( n_observations=len(X), diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index 28bee89..321cd94 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -8,7 +8,14 @@ from joblib import Parallel, delayed from typing_extensions import Self -from metalearners._typing import Matrix, OosMethod, Scoring, Vector, _ScikitModel +from metalearners._typing import ( + Matrix, + OosMethod, + Scoring, + SplitIndices, + Vector, + _ScikitModel, +) from metalearners._utils import ( check_spox_installed, copydoc, @@ -99,31 +106,39 @@ def fit_all_nuisance( qualified_fit_params = self._qualified_fit_params(fit_params) - self._cvs: list = [] + # TODO: Move this to object initialization. + if not synchronize_cross_fitting: + raise ValueError( + "The X-Learner does not support synchronize_cross_fitting=False." + ) + + self._cv_split_indices: SplitIndices = self._split(X) + self._treatment_cv_split_indices: dict[int, SplitIndices] = {} for treatment_variant in range(self.n_variants): self._treatment_variants_indices.append(w == treatment_variant) - if synchronize_cross_fitting: - cv_split_indices = self._split( - index_matrix(X, self._treatment_variants_indices[treatment_variant]) + treatment_indices = np.where( + self._treatment_variants_indices[treatment_variant] + )[0] + self._treatment_cv_split_indices[treatment_variant] = [ + ( + np.intersect1d(train_indices, treatment_indices), + np.intersect1d(test_indices, treatment_indices), ) - else: - cv_split_indices = None - self._cvs.append(cv_split_indices) + for train_indices, test_indices in self._cv_split_indices + ] nuisance_jobs: list[_ParallelJoblibSpecification | None] = [] for treatment_variant in range(self.n_variants): nuisance_jobs.append( self._nuisance_joblib_specifications( - X=index_matrix( - X, self._treatment_variants_indices[treatment_variant] - ), - y=y[self._treatment_variants_indices[treatment_variant]], + X=X, + y=y, model_kind=VARIANT_OUTCOME_MODEL, model_ord=treatment_variant, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL], - cv=self._cvs[treatment_variant], + cv=self._treatment_cv_split_indices[treatment_variant], ) ) @@ -160,14 +175,14 @@ def fit_all_treatment( ) -> Self: if self._treatment_variants_indices is None: raise ValueError( - "The nuisance models need to be fitted before fitting the treatment models." + "The nuisance models need to be fitted before fitting the treatment models. " "In particular, the MetaLearner's attribute _treatment_variant_indices, " "typically set during nuisance fitting, is None." ) - if not hasattr(self, "_cvs"): + if not hasattr(self, "_treatment_cv_split_indices"): raise ValueError( - "The nuisance models need to be fitted before fitting the treatment models." - "In particular, the MetaLearner's attribute _cvs, " + "The nuisance models need to be fitted before fitting the treatment models. " + "In particular, the MetaLearner's attribute _treatment_cv_split_indices, " "typically set during nuisance fitting, does not exist." ) qualified_fit_params = self._qualified_fit_params(fit_params) @@ -180,34 +195,32 @@ def fit_all_treatment( is_oos=False, ) ) - for treatment_variant in range(1, self.n_variants): imputed_te_control, imputed_te_treatment = self._pseudo_outcome( y, w, treatment_variant, conditional_average_outcome_estimates ) + treatment_jobs.append( self._treatment_joblib_specifications( - X=index_matrix( - X, self._treatment_variants_indices[treatment_variant] - ), + X=X, y=imputed_te_treatment, model_kind=TREATMENT_EFFECT_MODEL, model_ord=treatment_variant - 1, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL], - cv=self._cvs[treatment_variant], + cv=self._treatment_cv_split_indices[treatment_variant], ) ) treatment_jobs.append( self._treatment_joblib_specifications( - X=index_matrix(X, self._treatment_variants_indices[0]), + X=X, y=imputed_te_control, model_kind=CONTROL_EFFECT_MODEL, model_ord=treatment_variant - 1, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL], - cv=self._cvs[0], + cv=self._treatment_cv_split_indices[0], ) ) @@ -216,6 +229,7 @@ def fit_all_treatment( delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs ) self._assign_joblib_treatment_results(results) + return self def predict( @@ -278,19 +292,18 @@ def predict( oos_method=oos_method, ) ) - tau_hat_treatment[treatment_variant_indices] = self.predict_treatment( - X=index_matrix(X, treatment_variant_indices), + X=X, model_kind=TREATMENT_EFFECT_MODEL, model_ord=treatment_variant - 1, is_oos=False, - ) + )[treatment_variant_indices] tau_hat_control[control_indices] = self.predict_treatment( - X=index_matrix(X, control_indices), + X=X, model_kind=CONTROL_EFFECT_MODEL, model_ord=treatment_variant - 1, is_oos=False, - ) + )[control_indices] tau_hat_control[non_control_indices] = self.predict_treatment( X=index_matrix(X, non_control_indices), model_kind=CONTROL_EFFECT_MODEL, @@ -337,8 +350,8 @@ def evaluate( variant_outcome_evaluation = _evaluate_model_kind( cfes=self._nuisance_models[VARIANT_OUTCOME_MODEL], - Xs=[X[w == tv] for tv in range(self.n_variants)], - ys=[y[w == tv] for tv in range(self.n_variants)], + Xs=[X] * self.n_variants, + ys=[y] * self.n_variants, scorers=safe_scoring[VARIANT_OUTCOME_MODEL], model_kind=VARIANT_OUTCOME_MODEL, is_oos=is_oos, @@ -378,7 +391,7 @@ def evaluate( te_treatment_evaluation = _evaluate_model_kind( self._treatment_models[TREATMENT_EFFECT_MODEL], - Xs=[X[w == tv] for tv in range(1, self.n_variants)], + Xs=[X] * self.n_variants, ys=imputed_te_treatment, scorers=safe_scoring[TREATMENT_EFFECT_MODEL], model_kind=TREATMENT_EFFECT_MODEL, @@ -390,7 +403,7 @@ def evaluate( te_control_evaluation = _evaluate_model_kind( self._treatment_models[CONTROL_EFFECT_MODEL], - Xs=[X[w == 0] for _ in range(1, self.n_variants)], + Xs=[X] * self.n_variants, ys=imputed_te_control, scorers=safe_scoring[CONTROL_EFFECT_MODEL], model_kind=CONTROL_EFFECT_MODEL, @@ -424,16 +437,8 @@ def _pseudo_outcome( This function can be used with both in-sample or out-of-sample data. """ validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants) - - treatment_indices = w == treatment_variant - control_indices = w == 0 - - treatment_outcome = index_matrix( - conditional_average_outcome_estimates, control_indices - )[:, treatment_variant] - control_outcome = index_matrix( - conditional_average_outcome_estimates, treatment_indices - )[:, 0] + treatment_outcome = conditional_average_outcome_estimates[:, treatment_variant] + control_outcome = conditional_average_outcome_estimates[:, 0] if self.is_classification: # Get the probability of positive class, multiclass is currently not supported. @@ -443,8 +448,8 @@ def _pseudo_outcome( control_outcome = control_outcome[:, 0] treatment_outcome = treatment_outcome[:, 0] - imputed_te_treatment = y[treatment_indices] - control_outcome - imputed_te_control = treatment_outcome - y[control_indices] + imputed_te_treatment = y - control_outcome + imputed_te_control = treatment_outcome - y return imputed_te_control, imputed_te_treatment @@ -534,3 +539,54 @@ def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): final_model = build(input_dict, {output_name: cate}) check_model(final_model, full_check=True) return final_model + + def predict_conditional_average_outcomes( + self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL + ) -> np.ndarray: + if self._treatment_variants_indices is None: + raise ValueError( + "The metalearner needs to be fitted before predicting." + "In particular, the MetaLearner's attribute _treatment_variant_indices, " + "typically set during fitting, is None." + ) + # TODO: Consider multiprocessing + n_obs = len(X) + cao_tensor = self._nuisance_tensors(n_obs)[VARIANT_OUTCOME_MODEL][0] + predict_method_name = self.nuisance_model_specifications()[ + VARIANT_OUTCOME_MODEL + ]["predict_method"](self) + conditional_average_outcomes_list = [] + + for tv in range(self.n_variants): + if is_oos: + conditional_average_outcomes_list.append( + self.predict_nuisance( + X=X, + model_kind=VARIANT_OUTCOME_MODEL, + model_ord=tv, + is_oos=is_oos, + oos_method=oos_method, + ) + ) + else: + # TODO: Consider moving this logic to CrossFitEstimator.predict. + cfe = self._nuisance_models[VARIANT_OUTCOME_MODEL][tv] + conditional_average_outcome_estimates = cao_tensor.copy() + + for fold_index, (train_indices, prediction_indices) in enumerate( + self._cv_split_indices + ): + fold_model = cfe._estimators[fold_index] + predict_method = getattr(fold_model, predict_method_name) + fold_estimates = predict_method(index_matrix(X, prediction_indices)) + conditional_average_outcome_estimates[prediction_indices] = ( + fold_estimates + ) + + conditional_average_outcomes_list.append( + conditional_average_outcome_estimates + ) + + return np.stack(conditional_average_outcomes_list, axis=1).reshape( + n_obs, self.n_variants, -1 + ) diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py index d9ac1f6..e0907d1 100644 --- a/tests/test_metalearner.py +++ b/tests/test_metalearner.py @@ -727,9 +727,17 @@ def test_fit_params(metalearner_factory, fit_params, expected_keys, dummy_datase is_classification=False, n_folds=1, ) - # Using cross-fitting is not possible with a single fold. + if metalearner_factory == XLearner: + # TODO: The X-Learner doesn't support using synchronize_cross_fitting=False. + # As a consequence, it doesn't support n_folds=1 either. + # We should find an alternative to testing this property for the X-Learner. + pytest.skip() metalearner.fit( - X=X, y=y, w=w, fit_params=fit_params, synchronize_cross_fitting=False + X=X, + y=y, + w=w, + fit_params=fit_params, + synchronize_cross_fitting=False, ) @@ -994,9 +1002,9 @@ def test_shap_values_smoke( [ TLearner, SLearner, - XLearner, RLearner, DRLearner, + # The X-Learner does not support synchronze_cross_fitting = False. ], ) @pytest.mark.parametrize("n_variants", [2, 5])