diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py index 4f01943d..a65087cd 100644 --- a/tests/test_metalearner.py +++ b/tests/test_metalearner.py @@ -1125,7 +1125,7 @@ def test_validate_outcome_different_classes(implementation, use_pandas, rng): "implementation", [TLearner, SLearner, XLearner, RLearner, DRLearner], ) -def test_init_args_smoke(implementation): +def test_init_args(implementation): ml = implementation( True, 2, @@ -1133,4 +1133,8 @@ def test_init_args_smoke(implementation): LinearRegression, LogisticRegression, ) - implementation(**ml.init_args) + ml2 = implementation(**ml.init_args) + + assert set(ml.__dict__.keys()) == set(ml2.__dict__.keys()) + for key in ml.__dict__: + assert ml.__dict__[key] == ml2.__dict__[key]