Skip to content

Commit

Permalink
final touches
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvalal committed Aug 28, 2024
1 parent 8282381 commit e4501b4
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 158 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
Changelog
=========

0.11.0 (2024-09-xx)
-------------------

**New features**

* Add support for using ``scipy.sparse.csr_matrix`` as datastructure for covariates ``X``.


0.10.0 (2024-08-13)
-------------------

Expand Down
100 changes: 20 additions & 80 deletions docs/examples/example_estimating_ates.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -40,7 +40,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -99,20 +99,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(2.083595103597918, 0.06526671583747883)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"naive_lm = smf.ols(f\"{outcome_column} ~ {treatment_column}\", df) .fit(cov_type=\"HC1\")\n",
"naive_est = naive_lm.params.iloc[1], naive_lm.bse.iloc[1]\n",
Expand All @@ -121,20 +110,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(2.1433722387308025, 0.06345124983351998)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"covaradjust_lm = smf.ols(f\"{outcome_column} ~ {treatment_column}+{'+'.join(feature_columns)}\",\n",
" df) .fit(cov_type=\"HC1\")\n",
Expand All @@ -160,42 +138,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type='text/css'>\n",
".datatable table.frame { margin-bottom: 0; }\n",
".datatable table.frame thead { border-bottom: none; }\n",
".datatable table.frame tr.coltypes td { color: #FFFFFF; line-height: 6px; padding: 0 0.5em;}\n",
".datatable .bool { background: #DDDD99; }\n",
".datatable .object { background: #565656; }\n",
".datatable .int { background: #5D9E5D; }\n",
".datatable .float { background: #4040CC; }\n",
".datatable .str { background: #CC4040; }\n",
".datatable .time { background: #40CC40; }\n",
".datatable .row_index { background: var(--jp-border-color3); border-right: 1px solid var(--jp-border-color0); color: var(--jp-ui-font-color3); font-size: 9px;}\n",
".datatable .frame tbody td { text-align: left; }\n",
".datatable .frame tr.coltypes .row_index { background: var(--jp-border-color0);}\n",
".datatable th:nth-child(2) { padding-left: 12px; }\n",
".datatable .hellipsis { color: var(--jp-cell-editor-border-color);}\n",
".datatable .vellipsis { background: var(--jp-layout-color0); color: var(--jp-cell-editor-border-color);}\n",
".datatable .na { color: var(--jp-cell-editor-border-color); font-size: 80%;}\n",
".datatable .sp { opacity: 0.25;}\n",
".datatable .footer { font-size: 9px; }\n",
".datatable .frame_dimensions { background: var(--jp-border-color3); border-top: 1px solid var(--jp-border-color0); color: var(--jp-ui-font-color3); display: inline-block; opacity: 0.6; padding: 1px 10px 1px 5px;}\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from metalearners import DRLearner\n",
"from lightgbm import LGBMRegressor, LGBMClassifier\n",
Expand All @@ -204,26 +149,15 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"(array([1.02931589]), array([0.06679633]))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"metalearners_dr = DRLearner(\n",
" nuisance_model_factory=LGBMRegressor,\n",
Expand Down Expand Up @@ -629,8 +563,14 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "py311",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"name": "python",
"version": "3.11.7"
},
"mystnb": {
"execution_timeout": 120
Expand Down
93 changes: 16 additions & 77 deletions docs/examples/example_sparse_inputs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type='text/css'>\n",
".datatable table.frame { margin-bottom: 0; }\n",
".datatable table.frame thead { border-bottom: none; }\n",
".datatable table.frame tr.coltypes td { color: #FFFFFF; line-height: 6px; padding: 0 0.5em;}\n",
".datatable .bool { background: #DDDD99; }\n",
".datatable .object { background: #565656; }\n",
".datatable .int { background: #5D9E5D; }\n",
".datatable .float { background: #4040CC; }\n",
".datatable .str { background: #CC4040; }\n",
".datatable .time { background: #40CC40; }\n",
".datatable .row_index { background: var(--jp-border-color3); border-right: 1px solid var(--jp-border-color0); color: var(--jp-ui-font-color3); font-size: 9px;}\n",
".datatable .frame tbody td { text-align: left; }\n",
".datatable .frame tr.coltypes .row_index { background: var(--jp-border-color0);}\n",
".datatable th:nth-child(2) { padding-left: 12px; }\n",
".datatable .hellipsis { color: var(--jp-cell-editor-border-color);}\n",
".datatable .vellipsis { background: var(--jp-layout-color0); color: var(--jp-cell-editor-border-color);}\n",
".datatable .na { color: var(--jp-cell-editor-border-color); font-size: 80%;}\n",
".datatable .sp { opacity: 0.25;}\n",
".datatable .footer { font-size: 9px; }\n",
".datatable .frame_dimensions { background: var(--jp-border-color3); border-top: 1px solid var(--jp-border-color0); color: var(--jp-ui-font-color3); display: inline-block; opacity: 0.6; padding: 1px 10px 1px 5px;}\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"import time, psutil, os, gc\n",
"import numpy as np\n",
Expand All @@ -76,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -104,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -175,7 +142,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -192,19 +159,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Sparse data memory: 0.76MB\n",
"Dense data memory: 41.28MB\n"
]
}
],
"outputs": [],
"source": [
"print(f\"\\nSparse data memory: {X_csr.data.nbytes / 1024 / 1024:.2f}MB\")\n",
"print(f\"Dense data memory: {X_np.nbytes / 1024 / 1024:.2f}MB\")"
Expand All @@ -219,11 +176,11 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def fit_drlearner_wrapper(X):\n",
"def fit_drlearner_wrapper(X, name):\n",
" start_memory = get_memory_usage()\n",
" start_time = time.time()\n",
" metalearners_dr = DRLearner(\n",
Expand Down Expand Up @@ -251,7 +208,7 @@
" end_memory = get_memory_usage()\n",
" runtime = end_time - start_time\n",
" memory_used = end_memory - start_memory\n",
" print(f\"Sparse data - Runtime: {runtime:.2f}s, Memory used: {memory_used:.2f}MB\")\n",
" print(f\"{name} data - Runtime: {runtime:.2f}s, Memory used: {memory_used:.2f}MB\")\n",
" print(metalearners_est)"
]
},
Expand All @@ -264,20 +221,11 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sparse data - Runtime: 3.06s, Memory used: 115.93MB\n",
"(array([1.0161235]), array([0.06374022]))\n"
]
}
],
"outputs": [],
"source": [
"fit_drlearner_wrapper(X_csr)"
"fit_drlearner_wrapper(X_csr, \"Sparse\")"
]
},
{
Expand All @@ -289,20 +237,11 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sparse data - Runtime: 6.91s, Memory used: 131.66MB\n",
"(array([1.01609547]), array([0.06384197]))\n"
]
}
],
"outputs": [],
"source": [
"fit_drlearner_wrapper(X_np)"
"fit_drlearner_wrapper(X_np, \"Dense\")"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
default_rng = np.random.default_rng()


def safe_len(X):
def safe_len(X: Matrix) -> int:
if scipy.sparse.issparse(X):
return X.shape[0]
return len(X)
Expand Down

0 comments on commit e4501b4

Please sign in to comment.