Skip to content

Commit

Permalink
Fix order
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 10, 2024
1 parent 5a08a60 commit f79658a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions metalearners/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame:
for result in results:
row: dict[str, str | int | float] = {}
row["metalearner"] = result.metalearner.__class__.__name__
nuisance_models = (
nuisance_models = sorted(
set(result.metalearner.nuisance_model_specifications().keys())
- result.metalearner._prefitted_nuisance_models
)
treatment_models = set(
result.metalearner.treatment_model_specifications().keys()
treatment_models = sorted(
set(result.metalearner.treatment_model_specifications().keys())
)
for model_kind in nuisance_models:
row[model_kind] = result.metalearner.nuisance_model_factory[
Expand Down
4 changes: 2 additions & 2 deletions tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@
6,
[
"metalearner",
"variant_outcome_model", # The order of the nuisance models depends on the order of nuisance_model_specifications()
"propensity_model",
"propensity_model_n_estimators",
"variant_outcome_model",
"control_effect_model",
"control_effect_model_n_estimators",
"treatment_effect_model",
Expand All @@ -90,9 +90,9 @@
9,
[
"metalearner",
"outcome_model",
"propensity_model",
"propensity_model_n_estimators",
"outcome_model",
"treatment_model",
"treatment_model_learning_rate",
"treatment_model_n_estimators",
Expand Down

0 comments on commit f79658a

Please sign in to comment.