diff --git a/docs/background.rst b/docs/background.rst index 7eee9a2..969eeaf 100644 --- a/docs/background.rst +++ b/docs/background.rst @@ -358,15 +358,15 @@ It is an extension of the T-Learner and consists of three stages: \widetilde{D}_1^i &:= Y^i_1 - \hat{\mu}_0(X^i_1) \\ \widetilde{D}_0^i &:= \hat{\mu}_1(X^i_0) - Y^i_0 - Then estimate :math:`\tau_1(x) := \mathbb{E}[\widetilde{D}^i_1 | X]` and - :math:`\tau_0(x) := \mathbb{E}[\widetilde{D}^i_0 | X]` using the observations in the + Then estimate :math:`\tau_1(x) := \mathbb{E}[\widetilde{D}^i_1 | X=x]` and + :math:`\tau_0(x) := \mathbb{E}[\widetilde{D}^i_0 | X=x]` using the observations in the treatment group and the ones in the control group respectively. #. Define the CATE estimate by a weighted average of the two estimates in stage 2: .. math:: \hat{\tau}^X(x) := g(x)\hat{\tau}_0(x) + (1-g(x))\hat{\tau}_1(x) - where :math:`g(x) \in [0,1]`. We take :math:`g(x) := \mathbb{E}[W = 1 | X]` to be + where :math:`g(x) \in [0,1]`. We take :math:`g(x) := \mathbb{E}[W = 1 | X=x]` to be the propensity score. More than binary treatment @@ -388,8 +388,8 @@ In the case of multiple discrete treatments the stages are similar to the binary \widetilde{D}_k^i &:= Y^i_k - \hat{\mu}_0(X^i_k) \\ \widetilde{D}_{0,k}^i &:= \hat{\mu}_k(X^i_0) - Y^i_0 - Then :math:`\tau_k(x) := \mathbb{E}[\widetilde{D}^i_k | X]` is estimated using the - observations which received treatment :math:`k` and :math:`\tau_{0,k}(x) := \mathbb{E}[\widetilde{D}^i_{0,k} | X]` + Then :math:`\tau_k(x) := \mathbb{E}[\widetilde{D}^i_k | X=x]` is estimated using the + observations which received treatment :math:`k` and :math:`\tau_{0,k}(x) := \mathbb{E}[\widetilde{D}^i_{0,k} | X=x]` using the observations in the control group. #. Finally the CATE for each variant is estimated as a weighted average: @@ -419,9 +419,15 @@ It consists of two stages: .. math:: \DeclareMathOperator*{\argmin}{arg\,min} - \hat{\tau}^R (x) &:= \argmin_{\tau}\Bigg\{\mathbb{E}\Bigg[\bigg(\left\{Y^i - \hat{m}(X^i)\right\} - \left\{W^i - \hat{e}(X^i)\right\}\tau(X^i)\bigg)^2\Bigg]\Bigg\} \\ + \hat{\tau}^R (\cdot) &:= \argmin_{\tau}\Bigg\{\mathbb{E}\Bigg[\bigg(\left\{Y^i - \hat{m}(X^i)\right\} - \left\{W^i - \hat{e}(X^i)\right\}\tau(X^i)\bigg)^2\Bigg]\Bigg\} \\ &=\argmin_{\tau}\left\{\mathbb{E}\left[\left\{W^i - \hat{e}(X^i)\right\}^2\bigg(\frac{\left\{Y^i - \hat{m}(X^i)\right\}}{\left\{W^i - \hat{e}(X^i)\right\}} - \tau(X^i)\bigg)^2\right]\right\} \\ - &= \argmin_{\tau}\left\{\mathbb{E}\left[{\tilde{W}^i}^2\bigg(\frac{\tilde{Y}^i}{\tilde{W}^i} - \tau(X^i)\bigg)^2\right]\right\} + &= \argmin_{\tau}\left\{\mathbb{E}\left[{\widetilde{W}^i}^2\bigg(\frac{\widetilde{Y}^i}{\widetilde{W}^i} - \tau(X^i)\bigg)^2\right]\right\} + + Where + + .. math:: + \widetilde{W}^i &= W^i - \hat{e}(X^i) \\ + \widetilde{Y}^i &= Y^i - \hat{m}(X^i) And therefore any ML model which supports weighting each observation differently can be used for the final model. @@ -484,7 +490,7 @@ It consists of two stages: #. Estimate the CATE by regressing :math:`\varphi` on :math:`X`: .. math:: - \hat{\tau}^{DR}(x) := \mathbb{E}[\varphi(X^i, W^i, Y^i) | X^i] + \hat{\tau}^{DR}(x) := \mathbb{E}[\varphi(X^i, W^i, Y^i) | X^i=x] More than binary treatment ************************** @@ -508,4 +514,4 @@ In the case of multiple discrete treatments the stages are similar to the binary treatment variant, :math:`\forall k \in \{1,\dots, K-1\}`: .. math:: - \hat{\tau}_k^{DR}(x) := \mathbb{E}[\varphi_k(X^i, W^i, Y^i) | X^i] + \hat{\tau}_k^{DR}(x) := \mathbb{E}[\varphi_k(X^i, W^i, Y^i) | X^i=x] diff --git a/docs/examples/example_basic.ipynb b/docs/examples/example_basic.ipynb index eb9177a..3780e48 100644 --- a/docs/examples/example_basic.ipynb +++ b/docs/examples/example_basic.ipynb @@ -115,7 +115,9 @@ "* We need to specify the observed treatment assignment ``w`` in the call to the\n", " ``fit`` method.\n", "* We need to specify whether we want in-sample or out-of-sample\n", - " estimates in the {meth}`~metalearners.TLearner.predict` call via ``is_oos``." + " CATE estimates in the {meth}`~metalearners.TLearner.predict` call via ``is_oos``. In the\n", + " case of in-sample predictions, the data passed to {meth}`~metalearners.TLearner.predict`\n", + " must be exactly the same as the data that was used to call {meth}`~metalearners.TLearner.fit`." ] }, { @@ -176,7 +178,7 @@ "Using a MetaLearner with two stages\n", "-----------------------------------\n", "\n", - "Instead of using a T-Learner, we can of course also some other\n", + "Instead of using a T-Learner, we can of course also use some other\n", "MetaLearner, such as the {class}`~metalearners.RLearner`.\n", "The R-Learner's documentation tells us that two more instantiation\n", "parameters are necessary: ``propensity_model_factory`` and\n", @@ -209,7 +211,22 @@ "metadata": {}, "source": [ "where we choose a classifier class to serve as a blueprint for our\n", - "eventual propensity model.\n", + "eventual propensity model. It is important to notice that although we consider the propensity\n", + "model a nuisance model, the initialization parameters for it are separated from the other\n", + "nuisance parameters to allow a more understandable user interface, see the next code prompt.\n", + "\n", + "In general, when initializing a MetaLearner, the ``nuisance_model_factory`` parameter will\n", + "be used to create all the nuisance models which are not a propensity model, the\n", + "``propensity_model_factory`` will be used for the propensity model if the MetaLearner\n", + "contains one, and the ``treatment_model_factory`` will be used for the models predicting\n", + "the CATE. To see the models present in each MetaLearner type see\n", + "{meth}`~metalearners.metalearner.MetaLearner.nuisance_model_specifications` and\n", + "{meth}`~metalearners.metalearner.MetaLearner.treatment_model_specifications`.\n", + "\n", + "In the {class}`~metalearners.RLearner` case, the ``nuisance_model_factory`` parameter will\n", + "be used to create the outcome model, the ``propensity_model_factory`` will be used for the\n", + "propensity model and the ``treatment_model_factory`` will be used for the model predicting\n", + "the CATE.\n", "\n", "If we want to make sure these models are initialized in a specific\n", "way, e.g. with a specific value for the hyperparameter ``n_estimators``, we can do that\n", diff --git a/docs/examples/example_feature_importance_shap.ipynb b/docs/examples/example_feature_importance_shap.ipynb index 6e2a662..a42d8af 100644 --- a/docs/examples/example_feature_importance_shap.ipynb +++ b/docs/examples/example_feature_importance_shap.ipynb @@ -326,7 +326,8 @@ "source": [ "Note that the method {meth}`~metalearners.explainer.Explainer.feature_importances`\n", "returns a list of length {math}`n_{variats} -1` that indicates the feature importance for\n", - "each variant against control.\n", + "each variant against control. Remember that a higher value means that the corresponding\n", + "feature is more important for the CATE prediction.\n", "\n", "### Computing and plotting the SHAP values\n", "\n", @@ -367,7 +368,23 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "For guidelines on how to interpret such SHAP plots please see the [SHAP documentation](https://github.com/shap/shap).\n", + "In these SHAP summary plots, the color and orientation of the plotted values help us to understand\n", + "their impact on model predictions.\n", + "\n", + "Each dot in the plot represents a single instance of the given feature present in the data set.\n", + "The x-axis displays the Shapley value, signifying the strength and directionality of the\n", + "feature's impact. The y-axis displays a subset of the features in the model.\n", + "\n", + "The Shapley value, exhibited on the horizontal axis, is oriented such that values on the\n", + "right of the center line (0 mark) contribute to a positive shift in the predicted outcome,\n", + "while those on the left indicate a negative impact.\n", + "\n", + "The color coding implemented in these plots is straightforward: red implies a high feature value,\n", + "while blue denotes a low feature value. This color scheme assists in identifying whether\n", + "high or low values of a certain feature influence the model's output positively or negatively.\n", + "The categorical variables are colored in grey.\n", + "\n", + "For more guidelines on how to interpret such SHAP plots please see the [SHAP documentation](https://github.com/shap/shap).\n", "\n", "Note that the method {meth}`~metalearners.explainer.Explainer.shap_values`\n", "returns a list of length {math}`n_{variats} -1` that indicates the SHAP values for\n", diff --git a/docs/examples/example_lime.ipynb b/docs/examples/example_lime.ipynb index cf22a6f..d3c97ee 100644 --- a/docs/examples/example_lime.ipynb +++ b/docs/examples/example_lime.ipynb @@ -54,7 +54,7 @@ "* {math}`f`, the original model -- in our case the MetaLearner\n", "* {math}`G`, the class of possible, interpretable surrogate models\n", "* {math}`\\Omega(g)`, a measure of complexity for {math}`g \\in G`\n", - "* {math}`\\pi_x(z)` a proximity measure of {math}`z` with respect to data point {math}`x`\n", + "* {math}`\\pi_x(z)` a proximity measure of an instance {math}`z` with respect to data point {math}`x`\n", "* {math}`\\mathcal{L}(f, g, \\pi_x)` a measure of how unfaithful a {math}`g \\in G` is to {math}`f` in the locality defined by {math}`\\pi_x`\n", "\n", "Given all of these objects as well as a to be explained data point {math}`x`, the authors suggest that the most appropriate surrogate {math}`g`, also referred to as explanation for {math}`x`, {math}`\\xi(x)`, can be expressed as follows:\n", @@ -73,11 +73,13 @@ "* have little redundancy between each other\n", "* showcase the features with highest global importance\n", "\n", - "In line with this ambition, they define a notion of 'coverage' -- to\n", - "be maximized --as follows:\n", + "In line with this ambition, they define a notion of 'coverage' which specifies how well a set\n", + "of candidate datapoints {math}`V` are explained by features that are relevant for\n", + "many observed datapoints. The goal is to find {math}`V` that is not larger than some\n", + "pre-specified size such that this coverage is maximal.\n", "\n", "```{math}\n", - " c(V, W, \\mathcal{I}) = \\sum_{j=1}^{d} I[\\exists i \\in V: W_{i,j} > 0] \\mathcal{I}_j\n", + " c(V, W, \\mathcal{I}) = \\sum_{j=1}^{d} \\mathbb{I}\\{\\exists i \\in V: W_{i,j} > 0\\} \\mathcal{I}_j\n", "````\n", "\n", "where\n", @@ -85,7 +87,8 @@ "* {math}`d` is the number of features\n", "* {math}`V` is the candidate set of explanations to be shown to\n", " humans, within a fixed budget -- this is the variable to be optimized\n", - "* {math}`W` is a {math}`n \\times d` local feature importance matrix and\n", + "* {math}`W` is a {math}`n \\times d` local feature importance matrix that represents\n", + " the local importance of each feature for each instance, and\n", "* {math}`\\mathcal{I}` is a {math}`d`-dimensional vector of global\n", " feature importances\n", "\n", @@ -359,7 +362,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "For guidelines on how to interpret such lime plots please see the [lime documentation](https://github.com/marcotcr/lime)." + "In these plots, the green bars signify that the presence of the corresponding feature\n", + "referenced on the y-axis, increases the CATE estimate for that observation, whereas, the\n", + "red bars represent that the feature presence in the observation reduces the CATE.\n", + "Furthermore, the length of these colored bars corresponds to the magnitude of each feature's\n", + "contribution towards the model prediction. Therefore, the longer the bar, the more\n", + "significant the impact of that feature on the model prediction.\n", + "\n", + "For more guidelines on how to interpret such lime plots please see the [lime documentation](https://github.com/marcotcr/lime)." ] } ], diff --git a/docs/faq.rst b/docs/faq.rst index 13d1de2..8296780 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -29,7 +29,8 @@ FAQ Double machine learning is an ATE estimation technique, pioneered by `Chernozhukov et al. (2016) `_. It is 'double' in the sense that it relies on two preliminary models: one for the probability of - receiving treatment given covariates (the propensity score), and one for the outcome given treatment and covariates. + receiving treatment given covariates (the propensity score), and one for the outcome given covariates and + optionally the (discrete) treatment. Double ML is also referred to as 'debiased' ML, since the propensity score model is used to 'debias' a naive estimator that uses the outcome model to predict the expected outcome under treatment, and under no treatment, diff --git a/docs/glossary.rst b/docs/glossary.rst index d9232cd..cd3bf77 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -24,8 +24,8 @@ Glossary Similar to the R-Learner, the Double Machine Learning blueprint relies on estimating two nuisance models in its first stage: a propensity model as well as an outcome model. Unlike the - R-Learner, the last-stage or treatment effect model might not - be any estimator. + R-Learner, the last-stage or treatment effect model might need to be a + specific type of estimator. See `Chernozhukov et al. (2016) `_. Heterogeneous Treatment Effect (HTE) diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index 15831f0..6964547 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -746,7 +746,7 @@ def fit( pattern, propensity models are considered a nuisance model. ``synchronize_cross_fitting`` indicates whether the learning of different base models should use exactly - the same data splits where possible. Note that if there are several to be synchronize models which are + the same data splits where possible. Note that if there are several models to be synchronized which are classifiers, these cannot be split via stratification. """ ...