Skip to content

Commit

Permalink
Reuse conditional average outcome estimates for X-Learner pseudo outc…
Browse files Browse the repository at this point in the history
…ome.
  • Loading branch information
kklein committed Aug 14, 2024
1 parent d863df1 commit 7941aef
Showing 1 changed file with 37 additions and 28 deletions.
65 changes: 37 additions & 28 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,17 @@ def fit_all_treatment(
qualified_fit_params = self._qualified_fit_params(fit_params)

treatment_jobs: list[_ParallelJoblibSpecification] = []

conditional_average_outcome_estimates = (
self.predict_conditional_average_outcomes(
X=X,
is_oos=False,
)
)

for treatment_variant in range(1, self.n_variants):
imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
X, y, w, treatment_variant
y, w, treatment_variant, conditional_average_outcome_estimates
)
treatment_jobs.append(
self._treatment_joblib_specifications(
Expand Down Expand Up @@ -270,6 +278,7 @@ def predict(
oos_method=oos_method,
)
)

tau_hat_treatment[treatment_variant_indices] = self.predict_treatment(
X=index_matrix(X, treatment_variant_indices),
model_kind=TREATMENT_EFFECT_MODEL,
Expand Down Expand Up @@ -350,11 +359,19 @@ def evaluate(
feature_set=self.feature_set[PROPENSITY_MODEL],
)

conditional_average_outcome_estimates = (
self.predict_conditional_average_outcomes(
X=X,
is_oos=is_oos,
oos_method=oos_method,
)
)

imputed_te_control: list[np.ndarray] = []
imputed_te_treatment: list[np.ndarray] = []
for treatment_variant in range(1, self.n_variants):
tv_imputed_te_control, tv_imputed_te_treatment = self._pseudo_outcome(
X, y, w, treatment_variant
y, w, treatment_variant, conditional_average_outcome_estimates
)
imputed_te_control.append(tv_imputed_te_control)
imputed_te_treatment.append(tv_imputed_te_treatment)
Expand Down Expand Up @@ -391,7 +408,11 @@ def evaluate(
)

def _pseudo_outcome(
self, X: Matrix, y: Vector, w: Vector, treatment_variant: int
self,
y: Vector,
w: Vector,
treatment_variant: int,
conditional_average_outcome_estimates: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute the X-Learner pseudo outcome.
Expand All @@ -404,38 +425,26 @@ def _pseudo_outcome(
"""
validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants)

mask = (w == treatment_variant) | (w == 0)
treatment_indices = w == treatment_variant
control_indices = w == 0

X_filt = X[mask]
y_filt = y[mask]
w_filt = (w[mask] == treatment_variant).astype(int)
treatment_outcome = index_matrix(
conditional_average_outcome_estimates, control_indices
)[:, treatment_variant]
control_outcome = index_matrix(
conditional_average_outcome_estimates, treatment_indices
)[:, 0]

treatment_indices = w_filt == 1
control_indices = w_filt == 0

# This is always oos because the VARIANT_OUTCOME_MODEL[0] is used to predict the
# control outcomes of the treated observations and vice versa.
control_outcome = self.predict_nuisance(
X=index_matrix(X_filt, treatment_indices),
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=0,
is_oos=True,
oos_method=OVERALL,
)
treatment_outcome = self.predict_nuisance(
X=index_matrix(X_filt, control_indices),
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
is_oos=True,
oos_method=OVERALL,
)
if self.is_classification:
# Get the probability of positive class, multiclass is currently not supported.
control_outcome = control_outcome[:, 1]
treatment_outcome = treatment_outcome[:, 1]
else:
control_outcome = control_outcome[:, 0]
treatment_outcome = treatment_outcome[:, 0]

imputed_te_treatment = y_filt[treatment_indices] - control_outcome
imputed_te_control = treatment_outcome - y_filt[control_indices]
imputed_te_treatment = y[treatment_indices] - control_outcome
imputed_te_control = treatment_outcome - y[control_indices]

return imputed_te_control, imputed_te_treatment

Expand Down

0 comments on commit 7941aef

Please sign in to comment.