Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split fit in fit_all_nuisance and fit_all_treatment #64

Merged
merged 6 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Changelog
0.8.0 (2024-07-xx)
------------------

**New features**

* Added :meth:`metalearners.metalearner.MetaLearner.fit_all_nuisance` and
:meth:`metalearners.metalearner.MetaLearner.fit_all_treatment`.

* Implement :meth:`metalearners.cross_fit_estimator.CrossFitEstimator.score`.

**Bug fixes**
Expand Down
33 changes: 28 additions & 5 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
OosMethod,
Params,
Scoring,
SplitIndices,
Vector,
_ScikitModel,
)
Expand Down Expand Up @@ -128,7 +129,7 @@
)
self.adaptive_clipping = adaptive_clipping

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand All @@ -148,10 +149,12 @@
for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)

self._cv_split_indices: SplitIndices | None

if synchronize_cross_fitting:
cv_split_indices = self._split(X)
self._cv_split_indices = self._split(X)
else:
cv_split_indices = None
self._cv_split_indices = None

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
Expand All @@ -176,7 +179,7 @@
model_ord=0,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL],
cv=cv_split_indices,
cv=self._cv_split_indices,
)
)

Expand All @@ -189,6 +192,25 @@

self._assign_joblib_nuisance_results(results)

return self

def fit_all_treatment(
self,
X: Matrix,
y: Vector,
w: Vector,
n_jobs_cross_fitting: int | None = None,
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
) -> Self:
if not hasattr(self, "_cv_split_indices"):
raise ValueError(

Check warning on line 208 in metalearners/drlearner.py

View check run for this annotation

Codecov / codecov/patch

metalearners/drlearner.py#L208

Added line #L208 was not covered by tests
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _cv_split_indices, "
"typically set during nuisance fitting, does not exist."
)
qualified_fit_params = self._qualified_fit_params(fit_params)
treatment_jobs: list[_ParallelJoblibSpecification] = []
for treatment_variant in range(1, self.n_variants):
pseudo_outcomes = self._pseudo_outcome(
Expand All @@ -207,9 +229,10 @@
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][TREATMENT_MODEL],
cv=cv_split_indices,
cv=self._cv_split_indices,
)
)
parallel = Parallel(n_jobs=n_jobs_base_learners)
results = parallel(
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
)
Expand Down
57 changes: 56 additions & 1 deletion metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,41 @@
] = result.cross_fit_estimator

@abstractmethod
def fit_all_nuisance(
    self,
    X: Matrix,
    y: Vector,
    w: Vector,
    n_jobs_cross_fitting: int | None = None,
    fit_params: dict | None = None,
    synchronize_cross_fitting: bool = True,
    n_jobs_base_learners: int | None = None,
) -> Self:
    """Fit all nuisance models of the MetaLearner.

    If pre-fitted models were passed at instantiation, these are never refitted.

    For the parameters check :meth:`metalearners.metalearner.MetaLearner.fit`.

    Concrete subclasses implement this; it is called by ``fit`` before
    :meth:`metalearners.metalearner.MetaLearner.fit_all_treatment` and is
    expected to return ``self`` (see the ``Self`` return annotation) so the
    two fitting stages can be chained.
    """
    ...

Check warning on line 765 in metalearners/metalearner.py

View check run for this annotation

Codecov / codecov/patch

metalearners/metalearner.py#L765

Added line #L765 was not covered by tests

@abstractmethod
def fit_all_treatment(
    self,
    X: Matrix,
    y: Vector,
    w: Vector,
    n_jobs_cross_fitting: int | None = None,
    fit_params: dict | None = None,
    synchronize_cross_fitting: bool = True,
    n_jobs_base_learners: int | None = None,
) -> Self:
    """Fit all treatment models of the MetaLearner.

    For the parameters check :meth:`metalearners.metalearner.MetaLearner.fit`.

    Concrete subclasses implement this; it is called by ``fit`` after
    :meth:`metalearners.metalearner.MetaLearner.fit_all_nuisance`.
    NOTE(review): implementations appear to rely on state set during nuisance
    fitting (e.g. cached cv split indices) — calling this before
    ``fit_all_nuisance`` is expected to raise; confirm per subclass.
    Expected to return ``self``.
    """
    ...

Check warning on line 782 in metalearners/metalearner.py

View check run for this annotation

Codecov / codecov/patch

metalearners/metalearner.py#L782

Added line #L782 was not covered by tests

def fit(
self,
X: Matrix,
Expand Down Expand Up @@ -793,7 +828,27 @@
the same data splits where possible. Note that if there are several models to be synchronized which are
classifiers, these cannot be split via stratification.
"""
...
self.fit_all_nuisance(
X=X,
y=y,
w=w,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=fit_params,
synchronize_cross_fitting=synchronize_cross_fitting,
n_jobs_base_learners=n_jobs_base_learners,
)

self.fit_all_treatment(
X=X,
y=y,
w=w,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=fit_params,
synchronize_cross_fitting=synchronize_cross_fitting,
n_jobs_base_learners=n_jobs_base_learners,
)

return self

def predict_nuisance(
self,
Expand Down
22 changes: 17 additions & 5 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _supports_multi_treatment(cls) -> bool:
def _supports_multi_class(cls) -> bool:
return False

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand All @@ -165,14 +165,10 @@ def fit(
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
epsilon: float = _EPSILON,
) -> Self:

self._validate_treatment(w)
self._validate_outcome(y, w)

self._variants_indices = []

qualified_fit_params = self._qualified_fit_params(fit_params)
self._validate_fit_params(qualified_fit_params)

Expand Down Expand Up @@ -214,7 +210,22 @@ def fit(
)
self._assign_joblib_nuisance_results(results)

return self

def fit_all_treatment(
self,
X: Matrix,
y: Vector,
w: Vector,
n_jobs_cross_fitting: int | None = None,
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
epsilon: float = _EPSILON,
) -> Self:
qualified_fit_params = self._qualified_fit_params(fit_params)
treatment_jobs: list[_ParallelJoblibSpecification] = []
self._variants_indices = []
for treatment_variant in range(1, self.n_variants):

is_treatment = w == treatment_variant
Expand Down Expand Up @@ -246,6 +257,7 @@ def fit(
n_jobs_cross_fitting=n_jobs_cross_fitting,
)
)
parallel = Parallel(n_jobs=n_jobs_base_learners)
results = parallel(
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
)
Expand Down
14 changes: 13 additions & 1 deletion metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(
random_state=random_state,
)

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand Down Expand Up @@ -175,6 +175,18 @@ def fit(
)
return self

def fit_all_treatment(
    self,
    X: Matrix,
    y: Vector,
    w: Vector,
    n_jobs_cross_fitting: int | None = None,
    fit_params: dict | None = None,
    synchronize_cross_fitting: bool = True,
    n_jobs_base_learners: int | None = None,
) -> Self:
    """No-op implementation of the treatment-fitting stage.

    Deliberately does nothing and returns ``self`` unchanged; the parameters
    exist only to satisfy the shared ``MetaLearner`` interface and are
    ignored.  Presumably the S-Learner has no separate treatment models —
    all fitting happens in ``fit_all_nuisance`` — TODO confirm against the
    model specifications.
    """
    return self

def predict(
self,
X: Matrix,
Expand Down
14 changes: 13 additions & 1 deletion metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _supports_multi_treatment(cls) -> bool:
def _supports_multi_class(cls) -> bool:
return True

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand Down Expand Up @@ -92,6 +92,18 @@ def fit(
self._assign_joblib_nuisance_results(results)
return self

def fit_all_treatment(
    self,
    X: Matrix,
    y: Vector,
    w: Vector,
    n_jobs_cross_fitting: int | None = None,
    fit_params: dict | None = None,
    synchronize_cross_fitting: bool = True,
    n_jobs_base_learners: int | None = None,
) -> Self:
    """No-op implementation of the treatment-fitting stage.

    Deliberately does nothing and returns ``self`` unchanged; the parameters
    exist only to satisfy the shared ``MetaLearner`` interface and are
    ignored.  Presumably the T-Learner trains only per-variant outcome
    (nuisance) models and has no treatment models — TODO confirm against the
    model specifications.
    """
    return self

def predict(
self,
X,
Expand Down
39 changes: 33 additions & 6 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
def _supports_multi_class(cls) -> bool:
return False

def fit(
def fit_all_nuisance(
self,
X: Matrix,
y: Vector,
Expand All @@ -91,7 +91,7 @@

qualified_fit_params = self._qualified_fit_params(fit_params)

cvs: list = []
self._cvs: list = []

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
Expand All @@ -101,7 +101,7 @@
)
else:
cv_split_indices = None
cvs.append(cv_split_indices)
self._cvs.append(cv_split_indices)

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
Expand All @@ -115,7 +115,7 @@
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
cv=cvs[treatment_variant],
cv=self._cvs[treatment_variant],
)
)

Expand All @@ -138,6 +138,32 @@
)
self._assign_joblib_nuisance_results(results)

return self

def fit_all_treatment(
self,
X: Matrix,
y: Vector,
w: Vector,
n_jobs_cross_fitting: int | None = None,
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
) -> Self:
if self._treatment_variants_indices is None:
raise ValueError(

Check warning on line 154 in metalearners/xlearner.py

View check run for this annotation

Codecov / codecov/patch

metalearners/xlearner.py#L154

Added line #L154 was not covered by tests
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"typically set during nuisance fitting, is None."
)
if not hasattr(self, "_cvs"):
raise ValueError(

Check warning on line 160 in metalearners/xlearner.py

View check run for this annotation

Codecov / codecov/patch

metalearners/xlearner.py#L160

Added line #L160 was not covered by tests
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _cvs, "
"typically set during nuisance fitting, does not exist."
)
qualified_fit_params = self._qualified_fit_params(fit_params)

treatment_jobs: list[_ParallelJoblibSpecification] = []
for treatment_variant in range(1, self.n_variants):
imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
Expand All @@ -153,7 +179,7 @@
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL],
cv=cvs[treatment_variant],
cv=self._cvs[treatment_variant],
)
)

Expand All @@ -165,10 +191,11 @@
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL],
cv=cvs[0],
cv=self._cvs[0],
)
)

parallel = Parallel(n_jobs=n_jobs_base_learners)
results = parallel(
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
)
Expand Down
14 changes: 13 additions & 1 deletion tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _supports_multi_class(cls) -> bool:

def _validate_models(self) -> None: ...

def fit(
def fit_all_nuisance(
self,
X,
y,
Expand All @@ -83,6 +83,18 @@ def fit(
self.nuisance_model_specifications()[model_kind]["cardinality"](self)
):
self.fit_nuisance(X, y, model_kind, model_ord)
return self

def fit_all_treatment(
self,
X,
y,
w,
n_jobs_cross_fitting: int | None = None,
fit_params: dict | None = None,
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
):
for model_kind in self.__class__.treatment_model_specifications():
for model_ord in range(
self.treatment_model_specifications()[model_kind]["cardinality"](self)
Expand Down
Loading