Skip to content

Commit

Permalink
Merge branch 'pixi' of github.com:Quantco/metalearners into pixi
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Jun 25, 2024
2 parents 28ea33b + 1f39681 commit 5ee2860
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 24 deletions.
24 changes: 15 additions & 9 deletions docs/background.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,15 @@ It is an extension of the T-Learner and consists of three stages:
\widetilde{D}_1^i &:= Y^i_1 - \hat{\mu}_0(X^i_1) \\
\widetilde{D}_0^i &:= \hat{\mu}_1(X^i_0) - Y^i_0
Then estimate :math:`\tau_1(x) := \mathbb{E}[\widetilde{D}^i_1 | X]` and
:math:`\tau_0(x) := \mathbb{E}[\widetilde{D}^i_0 | X]` using the observations in the
Then estimate :math:`\tau_1(x) := \mathbb{E}[\widetilde{D}^i_1 | X=x]` and
:math:`\tau_0(x) := \mathbb{E}[\widetilde{D}^i_0 | X=x]` using the observations in the
treatment group and the ones in the control group respectively.
#. Define the CATE estimate by a weighted average of the two estimates in stage 2:

.. math::
\hat{\tau}^X(x) := g(x)\hat{\tau}_0(x) + (1-g(x))\hat{\tau}_1(x)
where :math:`g(x) \in [0,1]`. We take :math:`g(x) := \mathbb{E}[W = 1 | X]` to be
where :math:`g(x) \in [0,1]`. We take :math:`g(x) := \mathbb{E}[W = 1 | X=x]` to be
the propensity score.

More than binary treatment
Expand All @@ -388,8 +388,8 @@ In the case of multiple discrete treatments the stages are similar to the binary
\widetilde{D}_k^i &:= Y^i_k - \hat{\mu}_0(X^i_k) \\
\widetilde{D}_{0,k}^i &:= \hat{\mu}_k(X^i_0) - Y^i_0
Then :math:`\tau_k(x) := \mathbb{E}[\widetilde{D}^i_k | X]` is estimated using the
observations which received treatment :math:`k` and :math:`\tau_{0,k}(x) := \mathbb{E}[\widetilde{D}^i_{0,k} | X]`
Then :math:`\tau_k(x) := \mathbb{E}[\widetilde{D}^i_k | X=x]` is estimated using the
observations which received treatment :math:`k` and :math:`\tau_{0,k}(x) := \mathbb{E}[\widetilde{D}^i_{0,k} | X=x]`
using the observations in the control group.

#. Finally the CATE for each variant is estimated as a weighted average:
Expand Down Expand Up @@ -419,9 +419,15 @@ It consists of two stages:

.. math::
\DeclareMathOperator*{\argmin}{arg\,min}
\hat{\tau}^R (x) &:= \argmin_{\tau}\Bigg\{\mathbb{E}\Bigg[\bigg(\left\{Y^i - \hat{m}(X^i)\right\} - \left\{W^i - \hat{e}(X^i)\right\}\tau(X^i)\bigg)^2\Bigg]\Bigg\} \\
\hat{\tau}^R (\cdot) &:= \argmin_{\tau}\Bigg\{\mathbb{E}\Bigg[\bigg(\left\{Y^i - \hat{m}(X^i)\right\} - \left\{W^i - \hat{e}(X^i)\right\}\tau(X^i)\bigg)^2\Bigg]\Bigg\} \\
&=\argmin_{\tau}\left\{\mathbb{E}\left[\left\{W^i - \hat{e}(X^i)\right\}^2\bigg(\frac{\left\{Y^i - \hat{m}(X^i)\right\}}{\left\{W^i - \hat{e}(X^i)\right\}} - \tau(X^i)\bigg)^2\right]\right\} \\
&= \argmin_{\tau}\left\{\mathbb{E}\left[{\tilde{W}^i}^2\bigg(\frac{\tilde{Y}^i}{\tilde{W}^i} - \tau(X^i)\bigg)^2\right]\right\}
&= \argmin_{\tau}\left\{\mathbb{E}\left[{\widetilde{W}^i}^2\bigg(\frac{\widetilde{Y}^i}{\widetilde{W}^i} - \tau(X^i)\bigg)^2\right]\right\}
Where

.. math::
\widetilde{W}^i &= W^i - \hat{e}(X^i) \\
\widetilde{Y}^i &= Y^i - \hat{m}(X^i)
And therefore any ML model which supports weighting each observation differently can be used for the final model.

Expand Down Expand Up @@ -484,7 +490,7 @@ It consists of two stages:
#. Estimate the CATE by regressing :math:`\varphi` on :math:`X`:

.. math::
\hat{\tau}^{DR}(x) := \mathbb{E}[\varphi(X^i, W^i, Y^i) | X^i]
\hat{\tau}^{DR}(x) := \mathbb{E}[\varphi(X^i, W^i, Y^i) | X^i=x]
More than binary treatment
**************************
Expand All @@ -508,4 +514,4 @@ In the case of multiple discrete treatments the stages are similar to the binary
treatment variant, :math:`\forall k \in \{1,\dots, K-1\}`:

.. math::
\hat{\tau}_k^{DR}(x) := \mathbb{E}[\varphi_k(X^i, W^i, Y^i) | X^i]
\hat{\tau}_k^{DR}(x) := \mathbb{E}[\varphi_k(X^i, W^i, Y^i) | X^i=x]
23 changes: 20 additions & 3 deletions docs/examples/example_basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@
"* We need to specify the observed treatment assignment ``w`` in the call to the\n",
" ``fit`` method.\n",
"* We need to specify whether we want in-sample or out-of-sample\n",
" estimates in the {meth}`~metalearners.TLearner.predict` call via ``is_oos``."
" CATE estimates in the {meth}`~metalearners.TLearner.predict` call via ``is_oos``. In the\n",
" case of in-sample predictions, the data passed to {meth}`~metalearners.TLearner.predict`\n",
" must be exactly the same as the data that was used to call {meth}`~metalearners.TLearner.fit`."
]
},
{
Expand Down Expand Up @@ -176,7 +178,7 @@
"Using a MetaLearner with two stages\n",
"-----------------------------------\n",
"\n",
"Instead of using a T-Learner, we can of course also some other\n",
"Instead of using a T-Learner, we can of course also use some other\n",
"MetaLearner, such as the {class}`~metalearners.RLearner`.\n",
"The R-Learner's documentation tells us that two more instantiation\n",
"parameters are necessary: ``propensity_model_factory`` and\n",
Expand Down Expand Up @@ -209,7 +211,22 @@
"metadata": {},
"source": [
"where we choose a classifier class to serve as a blueprint for our\n",
"eventual propensity model.\n",
"eventual propensity model. It is important to notice that although we consider the propensity\n",
"model a nuisance model, the initialization parameters for it are separated from the other\n",
"nuisance parameters to allow a more understandable user interface, see the next code prompt.\n",
"\n",
"In general, when initializing a MetaLearner, the ``nuisance_model_factory`` parameter will\n",
"be used to create all the nuisance models which are not a propensity model, the\n",
"``propensity_model_factory`` will be used for the propensity model if the MetaLearner\n",
"contains one, and the ``treatment_model_factory`` will be used for the models predicting\n",
"the CATE. To see the models present in each MetaLearner type see\n",
"{meth}`~metalearners.metalearner.MetaLearner.nuisance_model_specifications` and\n",
"{meth}`~metalearners.metalearner.MetaLearner.treatment_model_specifications`.\n",
"\n",
"In the {class}`~metalearners.RLearner` case, the ``nuisance_model_factory`` parameter will\n",
"be used to create the outcome model, the ``propensity_model_factory`` will be used for the\n",
"propensity model and the ``treatment_model_factory`` will be used for the model predicting\n",
"the CATE.\n",
"\n",
"If we want to make sure these models are initialized in a specific\n",
"way, e.g. with a specific value for the hyperparameter ``n_estimators``, we can do that\n",
Expand Down
21 changes: 19 additions & 2 deletions docs/examples/example_feature_importance_shap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,8 @@
"source": [
"Note that the method {meth}`~metalearners.explainer.Explainer.feature_importances`\n",
"returns a list of length {math}`n_{variats} -1` that indicates the feature importance for\n",
"each variant against control.\n",
"each variant against control. Remember that a higher value means that the corresponding\n",
"feature is more important for the CATE prediction.\n",
"\n",
"### Computing and plotting the SHAP values\n",
"\n",
Expand Down Expand Up @@ -367,7 +368,23 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"For guidelines on how to interpret such SHAP plots please see the [SHAP documentation](https://github.com/shap/shap).\n",
"In these SHAP summary plots, the color and orientation of the plotted values help us to understand\n",
"their impact on model predictions.\n",
"\n",
"Each dot in the plot represents a single instance of the given feature present in the data set.\n",
"The x-axis displays the Shapley value, signifying the strength and directionality of the\n",
"feature's impact. The y-axis displays a subset of the features in the model.\n",
"\n",
"The Shapley value, exhibited on the horizontal axis, is oriented such that values on the\n",
"right of the center line (0 mark) contribute to a positive shift in the predicted outcome,\n",
"while those on the left indicate a negative impact.\n",
"\n",
"The color coding implemented in these plots is straightforward: red implies a high feature value,\n",
"while blue denotes a low feature value. This color scheme assists in identifying whether\n",
"high or low values of a certain feature influence the model's output positively or negatively.\n",
"The categorical variables are colored in grey.\n",
"\n",
"For more guidelines on how to interpret such SHAP plots please see the [SHAP documentation](https://github.com/shap/shap).\n",
"\n",
"Note that the method {meth}`~metalearners.explainer.Explainer.shap_values`\n",
"returns a list of length {math}`n_{variats} -1` that indicates the SHAP values for\n",
Expand Down
22 changes: 16 additions & 6 deletions docs/examples/example_lime.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"* {math}`f`, the original model -- in our case the MetaLearner\n",
"* {math}`G`, the class of possible, interpretable surrogate models\n",
"* {math}`\\Omega(g)`, a measure of complexity for {math}`g \\in G`\n",
"* {math}`\\pi_x(z)` a proximity measure of {math}`z` with respect to data point {math}`x`\n",
"* {math}`\\pi_x(z)` a proximity measure of an instance {math}`z` with respect to data point {math}`x`\n",
"* {math}`\\mathcal{L}(f, g, \\pi_x)` a measure of how unfaithful a {math}`g \\in G` is to {math}`f` in the locality defined by {math}`\\pi_x`\n",
"\n",
"Given all of these objects as well as a to be explained data point {math}`x`, the authors suggest that the most appropriate surrogate {math}`g`, also referred to as explanation for {math}`x`, {math}`\\xi(x)`, can be expressed as follows:\n",
Expand All @@ -73,19 +73,22 @@
"* have little redundancy between each other\n",
"* showcase the features with highest global importance\n",
"\n",
"In line with this ambition, they define a notion of 'coverage' -- to\n",
"be maximized --as follows:\n",
"In line with this ambition, they define a notion of 'coverage' which specifies how well a set\n",
"of candidate datapoints {math}`V` are explained by features that are relevant for\n",
"many observed datapoints. The goal is to find {math}`V` that is not larger than some\n",
"pre-specified size such that this coverage is maximal.\n",
"\n",
"```{math}\n",
" c(V, W, \\mathcal{I}) = \\sum_{j=1}^{d} I[\\exists i \\in V: W_{i,j} > 0] \\mathcal{I}_j\n",
" c(V, W, \\mathcal{I}) = \\sum_{j=1}^{d} \\mathbb{I}\\{\\exists i \\in V: W_{i,j} > 0\\} \\mathcal{I}_j\n",
"````\n",
"\n",
"where\n",
"\n",
"* {math}`d` is the number of features\n",
"* {math}`V` is the candidate set of explanations to be shown to\n",
" humans, within a fixed budget -- this is the variable to be optimized\n",
"* {math}`W` is a {math}`n \\times d` local feature importance matrix and\n",
"* {math}`W` is a {math}`n \\times d` local feature importance matrix that represents\n",
" the local importance of each feature for each instance, and\n",
"* {math}`\\mathcal{I}` is a {math}`d`-dimensional vector of global\n",
" feature importances\n",
"\n",
Expand Down Expand Up @@ -359,7 +362,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"For guidelines on how to interpret such lime plots please see the [lime documentation](https://github.com/marcotcr/lime)."
"In these plots, the green bars signify that the presence of the corresponding feature\n",
"referenced on the y-axis, increases the CATE estimate for that observation, whereas, the\n",
"red bars represent that the feature presence in the observation reduces the CATE.\n",
"Furthermore, the length of these colored bars corresponds to the magnitude of each feature's\n",
"contribution towards the model prediction. Therefore, the longer the bar, the more\n",
"significant the impact of that feature on the model prediction.\n",
"\n",
"For more guidelines on how to interpret such lime plots please see the [lime documentation](https://github.com/marcotcr/lime)."
]
}
],
Expand Down
3 changes: 2 additions & 1 deletion docs/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ FAQ
Double machine learning is an ATE estimation technique, pioneered by
`Chernozhukov et al. (2016) <https://arxiv.org/abs/1608.00060>`_.
It is 'double' in the sense that it relies on two preliminary models: one for the probability of
receiving treatment given covariates (the propensity score), and one for the outcome given treatment and covariates.
receiving treatment given covariates (the propensity score), and one for the outcome given covariates and
optionally the (discrete) treatment.

Double ML is also referred to as 'debiased' ML, since the propensity score model is used to 'debias'
a naive estimator that uses the outcome model to predict the expected outcome under treatment, and under no treatment,
Expand Down
4 changes: 2 additions & 2 deletions docs/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ Glossary
Similar to the R-Learner, the Double Machine Learning blueprint
relies on estimating two nuisance models in its first stage: a
propensity model as well as an outcome model. Unlike the
R-Learner, the last-stage or treatment effect model might not
be any estimator.
R-Learner, the last-stage or treatment effect model might need to be a
specific type of estimator.
See `Chernozhukov et al. (2016) <https://arxiv.org/abs/1608.00060>`_.

Heterogeneous Treatment Effect (HTE)
Expand Down
2 changes: 1 addition & 1 deletion metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def fit(
pattern, propensity models are considered a nuisance model.
``synchronize_cross_fitting`` indicates whether the learning of different base models should use exactly
the same data splits where possible. Note that if there are several to be synchronize models which are
the same data splits where possible. Note that if there are several models to be synchronized which are
classifiers, these cannot be split via stratification.
"""
...
Expand Down

0 comments on commit 5ee2860

Please sign in to comment.