Skip to content

Commit

Permalink
Fix logic revolving around pre-fitted models.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Jul 24, 2024
1 parent 46dd4ea commit 2d775b5
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,29 +1126,43 @@ def _default_scoring() -> Scoring:
@property
def init_args(self) -> dict[str, Any]:
"""Create initiliazation parameters for a new MetaLearner."""
if self._prefitted_nuisance_models:
raise ValueError(
"Cannot recreate MetaLearner if a pre-fitted model was used."
)
return {
"is_classification": self.is_classification,
"n_variants": self.n_variants,
"nuisance_model_factory": {
k: v
for k, v in self.nuisance_model_factory.items()
if k != PROPENSITY_MODEL
if k not in self._prefitted_nuisance_models
},
"treatment_model_factory": self.treatment_model_factory,
"propensity_model_factory": self.nuisance_model_factory.get(
PROPENSITY_MODEL
"propensity_model_factory": (
self.nuisance_model_factory.get(PROPENSITY_MODEL)
if PROPENSITY_MODEL not in self._prefitted_nuisance_models
else None
),
"nuisance_model_params": {
k: v
for k, v in self.nuisance_model_params.items()
if k != PROPENSITY_MODEL
if k not in self._prefitted_nuisance_models
},
"treatment_model_params": self.treatment_model_params,
"propensity_model_params": self.nuisance_model_params.get(PROPENSITY_MODEL),
"propensity_model_params": (
self.nuisance_model_params.get(PROPENSITY_MODEL)
if PROPENSITY_MODEL not in self._prefitted_nuisance_models
else None
),
"fitted_nuisance_models": {
k: v
for k, v in self._nuisance_models.items()
if k in self._prefitted_nuisance_models and k != PROPENSITY_MODEL
},
"fitted_propensity_model": (
self._nuisance_models.get(PROPENSITY_MODEL)
if PROPENSITY_MODEL in self._prefitted_nuisance_models
else None
),
"feature_set": self.feature_set,
"n_folds": self.n_folds,
"random_state": self.random_state,
Expand Down

0 comments on commit 2d775b5

Please sign in to comment.