diff --git a/metalearners/cross_fit_estimator.py b/metalearners/cross_fit_estimator.py index aa112c0..66d86a0 100644 --- a/metalearners/cross_fit_estimator.py +++ b/metalearners/cross_fit_estimator.py @@ -242,12 +242,12 @@ def _predict_in_sample( ) -> np.ndarray: if not self._test_indices: raise ValueError() - if len(X) != sum(len(fold) for fold in self._test_indices): - raise ValueError( - "Trying to predict in-sample on data that is unlike data encountered in training. " - f"Training data included {sum(len(fold) for fold in self._test_indices)} " - f"observations while prediction data includes {len(X)} observations." - ) + # if len(X) != sum(len(fold) for fold in self._test_indices): + # raise ValueError( + # "Trying to predict in-sample on data that is unlike data encountered in training. " + # f"Training data included {sum(len(fold) for fold in self._test_indices)} " + # f"observations while prediction data includes {len(X)} observations." + # ) n_outputs = self._n_outputs(method) predictions = self._initialize_prediction_tensor( n_observations=len(X), diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index 28bee89..082917d 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -99,31 +99,36 @@ def fit_all_nuisance( qualified_fit_params = self._qualified_fit_params(fit_params) - self._cvs: list = [] + if not synchronize_cross_fitting: + raise ValueError() + + self._cv_split_indices = self._split(X) + self._treatment_cv_split_indices = {} for treatment_variant in range(self.n_variants): self._treatment_variants_indices.append(w == treatment_variant) - if synchronize_cross_fitting: - cv_split_indices = self._split( - index_matrix(X, self._treatment_variants_indices[treatment_variant]) + treatment_indices = np.where( + self._treatment_variants_indices[treatment_variant] + )[0] + self._treatment_cv_split_indices[treatment_variant] = [ + ( + np.intersect1d(train_indices, treatment_indices), + np.intersect1d(test_indices, treatment_indices), ) - else: - cv_split_indices = None - self._cvs.append(cv_split_indices) + for train_indices, test_indices in self._cv_split_indices + ] nuisance_jobs: list[_ParallelJoblibSpecification | None] = [] for treatment_variant in range(self.n_variants): nuisance_jobs.append( self._nuisance_joblib_specifications( - X=index_matrix( - X, self._treatment_variants_indices[treatment_variant] - ), - y=y[self._treatment_variants_indices[treatment_variant]], + X=X, + y=y, model_kind=VARIANT_OUTCOME_MODEL, model_ord=treatment_variant, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL], - cv=self._cvs[treatment_variant], + cv=self._treatment_cv_split_indices[treatment_variant], ) ) @@ -160,13 +165,13 @@ def fit_all_treatment( ) -> Self: if self._treatment_variants_indices is None: raise ValueError( - "The nuisance models need to be fitted before fitting the treatment models." + "The nuisance models need to be fitted before fitting the treatment models. " "In particular, the MetaLearner's attribute _treatment_variant_indices, " "typically set during nuisance fitting, is None." ) - if not hasattr(self, "_cvs"): + if not hasattr(self, "_treatment_cv_split_indices"): raise ValueError( - "The nuisance models need to be fitted before fitting the treatment models." + "The nuisance models need to be fitted before fitting the treatment models. " "In particular, the MetaLearner's attribute _cvs, " "typically set during nuisance fitting, does not exist." ) @@ -180,34 +185,31 @@ def fit_all_treatment( is_oos=False, ) ) - for treatment_variant in range(1, self.n_variants): imputed_te_control, imputed_te_treatment = self._pseudo_outcome( y, w, treatment_variant, conditional_average_outcome_estimates ) treatment_jobs.append( self._treatment_joblib_specifications( - X=index_matrix( - X, self._treatment_variants_indices[treatment_variant] - ), + X=X, y=imputed_te_treatment, model_kind=TREATMENT_EFFECT_MODEL, model_ord=treatment_variant - 1, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL], - cv=self._cvs[treatment_variant], + cv=self._treatment_cv_split_indices[treatment_variant], ) ) treatment_jobs.append( self._treatment_joblib_specifications( - X=index_matrix(X, self._treatment_variants_indices[0]), + X=X, y=imputed_te_control, model_kind=CONTROL_EFFECT_MODEL, model_ord=treatment_variant - 1, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL], - cv=self._cvs[0], + cv=self._treatment_cv_split_indices[0], ) ) @@ -278,19 +280,18 @@ def predict( oos_method=oos_method, ) ) - tau_hat_treatment[treatment_variant_indices] = self.predict_treatment( - X=index_matrix(X, treatment_variant_indices), + X=X, model_kind=TREATMENT_EFFECT_MODEL, model_ord=treatment_variant - 1, is_oos=False, - ) + )[treatment_variant_indices] tau_hat_control[control_indices] = self.predict_treatment( - X=index_matrix(X, control_indices), + X=X, model_kind=CONTROL_EFFECT_MODEL, model_ord=treatment_variant - 1, is_oos=False, - ) + )[control_indices] tau_hat_control[non_control_indices] = self.predict_treatment( X=index_matrix(X, non_control_indices), model_kind=CONTROL_EFFECT_MODEL, @@ -424,16 +425,8 @@ def _pseudo_outcome( This function can be used with both in-sample or out-of-sample data. """ validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants) - - treatment_indices = w == treatment_variant - control_indices = w == 0 - - treatment_outcome = index_matrix( - conditional_average_outcome_estimates, control_indices - )[:, treatment_variant] - control_outcome = index_matrix( - conditional_average_outcome_estimates, treatment_indices - )[:, 0] + treatment_outcome = conditional_average_outcome_estimates[:, treatment_variant] + control_outcome = conditional_average_outcome_estimates[:, 0] if self.is_classification: # Get the probability of positive class, multiclass is currently not supported. @@ -443,8 +436,8 @@ def _pseudo_outcome( control_outcome = control_outcome[:, 0] treatment_outcome = treatment_outcome[:, 0] - imputed_te_treatment = y[treatment_indices] - control_outcome - imputed_te_control = treatment_outcome - y[control_indices] + imputed_te_treatment = y - control_outcome + imputed_te_control = treatment_outcome - y return imputed_te_control, imputed_te_treatment @@ -534,3 +527,46 @@ def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): final_model = build(input_dict, {output_name: cate}) check_model(final_model, full_check=True) return final_model + + def predict_conditional_average_outcomes( + self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL + ) -> np.ndarray: + if self._treatment_variants_indices is None: + raise ValueError( + "The metalearner needs to be fitted before predicting." + "In particular, the MetaLearner's attribute _treatment_variant_indices, " + "typically set during fitting, is None." + ) + # TODO: Consider multiprocessing + n_obs = len(X) + nuisance_tensors = self._nuisance_tensors(n_obs) + conditional_average_outcomes_list = [] + + for tv in range(self.n_variants): + if is_oos: + conditional_average_outcomes_list.append( + self.predict_nuisance( + X=X, + model_kind=VARIANT_OUTCOME_MODEL, + model_ord=tv, + is_oos=True, + oos_method=oos_method, + ) + ) + else: + cfe = self._nuisance_models[VARIANT_OUTCOME_MODEL][tv] + conditional_average_outcomes_list.append( + nuisance_tensors[VARIANT_OUTCOME_MODEL][0].copy() + ) + for split_index, test_indices in enumerate(cfe._test_indices): # type: ignore[arg-type] + model = cfe._estimators[split_index] + predict_method_name = self.nuisance_model_specifications()[ + VARIANT_OUTCOME_MODEL + ]["predict_method"](self) + predict_method = getattr(model, predict_method_name) + conditional_average_outcomes_list[tv][test_indices] = ( + predict_method(X[test_indices]) + ) + return np.stack(conditional_average_outcomes_list, axis=1).reshape( + n_obs, self.n_variants, -1 + )