diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 02d971c..f9d98ef 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -36,7 +36,7 @@ jobs:
uses: prefix-dev/setup-pixi@v0.8.1
- name: Run mypy
run: |
- pixi run jupyter nbconvert --to script docs/examples/*.ipynb
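+        # Strip cells tagged "no-convert" (e.g. the %%html cells) so they don't end up in the scripts type-checked by mypy.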
+ pixi run jupyter nbconvert --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags no-convert --to script docs/examples/*.ipynb
for file in docs/examples/*.txt; do mv -- "$file" "${file%.txt}.py"; done
pixi run mypy docs/examples/*.py
diff --git a/docs/_static/custom.css b/docs/_static/custom.css
new file mode 100644
index 0000000..b07ad9c
--- /dev/null
+++ b/docs/_static/custom.css
@@ -0,0 +1,30 @@
+/* Copied from https://github.com/executablebooks/MyST-NB/issues/453 */
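+/* MyST-NB turns each cell tag into a tag_<tag> CSS class, so these rules apply to cells tagged "scroll-output". */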
+div.cell.tag_scroll-output div.cell_output {
+ max-height: 24em;
+ overflow-y: auto;
+ max-width: 100%;
+ overflow-x: auto;
+}
+
+div.cell.tag_scroll-output div.cell_output::-webkit-scrollbar {
+ width: 0.3rem;
+ height: 0.3rem;
+}
+
+div.cell.tag_scroll-output div.cell_output::-webkit-scrollbar-thumb {
+ background: #c1c1c1;
+ border-radius: 0.25rem;
+}
+
+div.cell.tag_scroll-output div.cell_output::-webkit-scrollbar-thumb:hover {
+ background: #a0a0a0;
+}
+
+@media print {
+ div.cell.tag_scroll-output div.cell_output {
+ max-height: unset;
+ overflow-y: visible;
+ max-width: unset;
+ overflow-x: visible;
+ }
+}
diff --git a/docs/conf.py b/docs/conf.py
index c6cdcf1..bb8effc 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -68,6 +68,9 @@
numpydoc_show_class_members = False
+html_css_files = ["custom.css"]
+
+
# Copied and adapted from
# https://github.com/pandas-dev/pandas/blob/4a14d064187367cacab3ff4652a12a0e45d0711b/doc/source/conf.py#L613-L659
# Required configuration function to use sphinx.ext.linkcode
diff --git a/docs/examples/example_gridsearch.ipynb b/docs/examples/example_gridsearch.ipynb
new file mode 100644
index 0000000..586c5d0
--- /dev/null
+++ b/docs/examples/example_gridsearch.ipynb
@@ -0,0 +1,349 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": [
+ "hide-cell",
+ "no-convert"
+ ],
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "%%html\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "(example-grid-search)=\n",
+ "\n",
+ "# Tuning hyperparameters of a MetaLearner with ``MetaLearnerGridSearch``\n",
+ "\n",
+ "Motivation\n",
+ "----------\n",
+ "\n",
+ "We know that model selection and/or hyperparameter optimization (HPO) can\n",
+ "have massive impacts on the prediction quality in regular Machine\n",
+ "Learning. Yet, it seems that model selection and hyperparameter\n",
+ "optimization are of substantial importance for CATE estimation with\n",
+ "MetaLearners, too, see e.g. [Machlanski et. al](https://arxiv.org/abs/2303.01412>).\n",
+ "\n",
+ "However, model selection and HPO for MetaLearners look quite different from what we're used to from e.g. simple supervised learning problems. Concretely,\n",
+ "\n",
+ "* In terms of a MetaLearners's option space, there are several levels\n",
+ " to optimize for:\n",
+ "\n",
+ " 1. The MetaLearner architecture, e.g. R-Learner vs DR-Learner\n",
+ " 2. The model to choose per base estimator of said MetaLearner architecture, e.g. ``LogisticRegression`` vs ``LGBMClassifier``\n",
+ " 3. The model hyperparameters per base model\n",
+ "\n",
+ "* On a conceptual level, it's not clear how to measure model quality\n",
+ " for MetaLearners. As a proxy for the underlying quantity of\n",
+ " interest one might look into base model performance, the R-Loss of\n",
+ " the CATE estimates or some more elaborate approaches alluded to by\n",
+ " [Machlanski et. al](https://arxiv.org/abs/2303.01412).\n",
+ "\n",
+ "We think that HPO can be divided into two camps:\n",
+ "\n",
+ "* Exploration of (hyperparameter, metric evaluation) pairs where the\n",
+ " pairs do not influence each other (e.g. grid search, random search)\n",
+ "\n",
+ "* Exploration of (hyperparameter, metric evaluation) pairs where the\n",
+ " pairs do influence each other (e.g. Bayesian optimization,\n",
+ " evolutionary algorithms); in other words, there is a feedback-loop between\n",
+ " sample result and sample\n",
+ "\n",
+ "In this example, we will illustrate the former and how one can make use of\n",
+ "{class}`~metalearners.grid_search.MetaLearnerGridSearch` for it. For the latter please\n",
+ "refer to the {ref}`example on model selection with optuna`.\n",
+ "\n",
+ "Loading the data\n",
+ "----------------\n",
+ "\n",
+ "Just like in our {ref}`example on estimating CATEs with a MetaLearner\n",
+ "`, we will first load some experiment data:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "from pathlib import Path\n",
+ "from git_root import git_root\n",
+ "\n",
+ "df = pd.read_csv(git_root(\"data/learning_mindset.zip\"))\n",
+ "outcome_column = \"achievement_score\"\n",
+ "treatment_column = \"intervention\"\n",
+ "feature_columns = [\n",
+ " column for column in df.columns if column not in [outcome_column, treatment_column]\n",
+ "]\n",
+ "categorical_feature_columns = [\n",
+ " \"ethnicity\",\n",
+ " \"gender\",\n",
+ " \"frst_in_family\",\n",
+ " \"school_urbanicity\",\n",
+ " \"schoolid\",\n",
+ "]\n",
+ "# Note that explicitly setting the dtype of these features to category\n",
+ "# allows both lightgbm as well as shap plots to\n",
+ "# 1. Operate on features which are not of type int, bool or float\n",
+ "# 2. Correctly interpret categoricals with int values to be\n",
+ "# interpreted as categoricals, as compared to ordinals/numericals.\n",
+ "for categorical_feature_column in categorical_feature_columns:\n",
+ " df[categorical_feature_column] = df[categorical_feature_column].astype(\"category\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now that we've loaded the experiment data, we can split it up into\n",
+ "train and validation data:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "X_train, X_validation, y_train, y_validation, w_train, w_validation = train_test_split(\n",
+ " df[feature_columns], df[outcome_column], df[treatment_column], test_size=0.25\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Performing the grid search\n",
+ "--------------------------\n",
+ "\n",
+ "We can run a grid search by using the {class}`~metalearners.grid_search.MetaLearnerGridSearch`\n",
+ "class. However, it's important to note that this class only supports a single MetaLearner\n",
+ "architecture at a time. If you're interested in conducting a grid search across multiple architectures,\n",
+ "it will require several grid searches.\n",
+ "\n",
+ "Let's say we want to work with a {class}`~metalearners.DRLearner`. We can check the names of\n",
+ "the base models for this architecture with the following code:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from metalearners import DRLearner\n",
+ "\n",
+ "print(DRLearner.nuisance_model_specifications().keys())\n",
+ "print(DRLearner.treatment_model_specifications().keys())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "source": [
+ "We see that this MetaLearner contains three base models: ``\"variant_outcome_model\"``,\n",
+ "``\"propensity_model\"`` and ``\"treatment_model\"``.\n",
+ "\n",
+ "Since our problem has a regression outcome, the ``\"variant_outcome_model\"`` should be a regressor.\n",
+ "The ``\"propensity_model\"`` and ``\"treatment_model\"`` are always a classifier and a regressor\n",
+ "respectively.\n",
+ "\n",
+ "To instantiate the {class}`~metalearners.grid_search.MetaLearnerGridSearch` object we need to\n",
+ "specify the different base models to be used. Moreover, if we'd like to use non-default hyperparameters for a given base model, we need to specify those, too.\n",
+ "\n",
+ "In this tutorial we test a ``LinearRegression`` and ``LGBMRegressor`` for the outcome model,\n",
+ "a ``LGBMClassifier`` and ``QuadraticDiscriminantAnalysis`` for the propensity model and a\n",
+ "``LGBMRegressor`` for the treatment model.\n",
+ "\n",
+ "Finally we can define the hyperparameters to test for the base models using the ``param_grid``\n",
+ "parameter."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from metalearners.grid_search import MetaLearnerGridSearch\n",
+ "from lightgbm import LGBMClassifier, LGBMRegressor\n",
+ "from sklearn.linear_model import LinearRegression\n",
+ "from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis\n",
+ "\n",
+ "gs = MetaLearnerGridSearch(\n",
+ " metalearner_factory=DRLearner,\n",
+ " metalearner_params={\"is_classification\": False, \"n_variants\": 2},\n",
+ " base_learner_grid={\n",
+ " \"variant_outcome_model\": [LinearRegression, LGBMRegressor],\n",
+ " \"propensity_model\": [LGBMClassifier, QuadraticDiscriminantAnalysis],\n",
+ " \"treatment_model\": [LGBMRegressor],\n",
+ " },\n",
+ " param_grid={\n",
+ " \"variant_outcome_model\": {\n",
+ " \"LGBMRegressor\": {\"n_estimators\": [3, 5], \"verbose\": [-1]}\n",
+ " },\n",
+ " \"treatment_model\": {\"LGBMRegressor\": {\"n_estimators\": [1, 2], \"verbose\": [-1]}},\n",
+ " \"propensity_model\": {\n",
+ " \"LGBMClassifier\": {\"n_estimators\": [1, 2, 3], \"verbose\": [-1]}\n",
+ " },\n",
+ " },\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we can call {meth}`~metalearners.grid_search.MetaLearnerGridSearch.fit` with the train\n",
+ "and validation data and can inspect the results ``DataFrame`` in ``results_``."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": [
+ "scroll-output"
+ ],
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "gs.fit(X_train, y_train, w_train, X_validation, y_validation, w_validation)\n",
+ "gs.results_"
+ ]
+ },
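+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Since ``results_`` is a regular ``DataFrame``, it can be post-processed with standard\n",
+    "``pandas`` operations. As a minimal sketch -- the exact columns reported in ``results_``\n",
+    "depend on the ``metalearners`` version, so we look them up instead of hard-coding\n",
+    "them -- we could sort the tested configurations by one of the reported metrics:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": [
+     "scroll-output"
+    ],
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Which columns are reported depends on the metalearners version, so we\n",
+    "# inspect them instead of hard-coding any names.\n",
+    "print(gs.results_.columns.tolist())\n",
+    "\n",
+    "# As an illustration, sort the tested configurations by the first reported column.\n",
+    "gs.results_.sort_values(by=[gs.results_.columns[0]])"
+   ]
+  },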
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Reusing base models\n",
+ "--------------------\n",
+ "In order to decrease the grid search runtime, it may sometimes be desirable to reuse some nuisance models.\n",
+ "We refer to our {ref}`example of model reusage ` for a more in depth explanation\n",
+ "on how this can be achieved, but here we'll show an example for the integration of model\n",
+ "reusage with {class}`~metalearners.grid_search.MetaLearnerGridSearch`.\n",
+ "\n",
+ "We will reuse the ``\"variant_outcome_model\"`` of a {class}`~metalearners.TLearner` for\n",
+ "a grid search over the {class}`~metalearners.XLearner`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": [
+ "scroll-output"
+ ],
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from metalearners import TLearner, XLearner\n",
+ "\n",
+ "tl = TLearner(\n",
+ " False,\n",
+ " 2,\n",
+ " LGBMRegressor,\n",
+ " nuisance_model_params={\"verbose\": -1, \"n_estimators\": 20, \"learning_rate\": 0.05},\n",
+ " n_folds=2,\n",
+ ")\n",
+ "tl.fit(X_train, y_train, w_train)\n",
+ "\n",
+ "gs = MetaLearnerGridSearch(\n",
+ " metalearner_factory=XLearner,\n",
+ " metalearner_params={\n",
+ " \"is_classification\": False,\n",
+ " \"n_variants\": 2,\n",
+ " \"n_folds\": 5, # The number of folds does not need to be the same as in the TLearner\n",
+ " \"fitted_nuisance_models\": {\n",
+ " \"variant_outcome_model\": tl._nuisance_models[\"variant_outcome_model\"]\n",
+ " },\n",
+ " },\n",
+ " base_learner_grid={\n",
+ " \"propensity_model\": [LGBMClassifier],\n",
+ " \"control_effect_model\": [LGBMRegressor, LinearRegression],\n",
+ " \"treatment_effect_model\": [LGBMRegressor, LinearRegression],\n",
+ " },\n",
+ " param_grid={\n",
+ " \"propensity_model\": {\"LGBMClassifier\": {\"n_estimators\": [5], \"verbose\": [-1]}},\n",
+ " \"treatment_effect_model\": {\n",
+ " \"LGBMRegressor\": {\"n_estimators\": [5, 10], \"verbose\": [-1]}\n",
+ " },\n",
+ " \"control_effect_model\": {\n",
+ " \"LGBMRegressor\": {\"n_estimators\": [1, 3], \"verbose\": [-1]}\n",
+ " },\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "gs.fit(X_train, y_train, w_train, X_validation, y_validation, w_validation)\n",
+ "gs.results_"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Further comments\n",
+ "-------------------\n",
+ "* We strongly recommend only reusing base models if they have been trained on\n",
+ " exactly the same data. If this is not the case, some functionalities\n",
+ " will probably not work as hoped for."
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/examples/example_optuna.ipynb b/docs/examples/example_optuna.ipynb
index 1d87f94..2479cf2 100644
--- a/docs/examples/example_optuna.ipynb
+++ b/docs/examples/example_optuna.ipynb
@@ -46,6 +46,7 @@
"In this example, we will illustrate the latter camp based on an\n",
"application of [optuna](https://github.com/optuna/optuna) -- a\n",
"popular framework for HPO -- in interplay with ``metalearners``.\n",
+    "For the former, please refer to the {ref}`example on hyperparameter tuning with MetaLearnerGridSearch`.\n",
"\n",
"Installation\n",
"------------\n",
diff --git a/docs/examples/index.rst b/docs/examples/index.rst
index 629a177..d825a99 100644
--- a/docs/examples/index.rst
+++ b/docs/examples/index.rst
@@ -10,4 +10,5 @@ Examples
Explainability: Lime plots of MetaLearners
Explainability: Feature importance and SHAP values
Model selection with optuna
+ Tuning hyperparameters of a MetaLearner with MetaLearnerGridSearch
Generating data