diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d2f6e00..354540c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,13 @@ Changelog ========= +0.6.1 (2024-07-xx) +------------------ + +**Other changes** + +* Changed the order of the index columns in ``MetaLearnerGridSearch.results_``: ``metalearner`` now comes first, followed by the nuisance model columns and then the treatment model columns, each sorted alphabetically. + 0.6.0 (2024-07-08) ------------------ diff --git a/metalearners/grid_search.py b/metalearners/grid_search.py index 1d85f63..cc9c732 100644 --- a/metalearners/grid_search.py +++ b/metalearners/grid_search.py @@ -83,12 +83,12 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame: for result in results: row: dict[str, str | int | float] = {} row["metalearner"] = result.metalearner.__class__.__name__ - nuisance_models = ( + nuisance_models = sorted( set(result.metalearner.nuisance_model_specifications().keys()) - result.metalearner._prefitted_nuisance_models ) - treatment_models = set( - result.metalearner.treatment_model_specifications().keys() + treatment_models = sorted( + set(result.metalearner.treatment_model_specifications().keys()) ) for model_kind in nuisance_models: row[model_kind] = result.metalearner.nuisance_model_factory[ @@ -115,13 +115,16 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame: row[f"test_{name}"] = value rows.append(row) df = pd.DataFrame(rows) - index_columns = [ - c - for c in df.columns - if not c.endswith("_time") - and not c.startswith("train_") - and not c.startswith("test_") - ] + sorted_cols = sorted(df.columns) + index_columns = ["metalearner"] + for model_kind in nuisance_models: + for c in sorted_cols: + if c.startswith(model_kind): + index_columns.append(c) + for model_kind in treatment_models: + for c in sorted_cols: + if c.startswith(model_kind): + index_columns.append(c) df = df.set_index(index_columns) return df diff --git a/tests/test_grid_search.py b/tests/test_grid_search.py index fd953ff..e29d3d3 100644 --- a/tests/test_grid_search.py +++ b/tests/test_grid_search.py @@ -25,7 +25,7 @@ {"base_model": 
[LinearRegression, LGBMRegressor]}, {"base_model": {"LGBMRegressor": {"n_estimators": [1, 2]}}}, 3, - 3, + ["metalearner", "base_model", "base_model_n_estimators"], ), ( SLearner, @@ -33,7 +33,7 @@ {"base_model": [LogisticRegression, LGBMClassifier]}, {"base_model": {"LGBMClassifier": {"n_estimators": [1, 2]}}}, 3, - 3, + ["metalearner", "base_model", "base_model_n_estimators"], ), ( TLearner, @@ -41,7 +41,11 @@ {"variant_outcome_model": [LinearRegression, LGBMRegressor]}, {"variant_outcome_model": {"LGBMRegressor": {"n_estimators": [1, 2, 3]}}}, 4, - 3, + [ + "metalearner", + "variant_outcome_model", + "variant_outcome_model_n_estimators", + ], ), ( XLearner, @@ -58,7 +62,16 @@ "treatment_effect_model": {"LGBMRegressor": {"n_estimators": [1]}}, }, 6, - 8, + [ + "metalearner", + "propensity_model", + "propensity_model_n_estimators", + "variant_outcome_model", + "control_effect_model", + "control_effect_model_n_estimators", + "treatment_effect_model", + "treatment_effect_model_n_estimators", + ], ), ( RLearner, @@ -75,7 +88,15 @@ }, }, 9, - 7, + [ + "metalearner", + "outcome_model", + "propensity_model", + "propensity_model_n_estimators", + "treatment_model", + "treatment_model_learning_rate", + "treatment_model_n_estimators", + ], ), ( DRLearner, @@ -89,7 +110,13 @@ "propensity_model": {"LGBMClassifier": {"n_estimators": [1, 2, 3, 4]}}, }, 4, - 5, + [ + "metalearner", + "propensity_model", + "propensity_model_n_estimators", + "variant_outcome_model", + "treatment_model", + ], ), ], ) @@ -125,7 +152,7 @@ def test_metalearnergridsearch_smoke( gs.fit(X, y, w, X_test, y_test, w_test) assert gs.results_ is not None assert gs.results_.shape[0] == expected_n_configs - assert len(gs.results_.index.names) == expected_index_cols + assert gs.results_.index.names == expected_index_cols train_scores_cols = set( c[6:] for c in list(gs.results_.columns) if c.startswith("train_")