diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py index c0bc92b..e87be41 100644 --- a/metalearners/drlearner.py +++ b/metalearners/drlearner.py @@ -13,6 +13,7 @@ OosMethod, Params, Scoring, + SplitIndices, Vector, _ScikitModel, ) @@ -128,7 +129,7 @@ def __init__( ) self.adaptive_clipping = adaptive_clipping - def fit( + def fit_all_nuisance( self, X: Matrix, y: Vector, @@ -148,10 +149,12 @@ def fit( for treatment_variant in range(self.n_variants): self._treatment_variants_indices.append(w == treatment_variant) + self._cv_split_indices: SplitIndices | None + if synchronize_cross_fitting: - cv_split_indices = self._split(X) + self._cv_split_indices = self._split(X) else: - cv_split_indices = None + self._cv_split_indices = None nuisance_jobs: list[_ParallelJoblibSpecification | None] = [] for treatment_variant in range(self.n_variants): @@ -176,7 +179,7 @@ def fit( model_ord=0, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL], - cv=cv_split_indices, + cv=self._cv_split_indices, ) ) @@ -189,6 +192,25 @@ def fit( self._assign_joblib_nuisance_results(results) + return self + + def fit_all_treatment( + self, + X: Matrix, + y: Vector, + w: Vector, + n_jobs_cross_fitting: int | None = None, + fit_params: dict | None = None, + synchronize_cross_fitting: bool = True, + n_jobs_base_learners: int | None = None, + ) -> Self: + if not hasattr(self, "_cv_split_indices"): + raise ValueError( + "The nuisance models need to be fitted before fitting the treatment models. " + "In particular, the MetaLearner's attribute _cv_split_indices, " + "typically set during nuisance fitting, does not exist." 
+ ) + qualified_fit_params = self._qualified_fit_params(fit_params) treatment_jobs: list[_ParallelJoblibSpecification] = [] for treatment_variant in range(1, self.n_variants): pseudo_outcomes = self._pseudo_outcome( @@ -207,9 +229,10 @@ def fit( model_ord=treatment_variant - 1, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[TREATMENT][TREATMENT_MODEL], - cv=cv_split_indices, + cv=self._cv_split_indices, ) ) + parallel = Parallel(n_jobs=n_jobs_base_learners) results = parallel( delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs ) diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index 6214af9..2fb285b 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -744,6 +744,41 @@ def _assign_joblib_treatment_results( ] = result.cross_fit_estimator @abstractmethod + def fit_all_nuisance( + self, + X: Matrix, + y: Vector, + w: Vector, + n_jobs_cross_fitting: int | None = None, + fit_params: dict | None = None, + synchronize_cross_fitting: bool = True, + n_jobs_base_learners: int | None = None, + ) -> Self: + """Fit all nuisance models of the MetaLearner. + + If pre-fitted models were passed at instantiation, these are never refitted. + + For the parameters check :meth:`metalearners.metalearner.MetaLearner.fit`. + """ + ... + + @abstractmethod + def fit_all_treatment( + self, + X: Matrix, + y: Vector, + w: Vector, + n_jobs_cross_fitting: int | None = None, + fit_params: dict | None = None, + synchronize_cross_fitting: bool = True, + n_jobs_base_learners: int | None = None, + ) -> Self: + """Fit all treatment models of the MetaLearner. + + For the parameters check :meth:`metalearners.metalearner.MetaLearner.fit`. + """ + ... + def fit( self, X: Matrix, @@ -791,7 +826,27 @@ def fit( the same data splits where possible. Note that if there are several models to be synchronized which are classifiers, these cannot be split via stratification. """ - ... 
+ self.fit_all_nuisance( + X=X, + y=y, + w=w, + n_jobs_cross_fitting=n_jobs_cross_fitting, + fit_params=fit_params, + synchronize_cross_fitting=synchronize_cross_fitting, + n_jobs_base_learners=n_jobs_base_learners, + ) + + self.fit_all_treatment( + X=X, + y=y, + w=w, + n_jobs_cross_fitting=n_jobs_cross_fitting, + fit_params=fit_params, + synchronize_cross_fitting=synchronize_cross_fitting, + n_jobs_base_learners=n_jobs_base_learners, + ) + + return self def predict_nuisance( self, diff --git a/metalearners/rlearner.py b/metalearners/rlearner.py index b7e9297..8e85502 100644 --- a/metalearners/rlearner.py +++ b/metalearners/rlearner.py @@ -156,7 +156,7 @@ def _supports_multi_treatment(cls) -> bool: def _supports_multi_class(cls) -> bool: return False - def fit( + def fit_all_nuisance( self, X: Matrix, y: Vector, @@ -165,14 +165,10 @@ def fit( fit_params: dict | None = None, synchronize_cross_fitting: bool = True, n_jobs_base_learners: int | None = None, - epsilon: float = _EPSILON, ) -> Self: - self._validate_treatment(w) self._validate_outcome(y, w) - self._variants_indices = [] - qualified_fit_params = self._qualified_fit_params(fit_params) self._validate_fit_params(qualified_fit_params) @@ -214,7 +210,22 @@ def fit( ) self._assign_joblib_nuisance_results(results) + return self + + def fit_all_treatment( + self, + X: Matrix, + y: Vector, + w: Vector, + n_jobs_cross_fitting: int | None = None, + fit_params: dict | None = None, + synchronize_cross_fitting: bool = True, + n_jobs_base_learners: int | None = None, + epsilon: float = _EPSILON, + ) -> Self: + qualified_fit_params = self._qualified_fit_params(fit_params) treatment_jobs: list[_ParallelJoblibSpecification] = [] + self._variants_indices = [] for treatment_variant in range(1, self.n_variants): is_treatment = w == treatment_variant @@ -246,6 +257,7 @@ def fit( n_jobs_cross_fitting=n_jobs_cross_fitting, ) ) + parallel = Parallel(n_jobs=n_jobs_base_learners) results = parallel( 
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs ) diff --git a/metalearners/slearner.py b/metalearners/slearner.py index 7158cca..c6c99a6 100644 --- a/metalearners/slearner.py +++ b/metalearners/slearner.py @@ -141,7 +141,7 @@ def __init__( random_state=random_state, ) - def fit( + def fit_all_nuisance( self, X: Matrix, y: Vector, @@ -175,6 +175,18 @@ def fit( ) return self + def fit_all_treatment( + self, + X: Matrix, + y: Vector, + w: Vector, + n_jobs_cross_fitting: int | None = None, + fit_params: dict | None = None, + synchronize_cross_fitting: bool = True, + n_jobs_base_learners: int | None = None, + ) -> Self: + return self + def predict( self, X: Matrix, diff --git a/metalearners/tlearner.py b/metalearners/tlearner.py index cc61fe5..17442cf 100644 --- a/metalearners/tlearner.py +++ b/metalearners/tlearner.py @@ -48,7 +48,7 @@ def _supports_multi_treatment(cls) -> bool: def _supports_multi_class(cls) -> bool: return True - def fit( + def fit_all_nuisance( self, X: Matrix, y: Vector, @@ -92,6 +92,18 @@ def fit( self._assign_joblib_nuisance_results(results) return self + def fit_all_treatment( + self, + X: Matrix, + y: Vector, + w: Vector, + n_jobs_cross_fitting: int | None = None, + fit_params: dict | None = None, + synchronize_cross_fitting: bool = True, + n_jobs_base_learners: int | None = None, + ) -> Self: + return self + def predict( self, X, diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index 109619d..72d2107 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -74,7 +74,7 @@ def _supports_multi_treatment(cls) -> bool: def _supports_multi_class(cls) -> bool: return False - def fit( + def fit_all_nuisance( self, X: Matrix, y: Vector, @@ -91,7 +91,7 @@ def fit( qualified_fit_params = self._qualified_fit_params(fit_params) - cvs: list = [] + self._cvs: list = [] for treatment_variant in range(self.n_variants): self._treatment_variants_indices.append(w == treatment_variant) @@ -101,7 +101,7 @@ def 
fit( ) else: cv_split_indices = None - cvs.append(cv_split_indices) + self._cvs.append(cv_split_indices) nuisance_jobs: list[_ParallelJoblibSpecification | None] = [] for treatment_variant in range(self.n_variants): @@ -115,7 +115,7 @@ def fit( model_ord=treatment_variant, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL], - cv=cvs[treatment_variant], + cv=self._cvs[treatment_variant], ) ) @@ -138,6 +138,32 @@ def fit( ) self._assign_joblib_nuisance_results(results) + return self + + def fit_all_treatment( + self, + X: Matrix, + y: Vector, + w: Vector, + n_jobs_cross_fitting: int | None = None, + fit_params: dict | None = None, + synchronize_cross_fitting: bool = True, + n_jobs_base_learners: int | None = None, + ) -> Self: + if self._treatment_variants_indices is None: + raise ValueError( + "The nuisance models need to be fitted before fitting the treatment models. " + "In particular, the MetaLearner's attribute _treatment_variants_indices, " + "typically set during nuisance fitting, is None." + ) + if not hasattr(self, "_cvs"): + raise ValueError( + "The nuisance models need to be fitted before fitting the treatment models. " + "In particular, the MetaLearner's attribute _cvs, " + "typically set during nuisance fitting, does not exist." 
+ ) + qualified_fit_params = self._qualified_fit_params(fit_params) + treatment_jobs: list[_ParallelJoblibSpecification] = [] for treatment_variant in range(1, self.n_variants): imputed_te_control, imputed_te_treatment = self._pseudo_outcome( @@ -153,7 +179,7 @@ def fit( model_ord=treatment_variant - 1, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL], - cv=cvs[treatment_variant], + cv=self._cvs[treatment_variant], ) ) @@ -165,10 +191,11 @@ def fit( model_ord=treatment_variant - 1, n_jobs_cross_fitting=n_jobs_cross_fitting, fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL], - cv=cvs[0], + cv=self._cvs[0], ) ) + parallel = Parallel(n_jobs=n_jobs_base_learners) results = parallel( delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs )