From f79658a5bbff3e894ccf151719596e80abc4b6fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Wed, 10 Jul 2024 09:53:10 +0200 Subject: [PATCH] Fix order --- metalearners/grid_search.py | 6 +++--- tests/test_grid_search.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/metalearners/grid_search.py b/metalearners/grid_search.py index cfbc480c..cc9c732e 100644 --- a/metalearners/grid_search.py +++ b/metalearners/grid_search.py @@ -83,12 +83,12 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame: for result in results: row: dict[str, str | int | float] = {} row["metalearner"] = result.metalearner.__class__.__name__ - nuisance_models = ( + nuisance_models = sorted( set(result.metalearner.nuisance_model_specifications().keys()) - result.metalearner._prefitted_nuisance_models ) - treatment_models = set( - result.metalearner.treatment_model_specifications().keys() + treatment_models = sorted( + set(result.metalearner.treatment_model_specifications().keys()) ) for model_kind in nuisance_models: row[model_kind] = result.metalearner.nuisance_model_factory[ diff --git a/tests/test_grid_search.py b/tests/test_grid_search.py index 387775c8..e29d3d3b 100644 --- a/tests/test_grid_search.py +++ b/tests/test_grid_search.py @@ -64,9 +64,9 @@ 6, [ "metalearner", - "variant_outcome_model", # The order of the nuisance models depends on the order of nuisance_model_specifications() "propensity_model", "propensity_model_n_estimators", + "variant_outcome_model", "control_effect_model", "control_effect_model_n_estimators", "treatment_effect_model", @@ -90,9 +90,9 @@ 9, [ "metalearner", + "outcome_model", "propensity_model", "propensity_model_n_estimators", - "outcome_model", "treatment_model", "treatment_model_learning_rate", "treatment_model_n_estimators",