Skip to content

Commit

Permalink
Mention estimation approach in doc string.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Aug 9, 2024
1 parent 486d4f8 commit ce2334b
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,30 @@ def _necessary_onnx_models(self) -> dict[str, list[_ScikitModel]]:
def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
r"""Predict the vectors of conditional average outcomes.
These are defined as :math:`\mathbb{E}[Y_i(w) | X]` for each treatment variant
:math:`w`.
If ``is_oos``, an acronym for 'is out of sample' is ``False``,
the estimates will stem from cross-fitting. Otherwise,
various approaches exist, specified via ``oos_method``.
The returned ndarray is of shape:
* :math:`(n_{obs}, n_{variants}, 1)` if the outcome is a scalar, i.e. in case
of a regression problem.
* :math:`(n_{obs}, n_{variants}, n_{classes})` if the outcome is a class,
i.e. in case of a classification problem.
The conditional average outcomes are estimated as follows:
* :math:`Y_i(0) = \hat{\mu}(X_i) - \sum_{k=1}^{K} \hat{e}_k(X_i) \hat{\tau_k}(X_i)`
* :math:`Y_i(k) = Y_i(0) + \hat{\tau_k}(X_i)` for :math:`k \in \{1, \dots, K\}`
where :math:`K` is the number of treatment variants.
"""
n_obs = len(X)

cate_estimates = self.predict(
Expand Down

0 comments on commit ce2334b

Please sign in to comment.