Commit message:

* Example gridsearch
* CSS file
* Fix mypy using tags
* Change tag for removing when converting to avoid not running the cell (listed twice)
* Update docs/examples/example_gridsearch.ipynb, Co-authored-by: Kevin Klein <[email protected]> (listed ten times)

Co-authored-by: Kevin Klein <[email protected]>
Commit d8d84fc (parent 58c3741), showing 6 changed files with 385 additions and 1 deletion.
GitHub Actions workflow, "Run mypy" step:

@@ -36,7 +36,7 @@ jobs:
       uses: prefix-dev/[email protected]
     - name: Run mypy
       run: |
-        pixi run jupyter nbconvert --to script docs/examples/*.ipynb
+        pixi run jupyter nbconvert --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags no-convert --to script docs/examples/*.ipynb
         for file in docs/examples/*.txt; do mv -- "$file" "${file%.txt}.py"; done
         pixi run mypy docs/examples/*.py
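The added flags enable nbconvert's TagRemovePreprocessor so that notebook cells tagged "no-convert" are dropped before the script export that mypy checks. As a minimal sketch, the same behaviour can be reproduced with nbconvert's Python API (the notebook path below is just an example):

    from traitlets.config import Config
    from nbconvert import ScriptExporter

    # Drop cells tagged "no-convert" before exporting the notebook to a script.
    config = Config()
    config.TagRemovePreprocessor.enabled = True
    config.TagRemovePreprocessor.remove_cell_tags = ("no-convert",)
    config.ScriptExporter.preprocessors = ["nbconvert.preprocessors.TagRemovePreprocessor"]

    exporter = ScriptExporter(config=config)
    source, _resources = exporter.from_filename("docs/examples/example_gridsearch.ipynb")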
New CSS stylesheet for scrollable cell outputs:

@@ -0,0 +1,30 @@
/* Copied from https://github.com/executablebooks/MyST-NB/issues/453 */
div.cell.tag_scroll-output div.cell_output {
    max-height: 24em;
    overflow-y: auto;
    max-width: 100%;
    overflow-x: auto;
}

div.cell.tag_scroll-output div.cell_output::-webkit-scrollbar {
    width: 0.3rem;
    height: 0.3rem;
}

div.cell.tag_scroll-output div.cell_output::-webkit-scrollbar-thumb {
    background: #c1c1c1;
    border-radius: 0.25rem;
}

div.cell.tag_scroll-output div.cell_output::-webkit-scrollbar-thumb:hover {
    background: #a0a0a0;
}

@media print {
    div.cell.tag_scroll-output div.cell_output {
        max-height: unset;
        overflow-y: visible;
        max-width: unset;
        overflow-x: visible;
    }
}
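These selectors target the tag_scroll-output class that MyST-NB derives from the "scroll-output" cell tag used in the notebook below, capping the height of such outputs and making them scrollable. For the stylesheet to take effect it typically has to be registered with Sphinx; a minimal sketch, assuming a hypothetical location _static/scroll_output.css (the actual path and file name are not shown in this view):

    # docs/conf.py (sketch; the directory and file name are assumptions)
    html_static_path = ["_static"]
    html_css_files = ["scroll_output.css"]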
New notebook docs/examples/example_gridsearch.ipynb:

@@ -0,0 +1,349 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-cell",
     "no-convert"
    ],
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "%%html\n",
    "<style>\n",
    "/* Any CSS style can go in here. */\n",
    ".dataframe th {\n",
    "  font-size: 12px;\n",
    "}\n",
    ".dataframe td {\n",
    "  font-size: 12px;\n",
    "}\n",
    "</style>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(example-grid-search)=\n",
    "\n",
    "# Tuning hyperparameters of a MetaLearner with ``MetaLearnerGridSearch``\n",
    "\n",
    "Motivation\n",
    "----------\n",
    "\n",
    "We know that model selection and/or hyperparameter optimization (HPO) can\n",
    "have massive impacts on the prediction quality in regular Machine\n",
    "Learning. It turns out that model selection and hyperparameter\n",
    "optimization are of substantial importance for CATE estimation with\n",
    "MetaLearners, too; see e.g. [Machlanski et al.](https://arxiv.org/abs/2303.01412).\n",
    "\n",
    "However, model selection and HPO for MetaLearners look quite different from what we are used to in, e.g., simple supervised learning problems. Concretely,\n",
    "\n",
    "* In terms of a MetaLearner's option space, there are several levels\n",
    "  to optimize for:\n",
    "\n",
    "  1. The MetaLearner architecture, e.g. R-Learner vs DR-Learner\n",
    "  2. The model to choose per base estimator of said MetaLearner architecture, e.g. ``LogisticRegression`` vs ``LGBMClassifier``\n",
    "  3. The model hyperparameters per base model\n",
    "\n",
    "* On a conceptual level, it's not clear how to measure model quality\n",
    "  for MetaLearners. As a proxy for the underlying quantity of\n",
    "  interest one might look into base model performance, the R-Loss of\n",
    "  the CATE estimates or some more elaborate approaches alluded to by\n",
    "  [Machlanski et al.](https://arxiv.org/abs/2303.01412).\n",
    "\n",
    "We think that HPO can be divided into two camps:\n",
    "\n",
    "* Exploration of (hyperparameter, metric evaluation) pairs where the\n",
    "  pairs do not influence each other (e.g. grid search, random search)\n",
    "\n",
    "* Exploration of (hyperparameter, metric evaluation) pairs where the\n",
    "  pairs do influence each other (e.g. Bayesian optimization,\n",
    "  evolutionary algorithms); in other words, there is a feedback loop between\n",
    "  sample results and subsequent samples\n",
    "\n",
    "In this example, we will illustrate the former and how one can make use of\n",
    "{class}`~metalearners.grid_search.MetaLearnerGridSearch` for it. For the latter, please\n",
    "refer to the {ref}`example on model selection with optuna<example-optuna>`.\n",
    "\n",
    "Loading the data\n",
    "----------------\n",
    "\n",
    "Just like in our {ref}`example on estimating CATEs with a MetaLearner\n",
    "<example-basic>`, we will first load some experiment data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from git_root import git_root\n",
    "\n",
    "df = pd.read_csv(git_root(\"data/learning_mindset.zip\"))\n",
    "outcome_column = \"achievement_score\"\n",
    "treatment_column = \"intervention\"\n",
    "feature_columns = [\n",
    "    column for column in df.columns if column not in [outcome_column, treatment_column]\n",
    "]\n",
    "categorical_feature_columns = [\n",
    "    \"ethnicity\",\n",
    "    \"gender\",\n",
    "    \"frst_in_family\",\n",
    "    \"school_urbanicity\",\n",
    "    \"schoolid\",\n",
    "]\n",
    "# Note that explicitly setting the dtype of these features to category\n",
    "# allows both lightgbm and shap plots to\n",
    "# 1. Operate on features which are not of type int, bool or float\n",
    "# 2. Interpret categoricals encoded as integers as categoricals\n",
    "#    rather than as ordinals/numericals.\n",
    "for categorical_feature_column in categorical_feature_columns:\n",
    "    df[categorical_feature_column] = df[categorical_feature_column].astype(\"category\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've loaded the experiment data, we can split it up into\n",
    "train and validation data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X_train, X_validation, y_train, y_validation, w_train, w_validation = train_test_split(\n",
    "    df[feature_columns], df[outcome_column], df[treatment_column], test_size=0.25\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Performing the grid search\n",
    "--------------------------\n",
    "\n",
    "We can run a grid search by using the {class}`~metalearners.grid_search.MetaLearnerGridSearch`\n",
    "class. However, it's important to note that this class only supports a single MetaLearner\n",
    "architecture at a time. Comparing multiple architectures therefore requires running several\n",
    "grid searches, one per architecture; see the sketch further below.\n",
    "\n",
    "Let's say we want to work with a {class}`~metalearners.DRLearner`. We can check the names of\n",
    "the base models for this architecture with the following code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from metalearners import DRLearner\n",
    "\n",
    "print(DRLearner.nuisance_model_specifications().keys())\n",
    "print(DRLearner.treatment_model_specifications().keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "We see that this MetaLearner contains three base models: ``\"variant_outcome_model\"``,\n",
    "``\"propensity_model\"`` and ``\"treatment_model\"``.\n",
    "\n",
    "Since our problem has a regression outcome, the ``\"variant_outcome_model\"`` should be a regressor.\n",
    "The ``\"propensity_model\"`` and ``\"treatment_model\"`` are always a classifier and a regressor,\n",
    "respectively.\n",
    "\n",
    "To instantiate the {class}`~metalearners.grid_search.MetaLearnerGridSearch` object we need to\n",
    "specify the different base models to be used. Moreover, if we'd like to use non-default hyperparameters for a given base model, we need to specify those, too.\n",
    "\n",
    "In this tutorial we test a ``LinearRegression`` and an ``LGBMRegressor`` for the outcome model,\n",
    "an ``LGBMClassifier`` and a ``QuadraticDiscriminantAnalysis`` for the propensity model and an\n",
    "``LGBMRegressor`` for the treatment model.\n",
    "\n",
    "Finally, we can define the hyperparameters to test for the base models using the ``param_grid``\n",
    "parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from metalearners.grid_search import MetaLearnerGridSearch\n",
    "from lightgbm import LGBMClassifier, LGBMRegressor\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis\n",
    "\n",
    "gs = MetaLearnerGridSearch(\n",
    "    metalearner_factory=DRLearner,\n",
    "    metalearner_params={\"is_classification\": False, \"n_variants\": 2},\n",
    "    base_learner_grid={\n",
    "        \"variant_outcome_model\": [LinearRegression, LGBMRegressor],\n",
    "        \"propensity_model\": [LGBMClassifier, QuadraticDiscriminantAnalysis],\n",
    "        \"treatment_model\": [LGBMRegressor],\n",
    "    },\n",
    "    param_grid={\n",
    "        \"variant_outcome_model\": {\n",
    "            \"LGBMRegressor\": {\"n_estimators\": [3, 5], \"verbose\": [-1]}\n",
    "        },\n",
    "        \"treatment_model\": {\"LGBMRegressor\": {\"n_estimators\": [1, 2], \"verbose\": [-1]}},\n",
    "        \"propensity_model\": {\n",
    "            \"LGBMClassifier\": {\"n_estimators\": [1, 2, 3], \"verbose\": [-1]}\n",
    "        },\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can call {meth}`~metalearners.grid_search.MetaLearnerGridSearch.fit` with the train\n",
    "and validation data and can inspect the results ``DataFrame`` in ``results_``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "scroll-output"
    ],
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "gs.fit(X_train, y_train, w_train, X_validation, y_validation, w_validation)\n",
    "gs.results_"
   ]
  },
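  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As noted above, {class}`~metalearners.grid_search.MetaLearnerGridSearch` covers a single\n",
    "architecture at a time. The following cell is a minimal, illustrative sketch (not the only\n",
    "possible approach) of how several architectures could be compared: we run one grid search per\n",
    "architecture with deliberately small grids and collect the resulting ``results_`` frames in a\n",
    "dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "scroll-output"
    ],
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from metalearners import TLearner\n",
    "\n",
    "# Illustrative, deliberately small grids: one entry per architecture.\n",
    "# The TLearner's only base model is the \"variant_outcome_model\".\n",
    "architecture_grids = {\n",
    "    DRLearner: {\n",
    "        \"base_learner_grid\": {\n",
    "            \"variant_outcome_model\": [LinearRegression, LGBMRegressor],\n",
    "            \"propensity_model\": [LGBMClassifier],\n",
    "            \"treatment_model\": [LGBMRegressor],\n",
    "        },\n",
    "        \"param_grid\": {\n",
    "            \"variant_outcome_model\": {\"LGBMRegressor\": {\"n_estimators\": [3], \"verbose\": [-1]}},\n",
    "            \"propensity_model\": {\"LGBMClassifier\": {\"n_estimators\": [2], \"verbose\": [-1]}},\n",
    "            \"treatment_model\": {\"LGBMRegressor\": {\"n_estimators\": [2], \"verbose\": [-1]}},\n",
    "        },\n",
    "    },\n",
    "    TLearner: {\n",
    "        \"base_learner_grid\": {\"variant_outcome_model\": [LinearRegression, LGBMRegressor]},\n",
    "        \"param_grid\": {\n",
    "            \"variant_outcome_model\": {\"LGBMRegressor\": {\"n_estimators\": [3], \"verbose\": [-1]}}\n",
    "        },\n",
    "    },\n",
    "}\n",
    "\n",
    "per_architecture_results = {}\n",
    "for factory, grids in architecture_grids.items():\n",
    "    architecture_gs = MetaLearnerGridSearch(\n",
    "        metalearner_factory=factory,\n",
    "        metalearner_params={\"is_classification\": False, \"n_variants\": 2},\n",
    "        **grids,\n",
    "    )\n",
    "    architecture_gs.fit(X_train, y_train, w_train, X_validation, y_validation, w_validation)\n",
    "    per_architecture_results[factory.__name__] = architecture_gs.results_"
   ]
  },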
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reusing base models\n",
    "-------------------\n",
    "\n",
    "In order to decrease the grid search runtime, it may sometimes be desirable to reuse some nuisance models.\n",
    "We refer to our {ref}`example of model reuse <example-reuse>` for a more in-depth explanation\n",
    "of how this can be achieved; here we show how model reuse can be combined\n",
    "with {class}`~metalearners.grid_search.MetaLearnerGridSearch`.\n",
    "\n",
    "We will reuse the ``\"variant_outcome_model\"`` of a {class}`~metalearners.TLearner` for\n",
    "a grid search over the {class}`~metalearners.XLearner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "scroll-output"
    ],
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from metalearners import TLearner, XLearner\n",
    "\n",
    "tl = TLearner(\n",
    "    False,\n",
    "    2,\n",
    "    LGBMRegressor,\n",
    "    nuisance_model_params={\"verbose\": -1, \"n_estimators\": 20, \"learning_rate\": 0.05},\n",
    "    n_folds=2,\n",
    ")\n",
    "tl.fit(X_train, y_train, w_train)\n",
    "\n",
    "gs = MetaLearnerGridSearch(\n",
    "    metalearner_factory=XLearner,\n",
    "    metalearner_params={\n",
    "        \"is_classification\": False,\n",
    "        \"n_variants\": 2,\n",
    "        \"n_folds\": 5,  # The number of folds does not need to be the same as in the TLearner\n",
    "        \"fitted_nuisance_models\": {\n",
    "            \"variant_outcome_model\": tl._nuisance_models[\"variant_outcome_model\"]\n",
    "        },\n",
    "    },\n",
    "    base_learner_grid={\n",
    "        \"propensity_model\": [LGBMClassifier],\n",
    "        \"control_effect_model\": [LGBMRegressor, LinearRegression],\n",
    "        \"treatment_effect_model\": [LGBMRegressor, LinearRegression],\n",
    "    },\n",
    "    param_grid={\n",
    "        \"propensity_model\": {\"LGBMClassifier\": {\"n_estimators\": [5], \"verbose\": [-1]}},\n",
    "        \"treatment_effect_model\": {\n",
    "            \"LGBMRegressor\": {\"n_estimators\": [5, 10], \"verbose\": [-1]}\n",
    "        },\n",
    "        \"control_effect_model\": {\n",
    "            \"LGBMRegressor\": {\"n_estimators\": [1, 3], \"verbose\": [-1]}\n",
    "        },\n",
    "    },\n",
    ")\n",
    "\n",
    "gs.fit(X_train, y_train, w_train, X_validation, y_validation, w_validation)\n",
    "gs.results_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Further comments\n",
    "----------------\n",
    "\n",
    "* We strongly recommend only reusing base models if they have been trained on\n",
    "  exactly the same data. If this is not the case, some functionality\n",
    "  will probably not work as intended."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}