Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for sparse X #86

Merged
merged 11 commits into from
Aug 28, 2024
119 changes: 90 additions & 29 deletions docs/examples/example_estimating_ates.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
apoorvalal marked this conversation as resolved.
Show resolved Hide resolved
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -40,7 +40,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -99,9 +99,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(2.083595103597918, 0.06526671583747883)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"naive_lm = smf.ols(f\"{outcome_column} ~ {treatment_column}\", df) .fit(cov_type=\"HC1\")\n",
"naive_est = naive_lm.params.iloc[1], naive_lm.bse.iloc[1]\n",
Expand All @@ -110,9 +121,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(2.1433722387308025, 0.06345124983351998)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"covaradjust_lm = smf.ols(f\"{outcome_column} ~ {treatment_column}+{'+'.join(feature_columns)}\",\n",
" df) .fit(cov_type=\"HC1\")\n",
Expand All @@ -138,9 +160,42 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type='text/css'>\n",
".datatable table.frame { margin-bottom: 0; }\n",
".datatable table.frame thead { border-bottom: none; }\n",
".datatable table.frame tr.coltypes td { color: #FFFFFF; line-height: 6px; padding: 0 0.5em;}\n",
".datatable .bool { background: #DDDD99; }\n",
".datatable .object { background: #565656; }\n",
".datatable .int { background: #5D9E5D; }\n",
".datatable .float { background: #4040CC; }\n",
".datatable .str { background: #CC4040; }\n",
".datatable .time { background: #40CC40; }\n",
".datatable .row_index { background: var(--jp-border-color3); border-right: 1px solid var(--jp-border-color0); color: var(--jp-ui-font-color3); font-size: 9px;}\n",
".datatable .frame tbody td { text-align: left; }\n",
".datatable .frame tr.coltypes .row_index { background: var(--jp-border-color0);}\n",
".datatable th:nth-child(2) { padding-left: 12px; }\n",
".datatable .hellipsis { color: var(--jp-cell-editor-border-color);}\n",
".datatable .vellipsis { background: var(--jp-layout-color0); color: var(--jp-cell-editor-border-color);}\n",
".datatable .na { color: var(--jp-cell-editor-border-color); font-size: 80%;}\n",
".datatable .sp { opacity: 0.25;}\n",
".datatable .footer { font-size: 9px; }\n",
".datatable .frame_dimensions { background: var(--jp-border-color3); border-top: 1px solid var(--jp-border-color0); color: var(--jp-ui-font-color3); display: inline-block; opacity: 0.6; padding: 1px 10px 1px 5px;}\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from metalearners import DRLearner\n",
"from lightgbm import LGBMRegressor, LGBMClassifier\n",
Expand All @@ -149,9 +204,26 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 6,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"(array([1.02931589]), array([0.06679633]))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"metalearners_dr = DRLearner(\n",
" nuisance_model_factory=LGBMRegressor,\n",
Expand Down Expand Up @@ -557,22 +629,11 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"name": "python"
},
"mystnb": {
"execution_timeout": 120
}
},
"nbformat": 4,
Expand Down
Loading
Loading