Skip to content

Commit

Permalink
rename additional variables in tlearner.py and update CHANGELOG.rst
Browse files Browse the repository at this point in the history
  • Loading branch information
kyracho committed Sep 4, 2024
1 parent 91baf35 commit 00fb48c
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Changelog

**New features**

* Rename ``_treatment_variants_indices`` to ``_treatment_variants_mask``in ``metalearner``, ``xlearner``, and ``test_learner``.
* Rename ``_treatment_variants_indices`` to ``_treatment_variants_mask``in ``metalearner``, ``xlearner``, ``tlearner``, and ``test_learner``.


0.11.0 (2024-09-xx)
Expand Down
2 changes: 1 addition & 1 deletion metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ def predict_conditional_average_outcomes(
if self._treatment_variants_mask is None:
raise ValueError(
"The metalearner needs to be fitted before predicting."
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"In particular, the MetaLearner's attribute _treatment_variant_mask, "
"typically set during fitting, is None."
)
# TODO: Consider multiprocessing
Expand Down
10 changes: 4 additions & 6 deletions metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,19 @@ def fit_all_nuisance(
self._validate_treatment(w)
self._validate_outcome(y, w)

self._treatment_variants_indices = []
self._treatment_variants_mask = []

for v in range(self.n_variants):
self._treatment_variants_indices.append(w == v)
self._treatment_variants_mask.append(w == v)

qualified_fit_params = self._qualified_fit_params(fit_params)

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
nuisance_jobs.append(
self._nuisance_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
y=y[self._treatment_variants_indices[treatment_variant]],
X=index_matrix(X, self._treatment_variants_mask[treatment_variant]),
y=y[self._treatment_variants_mask[treatment_variant]],
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
Expand Down
20 changes: 9 additions & 11 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def fit_all_treatment(
if self._treatment_variants_mask is None:
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"In particular, the MetaLearner's attribute _treatment_variant_mask, "
"typically set during nuisance fitting, is None."
)
if not hasattr(self, "_cvs"):
Expand Down Expand Up @@ -224,7 +224,7 @@ def predict(
if self._treatment_variants_mask is None:
raise ValueError(
"The MetaLearner needs to be fitted before predicting. "
"In particular, the X-Learner's attribute _treatment_variant_indices, "
"In particular, the X-Learner's attribute _treatment_variant_mask, "
"typically set during fitting, is None."
)
n_outputs = 2 if self.is_classification else 1
Expand All @@ -244,7 +244,7 @@ def predict(

for treatment_variant in range(1, self.n_variants):
treatment_variant_indices = self._treatment_variants_mask[treatment_variant]
non_treatment_variant_indices = ~treatment_variant_indices
non_treatment_variant_mask = ~treatment_variant_indices
if is_oos:
tau_hat_treatment = self.predict_treatment(
X=X,
Expand All @@ -264,14 +264,12 @@ def predict(
tau_hat_treatment = np.zeros(safe_len(X))
tau_hat_control = np.zeros(safe_len(X))

tau_hat_treatment[non_treatment_variant_indices] = (
self.predict_treatment(
X=index_matrix(X, non_treatment_variant_indices),
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=True,
oos_method=oos_method,
)
tau_hat_treatment[non_treatment_variant_mask] = self.predict_treatment(
X=index_matrix(X, non_treatment_variant_mask),
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=True,
oos_method=oos_method,
)

tau_hat_treatment[treatment_variant_indices] = self.predict_treatment(
Expand Down

0 comments on commit 00fb48c

Please sign in to comment.