Skip to content

Commit

Permalink
Add predict_conditional_average_outcomes method to R-Learner (#78)
Browse files Browse the repository at this point in the history
* Enable R-Learner for tests.

* Add method to MetaLearner class.

* Draft predict_conditional_average_outcomes method for R-Learner.

* Fix PR-unrelated typo in changelog.

* Add changelog entry.

* Fix shape issue.

* Mention estimation approach in doc string.

* Fix typo.

* Remove duplicated docstrings.
  • Loading branch information
kklein authored Aug 13, 2024
1 parent 4409cc5 commit 9c6e0a5
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 38 deletions.
16 changes: 13 additions & 3 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,18 @@
Changelog
=========

0.9.1 (2024-08-xx)
------------------
0.10.0 (2024-08-xx)
-------------------

**New features**

* Add abstract method
:meth:`~metalearners.metalearner.MetaLearner.predict_conditional_average_outcomes`
to :class:`~metalearners.metalearner.MetaLearner`.

* Implement
:meth:`~metalearners.rlearner.RLearner.predict_conditional_average_outcomes`
for :class:`~metalearners.rlearner.RLearner`.

**Bug fixes**

Expand All @@ -20,7 +30,7 @@ Changelog

**New features**

* Add :meth:`metalearners.metalearner.MetaLearner.init_params`.
* Add :meth:`metalearners.metalearner.MetaLearner.init_args`.

* Add :class:`metalearners.utils.FixedBinaryPropensity`.

Expand Down
40 changes: 23 additions & 17 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,29 @@ def predict(
"""
...

@abstractmethod
def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
r"""Predict the vectors of conditional average outcomes.
These are defined as :math:`\mathbb{E}[Y_i(w) | X]` for each treatment variant
:math:`w`.
If ``is_oos``, an acronym for 'is out of sample' is ``False``,
the estimates will stem from cross-fitting. Otherwise,
various approaches exist, specified via ``oos_method``.
The returned ndarray is of shape:
* :math:`(n_{obs}, n_{variants}, 1)` if the outcome is a scalar, i.e. in case
of a regression problem.
* :math:`(n_{obs}, n_{variants}, n_{classes})` if the outcome is a class,
i.e. in case of a classification problem.
"""
...

@abstractmethod
def evaluate(
self,
Expand Down Expand Up @@ -1317,23 +1340,6 @@ def __init__(
def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
r"""Predict the vectors of conditional average outcomes.
These are defined as :math:`\mathbb{E}[Y_i(w) | X]` for each treatment variant
:math:`w`.
If ``is_oos``, an acronym for 'is out of sample' is ``False``,
the estimates will stem from cross-fitting. Otherwise,
various approaches exist, specified via ``oos_method``.
The returned ndarray is of shape:
* :math:`(n_{obs}, n_{variants}, 1)` if the outcome is a scalar, i.e. in case
of a regression problem.
* :math:`(n_{obs}, n_{variants}, n_{classes})` if the outcome is a class,
i.e. in case of a classification problem.
"""
if self._treatment_variants_indices is None:
raise ValueError(
"The metalearner needs to be fitted before predicting."
Expand Down
83 changes: 83 additions & 0 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,89 @@ def _necessary_onnx_models(self) -> dict[str, list[_ScikitModel]]:
)
}

def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
r"""Predict the vectors of conditional average outcomes.
These are defined as :math:`\mathbb{E}[Y_i(w) | X]` for each treatment variant
:math:`w`.
If ``is_oos``, an acronym for 'is out of sample' is ``False``,
the estimates will stem from cross-fitting. Otherwise,
various approaches exist, specified via ``oos_method``.
The returned ndarray is of shape:
* :math:`(n_{obs}, n_{variants}, 1)` if the outcome is a scalar, i.e. in case
of a regression problem.
* :math:`(n_{obs}, n_{variants}, n_{classes})` if the outcome is a class,
i.e. in case of a classification problem.
The conditional average outcomes are estimated as follows:
* :math:`Y_i(0) = \hat{\mu}(X_i) - \sum_{k=1}^{K} \hat{e}_k(X_i) \hat{\tau_k}(X_i)`
* :math:`Y_i(k) = Y_i(0) + \hat{\tau_k}(X_i)` for :math:`k \in \{1, \dots, K\}`
where :math:`K` is the number of treatment variants.
"""
n_obs = len(X)

cate_estimates = self.predict(
X=X,
is_oos=is_oos,
oos_method=oos_method,
)
propensity_estimates = self.predict_nuisance(
X=X,
model_kind=PROPENSITY_MODEL,
model_ord=0,
is_oos=is_oos,
oos_method=oos_method,
)
outcome_estimates = self.predict_nuisance(
X=X,
model_kind=OUTCOME_MODEL,
model_ord=0,
is_oos=is_oos,
oos_method=oos_method,
)

conditional_average_outcomes_list = []

control_outcomes = outcome_estimates

# TODO: Consider whether the readability vs efficiency trade-off should be dealt with differently here.
# One could use matrix/tensor operations instead.
for treatment_variant in range(1, self.n_variants):
if (n_outputs := cate_estimates.shape[2]) > 1:
for outcome_channel in range(0, n_outputs):
control_outcomes[:, outcome_channel] -= (
propensity_estimates[:, treatment_variant]
* cate_estimates[:, treatment_variant - 1, outcome_channel]
)
else:
control_outcomes -= (
propensity_estimates[:, treatment_variant]
* cate_estimates[:, treatment_variant - 1, 0]
)

conditional_average_outcomes_list.append(control_outcomes)

for treatment_variant in range(1, self.n_variants):
conditional_average_outcomes_list.append(
control_outcomes
+ np.reshape(
cate_estimates[:, treatment_variant - 1, :],
(control_outcomes.shape),
)
)

return np.stack(conditional_average_outcomes_list, axis=1).reshape(
n_obs, self.n_variants, -1
)

@copydoc(MetaLearner._build_onnx, sep="")
def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""In the RLearner case, the necessary models are:
Expand Down
17 changes: 0 additions & 17 deletions metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,23 +231,6 @@ def evaluate(
def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
r"""Predict the vectors of conditional average outcomes.
These are defined as :math:`\mathbb{E}[Y_i(w) | X]` for each treatment variant
:math:`w`.
If ``is_oos``, an acronym for 'is out of sample' is ``False``,
the estimates will stem from cross-fitting. Otherwise,
various approaches exist, specified via ``oos_method``.
The returned ndarray is of shape:
* :math:`(n_{obs}, n_{variants}, 1)` if the outcome is a scalar, i.e. in case
of a regression problem.
* :math:`(n_{obs}, n_{variants}, n_{classes})` if the outcome is a class,
i.e. in case of a classification problem.
"""
n_obs = len(X)
conditional_average_outcomes_list = []

Expand Down
2 changes: 1 addition & 1 deletion tests/test_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ def test_validate_outcome_multi_class(metalearner_prefix, success):


@pytest.mark.parametrize("is_classification", [True, False])
@pytest.mark.parametrize("metalearner_prefix", ["S", "T", "X", "DR"])
@pytest.mark.parametrize("metalearner_prefix", ["S", "T", "R", "X", "DR"])
def test_conditional_average_outcomes_smoke(
metalearner_prefix, is_classification, request
):
Expand Down

0 comments on commit 9c6e0a5

Please sign in to comment.