Skip to content

Commit

Permalink
Fix S-Learner's leakage.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Aug 10, 2024
1 parent 802c2b7 commit 84e8c3d
Showing 1 changed file with 3 additions and 32 deletions.
35 changes: 3 additions & 32 deletions metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,47 +251,18 @@ def predict_conditional_average_outcomes(
n_obs = len(X)
conditional_average_outcomes_list = []

# The idea behind using is_oos = True for in sample predictions is the following:
# Assuming observation i has received variant v then the model has been trained
# on row (X_i, v), therefore when predicting the conditional average outcome for
# variant v we have to use cross fitting to avoid prediciting on an identical row
# which the model has been trained on. (This happens either with overall, mean
# or median as some of the models would be trained with this row). On the other
# hand, when predicting the conditional average outcome for variant v' != v,
# the model has never seen the row (X_i, v'), so we can use it as it was out of
# sample.
# This can bring some issues where the cross fitted predictions are based on models
# which have been trained with a smaller dataset (K-1 folds) than the overall
# model and this may produce some different distributions in the outputs, for this
# it may make sense to restrict the oos_method to mean or median when is_oos = False,
# although further investigation is needed.
if not is_oos:
X_with_w = _append_treatment_to_covariates(
X,
self._fitted_treatments,
self._supports_categoricals,
self.n_variants,
)
in_sample_pred = self.predict_nuisance(
X=X_with_w, model_kind=_BASE_MODEL, model_ord=0, is_oos=False
)

for v in range(self.n_variants):
w = np.array([v] * n_obs)
for treatment_variant in range(self.n_variants):
w = np.array([treatment_variant] * n_obs)
X_with_w = _append_treatment_to_covariates(
X, w, self._supports_categoricals, self.n_variants
)
variant_predictions = self.predict_nuisance(
X=X_with_w,
model_kind=_BASE_MODEL,
model_ord=0,
is_oos=True,
is_oos=is_oos,
oos_method=oos_method,
)
if not is_oos:
variant_predictions[self._fitted_treatments == v] = in_sample_pred[
self._fitted_treatments == v
]

conditional_average_outcomes_list.append(variant_predictions)

Expand Down

0 comments on commit 84e8c3d

Please sign in to comment.