Skip to content

Commit

Permalink
Split fit
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 19, 2024
1 parent 9e00603 commit 120885a
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 19 deletions.
33 changes: 28 additions & 5 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
OosMethod,
Params,
Scoring,
SplitIndices,
Vector,
_ScikitModel,
)
Expand Down Expand Up @@ -128,7 +129,7 @@ def __init__(
)
self.adaptive_clipping = adaptive_clipping

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand All @@ -148,10 +149,12 @@ def fit(
for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)

self._cv_split_indices: SplitIndices | None

if synchronize_cross_fitting:
cv_split_indices = self._split(X)
self._cv_split_indices = self._split(X)
else:
cv_split_indices = None
self._cv_split_indices = None

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
Expand All @@ -176,7 +179,7 @@ def fit(
model_ord=0,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL],
cv=cv_split_indices,
cv=self._cv_split_indices,
)
)

Expand All @@ -189,6 +192,25 @@ def fit(

self._assign_joblib_nuisance_results(results)

return self

def fit_all_treatment(
self,
X: Matrix,
y: Vector,
w: Vector,
n_jobs_cross_fitting: int | None = None,
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
) -> Self:
if not hasattr(self, "_cv_split_indices"):
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _cv_split_indices, "
"typically set during nuisance fitting, does not exist."
)
qualified_fit_params = self._qualified_fit_params(fit_params)
treatment_jobs: list[_ParallelJoblibSpecification] = []
for treatment_variant in range(1, self.n_variants):
pseudo_outcomes = self._pseudo_outcome(
Expand All @@ -207,9 +229,10 @@ def fit(
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][TREATMENT_MODEL],
cv=cv_split_indices,
cv=self._cv_split_indices,
)
)
parallel = Parallel(n_jobs=n_jobs_base_learners)
results = parallel(
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
)
Expand Down
57 changes: 56 additions & 1 deletion metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,41 @@ def _assign_joblib_treatment_results(
] = result.cross_fit_estimator

@abstractmethod
def fit_all_nuisance(
    self,
    X: Matrix,
    y: Vector,
    w: Vector,
    n_jobs_cross_fitting: int | None = None,
    fit_params: dict | None = None,
    synchronize_cross_fitting: bool = True,
    n_jobs_base_learners: int | None = None,
) -> Self:
    """Fit all nuisance models of the MetaLearner.

    If pre-fitted models were passed at instantiation, these are never refitted.

    Subclasses implement this first stage of fitting; it must be called before
    ``fit_all_treatment``. Returns ``self`` to allow call chaining.

    For the parameters check :meth:`metalearners.metalearner.MetaLearner.fit`.
    """
    ...

@abstractmethod
def fit_all_treatment(
self,
X: Matrix,
y: Vector,
w: Vector,
n_jobs_cross_fitting: int | None = None,
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
) -> Self:
"""Fit all treatment models of the MetaLearner.
For the parameters check :meth:`metalearners.metalearner.MetaLearner.fit`.
"""
...

def fit(
self,
X: Matrix,
Expand Down Expand Up @@ -791,7 +826,27 @@ def fit(
the same data splits where possible. Note that if there are several models to be synchronized which are
classifiers, these cannot be split via stratification.
"""
...
self.fit_all_nuisance(
X=X,
y=y,
w=w,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=fit_params,
synchronize_cross_fitting=synchronize_cross_fitting,
n_jobs_base_learners=n_jobs_base_learners,
)

self.fit_all_treatment(
X=X,
y=y,
w=w,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=fit_params,
synchronize_cross_fitting=synchronize_cross_fitting,
n_jobs_base_learners=n_jobs_base_learners,
)

return self

def predict_nuisance(
self,
Expand Down
22 changes: 17 additions & 5 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _supports_multi_treatment(cls) -> bool:
def _supports_multi_class(cls) -> bool:
return False

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand All @@ -165,14 +165,10 @@ def fit(
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
epsilon: float = _EPSILON,
) -> Self:

self._validate_treatment(w)
self._validate_outcome(y, w)

self._variants_indices = []

qualified_fit_params = self._qualified_fit_params(fit_params)
self._validate_fit_params(qualified_fit_params)

Expand Down Expand Up @@ -214,7 +210,22 @@ def fit(
)
self._assign_joblib_nuisance_results(results)

return self

def fit_all_treatment(
self,
X: Matrix,
y: Vector,
w: Vector,
n_jobs_cross_fitting: int | None = None,
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
epsilon: float = _EPSILON,
) -> Self:
qualified_fit_params = self._qualified_fit_params(fit_params)
treatment_jobs: list[_ParallelJoblibSpecification] = []
self._variants_indices = []
for treatment_variant in range(1, self.n_variants):

is_treatment = w == treatment_variant
Expand Down Expand Up @@ -246,6 +257,7 @@ def fit(
n_jobs_cross_fitting=n_jobs_cross_fitting,
)
)
parallel = Parallel(n_jobs=n_jobs_base_learners)
results = parallel(
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
)
Expand Down
14 changes: 13 additions & 1 deletion metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(
random_state=random_state,
)

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand Down Expand Up @@ -175,6 +175,18 @@ def fit(
)
return self

def fit_all_treatment(
    self,
    X: Matrix,
    y: Vector,
    w: Vector,
    n_jobs_cross_fitting: int | None = None,
    fit_params: dict | None = None,
    synchronize_cross_fitting: bool = True,
    n_jobs_base_learners: int | None = None,
) -> Self:
    """Fit all treatment models of the MetaLearner.

    This learner fits no treatment-stage models, so this is a no-op kept only
    to satisfy the abstract interface; all parameters are ignored and ``self``
    is returned unchanged.
    """
    return self

def predict(
self,
X: Matrix,
Expand Down
14 changes: 13 additions & 1 deletion metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _supports_multi_treatment(cls) -> bool:
def _supports_multi_class(cls) -> bool:
return True

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand Down Expand Up @@ -92,6 +92,18 @@ def fit(
self._assign_joblib_nuisance_results(results)
return self

def fit_all_treatment(
    self,
    X: Matrix,
    y: Vector,
    w: Vector,
    n_jobs_cross_fitting: int | None = None,
    fit_params: dict | None = None,
    synchronize_cross_fitting: bool = True,
    n_jobs_base_learners: int | None = None,
) -> Self:
    """Fit all treatment models of the MetaLearner.

    This learner fits no treatment-stage models, so this is a no-op kept only
    to satisfy the abstract interface; all parameters are ignored and ``self``
    is returned unchanged.
    """
    return self

def predict(
self,
X,
Expand Down
39 changes: 33 additions & 6 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _supports_multi_treatment(cls) -> bool:
def _supports_multi_class(cls) -> bool:
return False

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand All @@ -91,7 +91,7 @@ def fit(

qualified_fit_params = self._qualified_fit_params(fit_params)

cvs: list = []
self._cvs: list = []

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
Expand All @@ -101,7 +101,7 @@ def fit(
)
else:
cv_split_indices = None
cvs.append(cv_split_indices)
self._cvs.append(cv_split_indices)

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
Expand All @@ -115,7 +115,7 @@ def fit(
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
cv=cvs[treatment_variant],
cv=self._cvs[treatment_variant],
)
)

Expand All @@ -138,6 +138,32 @@ def fit(
)
self._assign_joblib_nuisance_results(results)

return self

def fit_all_treatment(
self,
X: Matrix,
y: Vector,
w: Vector,
n_jobs_cross_fitting: int | None = None,
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
) -> Self:
if self._treatment_variants_indices is None:
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"typically set during nuisance fitting, is None."
)
if not hasattr(self, "_cvs"):
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _cvs, "
"typically set during nuisance fitting, does not exist."
)
qualified_fit_params = self._qualified_fit_params(fit_params)

treatment_jobs: list[_ParallelJoblibSpecification] = []
for treatment_variant in range(1, self.n_variants):
imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
Expand All @@ -153,7 +179,7 @@ def fit(
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL],
cv=cvs[treatment_variant],
cv=self._cvs[treatment_variant],
)
)

Expand All @@ -165,10 +191,11 @@ def fit(
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL],
cv=cvs[0],
cv=self._cvs[0],
)
)

parallel = Parallel(n_jobs=n_jobs_base_learners)
results = parallel(
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
)
Expand Down

0 comments on commit 120885a

Please sign in to comment.