diff --git a/docs/examples/example_estimating_ates.ipynb b/docs/examples/example_estimating_ates.ipynb
index 9ba10c2..84d14d0 100644
--- a/docs/examples/example_estimating_ates.ipynb
+++ b/docs/examples/example_estimating_ates.ipynb
@@ -4,9 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "(example-ate)=\n",
- "\n",
- " Example: Estimating Average Treatment Effects\n",
+ "Example: Estimating Average Treatment Effects\n",
"=============================\n",
"\n",
"Motivation\n",
@@ -22,7 +20,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -42,7 +40,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -101,20 +99,9 @@
},
{
"cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(2.083595103597918, 0.06526671583747883)"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"naive_lm = smf.ols(f\"{outcome_column} ~ {treatment_column}\", df) .fit(cov_type=\"HC1\")\n",
"naive_est = naive_lm.params.iloc[1], naive_lm.bse.iloc[1]\n",
@@ -123,20 +110,9 @@
},
{
"cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(2.1433722387308025, 0.06345124983351998)"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"covaradjust_lm = smf.ols(f\"{outcome_column} ~ {treatment_column}+{'+'.join(feature_columns)}\",\n",
" df) .fit(cov_type=\"HC1\")\n",
@@ -162,42 +138,9 @@
},
{
"cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"from metalearners import DRLearner\n",
"from lightgbm import LGBMRegressor, LGBMClassifier\n",
@@ -206,20 +149,9 @@
},
{
"cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array([[1.02931589, 0.06679633]])"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"metalearners_dr = DRLearner(\n",
" nuisance_model_factory=LGBMRegressor,\n",
@@ -254,17 +186,9 @@
},
{
"cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "est: 1.0293158917468608, se: 0.06679966900982737\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"gamma_i = metalearners_dr._pseudo_outcome(\n",
" X=df[feature_columns],\n",
@@ -289,7 +213,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -314,17 +238,9 @@
},
{
"cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[1.08356716 0.05786543]\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(doubleml_est := aipw_mod.summary.values[0, :2])"
]
@@ -338,7 +254,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -348,17 +264,9 @@
},
{
"cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Y ~ 0 + X0+X1+X2+X3\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(ff := f\"{outcome_column} ~ 0 + {'+'.join(feature_columns)}\")\n",
"y, X = fm.Formula(ff).get_model_matrix(df, output=\"numpy\")\n",
@@ -367,7 +275,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -378,17 +286,9 @@
},
{
"cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[1.001 0.109]\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(econml_est := econml_dr.intercept__inference(1).summary_frame().iloc[0, :2].values)"
]
@@ -404,69 +304,9 @@
},
{
"cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " naive | \n",
- " linreg | \n",
- " metalearners | \n",
- " doubleml | \n",
- " econml | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " est | \n",
- " 2.083595 | \n",
- " 2.143372 | \n",
- " 1.029316 | \n",
- " 1.083567 | \n",
- " 1.001 | \n",
- "
\n",
- " \n",
- " se | \n",
- " 0.065267 | \n",
- " 0.063451 | \n",
- " 0.066796 | \n",
- " 0.057865 | \n",
- " 0.109 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " naive linreg metalearners doubleml econml\n",
- "est 2.083595 2.143372 1.029316 1.083567 1.001\n",
- "se 0.065267 0.063451 0.066796 0.057865 0.109"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"pd.DataFrame(\n",
" np.c_[\n",
@@ -491,102 +331,9 @@
},
{
"cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " Y | \n",
- " X0 | \n",
- " X1 | \n",
- " X2 | \n",
- " X3 | \n",
- " W | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 8.514662 | \n",
- " -0.104441 | \n",
- " 0.833485 | \n",
- " 1.802766 | \n",
- " 1.0 | \n",
- " 2 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 0.355338 | \n",
- " -0.690561 | \n",
- " -0.180011 | \n",
- " -1.715710 | \n",
- " 0.0 | \n",
- " 2 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 2.033627 | \n",
- " 0.685568 | \n",
- " -0.004838 | \n",
- " 0.671343 | \n",
- " 0.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 2.791453 | \n",
- " 1.984499 | \n",
- " -0.433412 | \n",
- " 0.921716 | \n",
- " 0.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 2.080426 | \n",
- " 0.189912 | \n",
- " -0.769235 | \n",
- " -0.450760 | \n",
- " 1.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " Y X0 X1 X2 X3 W\n",
- "0 8.514662 -0.104441 0.833485 1.802766 1.0 2\n",
- "1 0.355338 -0.690561 -0.180011 -1.715710 0.0 2\n",
- "2 2.033627 0.685568 -0.004838 0.671343 0.0 1\n",
- "3 2.791453 1.984499 -0.433412 0.921716 0.0 1\n",
- "4 2.080426 0.189912 -0.769235 -0.450760 1.0 1"
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"np.random.seed(123)\n",
"\n",
@@ -652,23 +399,9 @@
},
{
"cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Intercept 0.645659\n",
- "C(W)[T.1] 1.777798\n",
- "C(W)[T.2] 2.387949\n",
- "dtype: float64"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"lm_multi = smf.ols(\"Y ~ C(W)\", df_multi).fit()\n",
"lm_multi.params"
@@ -683,21 +416,9 @@
},
{
"cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array([[1.04071166, 0.14139908],\n",
- " [2.25658364, 0.21346275]])"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"metalearners_dr_2 = DRLearner(\n",
" nuisance_model_factory=LGBMRegressor,\n",
@@ -741,102 +462,9 @@
},
{
"cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " Y | \n",
- " W | \n",
- " X0 | \n",
- " X1 | \n",
- " X2 | \n",
- " X3 | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 0 | \n",
- " 1.0 | \n",
- " 1.138203 | \n",
- " -1.439008 | \n",
- " -1.981465 | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 1 | \n",
- " 1.0 | \n",
- " 0.297749 | \n",
- " -0.092953 | \n",
- " 1.455780 | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 1 | \n",
- " 1.0 | \n",
- " 1.150090 | \n",
- " 0.545041 | \n",
- " 0.799303 | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 1 | \n",
- " 0.0 | \n",
- " 0.832256 | \n",
- " -0.591107 | \n",
- " -1.076526 | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 0 | \n",
- " 1.0 | \n",
- " 1.559507 | \n",
- " -1.018031 | \n",
- " -1.137247 | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " Y W X0 X1 X2 X3\n",
- "0 0 1.0 1.138203 -1.439008 -1.981465 0.0\n",
- "1 1 1.0 0.297749 -0.092953 1.455780 0.0\n",
- "2 1 1.0 1.150090 0.545041 0.799303 0.0\n",
- "3 1 0.0 0.832256 -0.591107 -1.076526 0.0\n",
- "4 0 1.0 1.559507 -1.018031 -1.137247 0.0"
- ]
- },
- "execution_count": 18,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"def classification_dgp(n, k, pscore_fn, tau_fn, outcome_fn, k_cat=1):\n",
" \"\"\"DGP for a confounded treatment assignment with binary outcome\"\"\"\n",
@@ -884,24 +512,9 @@
},
{
"cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAANIAAAAQCAYAAABjuSH9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAABJ0AAASdAHeZh94AAAItUlEQVR4nO2be7BXVRXHPyiJhkoGImUmwkgCiZcyHqnIDYPygkFhNQ0kzgA5xgDyUKNs+XXGhAoUs4egA2ZMZYqk8UhAJkJNZxRHHVAoHkolCoSDAhGP/lj7yLnnnnN/55zfb8Z/7pr5zb5nn/Xae52912Pv2+rYsWO0QAu0QHXQOtkh6RPAbcCXgPbAv4ElgMzsP0WYF+UlqQGYBPSI4T8PzDGzZ1LwRwKXA3XARcBpwCIzG9WMTrOAi4FuQAfgALA96HWPme1O4G8Dzs1gt9PMOmXIKTSWGN0gYALQHzgD2A28DMw1s2XV6lbEJpLGAAuydA1w1MxObA5B0ijgwfA4zszuy8DLNWeS2gMjgAbgQuBs4BA+TwuABWZ2NMG7DE0rYGz49QRaARuB+4B5cfwTEoRdg+LXAs8BdwJbwuCeCcrkgqK8wgf+J+AzwApgLvAC8BXgqWCMJPwA/+jqgH/mVO0GoC2wMshYBBwGbgVeknROCs07gFJ+P00TUHIsSPoxsApf6I8Bs4GlwJnAwIzx5NathH1fzOAt4MmAszxDr0jmOcA9wLsV8IrM2dXAfKAv8CxwF/AI8Gn8I38oLAKqpPkNMA/oDPw24H0Y+CWwMI6Y9Ei/ADoCE83sZ7FBzsE/wNuB65qZj1K8JHUCpgE7gV5m9lYMvx432m1hYHG4AdgB/B33TGty6HW6mR1Mdkq6HZgBfA+4PvF6r5ndmoN36bFIGgdMBx4AxpvZocT7D2WIzK0bBe1rZi/ii6kJSIo8xLwsYeHDXIB71cX4vKThFZ2zTcBVwNK4V5A0A98gvgZ8FV8olKGRNAL4FrAV6GNmu0L/SQFntKQlZrYYYh4p7FaDgW3AzxNjNeC9QNw2fdoaTUxRXucGXZ6NTyKAma0B9uG7Msl3ZrbZzHInemmLKMBDoT0/L68MKDwWSW3wj/h1UhZRoP1fNUrV2L4XAv3wKGBpM6gTgS/gHvC9ZvAKzZmZPWlmjydDMTN7E/hVeByYeFeUZkRoZ0eLKOAfAm4JjxOi/nhoVx/aJ1KE7QOewt1aPypDUV6b8Xi1j6QOcXxJA/DcZ1UOudXAsNC+lPKujaRRkmZImiSpXlJWXlBmLF/EP5TFwFFJDZJuCrL6V9A7r261tO/40N5vZkfSECR1B2biud3aCvxqaf9owzmcEz+LJsovt6TgR32XBQ/VKLT7VGg3ZQjbjO9o3YDVFRQrxMvM9ki6CZgDbJC0BA8HuuLueCXwnQoyC4GkacCpQDs8J7kUX0QzU9A7cTxZjmCrpGvN7C/xzpJj+VxoDwLr8bg9rutaYKSZvV2FbjWxr6RTgFHAETxnSMNpHXR6HQ+Xm4Va2T/I/XZ4XFEJvwJN5IXOSyHrEtrW4e9X4x6pXWjfyZAZ9X8kh36FeZnZXXiM2hoYB9yMJ4hvAAuTLr8GMA0PaSbji2gFMDjlY10ADMI/2LZ4xedePAFdLumiJOMSY+kY2unAMeAyfBfuBTwBDAD+kDKGIrrVyr5fDzgrzOyNDJwfAr2BMWZ2oAI/oGb2n4lvQsvM7M955DZDE4WsUyR9NOoMuapieGdASvn7gwJJNwI/Au7GqzxvAhcAdwCLJNWZ2Y21kheVhiWdBXwen9D1koaa2QsxPCVIXwGuk/QuMBWv9o2II5QYS7ShHQauMrNt4fnlkPS+BlwuqX+8DFxGtxpAFNbdm/ZSUl/cC81ursyfQleV/SVNxMf8KjA6p8zmaH4X+obgXvKPeMRwBfAx3Nt+EjgKjXOkaEdqRzpE/Xtz6FiIl6SBwCzgMTObYmZbzGx/+KBH4EntVEldUnhVBWa208wexcOa9sCvc5JGCeqAeGfJsewN7frYIor02w9EO2WfKnSr2r6SeuKbzg5gWcr71vj8beJ4Ql4RqrW/pAl4uXwDUG9me3LIbJYm5H7DcM/4NnBN+G3G52BfQH0LGnuk10LbLUN2VM3KirHjUJTX0NA2KV+b2X5Jz+ET2pv05K9qMLPtkjYAdZI6xCs1GRCFgMkqV5mxRPO1N0NWdFB6SgWdmtOtFvatVGQ4Ncb/oJR0mADMlzQfL0JMDn2l7S9pMn4e9gowKE8ImJcmVEpnhV+c/mR8vnaZ2VZo7JGiQQyWlDyoPQ24BNgP/K2SoiV4tQltkxJ3or9JWbjG8PHQplaiEhBVt5ILu8xYVuO5UY/kfAWIig9bc+iVpVtV9g0fz2h8bu7PkPvf8C7ttz7grAvP8bCvlP1DgeJO/KyrPuciKkyTAt8ETsIPaYHYQjKzf+CJbWfgu0n5+O72oJm9fx4gqaukC5KHhSV4/TW04yWd3QhZ+jJu5IPA0zkHmgqSuklqEtpIOiEcyHYEno6uykjqnnauIqkzHsdD00PiwmMxs+3A43jMPSlBMxiP0/cSqyoV1a2MfRNwNZ5YL88qMpjZATMbm/bDb2oAPBD6fh8jLTxnkm7B89rnca9SKYIoTCPp9JS+OuAneJTwfoU3WWy4Pih7t/zO10b8SkU97vK/n8BfjR+mnYcf9JXl9TB+TnAFsFHSo3iy2R13+62Am63pPbjhwPDwGNX9+0taGP7eZWbx0/QrgTskrcN3993AWfitiC5B5rgY/jfw2Hwtfh9vH16SbQBOxvOE5FWcUmPBP+7ewBz5nbP1+LwOx73AWDOLV9zK6FbUvnGIwrrMmwxVQKE5k3QNftPhCL4IJ6aEkdvMbGH0UIYGWCnpAB4C7gv6NOD3M4eZ2b8ixEYuPuxaF+P3iPriFY2ueFLWL8X4mVCEVzggvBK/prIBj4en4iHKMmCImc1NEVPH8SRwSOjrEusbmcBfhYcVZ+Kl1un41ZA9+K7c08w2xPDX4Pe/uuLXRabgi25d4D80eQuh7FjMbAfwWdybnI97poG4p7rEzB5JkJTRrZR9w+HqpWQUGaqFEnMWne2ciB9fWMpvTEJMGZqH8WOIUfj89sI3kh7J88NWLf9G0QItUD38H35O5CrMQtFfAAAAAElFTkSuQmCC",
- "text/latex": [
- "$\\displaystyle 0.081358650748229$"
- ],
- "text/plain": [
- "0.08135865074822901"
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"# naive TE\n",
"df_class.groupby(\"W\")[\"Y\"].mean().diff()[1]"
@@ -909,20 +522,9 @@
},
{
"cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array([[0.04373356, 0.04191263]])"
- ]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"metalearners_dr_3 = DRLearner(\n",
" nuisance_model_factory=LGBMClassifier,\n",
@@ -939,10 +541,11 @@
" y=df_class[outcome_column],\n",
" w=df_class[treatment_column],\n",
")\n",
- "metalearners_est_3 = metalearners_dr_3.treatment_effect( # still need to pass data objects since DRLearner does not retain any data\n",
+ "metalearners_est_3 = metalearners_dr_3.average_treatment_effect( # still need to pass data objects since DRLearner does not retain any data\n",
" X=df_class[feature_columns],\n",
" y=df_class[outcome_column],\n",
" w=df_class[treatment_column],\n",
+ " is_oos=False,\n",
")\n",
"metalearners_est_3"
]
@@ -957,7 +560,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -971,9 +574,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.7"
+ "version": "3.11.9"
}
},
"nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
}