Fix S-Learner's leakage (#79)
* Update benchmark values.

* Fix S-Learner's leakage.

* Add changelog entry.

* Fix date in changelog.
kklein authored Aug 12, 2024
1 parent d00947a commit 4409cc5
Showing 3 changed files with 17 additions and 38 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
@@ -7,6 +7,14 @@
Changelog
=========

0.9.1 (2024-08-xx)
------------------

**Bug fixes**

* Fix bug in which the :class:`~metalearners.slearner.SLearner`'s
inference step would have some leakage in the in-sample scenario.

0.9.0 (2024-08-02)
------------------

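The changelog entry above refers to leakage during in-sample inference: predictions made by a model on rows it was trained on are optimistically biased, which is why cross-fitted predictions are used when `is_oos=False`. The following is a conceptual sketch of that distinction using scikit-learn only; it is not code from the metalearners package, and the dataset and model are made up for illustration.

```python
# Conceptual illustration of in-sample leakage vs. cross-fitting.
# NOT metalearners code; it only demonstrates why predictions on training
# rows differ from cross-fitted predictions.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
y = X[:, 0] + rng.normal(size=500)

# Naive in-sample predictions: every row is predicted by a model that has
# already seen it, so the apparent error is optimistically biased (leakage).
naive = RandomForestRegressor(random_state=0).fit(X, y).predict(X)

# Cross-fitted predictions: each row is predicted by a model trained on the
# remaining folds only, so no model predicts on a row it was trained on.
cross_fitted = cross_val_predict(RandomForestRegressor(random_state=0), X, y, cv=5)

print("naive MSE:       ", np.mean((naive - y) ** 2))
print("cross-fitted MSE:", np.mean((cross_fitted - y) ** 2))
```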
12 changes: 6 additions & 6 deletions benchmarks/readme.md
@@ -32,12 +32,12 @@ on ground truth CATEs:

| S-learner | causalml_in_sample | causalml_oos | econml_in_sample | econml_oos | metalearners_in_sample | metalearners_oos |
| :------------------------------------------------------------ | -----------------: | -----------: | ---------------: | ---------: | ---------------------: | ---------------: |
-| synthetic_data_continuous_outcome_binary_treatment_linear_te | 14.5706 | 14.6248 | 14.5706 | 14.6248 | 14.5729 | 14.6248 |
-| synthetic_data_binary_outcome_binary_treatment_linear_te | 0.229101 | 0.228616 | nan | nan | 0.229231 | 0.2286 |
-| twins_pandas | 0.314253 | 0.318554 | nan | nan | 0.371613 | 0.319028 |
-| twins_numpy | 0.314253 | 0.318554 | nan | nan | 0.361345 | 0.318554 |
-| synthetic_data_continuous_outcome_multi_treatment_linear_te | nan | nan | 14.1468 | 14.185 | 14.1478 | 14.1853 |
-| synthetic_data_continuous_outcome_multi_treatment_constant_te | nan | nan | 0.0110779 | 0.0110778 | 0.0104649 | 0.00897915 |
+| synthetic_data_continuous_outcome_binary_treatment_linear_te | 14.5706 | 14.6248 | 14.5706 | 14.6248 | 14.5707 | 14.6248 |
+| synthetic_data_binary_outcome_binary_treatment_linear_te | 0.229101 | 0.228616 | nan | nan | 0.229201 | 0.2286 |
+| twins_pandas | 0.314253 | 0.318554 | nan | nan | 0.322171 | 0.319028 |
+| twins_numpy | 0.314253 | 0.318554 | nan | nan | 0.322132 | 0.318554 |
+| synthetic_data_continuous_outcome_multi_treatment_linear_te | nan | nan | 14.1468 | 14.185 | 14.147 | 14.1853 |
+| synthetic_data_continuous_outcome_multi_treatment_constant_te | nan | nan | 0.0110779 | 0.0110778 | 0.0101122 | 0.00897915 |
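The figures in these tables compare estimated CATEs against each dataset's ground-truth CATEs, with lower values being better; the in-sample metalearners numbers change because of the leakage fix. Below is a minimal sketch of such a comparison, assuming an RMSE-style metric; the exact metric and data loading used by the benchmark script are not shown in this excerpt, so treat the helper as hypothetical.

```python
import numpy as np

def cate_rmse(cate_estimates: np.ndarray, true_cate: np.ndarray) -> float:
    """Root mean squared error between estimated and ground-truth CATEs.

    Hypothetical helper for illustration; not taken from the benchmark script.
    """
    cate_estimates = np.asarray(cate_estimates, dtype=float)
    true_cate = np.asarray(true_cate, dtype=float)
    return float(np.sqrt(np.mean((cate_estimates - true_cate) ** 2)))

# Example: a perfect estimator scores 0.0.
print(cate_rmse(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0])))
```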

| X-learner | causalml_in_sample | causalml_oos | econml_in_sample | econml_oos | metalearners_in_sample | metalearners_oos |
| :------------------------------------------------------------ | -----------------: | -----------: | ---------------: | ---------: | ---------------------: | ---------------: |
35 changes: 3 additions & 32 deletions metalearners/slearner.py
@@ -251,47 +251,18 @@ def predict_conditional_average_outcomes(
         n_obs = len(X)
         conditional_average_outcomes_list = []
 
-        # The idea behind using is_oos = True for in-sample predictions is the following:
-        # Assuming observation i has received variant v, the model has been trained
-        # on the row (X_i, v). Therefore, when predicting the conditional average outcome
-        # for variant v, we have to use cross-fitting to avoid predicting on an identical
-        # row which the model has been trained on. (This happens with overall, mean
-        # or median, as some of the models would have been trained with this row.) On the
-        # other hand, when predicting the conditional average outcome for variant v' != v,
-        # the model has never seen the row (X_i, v'), so we can treat it as out of
-        # sample.
-        # This can cause issues where the cross-fitted predictions are based on models
-        # which have been trained on a smaller dataset (K-1 folds) than the overall
-        # model, which may produce different distributions in the outputs; for this reason
-        # it may make sense to restrict the oos_method to mean or median when is_oos = False,
-        # although further investigation is needed.
-        if not is_oos:
-            X_with_w = _append_treatment_to_covariates(
-                X,
-                self._fitted_treatments,
-                self._supports_categoricals,
-                self.n_variants,
-            )
-            in_sample_pred = self.predict_nuisance(
-                X=X_with_w, model_kind=_BASE_MODEL, model_ord=0, is_oos=False
-            )
-
-        for v in range(self.n_variants):
-            w = np.array([v] * n_obs)
+        for treatment_variant in range(self.n_variants):
+            w = np.array([treatment_variant] * n_obs)
             X_with_w = _append_treatment_to_covariates(
                 X, w, self._supports_categoricals, self.n_variants
             )
             variant_predictions = self.predict_nuisance(
                 X=X_with_w,
                 model_kind=_BASE_MODEL,
                 model_ord=0,
-                is_oos=True,
+                is_oos=is_oos,
                 oos_method=oos_method,
             )
-            if not is_oos:
-                variant_predictions[self._fitted_treatments == v] = in_sample_pred[
-                    self._fitted_treatments == v
-                ]
 
             conditional_average_outcomes_list.append(variant_predictions)
 
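After this change, `predict_conditional_average_outcomes` simply forwards the caller's `is_oos` flag to `predict_nuisance` for every treatment variant, so in-sample requests are served entirely by cross-fitted base-model predictions instead of patching only the observed variant. Below is a usage sketch of the affected code path; the constructor and method signatures (`nuisance_model_factory`, `is_classification`, `n_variants`, `fit(X, y, w)`, `predict(X, is_oos, oos_method)`) are assumptions based on the metalearners documentation and may differ in detail from the installed version.

```python
# Usage sketch of the affected code path; signatures are assumed from the
# metalearners docs and may not match the installed version exactly.
import numpy as np
from lightgbm import LGBMRegressor
from metalearners import SLearner

rng = np.random.default_rng(0)
X = rng.normal(size=(1_000, 5))
w = rng.integers(0, 2, size=1_000)                       # binary treatment
y = X[:, 0] + w * X[:, 1] + rng.normal(size=1_000)       # synthetic outcome

slearner = SLearner(
    nuisance_model_factory=LGBMRegressor,
    is_classification=False,
    n_variants=2,
)
slearner.fit(X=X, y=y, w=w)

# In-sample CATE estimates: with this fix, cross-fitted predictions are used
# for every variant, avoiding leakage from rows the base model was trained on.
cate_in_sample = slearner.predict(X, is_oos=False)

# Out-of-sample estimates on new covariates use the overall model.
X_new = rng.normal(size=(100, 5))
cate_oos = slearner.predict(X_new, is_oos=True, oos_method="overall")
```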
