Skip to content

Commit

Permalink
Add netron visualization.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Jul 30, 2024
1 parent ce5888d commit c9e62c4
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 15 deletions.
142 changes: 127 additions & 15 deletions docs/examples/example_onnx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -113,9 +113,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"<metalearners.xlearner.XLearner at 0x10e5b9520>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from metalearners import XLearner\n",
"from lightgbm import LGBMRegressor, LGBMClassifier\n",
Expand Down Expand Up @@ -161,9 +172,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'propensity_model': [LGBMClassifier(n_estimators=5, verbose=-1)],\n",
" 'control_effect_model': [LGBMRegressor(n_estimators=5, verbose=-1)],\n",
" 'treatment_effect_model': [LGBMRegressor(n_estimators=5, verbose=-1)]}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"necessary_models = xlearner._necessary_onnx_models()\n",
"necessary_models"
Expand All @@ -187,9 +211,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The maximum opset needed by this model is only 9.\n",
"The maximum opset needed by this model is only 8.\n",
"The maximum opset needed by this model is only 8.\n"
]
}
],
"source": [
"import onnx\n",
"from onnxmltools import convert_lightgbm\n",
Expand Down Expand Up @@ -221,9 +255,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/kevinklein/Code/metalearners/metalearners/xlearner.py:463: UserWarning: _build_onnx is an experimental feature. Use it at your own risk!\n",
" warning_experimental_feature(\"_build_onnx\")\n"
]
}
],
"source": [
"onnx_model = xlearner._build_onnx(onnx_models)"
]
Expand All @@ -239,14 +282,60 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ONNX model input: [name: \"X\"\n",
"type {\n",
" tensor_type {\n",
" elem_type: 1\n",
" shape {\n",
" dim {\n",
" }\n",
" dim {\n",
" dim_value: 11\n",
" }\n",
" }\n",
" }\n",
"}\n",
"]\n",
"ONNX model output: [name: \"tau\"\n",
"type {\n",
" tensor_type {\n",
" elem_type: 1\n",
" shape {\n",
" dim {\n",
" }\n",
" dim {\n",
" dim_value: 1\n",
" }\n",
" dim {\n",
" dim_value: 1\n",
" }\n",
" }\n",
" }\n",
"}\n",
"]\n"
]
}
],
"source": [
"print(\"ONNX model input: \", onnx_model.graph.input)\n",
"print(\"ONNX model output: \", onnx_model.graph.output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also visualize the ONNX model with, e.g. [netron](https://netron.app/):\n",
"<img src=\"imgs/onnx_netron.png\" alt=\"Drawing\" style=\"width: 400px;\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -261,7 +350,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -282,7 +371,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -298,6 +387,15 @@
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"onnx.save_model(onnx_model, \"model.onnx\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -340,10 +438,24 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
Binary file added docs/examples/imgs/onnx_netron.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit c9e62c4

Please sign in to comment.