Commit 4e9b975 (1 parent: 0ce4c73)
Showing 3 changed files with 340 additions and 2 deletions.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Converting a MetaLearner to ONNX\n",
"\n",
"```{warning}\n",
"This is an experimental feature which is not subject to deprecation cycles. Use\n",
"it at your own risk!\n",
"```\n",
"\n",
"ONNX is an open standard for representing machine learning models, enabling models to be\n",
"easily transferred between various machine learning frameworks. By converting a MetaLearner\n",
"into an ONNX model, it becomes easier to leverage the model in different environments without\n",
"needing to worry about compatibility or performance issues.\n",
"\n",
"This conversion also allows models to be run on a variety of hardware setups. Moreover, ONNX\n",
"models are optimized for efficient computation, enabling faster inference compared to\n",
"the Python interface.\n",
"\n",
"For more information about ONNX, you can check their [website](https://onnx.ai/).\n",
"\n",
"In this example we will show how most MetaLearners can be converted to ONNX.\n",
"\n",
"## Installation\n",
"\n",
"In order to convert a MetaLearner to ONNX, we first need to install the following packages:\n",
"\n",
"* [onnx](https://github.com/onnx/onnx)\n",
"* [onnxmltools](https://github.com/onnx/onnxmltools)\n",
"* [onnxruntime](https://github.com/microsoft/onnxruntime)\n",
"* [spox](https://github.com/Quantco/spox)\n",
"\n",
"We can do so either via conda and conda-forge:\n",
"\n",
"```console\n",
"$ conda install onnx onnxmltools onnxruntime spox -c conda-forge\n",
"```\n",
"\n",
"or via pip and PyPI:\n",
"\n",
"```console\n",
"$ pip install onnx onnxmltools onnxruntime spox\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage\n",
"\n",
"```{warning}\n",
"Note that this method only works for {class}`~metalearners.TLearner`,\n",
"{class}`~metalearners.XLearner`, {class}`~metalearners.RLearner` and {class}`~metalearners.DRLearner`.\n",
"Converting an {class}`~metalearners.SLearner` depends heavily on whether the base\n",
"model supports categorical variables and is not implemented yet.\n",
"```\n",
"\n",
"### Loading the data\n",
"\n",
"Just like in our {ref}`example on estimating CATEs with a MetaLearner <example-basic>`,\n",
"we will first load some experiment data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pathlib import Path\n",
"from git_root import git_root\n",
"\n",
"df = pd.read_csv(git_root(\"data/learning_mindset.zip\"))\n",
"outcome_column = \"achievement_score\"\n",
"treatment_column = \"intervention\"\n",
"feature_columns = [\n",
"    column for column in df.columns if column not in [outcome_column, treatment_column]\n",
"]\n",
"categorical_feature_columns = [\n",
"    \"ethnicity\",\n",
"    \"gender\",\n",
"    \"frst_in_family\",\n",
"    \"school_urbanicity\",\n",
"    \"schoolid\",\n",
"]\n",
"# Note that explicitly setting the dtype of these features to category\n",
"# allows both lightgbm and shap plots to\n",
"# 1. Operate on features which are not of type int, bool or float, and\n",
"# 2. Correctly interpret categoricals with int values as categoricals,\n",
"#    as opposed to ordinals/numericals.\n",
"for categorical_feature_column in categorical_feature_columns:\n",
"    df[categorical_feature_column] = df[categorical_feature_column].astype(\"category\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we've loaded the experiment data, we can train a MetaLearner.\n",
"\n",
"\n",
"### Training a MetaLearner\n",
"\n",
"Again, mirroring our {ref}`example on estimating CATEs with a MetaLearner\n",
"<example-basic>`, we can train an\n",
"{class}`~metalearners.XLearner` as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from metalearners import XLearner\n",
"from lightgbm import LGBMRegressor, LGBMClassifier\n",
"\n",
"xlearner = XLearner(\n",
"    nuisance_model_factory=LGBMRegressor,\n",
"    propensity_model_factory=LGBMClassifier,\n",
"    treatment_model_factory=LGBMRegressor,\n",
"    is_classification=False,\n",
"    n_variants=2,\n",
"    nuisance_model_params={\"n_estimators\": 5, \"verbose\": -1},\n",
"    propensity_model_params={\"n_estimators\": 5, \"verbose\": -1},\n",
"    treatment_model_params={\"n_estimators\": 5, \"verbose\": -1},\n",
"    n_folds=2,\n",
")\n",
"\n",
"xlearner.fit(\n",
"    X=df[feature_columns],\n",
"    y=df[outcome_column],\n",
"    w=df[treatment_column],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Converting the base models to ONNX\n",
"\n",
"Before being able to convert the MetaLearner to ONNX we need to manually convert the\n",
"base models required for prediction. To get a list of the base models that need to be\n",
"converted we can use {meth}`~metalearners.MetaLearner.necessary_onnx_models`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xlearner.necessary_onnx_models()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that we need to convert the ``\"propensity_model\"``, the ``\"control_effect_model\"``\n",
"and the ``\"treatment_effect_model\"``. We can do this with the following code, where we\n",
"use the ``convert_lightgbm`` function from the ``onnxmltools`` package.\n",
"\n",
"```{note}\n",
"It is important to know that for classifiers we need to pass the ``zipmap=False`` option. This\n",
"is required so that the output probabilities are a matrix and not a list of dictionaries.\n",
"When using a ``sklearn`` model and the ``convert_sklearn`` function, this\n",
"option needs to be specified via the ``options={\"zipmap\": False}`` parameter, as sketched below.\n",
"```"
]
},
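{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the ``sklearn`` case, here is a minimal, purely illustrative sketch. It assumes a\n",
"hypothetical ``LogisticRegression`` propensity model fitted on synthetic data and uses the\n",
"``convert_sklearn`` function from ``skl2onnx`` (which ``onnxmltools`` builds upon); it is\n",
"not needed for the LightGBM-based example in this notebook:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.linear_model import LogisticRegression\n",
"from skl2onnx import convert_sklearn\n",
"from skl2onnx.common.data_types import FloatTensorType\n",
"\n",
"# Hypothetical, purely illustrative data and model -- not part of this example.\n",
"rng = np.random.default_rng(0)\n",
"X_numeric = rng.normal(size=(100, 3)).astype(np.float32)\n",
"w = rng.integers(0, 2, size=100)\n",
"propensity_clf = LogisticRegression().fit(X_numeric, w)\n",
"\n",
"onnx_propensity = convert_sklearn(\n",
"    propensity_clf,\n",
"    initial_types=[(\"X\", FloatTensorType([None, X_numeric.shape[1]]))],\n",
"    # sklearn converters take the zipmap option via ``options``:\n",
"    options={\"zipmap\": False},\n",
")\n",
"```"
]
},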
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import onnx\n",
"from onnxmltools import convert_lightgbm\n",
"from onnxconverter_common.data_types import FloatTensorType\n",
"\n",
"onnx_models: dict[str, list[onnx.ModelProto]] = {\n",
"    \"control_effect_model\": [],\n",
"    \"treatment_effect_model\": [],\n",
"    \"propensity_model\": [],\n",
"}\n",
"\n",
"for model in xlearner._nuisance_models[\"propensity_model\"]:\n",
"    onnx_model = convert_lightgbm(\n",
"        model._overall_estimator,\n",
"        initial_types=[(\"X\", FloatTensorType([None, len(feature_columns)]))],\n",
"        zipmap=False,\n",
"    )\n",
"    onnx_models[\"propensity_model\"].append(onnx_model)\n",
"\n",
"for model in xlearner._treatment_models[\"control_effect_model\"]:\n",
"    onnx_model = convert_lightgbm(\n",
"        model._overall_estimator,\n",
"        initial_types=[(\"X\", FloatTensorType([None, len(feature_columns)]))],\n",
"    )\n",
"    onnx_models[\"control_effect_model\"].append(onnx_model)\n",
"\n",
"for model in xlearner._treatment_models[\"treatment_effect_model\"]:\n",
"    onnx_model = convert_lightgbm(\n",
"        model._overall_estimator,\n",
"        initial_types=[(\"X\", FloatTensorType([None, len(feature_columns)]))],\n",
"    )\n",
"    onnx_models[\"treatment_effect_model\"].append(onnx_model)\n",
"\n",
"onnx_model = xlearner.build_onnx(onnx_models)"
]
},
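{
"cell_type": "markdown",
"metadata": {},
"source": [
"As mentioned in the further comments below, it would be preferable to convert with\n",
"``DoubleTensorType``, but some converters have issues with it. A minimal, optional sketch of\n",
"a try-and-fall-back pattern (``convert_with_fallback`` is a hypothetical helper name):\n",
"\n",
"```python\n",
"from onnxmltools import convert_lightgbm\n",
"from onnxconverter_common.data_types import DoubleTensorType, FloatTensorType\n",
"\n",
"\n",
"def convert_with_fallback(estimator, n_features, **kwargs):\n",
"    # Try the higher-precision DoubleTensorType first and fall back to\n",
"    # FloatTensorType if the converter fails on doubles.\n",
"    try:\n",
"        return convert_lightgbm(\n",
"            estimator,\n",
"            initial_types=[(\"X\", DoubleTensorType([None, n_features]))],\n",
"            **kwargs,\n",
"        )\n",
"    except Exception:\n",
"        return convert_lightgbm(\n",
"            estimator,\n",
"            initial_types=[(\"X\", FloatTensorType([None, n_features]))],\n",
"            **kwargs,\n",
"        )\n",
"```"
]
},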
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can explore the input and output of the model and see that it expects a matrix with 11\n",
"columns and returns a three-dimensional tensor with shape ``(..., 1, 1)``. This is expected,\n",
"as there are only two treatment variants and, since it is a regression problem, a single outcome."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"ONNX model input: \", onnx_model.graph.input)\n",
"print(\"ONNX model output: \", onnx_model.graph.output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{note}\n",
"We noticed that ``convert_lightgbm`` does not support native pandas categorical variables.\n",
"Since a numpy array needs to be passed at prediction time, we have to use\n",
"the category codes in the input matrix. For more context on this issue see\n",
"[here](https://github.com/onnx/onnxmltools/issues/309) and [here](https://github.com/microsoft/LightGBM/issues/5162).\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"X_onnx = df[feature_columns].copy(deep=True)\n",
"for c in categorical_feature_columns:\n",
"    X_onnx[c] = df[c].cat.codes\n",
"X_onnx = X_onnx.to_numpy(np.float32)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can finally use ``onnxruntime`` to perform predictions using our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import onnxruntime as rt\n",
"\n",
"sess = rt.InferenceSession(\n",
"    onnx_model.SerializeToString(), providers=rt.get_available_providers()\n",
")\n",
"\n",
"pred_onnx = sess.run(\n",
"    [\"tau\"],\n",
"    {\"X\": X_onnx},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We recommend always doing a final check with some data that the CATEs predicted by the Python\n",
"implementation and the ONNX model are the same (up to some tolerance). This can be done with\n",
"the following code:\n",
"\n",
"```{note}\n",
"We have to treat the data as out-of-sample, passing ``is_oos=True`` and ``oos_method=\"overall\"``,\n",
"since we used the ``_overall_estimator`` when converting the base models.\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.testing.assert_allclose(\n",
"    xlearner.predict(df[feature_columns], True, \"overall\"), pred_onnx[0], atol=1e-6\n",
")"
]
},
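{
"cell_type": "markdown",
"metadata": {},
"source": [
"Should this assertion fail, a useful first step is to check that each converted base model\n",
"produces the same outputs as its Python counterpart. A minimal sketch for the propensity\n",
"model, reusing ``onnx_models`` and ``X_onnx`` from above; with ``zipmap=False`` the probability\n",
"matrix is typically the second output of the converted classifier:\n",
"\n",
"```python\n",
"propensity_sess = rt.InferenceSession(\n",
"    onnx_models[\"propensity_model\"][0].SerializeToString(),\n",
"    providers=rt.get_available_providers(),\n",
")\n",
"onnx_probabilities = propensity_sess.run(None, {\"X\": X_onnx})[1]\n",
"lgbm_probabilities = xlearner._nuisance_models[\"propensity_model\"][0]._overall_estimator.predict_proba(\n",
"    df[feature_columns]\n",
")\n",
"np.testing.assert_allclose(onnx_probabilities, lgbm_probabilities, atol=1e-6)\n",
"```"
]
},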
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Further comments\n",
"\n",
"* It would be desirable to work with ``DoubleTensorType`` instead of ``FloatTensorType``,\n",
"  but we have noted that some converters have issues with it. We recommend trying\n",
"  ``DoubleTensorType`` first and switching to ``FloatTensorType`` in case the converter fails,\n",
"  as sketched after the conversion step above.\n",
"* In case the final assertion fails, we recommend first testing that the individual\n",
"  base models produce the same outputs in Python and in ONNX (see the sketch above), as we\n",
"  discovered some issues with some converters; see [this issue](https://github.com/onnx/sklearn-onnx/issues/1117) and\n",
"  [this issue](https://github.com/onnx/sklearn-onnx/issues/1116)."
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}