diff --git a/docs/examples/model.onnx b/docs/examples/model.onnx new file mode 100644 index 0000000..99cecad Binary files /dev/null and b/docs/examples/model.onnx differ diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py index 6c03ab3..f2c67e1 100644 --- a/metalearners/drlearner.py +++ b/metalearners/drlearner.py @@ -150,12 +150,12 @@ def fit_all_nuisance( self._validate_treatment(w) self._validate_outcome(y, w) - self._treatment_variants_indices = [] + self._treatment_variants_mask = [] qualified_fit_params = self._qualified_fit_params(fit_params) for treatment_variant in range(self.n_variants): - self._treatment_variants_indices.append(w == treatment_variant) + self._treatment_variants_mask.append(w == treatment_variant) self._cv_split_indices: SplitIndices | None @@ -168,10 +168,8 @@ def fit_all_nuisance( for treatment_variant in range(self.n_variants): nuisance_jobs.append( self._nuisance_joblib_specifications( - X=index_matrix( - X, self._treatment_variants_indices[treatment_variant] - ), - y=y[self._treatment_variants_indices[treatment_variant]], + X=index_matrix(X, self._treatment_variants_mask[treatment_variant]), + y=y[self._treatment_variants_mask[treatment_variant]], model_kind=VARIANT_OUTCOME_MODEL, model_ord=treatment_variant, n_jobs_cross_fitting=n_jobs_cross_fitting, diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index 1ec53c9..4024341 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -1336,15 +1336,15 @@ def __init__( n_folds=n_folds, random_state=random_state, ) - self._treatment_variants_indices: list[np.ndarray] | None = None + self._treatment_variants_mask: list[np.ndarray] | None = None def predict_conditional_average_outcomes( self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL ) -> np.ndarray: - if self._treatment_variants_indices is None: + if self._treatment_variants_mask is None: raise ValueError( "The metalearner needs to be fitted before predicting." - "In particular, the MetaLearner's attribute _treatment_variant_indices, " + "In particular, the MetaLearner's attribute _treatment_variant_mask, " "typically set during fitting, is None." ) # TODO: Consider multiprocessing @@ -1363,17 +1363,17 @@ def predict_conditional_average_outcomes( ) else: conditional_average_outcomes_list[tv][ - self._treatment_variants_indices[tv] + self._treatment_variants_mask[tv] ] = self.predict_nuisance( - X=index_matrix(X, self._treatment_variants_indices[tv]), + X=index_matrix(X, self._treatment_variants_mask[tv]), model_kind=VARIANT_OUTCOME_MODEL, model_ord=tv, is_oos=False, ) conditional_average_outcomes_list[tv][ - ~self._treatment_variants_indices[tv] + ~self._treatment_variants_mask[tv] ] = self.predict_nuisance( - X=index_matrix(X, ~self._treatment_variants_indices[tv]), + X=index_matrix(X, ~self._treatment_variants_mask[tv]), model_kind=VARIANT_OUTCOME_MODEL, model_ord=tv, is_oos=True, diff --git a/metalearners/tlearner.py b/metalearners/tlearner.py index c009e59..946d833 100644 --- a/metalearners/tlearner.py +++ b/metalearners/tlearner.py @@ -71,10 +71,10 @@ def fit_all_nuisance( self._validate_treatment(w) self._validate_outcome(y, w) - self._treatment_variants_indices = [] + self._treatment_variants_mask = [] for v in range(self.n_variants): - self._treatment_variants_indices.append(w == v) + self._treatment_variants_mask.append(w == v) qualified_fit_params = self._qualified_fit_params(fit_params) @@ -82,10 +82,8 @@ def fit_all_nuisance( for treatment_variant in range(self.n_variants): nuisance_jobs.append( self._nuisance_joblib_specifications( - X=index_matrix( - X, self._treatment_variants_indices[treatment_variant] - ), - y=y[self._treatment_variants_indices[treatment_variant]], + X=index_matrix(X, self._treatment_variants_mask[treatment_variant]), + y=y[self._treatment_variants_mask[treatment_variant]], model_kind=VARIANT_OUTCOME_MODEL, model_ord=treatment_variant, n_jobs_cross_fitting=n_jobs_cross_fitting, diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index 016a726..f9ff549 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -96,17 +96,17 @@ def fit_all_nuisance( self._validate_treatment(w) self._validate_outcome(y, w) - self._treatment_variants_indices = [] + self._treatment_variants_mask = [] qualified_fit_params = self._qualified_fit_params(fit_params) self._cvs: list = [] for treatment_variant in range(self.n_variants): - self._treatment_variants_indices.append(w == treatment_variant) + self._treatment_variants_mask.append(w == treatment_variant) if synchronize_cross_fitting: cv_split_indices = self._split( - index_matrix(X, self._treatment_variants_indices[treatment_variant]) + index_matrix(X, self._treatment_variants_mask[treatment_variant]) ) else: cv_split_indices = None @@ -116,10 +116,8 @@ def fit_all_nuisance( for treatment_variant in range(self.n_variants): nuisance_jobs.append( self._nuisance_joblib_specifications( - X=index_matrix( - X, self._treatment_variants_indices[treatment_variant] - ), - y=y[self._treatment_variants_indices[treatment_variant]], + X=index_matrix(X, self._treatment_variants_mask[treatment_variant]), + y=y[self._treatment_variants_mask[treatment_variant]], model_kind=VARIANT_OUTCOME_MODEL, model_ord=treatment_variant, n_jobs_cross_fitting=n_jobs_cross_fitting, @@ -159,10 +157,10 @@ def fit_all_treatment( synchronize_cross_fitting: bool = True, n_jobs_base_learners: int | None = None, ) -> Self: - if self._treatment_variants_indices is None: + if self._treatment_variants_mask is None: raise ValueError( "The nuisance models need to be fitted before fitting the treatment models." - "In particular, the MetaLearner's attribute _treatment_variant_indices, " + "In particular, the MetaLearner's attribute _treatment_variant_mask, " "typically set during nuisance fitting, is None." ) if not hasattr(self, "_cvs"): @@ -188,9 +186,7 @@ def fit_all_treatment( ) treatment_jobs.append( self._treatment_joblib_specifications( - X=index_matrix( - X, self._treatment_variants_indices[treatment_variant] - ), + X=index_matrix(X, self._treatment_variants_mask[treatment_variant]), y=imputed_te_treatment, model_kind=TREATMENT_EFFECT_MODEL, model_ord=treatment_variant - 1, @@ -202,7 +198,7 @@ def fit_all_treatment( treatment_jobs.append( self._treatment_joblib_specifications( - X=index_matrix(X, self._treatment_variants_indices[0]), + X=index_matrix(X, self._treatment_variants_mask[0]), y=imputed_te_control, model_kind=CONTROL_EFFECT_MODEL, model_ord=treatment_variant - 1, @@ -225,10 +221,10 @@ def predict( is_oos: bool, oos_method: OosMethod = OVERALL, ) -> np.ndarray: - if self._treatment_variants_indices is None: + if self._treatment_variants_mask is None: raise ValueError( "The MetaLearner needs to be fitted before predicting. " - "In particular, the X-Learner's attribute _treatment_variant_indices, " + "In particular, the X-Learner's attribute _treatment_variant_mask, " "typically set during fitting, is None." ) n_outputs = 2 if self.is_classification else 1 @@ -243,14 +239,12 @@ def predict( oos_method=propensity_score_oos, ) - control_indices = self._treatment_variants_indices[0] + control_indices = self._treatment_variants_mask[0] non_control_indices = ~control_indices for treatment_variant in range(1, self.n_variants): - treatment_variant_indices = self._treatment_variants_indices[ - treatment_variant - ] - non_treatment_variant_indices = ~treatment_variant_indices + treatment_variant_mask = self._treatment_variants_mask[treatment_variant] + non_treatment_variant_mask = ~treatment_variant_mask if is_oos: tau_hat_treatment = self.predict_treatment( X=X, @@ -270,18 +264,16 @@ def predict( tau_hat_treatment = np.zeros(safe_len(X)) tau_hat_control = np.zeros(safe_len(X)) - tau_hat_treatment[non_treatment_variant_indices] = ( - self.predict_treatment( - X=index_matrix(X, non_treatment_variant_indices), - model_kind=TREATMENT_EFFECT_MODEL, - model_ord=treatment_variant - 1, - is_oos=True, - oos_method=oos_method, - ) + tau_hat_treatment[non_treatment_variant_mask] = self.predict_treatment( + X=index_matrix(X, non_treatment_variant_mask), + model_kind=TREATMENT_EFFECT_MODEL, + model_ord=treatment_variant - 1, + is_oos=True, + oos_method=oos_method, ) - tau_hat_treatment[treatment_variant_indices] = self.predict_treatment( - X=index_matrix(X, treatment_variant_indices), + tau_hat_treatment[treatment_variant_mask] = self.predict_treatment( + X=index_matrix(X, treatment_variant_mask), model_kind=TREATMENT_EFFECT_MODEL, model_ord=treatment_variant - 1, is_oos=False, diff --git a/tests/test_learner.py b/tests/test_learner.py index afe2634..6b8f343 100644 --- a/tests/test_learner.py +++ b/tests/test_learner.py @@ -880,10 +880,10 @@ def test_model_reusage(outcome_kind, request): VARIANT_OUTCOME_MODEL: tlearner._nuisance_models[VARIANT_OUTCOME_MODEL] }, ) - # We need to manually copy _treatment_variants_indices for the xlearner as it's needed + # We need to manually copy _treatment_variants_mask for the xlearner as it's needed # for predict, the user should not have to do this as they should call fit before predict. # This is just for testing. - xlearner._treatment_variants_indices = tlearner._treatment_variants_indices + xlearner._treatment_variants_mask = tlearner._treatment_variants_mask np.testing.assert_allclose( tlearner.predict_conditional_average_outcomes(covariates, False), xlearner.predict_conditional_average_outcomes(covariates, False),