From dec01bc79421b1d90628afb63435f7b6cdde2b9d Mon Sep 17 00:00:00 2001 From: kklein Date: Wed, 7 Aug 2024 10:06:03 +0200 Subject: [PATCH] Fix notebook. --- docs/examples/example_estimating_ates.ipynb | 509 +++----------------- 1 file changed, 56 insertions(+), 453 deletions(-) diff --git a/docs/examples/example_estimating_ates.ipynb b/docs/examples/example_estimating_ates.ipynb index 9ba10c2..84d14d0 100644 --- a/docs/examples/example_estimating_ates.ipynb +++ b/docs/examples/example_estimating_ates.ipynb @@ -4,9 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "(example-ate)=\n", - "\n", - " Example: Estimating Average Treatment Effects\n", + "Example: Estimating Average Treatment Effects\n", "=============================\n", "\n", "Motivation\n", @@ -22,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -42,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -101,20 +99,9 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(2.083595103597918, 0.06526671583747883)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "naive_lm = smf.ols(f\"{outcome_column} ~ {treatment_column}\", df) .fit(cov_type=\"HC1\")\n", "naive_est = naive_lm.params.iloc[1], naive_lm.bse.iloc[1]\n", @@ -123,20 +110,9 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(2.1433722387308025, 0.06345124983351998)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "covaradjust_lm = smf.ols(f\"{outcome_column} ~ {treatment_column}+{'+'.join(feature_columns)}\",\n", " df) .fit(cov_type=\"HC1\")\n", @@ -162,42 +138,9 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from metalearners import DRLearner\n", "from lightgbm import LGBMRegressor, LGBMClassifier\n", @@ -206,20 +149,9 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1.02931589, 0.06679633]])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "metalearners_dr = DRLearner(\n", " nuisance_model_factory=LGBMRegressor,\n", @@ -254,17 +186,9 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "est: 1.0293158917468608, se: 0.06679966900982737\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "gamma_i = metalearners_dr._pseudo_outcome(\n", " X=df[feature_columns],\n", @@ -289,7 +213,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -314,17 +238,9 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1.08356716 0.05786543]\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "print(doubleml_est := aipw_mod.summary.values[0, :2])" ] @@ -338,7 +254,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -348,17 +264,9 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Y ~ 0 + X0+X1+X2+X3\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "print(ff := f\"{outcome_column} ~ 0 + {'+'.join(feature_columns)}\")\n", "y, X = fm.Formula(ff).get_model_matrix(df, output=\"numpy\")\n", @@ -367,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -378,17 +286,9 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1.001 0.109]\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "print(econml_est := econml_dr.intercept__inference(1).summary_frame().iloc[0, :2].values)" ] @@ -404,69 +304,9 @@ }, { "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
naivelinregmetalearnersdoublemleconml
est2.0835952.1433721.0293161.0835671.001
se0.0652670.0634510.0667960.0578650.109
\n", - "
" - ], - "text/plain": [ - " naive linreg metalearners doubleml econml\n", - "est 2.083595 2.143372 1.029316 1.083567 1.001\n", - "se 0.065267 0.063451 0.066796 0.057865 0.109" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "pd.DataFrame(\n", " np.c_[\n", @@ -491,102 +331,9 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
YX0X1X2X3W
08.514662-0.1044410.8334851.8027661.02
10.355338-0.690561-0.180011-1.7157100.02
22.0336270.685568-0.0048380.6713430.01
32.7914531.984499-0.4334120.9217160.01
42.0804260.189912-0.769235-0.4507601.01
\n", - "
" - ], - "text/plain": [ - " Y X0 X1 X2 X3 W\n", - "0 8.514662 -0.104441 0.833485 1.802766 1.0 2\n", - "1 0.355338 -0.690561 -0.180011 -1.715710 0.0 2\n", - "2 2.033627 0.685568 -0.004838 0.671343 0.0 1\n", - "3 2.791453 1.984499 -0.433412 0.921716 0.0 1\n", - "4 2.080426 0.189912 -0.769235 -0.450760 1.0 1" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "np.random.seed(123)\n", "\n", @@ -652,23 +399,9 @@ }, { "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Intercept 0.645659\n", - "C(W)[T.1] 1.777798\n", - "C(W)[T.2] 2.387949\n", - "dtype: float64" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "lm_multi = smf.ols(\"Y ~ C(W)\", df_multi).fit()\n", "lm_multi.params" @@ -683,21 +416,9 @@ }, { "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1.04071166, 0.14139908],\n", - " [2.25658364, 0.21346275]])" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "metalearners_dr_2 = DRLearner(\n", " nuisance_model_factory=LGBMRegressor,\n", @@ -741,102 +462,9 @@ }, { "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
YWX0X1X2X3
001.01.138203-1.439008-1.9814650.0
111.00.297749-0.0929531.4557800.0
211.01.1500900.5450410.7993030.0
310.00.832256-0.591107-1.0765260.0
401.01.559507-1.018031-1.1372470.0
\n", - "
" - ], - "text/plain": [ - " Y W X0 X1 X2 X3\n", - "0 0 1.0 1.138203 -1.439008 -1.981465 0.0\n", - "1 1 1.0 0.297749 -0.092953 1.455780 0.0\n", - "2 1 1.0 1.150090 0.545041 0.799303 0.0\n", - "3 1 0.0 0.832256 -0.591107 -1.076526 0.0\n", - "4 0 1.0 1.559507 -1.018031 -1.137247 0.0" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "def classification_dgp(n, k, pscore_fn, tau_fn, outcome_fn, k_cat=1):\n", " \"\"\"DGP for a confounded treatment assignment with binary outcome\"\"\"\n", @@ -884,24 +512,9 @@ }, { "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAANIAAAAQCAYAAABjuSH9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAABJ0AAASdAHeZh94AAAItUlEQVR4nO2be7BXVRXHPyiJhkoGImUmwkgCiZcyHqnIDYPygkFhNQ0kzgA5xgDyUKNs+XXGhAoUs4egA2ZMZYqk8UhAJkJNZxRHHVAoHkolCoSDAhGP/lj7yLnnnnN/55zfb8Z/7pr5zb5nn/Xae52912Pv2+rYsWO0QAu0QHXQOtkh6RPAbcCXgPbAv4ElgMzsP0WYF+UlqQGYBPSI4T8PzDGzZ1LwRwKXA3XARcBpwCIzG9WMTrOAi4FuQAfgALA96HWPme1O4G8Dzs1gt9PMOmXIKTSWGN0gYALQHzgD2A28DMw1s2XV6lbEJpLGAAuydA1w1MxObA5B0ijgwfA4zszuy8DLNWeS2gMjgAbgQuBs4BA+TwuABWZ2NMG7DE0rYGz49QRaARuB+4B5cfwTEoRdg+LXAs8BdwJbwuCeCcrkgqK8wgf+J+AzwApgLvAC8BXgqWCMJPwA/+jqgH/mVO0GoC2wMshYBBwGbgVeknROCs07gFJ+P00TUHIsSPoxsApf6I8Bs4GlwJnAwIzx5NathH1fzOAt4MmAszxDr0jmOcA9wLsV8IrM2dXAfKAv8CxwF/AI8Gn8I38oLAKqpPkNMA/oDPw24H0Y+CWwMI6Y9Ei/ADoCE83sZ7FBzsE/wNuB65qZj1K8JHUCpgE7gV5m9lYMvx432m1hYHG4AdgB/B33TGty6HW6mR1Mdkq6HZgBfA+4PvF6r5ndmoN36bFIGgdMBx4AxpvZocT7D2WIzK0bBe1rZi/ii6kJSIo8xLwsYeHDXIB71cX4vKThFZ2zTcBVwNK4V5A0A98gvgZ8FV8olKGRNAL4FrAV6GNmu0L/SQFntKQlZrYYYh4p7FaDgW3AzxNjNeC9QNw2fdoaTUxRXucGXZ6NTyKAma0B9uG7Msl3ZrbZzHInemmLKMBDoT0/L68MKDwWSW3wj/h1UhZRoP1fNUrV2L4XAv3wKGBpM6gTgS/gHvC9ZvAKzZmZPWlmjydDMTN7E/hVeByYeFeUZkRoZ0eLKOAfAm4JjxOi/nhoVx/aJ1KE7QOewt1aPypDUV6b8Xi1j6QOcXxJA/DcZ1UOudXAsNC+lPKujaRRkmZImiSpXlJWXlBmLF/EP5TFwFFJDZJuCrL6V9A7r261tO/40N5vZkfSECR1B2biud3aCvxqaf9owzmcEz+LJsovt6TgR32XBQ/VKLT7VGg3ZQjbjO9o3YDVFRQrxMvM9ki6CZgDbJC0BA8HuuLueCXwnQoyC4GkacCpQDs8J7kUX0QzU9A7cTxZjmCrpGvN7C/xzpJj+VxoDwLr8bg9rutaYKSZvV2FbjWxr6RTgFHAETxnSMNpHXR6HQ+Xm4Va2T/I/XZ4XFEJvwJN5IXOSyHrEtrW4e9X4x6pXWjfyZAZ9X8kh36FeZnZXXiM2hoYB9yMJ4hvAAuTLr8GMA0PaSbji2gFMDjlY10ADMI/2LZ4xedePAFdLumiJOMSY+kY2unAMeAyfBfuBTwBDAD+kDKGIrrVyr5fDzgrzOyNDJwfAr2BMWZ2oAI/oGb2n4lvQsvM7M955DZDE4WsUyR9NOoMuapieGdASvn7gwJJNwI/Au7GqzxvAhcAdwCLJNWZ2Y21kheVhiWdBXwen9D1koaa2QsxPCVIXwGuk/QuMBWv9o2II5QYS7ShHQauMrNt4fnlkPS+BlwuqX+8DFxGtxpAFNbdm/ZSUl/cC81ursyfQleV/SVNxMf8KjA6p8zmaH4X+obgXvKPeMRwBfAx3Nt+EjgKjXOkaEdqRzpE/Xtz6FiIl6SBwCzgMTObYmZbzGx/+KBH4EntVEldUnhVBWa208wexcOa9sCvc5JGCeqAeGfJsewN7frYIor02w9EO2WfKnSr2r6SeuKbzg5gWcr71vj8beJ4Ql4RqrW/pAl4uXwDUG9me3LIbJYm5H7DcM/4NnBN+G3G52BfQH0LGnuk10LbLUN2VM3KirHjUJTX0NA2KV+b2X5Jz+ET2pv05K9qMLPtkjYAdZI6xCs1GRCFgMkqV5mxRPO1N0NWdFB6SgWdmtOtFvatVGQ4Ncb/oJR0mADMlzQfL0JMDn2l7S9pMn4e9gowKE8ImJcmVEpnhV+c/mR8vnaZ2VZo7JGiQQyWlDyoPQ24BNgP/K2SoiV4tQltkxJ3or9JWbjG8PHQplaiEhBVt5ILu8xYVuO5UY/kfAWIig9bc+iVpVtV9g0fz2h8bu7PkPvf8C7ttz7grAvP8bCvlP1DgeJO/KyrPuciKkyTAt8ETsIPaYHYQjKzf+CJbWfgu0n5+O72oJm9fx4gqaukC5KHhSV4/TW04yWd3QhZ+jJu5IPA0zkHmgqSuklqEtpIOiEcyHYEno6uykjqnnauIqkzHsdD00PiwmMxs+3A43jMPSlBMxiP0/cSqyoV1a2MfRNwNZ5YL88qMpjZATMbm/bDb2oAPBD6fh8jLTxnkm7B89rnca9SKYIoTCPp9JS+OuAneJTwfoU3WWy4Pih7t/zO10b8SkU97vK/n8BfjR+mnYcf9JXl9TB+TnAFsFHSo3iy2R13+62Am63pPbjhwPDwGNX9+0taGP7eZWbx0/QrgTskrcN3993AWfitiC5B5rgY/jfw2Hwtfh9vH16SbQBOxvOE5FWcUmPBP+7ewBz5nbP1+LwOx73AWDOLV9zK6FbUvnGIwrrMmwxVQKE5k3QNftPhCL4IJ6aEkdvMbGH0UIYGWCnpAB4C7gv6NOD3M4eZ2b8ixEYuPuxaF+P3iPriFY2ueFLWL8X4mVCEVzggvBK/prIBj4en4iHKMmCImc1NEVPH8SRwSOjrEusbmcBfhYcVZ+Kl1un41ZA9+K7c08w2xPDX4Pe/uuLXRabgi25d4D80eQuh7FjMbAfwWdybnI97poG4p7rEzB5JkJTRrZR9w+HqpWQUGaqFEnMWne2ciB9fWMpvTEJMGZqH8WOIUfj89sI3kh7J88NWLf9G0QItUD38H35O5CrMQtFfAAAAAElFTkSuQmCC", - "text/latex": [ - "$\\displaystyle 0.081358650748229$" - ], - "text/plain": [ - "0.08135865074822901" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# naive TE\n", "df_class.groupby(\"W\")[\"Y\"].mean().diff()[1]" @@ -909,20 +522,9 @@ }, { "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[0.04373356, 0.04191263]])" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "metalearners_dr_3 = DRLearner(\n", " nuisance_model_factory=LGBMClassifier,\n", @@ -939,10 +541,11 @@ " y=df_class[outcome_column],\n", " w=df_class[treatment_column],\n", ")\n", - "metalearners_est_3 = metalearners_dr_3.treatment_effect( # still need to pass data objects since DRLearner does not retain any data\n", + "metalearners_est_3 = metalearners_dr_3.average_treatment_effect( # still need to pass data objects since DRLearner does not retain any data\n", " X=df_class[feature_columns],\n", " y=df_class[outcome_column],\n", " w=df_class[treatment_column],\n", + " is_oos=False,\n", ")\n", "metalearners_est_3" ] @@ -957,7 +560,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -971,9 +574,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.9" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 }