-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dfa95d8
commit 92dbe5a
Showing
23 changed files
with
2,360 additions
and
1,573 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,311 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"(example-basic)=\n", | ||
"\n", | ||
"Example: Estimating CATEs with a MetaLearner\n", | ||
"==============================================\n", | ||
"\n", | ||
"Loading the data\n", | ||
"----------------\n", | ||
"\n", | ||
"First, we will load and prepare some data for this example. In this\n", | ||
"particular case we rely on the so-called mindset data set, taken from\n", | ||
"[here](https://github.com/matheusfacure/python-causality-handbook/blob/master/causal-inference-for-the-brave-and-true/data/learning_mindset.csv)\n", | ||
"and under MIT License. It stems from an experimental setup where\n", | ||
"\n", | ||
"* The outcome was the achievement of a student in scalar form, found\n", | ||
" in column ``\"achievement_score\".``\n", | ||
"* The mindset intervention is a binary variable found in the column\n", | ||
" ``\"intervention\"``.\n", | ||
"* Both numerical and categorical covariates/features are present." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import pandas as pd\n", | ||
"from pathlib import Path\n", | ||
"from git_root import git_root\n", | ||
"\n", | ||
"df = pd.read_csv(git_root(\"data/learning_mindset.zip\"))\n", | ||
"outcome_column = \"achievement_score\"\n", | ||
"treatment_column = \"intervention\"\n", | ||
"feature_columns = [\n", | ||
" column\n", | ||
" for column in df.columns\n", | ||
" if column not in [outcome_column, treatment_column]\n", | ||
"]\n", | ||
"categorical_feature_columns = [\n", | ||
" \"ethnicity\",\n", | ||
" \"gender\",\n", | ||
" \"frst_in_family\",\n", | ||
" \"school_urbanicity\",\n", | ||
" \"schoolid\",\n", | ||
"]\n", | ||
"# Note that explicitly setting the dtype of these features to category\n", | ||
"# allows both lightgbm as well as shap plots to\n", | ||
"# 1. Operate on features which are not of type int, bool or float\n", | ||
"# 2. Correctly interpret categoricals with int values to be\n", | ||
"# interpreted as categoricals, as compared to ordinals/numericals.\n", | ||
"for categorical_feature_column in categorical_feature_columns:\n", | ||
" df[categorical_feature_column] = df[categorical_feature_column].astype(\n", | ||
" \"category\"\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Using a first, simple MetaLearner\n", | ||
"---------------------------------\n", | ||
"\n", | ||
"Now that the data has been loaded, we can get to actually using\n", | ||
"MetaLearners. Let's start with the\n", | ||
"{class}`metalearners.TLearner`.\n", | ||
"Investigating its documentation, we realize that only three initialization parameters\n", | ||
"are necessary in the case we do not want to reuse nuisance models: ``nuisance_model_factory``, ``is_classification`` and\n", | ||
"``n_variants``. Given that our outcome is a scalar, we want to set\n", | ||
"``is_classification=False`` and use a regressor as the\n", | ||
"``nuisance_model_factory``. In this case we arbitrarily choose a\n", | ||
"regressor from ``lightgbm``. Since we know that the intervention was\n", | ||
"binary, we set ``n_variants=2``." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from metalearners import TLearner\n", | ||
"from lightgbm import LGBMRegressor\n", | ||
"\n", | ||
"tlearner = TLearner(\n", | ||
" nuisance_model_factory=LGBMRegressor,\n", | ||
" is_classification=False,\n", | ||
" n_variants=2,\n", | ||
" nuisance_model_params={\"verbose\": -1}\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Once our T-Learner has been instantiated, we can use it\n", | ||
"in a fashion akin to scikit-learn's Estimator protocol. The subtle differences\n", | ||
"to aforementioned scikit-learn protocol are that\n", | ||
"\n", | ||
"* We need to specify the observed treatment assignment ``w`` in the call to the\n", | ||
" ``fit`` method.\n", | ||
"* We need to specify whether we want in-sample or out-of-sample\n", | ||
" estimates in the ``predict`` call via ``is_oos``." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"tlearner.fit(\n", | ||
" X=df[feature_columns],\n", | ||
" y=df[outcome_column],\n", | ||
" w=df[treatment_column],\n", | ||
")\n", | ||
"\n", | ||
"cate_estimates_tlearner = tlearner.predict(\n", | ||
" X=df[feature_columns],\n", | ||
" is_oos=False,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can now notice that ``cate_estimates_tlearner`` is of shape\n", | ||
"{math}`(n_{obs}, n_{variants} - 1, n_{outputs})`. This is meant to\n", | ||
"cater to a general case, where there are more than two variants and/or\n", | ||
"classification problems with many class probabilities. Given that we\n", | ||
"care about the simple case of binary variant regression, we can make use of\n", | ||
"{func}`metalearners.utils.simplify_output` to simplify this shape as such:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from metalearners.utils import simplify_output\n", | ||
"one_d_estimates = simplify_output(cate_estimates_tlearner)\n", | ||
"\n", | ||
"print(cate_estimates_tlearner.shape)\n", | ||
"print(one_d_estimates.shape)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Using a MetaLearner with two stages\n", | ||
"-----------------------------------\n", | ||
"\n", | ||
"Instead of using a T-Learner, we can of course also some other\n", | ||
"MetaLearner, such as the {class}`metalearners.RLearner`.\n", | ||
"The R-Learner's documentation tells us that two more instantiation\n", | ||
"parameters are necessary: ``propensity_model_factory`` and\n", | ||
"``treatment_model_factory``. Hence we can instantiate an R-Learner as follows" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from metalearners import RLearner\n", | ||
"from lightgbm import LGBMClassifier\n", | ||
"rlearner = RLearner(\n", | ||
" nuisance_model_factory=LGBMRegressor,\n", | ||
" propensity_model_factory=LGBMClassifier,\n", | ||
" treatment_model_factory=LGBMRegressor,\n", | ||
" is_classification=False,\n", | ||
" n_variants=2,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"where we choose a classifier class to serve as a blueprint for our\n", | ||
"eventual propensity model.\n", | ||
"\n", | ||
"If we want to make sure these models are initialized in a specific\n", | ||
"way, e.g. with a specific value for the hyperparameter ``n_estimators``, we can do that\n", | ||
"as follows:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"rlearner = RLearner(\n", | ||
" nuisance_model_factory=LGBMRegressor,\n", | ||
" propensity_model_factory=LGBMClassifier,\n", | ||
" treatment_model_factory=LGBMRegressor,\n", | ||
" is_classification=False,\n", | ||
" n_variants=2,\n", | ||
" nuisance_model_params={\"n_estimators\": 10, \"verbose\": -1},\n", | ||
" propensity_model_params={\"n_estimators\": 8, \"verbose\": -1},\n", | ||
" treatment_model_params={\"n_estimators\": 3, \"verbose\": -1},\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The estimation steps look identical to those of the T-Learner:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"rlearner.fit(\n", | ||
" X=df[feature_columns],\n", | ||
" y=df[outcome_column],\n", | ||
" w=df[treatment_column],\n", | ||
")\n", | ||
"\n", | ||
"cate_estimates_rlearner = rlearner.predict(\n", | ||
" X=df[feature_columns],\n", | ||
" is_oos=False,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Comparing estimates\n", | ||
"-------------------\n", | ||
"\n", | ||
"We can now compare the CATE estimates produced by both MetaLearners on\n", | ||
"a histogram:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"fig, ax = plt.subplots()\n", | ||
"\n", | ||
"ax.hist(simplify_output(cate_estimates_tlearner), density=True, alpha=.5, label=\"T-Learner\")\n", | ||
"ax.hist(simplify_output(cate_estimates_rlearner), density=True, alpha=.5, label=\"R-Learner\")\n", | ||
"ax.legend()\n", | ||
"ax.set_xlabel(\"CATE estimate\")\n", | ||
"ax.set_ylabel(\"relative frequency\")\n", | ||
"plt.show()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.