Skip to content

Commit

Permalink
Add outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasmarwitz committed Dec 13, 2024
1 parent 7fbabd5 commit 2ef89f1
Show file tree
Hide file tree
Showing 9 changed files with 1,104 additions and 148 deletions.
58 changes: 46 additions & 12 deletions docs/examples/example_basic.ipynb

Large diffs are not rendered by default.

506 changes: 453 additions & 53 deletions docs/examples/example_estimating_ates.ipynb

Large diffs are not rendered by default.

141 changes: 127 additions & 14 deletions docs/examples/example_feature_importance_shap.ipynb

Large diffs are not rendered by default.

74 changes: 65 additions & 9 deletions docs/examples/example_lime.ipynb

Large diffs are not rendered by default.

112 changes: 96 additions & 16 deletions docs/examples/example_onnx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -113,9 +113,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"<metalearners.xlearner.XLearner at 0x16a753050>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from metalearners import XLearner\n",
"from lightgbm import LGBMRegressor, LGBMClassifier\n",
Expand Down Expand Up @@ -161,9 +172,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'propensity_model': [LGBMClassifier(n_estimators=5, verbose=-1)],\n",
" 'control_effect_model': [LGBMRegressor(n_estimators=5, verbose=-1)],\n",
" 'treatment_effect_model': [LGBMRegressor(n_estimators=5, verbose=-1)]}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"necessary_models = xlearner._necessary_onnx_models()\n",
"necessary_models"
Expand All @@ -187,9 +211,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The maximum opset needed by this model is only 9.\n",
"The maximum opset needed by this model is only 8.\n",
"The maximum opset needed by this model is only 8.\n"
]
}
],
"source": [
"import onnx\n",
"from onnxmltools import convert_lightgbm\n",
Expand Down Expand Up @@ -221,9 +255,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"_build_onnx is an experimental feature. Use it at your own risk!\n"
]
}
],
"source": [
"onnx_model = xlearner._build_onnx(onnx_models)"
]
Expand All @@ -239,9 +281,47 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ONNX model input: [name: \"X\"\n",
"type {\n",
" tensor_type {\n",
" elem_type: 1\n",
" shape {\n",
" dim {\n",
" }\n",
" dim {\n",
" dim_value: 11\n",
" }\n",
" }\n",
" }\n",
"}\n",
"]\n",
"ONNX model output: [name: \"tau\"\n",
"type {\n",
" tensor_type {\n",
" elem_type: 1\n",
" shape {\n",
" dim {\n",
" }\n",
" dim {\n",
" dim_value: 1\n",
" }\n",
" dim {\n",
" dim_value: 1\n",
" }\n",
" }\n",
" }\n",
"}\n",
"]\n"
]
}
],
"source": [
"print(\"ONNX model input: \", onnx_model.graph.input)\n",
"print(\"ONNX model output: \", onnx_model.graph.output)"
Expand Down Expand Up @@ -269,7 +349,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -290,7 +370,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -308,7 +388,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -331,7 +411,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -372,7 +452,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
55 changes: 47 additions & 8 deletions docs/examples/example_propensity.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"vscode": {
"languageId": "plaintext"
Expand Down Expand Up @@ -69,13 +69,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"0.3256664421133673"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[treatment_column].mean()"
]
Expand Down Expand Up @@ -111,13 +122,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"<metalearners.rlearner.RLearner at 0x302d930d0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from metalearners import RLearner\n",
"from metalearners.utils import FixedBinaryPropensity\n",
Expand Down Expand Up @@ -149,13 +171,30 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([[0.7, 0.3],\n",
" [0.7, 0.3],\n",
" [0.7, 0.3],\n",
" ...,\n",
" [0.7, 0.3],\n",
" [0.7, 0.3],\n",
" [0.7, 0.3]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rlearner.predict_nuisance(\n",
" X=df[feature_columns], model_kind=\"propensity_model\", model_ord=0, is_oos=False\n",
Expand Down Expand Up @@ -192,7 +231,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 2ef89f1

Please sign in to comment.