From a5a6fc3637fa5b8d6c435f937c25020ffd1e7a6d Mon Sep 17 00:00:00 2001 From: Thomas Marwitz Date: Tue, 3 Dec 2024 14:06:12 +0100 Subject: [PATCH] --wip-- [skip ci] --- docs/examples/example_data_generation.ipynb | 182 ++- docs/examples/example_estimating_ates.ipynb | 2 +- docs/examples/example_gridsearch.ipynb | 1262 ++++++++++++++++++- docs/examples/example_lime.ipynb | 2 +- docs/examples/example_propensity.ipynb | 12 +- docs/glossary.md | 2 +- docs/motivation.md | 3 +- metalearners/rlearner.py | 58 +- mkdocs.yml | 12 +- pixi.lock | 85 ++ pixi.toml | 1 + 11 files changed, 1550 insertions(+), 71 deletions(-) diff --git a/docs/examples/example_data_generation.ipynb b/docs/examples/example_data_generation.ipynb index eeafcd5b..def9b1fc 100644 --- a/docs/examples/example_data_generation.ipynb +++ b/docs/examples/example_data_generation.ipynb @@ -53,13 +53,124 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { + "execution": { + "iopub.execute_input": "2024-11-26T14:13:33.556121Z", + "iopub.status.busy": "2024-11-26T14:13:33.556025Z", + "iopub.status.idle": "2024-11-26T14:13:34.685096Z", + "shell.execute_reply": "2024-11-26T14:13:34.684830Z" + }, "vscode": { "languageId": "plaintext" } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
01234567
0-0.432580-1.956691-2.724410-4.051359-4.275785023
10.213119-1.7942463.3352720.596448-8.053070023
2-0.333022-1.8553242.567406-0.507977-7.255018243
3-1.036547-1.3799201.721547-2.817249-4.626411231
4-1.514100-3.060547-4.077247-5.819707-4.468868503
\n", + "
" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 7\n", + "0 -0.432580 -1.956691 -2.724410 -4.051359 -4.275785 0 2 3\n", + "1 0.213119 -1.794246 3.335272 0.596448 -8.053070 0 2 3\n", + "2 -0.333022 -1.855324 2.567406 -0.507977 -7.255018 2 4 3\n", + "3 -1.036547 -1.379920 1.721547 -2.817249 -4.626411 2 3 1\n", + "4 -1.514100 -3.060547 -4.077247 -5.819707 -4.468868 5 0 3" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from metalearners.data_generation import generate_covariates\n", "\n", @@ -89,13 +200,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { + "execution": { + "iopub.execute_input": "2024-11-26T14:13:34.704946Z", + "iopub.status.busy": "2024-11-26T14:13:34.704791Z", + "iopub.status.idle": "2024-11-26T14:13:34.707624Z", + "shell.execute_reply": "2024-11-26T14:13:34.707391Z" + }, "vscode": { "languageId": "plaintext" } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(numpy.ndarray, array([0, 1]), 0.514)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import numpy as np\n", "from metalearners.data_generation import generate_treatment\n", @@ -125,13 +253,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { + "execution": { + "iopub.execute_input": "2024-11-26T14:13:34.708926Z", + "iopub.status.busy": "2024-11-26T14:13:34.708830Z", + "iopub.status.idle": "2024-11-26T14:13:34.713643Z", + "shell.execute_reply": "2024-11-26T14:13:34.713435Z" + }, "vscode": { "languageId": "plaintext" } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-4.6390948 , -6.99101697],\n", + " [-4.5927874 , -1.43775422],\n", + " [-5.6179741 , -3.62754599],\n", + " ...,\n", + " [-5.81369594, -2.16523526],\n", + " [ 0.89106589, 0.44998321],\n", + " [-6.62191898, -7.66198481]])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from metalearners._utils import get_linear_dimension\n", "from metalearners.outcome_functions import linear_treatment_effect\n", @@ -161,8 +312,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { + "execution": { + "iopub.execute_input": "2024-11-26T14:13:34.714879Z", + "iopub.status.busy": "2024-11-26T14:13:34.714797Z", + "iopub.status.idle": "2024-11-26T14:13:34.716540Z", + "shell.execute_reply": "2024-11-26T14:13:34.716345Z" + }, "vscode": { "languageId": "plaintext" } @@ -180,7 +337,16 @@ ], "metadata": { "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" } }, "nbformat": 4, diff --git a/docs/examples/example_estimating_ates.ipynb b/docs/examples/example_estimating_ates.ipynb index ec88ec48..7d06cdfd 100644 --- a/docs/examples/example_estimating_ates.ipynb +++ b/docs/examples/example_estimating_ates.ipynb @@ -254,7 +254,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### `econML`: `LinearDRLearner`" + "### econML: `LinearDRLearner`" ] }, { diff --git a/docs/examples/example_gridsearch.ipynb b/docs/examples/example_gridsearch.ipynb index 9c0b3c53..c9f3b93e 100644 --- a/docs/examples/example_gridsearch.ipynb +++ b/docs/examples/example_gridsearch.ipynb @@ -2,8 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { + "execution": { + "iopub.execute_input": "2024-11-26T14:12:26.656377Z", + "iopub.status.busy": "2024-11-26T14:12:26.656172Z", + "iopub.status.idle": "2024-11-26T14:12:26.664227Z", + "shell.execute_reply": "2024-11-26T14:12:26.663900Z" + }, "tags": [ "hide-cell", "no-convert" @@ -12,7 +18,28 @@ "languageId": "plaintext" } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "%%html\n", "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fit_timescore_timetrain_variant_outcome_model_0_neg_root_mean_squared_errortrain_variant_outcome_model_1_neg_root_mean_squared_errortrain_propensity_model_neg_log_losstrain_treatment_model_1_vs_0_neg_root_mean_squared_errortest_variant_outcome_model_0_neg_root_mean_squared_errortest_variant_outcome_model_1_neg_root_mean_squared_errortest_propensity_model_neg_log_losstest_treatment_model_1_vs_0_neg_root_mean_squared_error
metalearnerpropensity_modelpropensity_model_n_estimatorspropensity_model_verbosevariant_outcome_modelvariant_outcome_model_n_estimatorsvariant_outcome_model_verbosetreatment_modeltreatment_model_n_estimatorstreatment_model_verbose
DRLearnerLGBMClassifier1.0-1.0LinearRegressionNaNNaNLGBMRegressor1-10.3626320.068002-0.852074-0.848146-0.632090-1.813472-0.840509-0.832314-0.628208-1.771340
2-10.3882840.067350-0.852355-0.848704-0.631710-1.811203-0.840509-0.832314-0.628208-1.769949
2.0-1.0LinearRegressionNaNNaNLGBMRegressor1-10.3926400.066799-0.852027-0.847422-0.631746-1.812832-0.840509-0.832314-0.628263-1.773597
2-10.4546540.067791-0.852033-0.847687-0.632071-1.816253-0.840509-0.832314-0.628263-1.773111
3.0-1.0LinearRegressionNaNNaNLGBMRegressor1-10.4516040.069173-0.851851-0.847961-0.632294-1.815351-0.840509-0.832314-0.628397-1.777481
2-10.5125990.068227-0.852593-0.848230-0.632181-1.817798-0.840509-0.832314-0.628397-1.777205
1.0-1.0LGBMRegressor3.0-1.0LGBMRegressor1-10.7521400.090002-0.897820-0.914654-0.631841-1.937299-0.904254-0.883362-0.628208-1.893821
5.0-1.0LGBMRegressor1-11.0300020.092309-0.868651-0.881616-0.632241-1.867413-0.874428-0.851044-0.628208-1.824367
3.0-1.0LGBMRegressor2-10.8132040.088926-0.897875-0.916016-0.632136-1.938607-0.904254-0.883362-0.628208-1.893318
5.0-1.0LGBMRegressor2-11.0752020.097479-0.868589-0.883810-0.632147-1.869441-0.874428-0.851044-0.628208-1.823845
2.0-1.0LGBMRegressor3.0-1.0LGBMRegressor1-10.8638830.093149-0.898204-0.916213-0.631750-1.939523-0.904254-0.883362-0.628263-1.897660
5.0-1.0LGBMRegressor1-11.0659250.091348-0.869530-0.881849-0.631758-1.869640-0.874428-0.851044-0.628263-1.828047
3.0-1.0LGBMRegressor2-10.8747740.089929-0.896739-0.915254-0.632360-1.940699-0.904254-0.883362-0.628263-1.896607
5.0-1.0LGBMRegressor2-11.1156290.090655-0.869139-0.880639-0.631567-1.865421-0.874428-0.851044-0.628263-1.827504
3.0-1.0LGBMRegressor3.0-1.0LGBMRegressor1-10.8768340.088035-0.897744-0.914163-0.631423-1.937204-0.904254-0.883362-0.628397-1.901964
5.0-1.0LGBMRegressor1-11.1448600.089213-0.868147-0.882217-0.631853-1.871375-0.874428-0.851044-0.628397-1.831622
3.0-1.0LGBMRegressor2-10.9360020.092514-0.896970-0.916339-0.632114-1.946211-0.904254-0.883362-0.628397-1.901835
5.0-1.0LGBMRegressor2-11.1968360.098911-0.868453-0.880995-0.632492-1.872229-0.874428-0.851044-0.628397-1.831039
QuadraticDiscriminantAnalysisNaNNaNLinearRegressionNaNNaNLGBMRegressor1-10.2528700.058438-0.852213-0.849771-0.640512-2.276244-0.840509-0.832314-0.638479-4.386575
2-10.3256310.059346-0.851522-0.849607-0.640054-2.272417-0.840509-0.832314-0.638479-4.374702
LGBMRegressor3.0-1.0LGBMRegressor1-10.6781140.084764-0.897237-0.916300-0.640197-2.219056-0.904254-0.883362-0.638479-2.311704
5.0-1.0LGBMRegressor1-10.9129640.083096-0.869783-0.883112-0.641877-2.757328-0.874428-0.851044-0.638479-2.196190
3.0-1.0LGBMRegressor2-10.7477910.082716-0.897858-0.914723-0.640128-2.231605-0.904254-0.883362-0.638479-2.313465
5.0-1.0LGBMRegressor2-10.9775370.084649-0.868732-0.879986-0.639841-2.204121-0.874428-0.851044-0.638479-2.202680
\n", + "" + ], + "text/plain": [ + " fit_time \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 0.362632 \n", + " 2 -1 0.388284 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 0.392640 \n", + " 2 -1 0.454654 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 0.451604 \n", + " 2 -1 0.512599 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 0.752140 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 1.030002 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 0.813204 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 1.075202 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 0.863883 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 1.065925 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 0.874774 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 1.115629 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 0.876834 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 1.144860 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 0.936002 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 1.196836 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 0.252870 \n", + " 2 -1 0.325631 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 0.678114 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 0.912964 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 0.747791 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 0.977537 \n", + "\n", + " score_time \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 0.068002 \n", + " 2 -1 0.067350 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 0.066799 \n", + " 2 -1 0.067791 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 0.069173 \n", + " 2 -1 0.068227 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 0.090002 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 0.092309 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 0.088926 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 0.097479 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 0.093149 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 0.091348 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 0.089929 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 0.090655 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 0.088035 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 0.089213 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 0.092514 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 0.098911 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 0.058438 \n", + " 2 -1 0.059346 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 0.084764 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 0.083096 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 0.082716 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 0.084649 \n", + "\n", + " train_variant_outcome_model_0_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.852074 \n", + " 2 -1 -0.852355 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.852027 \n", + " 2 -1 -0.852033 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.851851 \n", + " 2 -1 -0.852593 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.897820 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.868651 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.897875 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.868589 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.898204 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.869530 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.896739 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.869139 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.897744 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.868147 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.896970 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.868453 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 -0.852213 \n", + " 2 -1 -0.851522 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.897237 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.869783 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.897858 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.868732 \n", + "\n", + " train_variant_outcome_model_1_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.848146 \n", + " 2 -1 -0.848704 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.847422 \n", + " 2 -1 -0.847687 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.847961 \n", + " 2 -1 -0.848230 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.914654 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.881616 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.916016 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.883810 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.916213 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.881849 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.915254 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.880639 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.914163 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.882217 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.916339 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.880995 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 -0.849771 \n", + " 2 -1 -0.849607 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.916300 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.883112 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.914723 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.879986 \n", + "\n", + " train_propensity_model_neg_log_loss \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.632090 \n", + " 2 -1 -0.631710 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.631746 \n", + " 2 -1 -0.632071 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.632294 \n", + " 2 -1 -0.632181 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.631841 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.632241 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.632136 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.632147 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.631750 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.631758 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.632360 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.631567 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.631423 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.631853 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.632114 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.632492 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 -0.640512 \n", + " 2 -1 -0.640054 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.640197 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.641877 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.640128 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.639841 \n", + "\n", + " train_treatment_model_1_vs_0_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -1.813472 \n", + " 2 -1 -1.811203 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -1.812832 \n", + " 2 -1 -1.816253 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -1.815351 \n", + " 2 -1 -1.817798 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -1.937299 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -1.867413 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -1.938607 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -1.869441 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -1.939523 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -1.869640 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -1.940699 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -1.865421 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -1.937204 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -1.871375 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -1.946211 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -1.872229 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 -2.276244 \n", + " 2 -1 -2.272417 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -2.219056 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -2.757328 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -2.231605 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -2.204121 \n", + "\n", + " test_variant_outcome_model_0_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.840509 \n", + " 2 -1 -0.840509 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.840509 \n", + " 2 -1 -0.840509 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.840509 \n", + " 2 -1 -0.840509 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.904254 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.874428 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.904254 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.874428 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.904254 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.874428 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.904254 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.874428 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.904254 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.874428 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.904254 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.874428 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 -0.840509 \n", + " 2 -1 -0.840509 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.904254 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.874428 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.904254 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.874428 \n", + "\n", + " test_variant_outcome_model_1_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.832314 \n", + " 2 -1 -0.832314 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.832314 \n", + " 2 -1 -0.832314 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.832314 \n", + " 2 -1 -0.832314 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.883362 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.851044 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.883362 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.851044 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.883362 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.851044 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.883362 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.851044 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.883362 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.851044 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.883362 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.851044 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 -0.832314 \n", + " 2 -1 -0.832314 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.883362 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.851044 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.883362 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.851044 \n", + "\n", + " test_propensity_model_neg_log_loss \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.628208 \n", + " 2 -1 -0.628208 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.628263 \n", + " 2 -1 -0.628263 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -0.628397 \n", + " 2 -1 -0.628397 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.628208 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.628208 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.628208 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.628208 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.628263 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.628263 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.628263 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.628263 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.628397 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.628397 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.628397 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.628397 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 -0.638479 \n", + " 2 -1 -0.638479 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -0.638479 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -0.638479 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -0.638479 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -0.638479 \n", + "\n", + " test_treatment_model_1_vs_0_neg_root_mean_squared_error \n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose variant_outcome_model variant_outcome_model_n_estimators variant_outcome_model_verbose treatment_model treatment_model_n_estimators treatment_model_verbose \n", + "DRLearner LGBMClassifier 1.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -1.771340 \n", + " 2 -1 -1.769949 \n", + " 2.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -1.773597 \n", + " 2 -1 -1.773111 \n", + " 3.0 -1.0 LinearRegression NaN NaN LGBMRegressor 1 -1 -1.777481 \n", + " 2 -1 -1.777205 \n", + " 1.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -1.893821 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -1.824367 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -1.893318 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -1.823845 \n", + " 2.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -1.897660 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -1.828047 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -1.896607 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -1.827504 \n", + " 3.0 -1.0 LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -1.901964 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -1.831622 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -1.901835 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -1.831039 \n", + " QuadraticDiscriminantAnalysis NaN NaN LinearRegression NaN NaN LGBMRegressor 1 -1 -4.386575 \n", + " 2 -1 -4.374702 \n", + " LGBMRegressor 3.0 -1.0 LGBMRegressor 1 -1 -2.311704 \n", + " 5.0 -1.0 LGBMRegressor 1 -1 -2.196190 \n", + " 3.0 -1.0 LGBMRegressor 2 -1 -2.313465 \n", + " 5.0 -1.0 LGBMRegressor 2 -1 -2.202680 " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "gs.fit(X_train, y_train, w_train, X_validation, y_validation, w_validation)\n", "gs.results_" @@ -275,8 +1108,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { + "execution": { + "iopub.execute_input": "2024-11-26T14:13:26.279819Z", + "iopub.status.busy": "2024-11-26T14:13:26.279699Z", + "iopub.status.idle": "2024-11-26T14:13:31.603633Z", + "shell.execute_reply": "2024-11-26T14:13:31.603343Z" + }, "tags": [ "scroll-output" ], @@ -284,7 +1123,397 @@ "languageId": "plaintext" } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fit_timescore_timetrain_variant_outcome_model_0_neg_root_mean_squared_errortrain_variant_outcome_model_1_neg_root_mean_squared_errortrain_propensity_model_neg_log_losstrain_treatment_effect_model_1_vs_0_neg_root_mean_squared_errortrain_control_effect_model_1_vs_0_neg_root_mean_squared_errortest_variant_outcome_model_0_neg_root_mean_squared_errortest_variant_outcome_model_1_neg_root_mean_squared_errortest_propensity_model_neg_log_losstest_treatment_effect_model_1_vs_0_neg_root_mean_squared_errortest_control_effect_model_1_vs_0_neg_root_mean_squared_error
metalearnerpropensity_modelpropensity_model_n_estimatorspropensity_model_verbosecontrol_effect_modelcontrol_effect_model_n_estimatorscontrol_effect_model_verbosetreatment_effect_modeltreatment_effect_model_n_estimatorstreatment_effect_model_verbose
XLearnerLGBMClassifier5-1LGBMRegressor1.0-1.0LGBMRegressor5.0-1.00.4685900.046710-0.835861-0.84813-0.631601-0.814415-0.824946-0.8372-0.811075-0.628424-0.789648-0.833076
10.0-1.00.6263670.046410-0.835861-0.84813-0.631980-0.803432-0.825162-0.8372-0.811075-0.628424-0.779851-0.833076
3.0-1.0LGBMRegressor5.0-1.00.5322250.047818-0.835861-0.84813-0.633323-0.813032-0.816535-0.8372-0.811075-0.628424-0.789648-0.825040
10.0-1.00.6914590.046533-0.835861-0.84813-0.633088-0.803686-0.816648-0.8372-0.811075-0.628424-0.779851-0.825040
1.0-1.0LinearRegressionNaNNaN0.2934620.043040-0.835861-0.84813-0.633597-0.808634-0.824883-0.8372-0.811075-0.628424-0.788946-0.833076
3.0-1.0LinearRegressionNaNNaN0.3669400.045290-0.835861-0.84813-0.632851-0.810018-0.815930-0.8372-0.811075-0.628424-0.788946-0.825040
LinearRegressionNaNNaNLGBMRegressor5.0-1.00.4274250.046298-0.835861-0.84813-0.633093-0.813929-0.817702-0.8372-0.811075-0.628424-0.789648-0.816616
10.0-1.00.5747890.045061-0.835861-0.84813-0.633856-0.805584-0.817893-0.8372-0.811075-0.628424-0.779851-0.816616
LinearRegressionNaNNaN0.2526760.040404-0.835861-0.84813-0.633268-0.809481-0.818456-0.8372-0.811075-0.628424-0.788946-0.816616
\n", + "
" + ], + "text/plain": [ + " fit_time \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 0.468590 \n", + " 10.0 -1.0 0.626367 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 0.532225 \n", + " 10.0 -1.0 0.691459 \n", + " 1.0 -1.0 LinearRegression NaN NaN 0.293462 \n", + " 3.0 -1.0 LinearRegression NaN NaN 0.366940 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 0.427425 \n", + " 10.0 -1.0 0.574789 \n", + " LinearRegression NaN NaN 0.252676 \n", + "\n", + " score_time \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 0.046710 \n", + " 10.0 -1.0 0.046410 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 0.047818 \n", + " 10.0 -1.0 0.046533 \n", + " 1.0 -1.0 LinearRegression NaN NaN 0.043040 \n", + " 3.0 -1.0 LinearRegression NaN NaN 0.045290 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 0.046298 \n", + " 10.0 -1.0 0.045061 \n", + " LinearRegression NaN NaN 0.040404 \n", + "\n", + " train_variant_outcome_model_0_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.835861 \n", + " 10.0 -1.0 -0.835861 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.835861 \n", + " 10.0 -1.0 -0.835861 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.835861 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.835861 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.835861 \n", + " 10.0 -1.0 -0.835861 \n", + " LinearRegression NaN NaN -0.835861 \n", + "\n", + " train_variant_outcome_model_1_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.84813 \n", + " 10.0 -1.0 -0.84813 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.84813 \n", + " 10.0 -1.0 -0.84813 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.84813 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.84813 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.84813 \n", + " 10.0 -1.0 -0.84813 \n", + " LinearRegression NaN NaN -0.84813 \n", + "\n", + " train_propensity_model_neg_log_loss \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.631601 \n", + " 10.0 -1.0 -0.631980 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.633323 \n", + " 10.0 -1.0 -0.633088 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.633597 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.632851 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.633093 \n", + " 10.0 -1.0 -0.633856 \n", + " LinearRegression NaN NaN -0.633268 \n", + "\n", + " train_treatment_effect_model_1_vs_0_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.814415 \n", + " 10.0 -1.0 -0.803432 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.813032 \n", + " 10.0 -1.0 -0.803686 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.808634 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.810018 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.813929 \n", + " 10.0 -1.0 -0.805584 \n", + " LinearRegression NaN NaN -0.809481 \n", + "\n", + " train_control_effect_model_1_vs_0_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.824946 \n", + " 10.0 -1.0 -0.825162 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.816535 \n", + " 10.0 -1.0 -0.816648 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.824883 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.815930 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.817702 \n", + " 10.0 -1.0 -0.817893 \n", + " LinearRegression NaN NaN -0.818456 \n", + "\n", + " test_variant_outcome_model_0_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.8372 \n", + " 10.0 -1.0 -0.8372 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.8372 \n", + " 10.0 -1.0 -0.8372 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.8372 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.8372 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.8372 \n", + " 10.0 -1.0 -0.8372 \n", + " LinearRegression NaN NaN -0.8372 \n", + "\n", + " test_variant_outcome_model_1_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.811075 \n", + " 10.0 -1.0 -0.811075 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.811075 \n", + " 10.0 -1.0 -0.811075 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.811075 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.811075 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.811075 \n", + " 10.0 -1.0 -0.811075 \n", + " LinearRegression NaN NaN -0.811075 \n", + "\n", + " test_propensity_model_neg_log_loss \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.628424 \n", + " 10.0 -1.0 -0.628424 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.628424 \n", + " 10.0 -1.0 -0.628424 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.628424 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.628424 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.628424 \n", + " 10.0 -1.0 -0.628424 \n", + " LinearRegression NaN NaN -0.628424 \n", + "\n", + " test_treatment_effect_model_1_vs_0_neg_root_mean_squared_error \\\n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.789648 \n", + " 10.0 -1.0 -0.779851 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.789648 \n", + " 10.0 -1.0 -0.779851 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.788946 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.788946 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.789648 \n", + " 10.0 -1.0 -0.779851 \n", + " LinearRegression NaN NaN -0.788946 \n", + "\n", + " test_control_effect_model_1_vs_0_neg_root_mean_squared_error \n", + "metalearner propensity_model propensity_model_n_estimators propensity_model_verbose control_effect_model control_effect_model_n_estimators control_effect_model_verbose treatment_effect_model treatment_effect_model_n_estimators treatment_effect_model_verbose \n", + "XLearner LGBMClassifier 5 -1 LGBMRegressor 1.0 -1.0 LGBMRegressor 5.0 -1.0 -0.833076 \n", + " 10.0 -1.0 -0.833076 \n", + " 3.0 -1.0 LGBMRegressor 5.0 -1.0 -0.825040 \n", + " 10.0 -1.0 -0.825040 \n", + " 1.0 -1.0 LinearRegression NaN NaN -0.833076 \n", + " 3.0 -1.0 LinearRegression NaN NaN -0.825040 \n", + " LinearRegression NaN NaN LGBMRegressor 5.0 -1.0 -0.816616 \n", + " 10.0 -1.0 -0.816616 \n", + " LinearRegression NaN NaN -0.816616 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from metalearners import TLearner, XLearner\n", "\n", @@ -361,7 +1590,16 @@ ], "metadata": { "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" }, "mystnb": { "execution_timeout": 60 diff --git a/docs/examples/example_lime.ipynb b/docs/examples/example_lime.ipynb index e967df83..0459bf8f 100644 --- a/docs/examples/example_lime.ipynb +++ b/docs/examples/example_lime.ipynb @@ -54,7 +54,7 @@ "* {math}`f`, the original model -- in our case the MetaLearner\n", "* {math}`G`, the class of possible, interpretable surrogate models\n", "* {math}`\\Omega(g)`, a measure of complexity for {math}`g \\in G`\n", - "* {math}`\\pi_x(z)` a proximity measure of an instance {math}`z` with respect to data point {math}`x`\n", + "* $\\pi_x(z)$ a proximity measure of an instance {math}`z` with respect to data point {math}`x`\n", "* {math}`\\mathcal{L}(f, g, \\pi_x)` a measure of how unfaithful a {math}`g \\in G` is to {math}`f` in the locality defined by {math}`\\pi_x`\n", "\n", "Given all of these objects as well as a to be explained data point {math}`x`, the authors suggest that the most appropriate surrogate {math}`g`, also referred to as explanation for {math}`x`, {math}`\\xi(x)`, can be expressed as follows:\n", diff --git a/docs/examples/example_propensity.ipynb b/docs/examples/example_propensity.ipynb index 7fc280f8..ee698258 100644 --- a/docs/examples/example_propensity.ipynb +++ b/docs/examples/example_propensity.ipynb @@ -84,10 +84,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "```{note}\n", - "The fact that we have a fixed propensity score for all observations is not true for this\n", - "dataset, we just use it for illustrational purposes.\n", - "```\n", + "\n", + "
\n", + "

Note

\n", + "

The fact that we have a fixed propensity score for all observations is not true for this dataset, we just use it for illustrational purposes.\n", + "

\n", + "
\n", + "\n", + "\n", "\n", "Now we can use a custom ``sklearn``-like classifier: {class}`~metalearners.utils.FixedBinaryPropensity`.\n", "The latter can be used like any ``sklearn`` classifier but will always return the same propensity,\n", diff --git a/docs/glossary.md b/docs/glossary.md index 9bb76e49..50e70f94 100644 --- a/docs/glossary.md +++ b/docs/glossary.md @@ -12,7 +12,7 @@ $\tau(X) = \mathbb{E}[Y(1) - Y(0)|X]$ in the binary case and $\tau_{i,j}(X) = \m $\mathbb{E}[Y_i(w) | X]$ for each treatment variant $w$. -##### Covariates +#### Covariates The features $X$ based on which a CATE is estimated. diff --git a/docs/motivation.md b/docs/motivation.md index 5e131805..ca1fdf0b 100644 --- a/docs/motivation.md +++ b/docs/motivation.md @@ -33,7 +33,8 @@ While MetaLearners are, in principle, designed in a very modular fashion, we've One reason to access the base models is to evaluate their individual performance. Due to the fundamental problem of Causal Inference, we are not able to evaluate a MetaLearner based on a simple metric measuring the mismatch between estimate and ground truth. Yet, we might want to do this for our base learners which often do have ground truth labels to compare the estimates to. Yet, this is not supported by `econml` and `causalml`. -![Component Evaluation](imgs/component_eval.drawio.svg) +![Component Evaluation](imgs/component_eval.drawio.svg#only-light) +![Component Evaluation](imgs/heterogeneity.svg#only-dark) In the illustration above, we indicate that we'd like to access, predict with, and evaluate a propensity model -- one base model of the MetaLearner at hand -- in isolation. diff --git a/metalearners/rlearner.py b/metalearners/rlearner.py index 95cc237e..95ca5440 100644 --- a/metalearners/rlearner.py +++ b/metalearners/rlearner.py @@ -93,27 +93,27 @@ class RLearner(MetaLearner): The R-Learner contains two nuisance models - * a ``"propensity_model"`` estimating :math:`\Pr[W=k|X]` - * an ``"outcome_model"`` estimating :math:`\mathbb{E}[Y|X]` + * a `"propensity_model"` estimating :math:`\Pr[W=k|X]` + * an `"outcome_model"` estimating :math:`\mathbb{E}[Y|X]` and one treatment model per treatment variant which isn't control - * ``"treatment_model"`` which estimates :math:`\mathbb{E}[Y(k) - Y(0) | X]` + * `"treatment_model"` which estimates :math:`\mathbb{E}[Y(k) - Y(0) | X]` - The ``treatment_model_factory`` provided needs to support the argument - ``sample_weight`` in its ``fit`` method. + The `treatment_model_factory` provided needs to support the argument + `sample_weight` in its `fit` method. """ def _validate_models(self) -> None: """Validate that the base models are appropriate. - In particular, it is validated that a base model to be used with ``"predict"`` is - recognized by ``scikit-learn`` as a regressor via ``sklearn.base.is_regressor`` and - a model to be used with ``"predict_proba"`` is recognized by ``scikit-learn` as - a classifier via ``sklearn.base.is_classifier``. + In particular, it is validated that a base model to be used with `"predict"` is + recognized by `scikit-learn` as a regressor via `sklearn.base.is_regressor` and + a model to be used with `"predict_proba"` is recognized by `scikit-learn` as + a classifier via `sklearn.base.is_classifier`. Additionally, this method ensures that the treatment model "treatment_model" supports - the ``"sample_weight"`` argument in its ``fit`` method. + the `"sample_weight"` argument in its `fit` method. """ if not function_has_argument( self.treatment_model_factory[TREATMENT_MODEL].fit, _SAMPLE_WEIGHT @@ -346,7 +346,6 @@ def predict( tau_hat[variant_indices, treatment_variant - 1] = variant_estimates return tau_hat - @copydoc(MetaLearner.evaluate, sep="\n\t") def evaluate( self, X: Matrix, @@ -356,10 +355,10 @@ def evaluate( oos_method: OosMethod = OVERALL, scoring: Scoring | None = None, ) -> dict[str, float]: - """In the RLearner case, the ``"treatment_model"`` is always evaluated with the + """In the RLearner case, the `"treatment_model"` is always evaluated with the :func:`~metalearners.rlearner.r_loss` besides the scorers in - ``scoring["treatment_model"]``, which should support passing the - ``sample_weight`` keyword argument.""" + `scoring["treatment_model"]`, which should support passing the `sample_weight` + keyword argument.""" safe_scoring = self._scoring(scoring) propensity_evaluation = _evaluate_model_kind( @@ -480,11 +479,11 @@ def _pseudo_outcome_and_weights( ) -> tuple[np.ndarray, np.ndarray]: """Compute the R-Learner pseudo outcome and corresponding weights. - If ``mask`` is provided, the retuned pseudo outcomes and weights are only + If `mask` is provided, the retuned pseudo outcomes and weights are only with respect the observations that the mask selects. Since the pseudo outcome is a fraction of residuals, we add a small - constant ``epsilon`` to the denominator in order to avoid numerical problems. + constant `epsilon` to the denominator in order to avoid numerical problems. """ if mask is None: mask = np.ones(safe_len(X), dtype=bool) @@ -537,29 +536,12 @@ def _necessary_onnx_models(self) -> dict[str, list[_ScikitModel]]: def predict_conditional_average_outcomes( self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL ) -> np.ndarray: - r"""Predict the vectors of conditional average outcomes. + r"""The conditional average outcomes are estimated as follows: - These are defined as :math:`\mathbb{E}[Y_i(w) | X]` for each treatment variant - :math:`w`. + * $Y_i(0) = \hat{\mu}(X_i) - \sum_{k=1}^{K} \hat{e}_k(X_i) \hat{\tau_k}(X_i)$ + * $Y_i(k) = Y_i(0) + \hat{\tau_k}(X_i)$ for $k \in \{1, \dots, K\}$ - If ``is_oos``, an acronym for 'is out of sample' is ``False``, - the estimates will stem from cross-fitting. Otherwise, - various approaches exist, specified via ``oos_method``. - - The returned ndarray is of shape: - - * :math:`(n_{obs}, n_{variants}, 1)` if the outcome is a scalar, i.e. in case - of a regression problem. - - * :math:`(n_{obs}, n_{variants}, n_{classes})` if the outcome is a class, - i.e. in case of a classification problem. - - The conditional average outcomes are estimated as follows: - - * :math:`Y_i(0) = \hat{\mu}(X_i) - \sum_{k=1}^{K} \hat{e}_k(X_i) \hat{\tau_k}(X_i)` - * :math:`Y_i(k) = Y_i(0) + \hat{\tau_k}(X_i)` for :math:`k \in \{1, \dots, K\}` - - where :math:`K` is the number of treatment variants. + where $K$ is the number of treatment variants. """ n_obs = safe_len(X) @@ -621,7 +603,7 @@ def predict_conditional_average_outcomes( def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): """In the RLearner case, the necessary models are: - * ``"treatment_model"`` + * `"treatment_model"` """ warning_experimental_feature("_build_onnx") check_spox_installed() diff --git a/mkdocs.yml b/mkdocs.yml index f1e9ad45..5920277e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -15,14 +15,12 @@ theme: toggle: icon: material/brightness-7 name: Switch to dark mode - primary: deep purple # Palette toggle for dark mode - media: "(prefers-color-scheme: dark)" scheme: slate toggle: icon: material/brightness-4 name: Switch to system preference - primary: deep purple features: - content.action.edit - search.suggest @@ -57,10 +55,13 @@ plugins: extensions: - griffe_inherited_docstrings: merge: true - unwrap_annotated: true + allow_inspection: true + # unwrap_annotated: true + show_signature_annotations: true + show_signature: true show_symbol_type_heading: true - docstring_style: numpy - docstring_section_style: spacy + docstring_style: "google" # null for no tables + docstring_section_style: "table" separate_signature: true merge_init_into_class: true show_submodules: true # show *all* code docu @@ -95,6 +96,7 @@ nav: markdown_extensions: - admonition + - pymdownx.details - pymdownx.highlight - pymdownx.superfences - pymdownx.inlinehilite diff --git a/pixi.lock b/pixi.lock index 59bdce70..42e8d94a 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2443,6 +2443,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/autograd-gamma-0.5.0-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/babel-2.14.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/black-24.10.0-py311h38be061_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/bleach-6.1.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/blosc-1.21.6-hef167b5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hd590300_1.conda @@ -2790,6 +2791,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/autograd-gamma-0.5.0-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/babel-2.14.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/black-24.10.0-py311h6eed73b_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/bleach-6.1.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/blosc-1.21.6-h7d75f6d_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/brotli-1.1.0-h0dc2134_1.conda @@ -3122,6 +3124,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/autograd-gamma-0.5.0-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/babel-2.14.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/black-24.10.0-py311h267d04e_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/bleach-6.1.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/blosc-1.21.6-h5499902_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-1.1.0-hb547adb_1.conda @@ -3453,6 +3456,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/autograd-gamma-0.5.0-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/babel-2.14.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/win-64/black-24.10.0-py311h1ea47a8_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/bleach-6.1.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/blosc-1.21.6-h85f69ea_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/brotli-1.1.0-hcfcfb64_1.conda @@ -7383,6 +7387,87 @@ packages: license_family: MIT size: 391969 timestamp: 1714119854151 +- kind: conda + name: black + version: 24.10.0 + build: py311h1ea47a8_0 + subdir: win-64 + url: https://conda.anaconda.org/conda-forge/win-64/black-24.10.0-py311h1ea47a8_0.conda + sha256: 8f818b565127c2846b331addd64ad9e7db59f3b413126a4eba5c08e50eca8d13 + md5: 22b6302f1fd44b83c5d659b40dc4fedc + depends: + - click >=8.0.0 + - mypy_extensions >=0.4.3 + - packaging >=22.0 + - pathspec >=0.9 + - platformdirs >=2 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: MIT + license_family: MIT + size: 423352 + timestamp: 1728504147411 +- kind: conda + name: black + version: 24.10.0 + build: py311h267d04e_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/black-24.10.0-py311h267d04e_0.conda + sha256: 70c41af42699e765acb5a027740b97300bf696af22d0406dd26d66cd1aa7959f + md5: 2780798c556604ad91ddeb01e2e3f2ea + depends: + - click >=8.0.0 + - mypy_extensions >=0.4.3 + - packaging >=22.0 + - pathspec >=0.9 + - platformdirs >=2 + - python >=3.11,<3.12.0a0 + - python >=3.11,<3.12.0a0 *_cpython + - python_abi 3.11.* *_cp311 + license: MIT + license_family: MIT + size: 399611 + timestamp: 1728503954546 +- kind: conda + name: black + version: 24.10.0 + build: py311h38be061_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/black-24.10.0-py311h38be061_0.conda + sha256: fad9bf51c6574af15900a4e1226f45d315fe79a5424f9fc97db35e3f74927e70 + md5: fceab39cec6b08375a81d1406e30ea1e + depends: + - click >=8.0.0 + - mypy_extensions >=0.4.3 + - packaging >=22.0 + - pathspec >=0.9 + - platformdirs >=2 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: MIT + license_family: MIT + size: 398766 + timestamp: 1728503879231 +- kind: conda + name: black + version: 24.10.0 + build: py311h6eed73b_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/black-24.10.0-py311h6eed73b_0.conda + sha256: d242a4f2faeca61e7cc323b8c61b65b08192bc660d80a41e5d8dab000abfd078 + md5: d3f6c057e93e39d9df94ffa3c4693333 + depends: + - click >=8.0.0 + - mypy_extensions >=0.4.3 + - packaging >=22.0 + - pathspec >=0.9 + - platformdirs >=2 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: MIT + license_family: MIT + size: 399950 + timestamp: 1728503926364 - kind: conda name: bleach version: 6.1.0 diff --git a/pixi.toml b/pixi.toml index 6bd693f0..a8e73654 100644 --- a/pixi.toml +++ b/pixi.toml @@ -74,6 +74,7 @@ mkdocs-material = ">=9.5.31,<10" mkdocstrings = ">=0.25.2" mkdocstrings-python = ">=1.12,<2" griffe-inherited-docstrings = ">=1.1.1,<2" +black = ">=24.10.0,<25" # enables formatted (e.g. line-wrapped) function signatures [feature.docs.tasks] # postinstall task needs to be executed in 'docs' environment beforehand to resolve API references