diff --git a/metalearners/_utils.py b/metalearners/_utils.py
index 00eecfb..0aca691 100644
--- a/metalearners/_utils.py
+++ b/metalearners/_utils.py
@@ -1,7 +1,6 @@
 # # Copyright (c) QuantCo 2024-2024
 # # SPDX-License-Identifier: BSD-3-Clause
 
-import operator
 from collections.abc import Callable
 from inspect import signature
 from operator import le, lt
@@ -66,14 +65,22 @@ def validate_all_vectors_same_index(*args: Vector) -> None:
 
 
 def validate_number_positive(
-    value: int | float, name: str, strict: bool = False
+    value: int | float, name: str, strict: bool = True
 ) -> None:
+    """Validates that a number is positive.
+
+    If ``strict = True`` then it validates that the number is strictly positive.
+    """
     if strict:
-        comparison = operator.lt
+        if value <= 0:
+            raise ValueError(
+                f"{name} was expected to be strictly positive but was {value}."
+            )
     else:
-        comparison = operator.le
-    if comparison(value, 0):
-        raise ValueError(f"{name} was expected to be positive but was {value}.")
+        if value < 0:
+            raise ValueError(
+                f"{name} was expected to be positive or zero but was {value}."
+            )
 
 
 def check_propensity_score(
diff --git a/metalearners/cross_fit_estimator.py b/metalearners/cross_fit_estimator.py
index e26d898..9765aa7 100644
--- a/metalearners/cross_fit_estimator.py
+++ b/metalearners/cross_fit_estimator.py
@@ -56,7 +56,7 @@ def _validate_data_match_prior_split(
 ) -> None:
     """Validate whether the previous test_indices and the passed data are based on the
     same number of observations."""
-    validate_number_positive(n_observations, "n_observations", strict=False)
+    validate_number_positive(n_observations, "n_observations", strict=True)
     if test_indices is None:
         return
     expected_n_observations = sum(len(x) for x in test_indices)
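The `strict` default flips from `False` to `True`, so `n_observations` is now rejected at zero rather than only below zero. A minimal illustration of the semantics defined in the `_utils.py` hunk above (the error messages are taken verbatim from the diff):

```python
# Illustration of validate_number_positive's strict/non-strict behavior.
from metalearners._utils import validate_number_positive

validate_number_positive(3, "n_folds")  # passes: 3 > 0, strict is now the default
validate_number_positive(0.0, "alpha", strict=False)  # passes: zero is allowed

try:
    validate_number_positive(0, "n_observations")  # strict=True rejects zero
except ValueError as exc:
    print(exc)  # n_observations was expected to be strictly positive but was 0.
```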
diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py
index 1d24d89..ea9b2f1 100644
--- a/metalearners/drlearner.py
+++ b/metalearners/drlearner.py
@@ -2,6 +2,7 @@
 # # SPDX-License-Identifier: BSD-3-Clause
 
 import numpy as np
+from joblib import Parallel, delayed
 from typing_extensions import Self
 
 from metalearners._typing import Matrix, OosMethod, Vector
@@ -22,7 +23,9 @@
     VARIANT_OUTCOME_MODEL,
     MetaLearner,
     _ConditionalAverageOutcomeMetaLearner,
+    _fit_cross_fit_estimator_joblib,
     _ModelSpecifications,
+    _ParallelJoblibSpecification,
 )
 
 _EPSILON = 1e-09
@@ -102,27 +105,43 @@ def fit(
         else:
             cv_split_indices = None
 
-        # TODO: Consider multiprocessing
+        nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
         for treatment_variant in range(self.n_variants):
-            self.fit_nuisance(
-                X=index_matrix(X, self._treatment_variants_indices[treatment_variant]),
-                y=y[self._treatment_variants_indices[treatment_variant]],
-                model_kind=VARIANT_OUTCOME_MODEL,
-                model_ord=treatment_variant,
+            nuisance_jobs.append(
+                self._nuisance_joblib_specifications(
+                    X=index_matrix(
+                        X, self._treatment_variants_indices[treatment_variant]
+                    ),
+                    y=y[self._treatment_variants_indices[treatment_variant]],
+                    model_kind=VARIANT_OUTCOME_MODEL,
+                    model_ord=treatment_variant,
+                    n_jobs_cross_fitting=n_jobs_cross_fitting,
+                    fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
+                )
+            )
+
+        nuisance_jobs.append(
+            self._nuisance_joblib_specifications(
+                X=X,
+                y=w,
+                model_kind=PROPENSITY_MODEL,
+                model_ord=0,
                 n_jobs_cross_fitting=n_jobs_cross_fitting,
-                fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
+                fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL],
+                cv=cv_split_indices,
             )
+        )
 
-        self.fit_nuisance(
-            X=X,
-            y=w,
-            model_kind=PROPENSITY_MODEL,
-            model_ord=0,
-            n_jobs_cross_fitting=n_jobs_cross_fitting,
-            fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL],
-            cv=cv_split_indices,
+        parallel = Parallel(n_jobs=n_jobs_base_learners)
+        results = parallel(
+            delayed(_fit_cross_fit_estimator_joblib)(job)
+            for job in nuisance_jobs
+            if job is not None
         )
+        self._assign_joblib_nuisance_results(results)
+
+        treatment_jobs: list[_ParallelJoblibSpecification] = []
 
         for treatment_variant in range(1, self.n_variants):
             pseudo_outcomes = self._pseudo_outcome(
                 X=X,
@@ -131,15 +150,21 @@
                 treatment_variant=treatment_variant,
             )
 
-            self.fit_treatment(
-                X=X,
-                y=pseudo_outcomes,
-                model_kind=TREATMENT_MODEL,
-                model_ord=treatment_variant - 1,
-                n_jobs_cross_fitting=n_jobs_cross_fitting,
-                fit_params=qualified_fit_params[TREATMENT][TREATMENT_MODEL],
-                cv=cv_split_indices,
+            treatment_jobs.append(
+                self._treatment_joblib_specifications(
+                    X=X,
+                    y=pseudo_outcomes,
+                    model_kind=TREATMENT_MODEL,
+                    model_ord=treatment_variant - 1,
+                    n_jobs_cross_fitting=n_jobs_cross_fitting,
+                    fit_params=qualified_fit_params[TREATMENT][TREATMENT_MODEL],
+                    cv=cv_split_indices,
+                )
             )
+        results = parallel(
+            delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
+        )
+        self._assign_joblib_treatment_results(results)
         return self
 
     def predict(
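The scheduling pattern replacing the old `# TODO: Consider multiprocessing` is plain joblib: each base model is first turned into a lightweight job specification, and a single `Parallel` instance then executes the batch. A self-contained sketch of the same pattern — `fit_one` is a hypothetical stand-in for `_fit_cross_fit_estimator_joblib`, which lives in `metalearners/metalearner.py` and is not shown in this diff:

```python
from joblib import Parallel, delayed


def fit_one(job: dict) -> dict:
    # Stand-in for _fit_cross_fit_estimator_joblib: fit a single cross-fit
    # estimator from its specification and return the fitted result.
    return {"model_kind": job["model_kind"], "model_ord": job["model_ord"]}


jobs = [{"model_kind": "variant_outcome_model", "model_ord": i} for i in range(3)]
parallel = Parallel(n_jobs=-1)  # n_jobs_base_learners; -1 means all available cores
results = parallel(delayed(fit_one)(job) for job in jobs if job is not None)
```

The `if job is not None` filter mirrors the nuisance loop above, where a specification can be `None` — presumably when a base model is reused rather than refit, as exercised by `test_model_reusage` further down.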
diff --git a/metalearners/rlearner.py b/metalearners/rlearner.py
index bf39caa..6a09847 100644
--- a/metalearners/rlearner.py
+++ b/metalearners/rlearner.py
@@ -2,6 +2,7 @@
 # # SPDX-License-Identifier: BSD-3-Clause
 
 import numpy as np
+from joblib import Parallel, delayed
 from sklearn.metrics import log_loss, root_mean_squared_error
 from typing_extensions import Self
 
@@ -23,7 +24,9 @@
     TREATMENT,
     TREATMENT_MODEL,
     MetaLearner,
+    _fit_cross_fit_estimator_joblib,
     _ModelSpecifications,
+    _ParallelJoblibSpecification,
 )
 
 OUTCOME_MODEL = "outcome_model"
@@ -175,25 +178,40 @@ def fit(
         else:
             cv_split_indices = None
 
-        self.fit_nuisance(
-            X=X,
-            y=w,
-            model_kind=PROPENSITY_MODEL,
-            model_ord=0,
-            n_jobs_cross_fitting=n_jobs_cross_fitting,
-            fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL],
-            cv=cv_split_indices,
+        nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
+
+        nuisance_jobs.append(
+            self._nuisance_joblib_specifications(
+                X=X,
+                y=w,
+                model_kind=PROPENSITY_MODEL,
+                model_ord=0,
+                n_jobs_cross_fitting=n_jobs_cross_fitting,
+                fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL],
+                cv=cv_split_indices,
+            )
         )
 
-        self.fit_nuisance(
-            X=X,
-            y=y,
-            model_kind=OUTCOME_MODEL,
-            model_ord=0,
-            n_jobs_cross_fitting=n_jobs_cross_fitting,
-            fit_params=qualified_fit_params[NUISANCE][OUTCOME_MODEL],
-            cv=cv_split_indices,
+        nuisance_jobs.append(
+            self._nuisance_joblib_specifications(
+                X=X,
+                y=y,
+                model_kind=OUTCOME_MODEL,
+                model_ord=0,
+                n_jobs_cross_fitting=n_jobs_cross_fitting,
+                fit_params=qualified_fit_params[NUISANCE][OUTCOME_MODEL],
+                cv=cv_split_indices,
+            )
+        )
+
+        parallel = Parallel(n_jobs=n_jobs_base_learners)
+        results = parallel(
+            delayed(_fit_cross_fit_estimator_joblib)(job)
+            for job in nuisance_jobs
+            if job is not None
         )
+        self._assign_joblib_nuisance_results(results)
 
+        treatment_jobs: list[_ParallelJoblibSpecification] = []
         for treatment_variant in range(1, self.n_variants):
             is_treatment = w == treatment_variant
@@ -213,15 +231,21 @@
 
             X_filtered = index_matrix(X, mask)
 
-            self.fit_treatment(
-                X=X_filtered,
-                y=pseudo_outcomes,
-                model_kind=TREATMENT_MODEL,
-                model_ord=treatment_variant - 1,
-                fit_params=qualified_fit_params[TREATMENT][TREATMENT_MODEL]
-                | {_SAMPLE_WEIGHT: weights},
-                n_jobs_cross_fitting=n_jobs_cross_fitting,
+            treatment_jobs.append(
+                self._treatment_joblib_specifications(
+                    X=X_filtered,
+                    y=pseudo_outcomes,
+                    model_kind=TREATMENT_MODEL,
+                    model_ord=treatment_variant - 1,
+                    fit_params=qualified_fit_params[TREATMENT][TREATMENT_MODEL]
+                    | {_SAMPLE_WEIGHT: weights},
+                    n_jobs_cross_fitting=n_jobs_cross_fitting,
+                )
             )
+        results = parallel(
+            delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
+        )
+        self._assign_joblib_treatment_results(results)
         return self
 
     def predict(
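Note the `fit_params=... | {_SAMPLE_WEIGHT: weights}` construction carried over into the job specification: this is Python 3.9+ dict union (PEP 584), attaching the R-Learner's pseudo-outcome weights to whatever fit parameters the user supplied. A small sketch — `_SAMPLE_WEIGHT` is a module-level constant in `rlearner.py`, and its value here is an assumption:

```python
# Dict union as used for fit_params above (Python 3.9+).
_SAMPLE_WEIGHT = "sample_weight"  # assumed value of the rlearner.py constant

user_fit_params = {"eval_metric": "rmse"}
weights = [0.5, 1.0, 2.0]

fit_params = user_fit_params | {_SAMPLE_WEIGHT: weights}
print(fit_params)  # {'eval_metric': 'rmse', 'sample_weight': [0.5, 1.0, 2.0]}
```

The right-hand operand wins on key collisions, so a user-supplied `sample_weight` would be overridden by the learner's weights.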
diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py
index 64d059f..729899c 100644
--- a/metalearners/xlearner.py
+++ b/metalearners/xlearner.py
@@ -2,6 +2,7 @@
 # # SPDX-License-Identifier: BSD-3-Clause
 
 import numpy as np
+from joblib import Parallel, delayed
 from typing_extensions import Self
 
 from metalearners._typing import Matrix, OosMethod, Vector
@@ -20,7 +21,9 @@
     VARIANT_OUTCOME_MODEL,
     MetaLearner,
     _ConditionalAverageOutcomeMetaLearner,
+    _fit_cross_fit_estimator_joblib,
     _ModelSpecifications,
+    _ParallelJoblibSpecification,
 )
 
 CONTROL_EFFECT_MODEL = "control_effect_model"
@@ -98,51 +101,76 @@ def fit(
                 cv_split_indices = None
             cvs.append(cv_split_indices)
 
-        # TODO: Consider multiprocessing
+        nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
         for treatment_variant in range(self.n_variants):
-            self.fit_nuisance(
-                X=index_matrix(X, self._treatment_variants_indices[treatment_variant]),
-                y=y[self._treatment_variants_indices[treatment_variant]],
-                model_kind=VARIANT_OUTCOME_MODEL,
-                model_ord=treatment_variant,
+            nuisance_jobs.append(
+                self._nuisance_joblib_specifications(
+                    X=index_matrix(
+                        X, self._treatment_variants_indices[treatment_variant]
+                    ),
+                    y=y[self._treatment_variants_indices[treatment_variant]],
+                    model_kind=VARIANT_OUTCOME_MODEL,
+                    model_ord=treatment_variant,
+                    n_jobs_cross_fitting=n_jobs_cross_fitting,
+                    fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
+                    cv=cvs[treatment_variant],
+                )
+            )
+
+        nuisance_jobs.append(
+            self._nuisance_joblib_specifications(
+                X=X,
+                y=w,
+                model_kind=PROPENSITY_MODEL,
+                model_ord=0,
                 n_jobs_cross_fitting=n_jobs_cross_fitting,
-                fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
-                cv=cvs[treatment_variant],
+                fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL],
             )
+        )
 
-        self.fit_nuisance(
-            X=X,
-            y=w,
-            model_kind=PROPENSITY_MODEL,
-            model_ord=0,
-            n_jobs_cross_fitting=n_jobs_cross_fitting,
-            fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL],
+        parallel = Parallel(n_jobs=n_jobs_base_learners)
+        results = parallel(
+            delayed(_fit_cross_fit_estimator_joblib)(job)
+            for job in nuisance_jobs
+            if job is not None
         )
+        self._assign_joblib_nuisance_results(results)
 
+        treatment_jobs: list[_ParallelJoblibSpecification] = []
         for treatment_variant in range(1, self.n_variants):
             imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
                 X, y, w, treatment_variant
             )
-
-            self.fit_treatment(
-                X=index_matrix(X, self._treatment_variants_indices[treatment_variant]),
-                y=imputed_te_treatment,
-                model_kind=TREATMENT_EFFECT_MODEL,
-                model_ord=treatment_variant - 1,
-                n_jobs_cross_fitting=n_jobs_cross_fitting,
-                fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL],
-                cv=cvs[treatment_variant],
+            treatment_jobs.append(
+                self._treatment_joblib_specifications(
+                    X=index_matrix(
+                        X, self._treatment_variants_indices[treatment_variant]
+                    ),
+                    y=imputed_te_treatment,
+                    model_kind=TREATMENT_EFFECT_MODEL,
+                    model_ord=treatment_variant - 1,
+                    n_jobs_cross_fitting=n_jobs_cross_fitting,
+                    fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL],
+                    cv=cvs[treatment_variant],
+                )
             )
 
-            self.fit_treatment(
-                X=index_matrix(X, self._treatment_variants_indices[0]),
-                y=imputed_te_control,
-                model_kind=CONTROL_EFFECT_MODEL,
-                model_ord=treatment_variant - 1,
-                n_jobs_cross_fitting=n_jobs_cross_fitting,
-                fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL],
-                cv=cvs[0],
+
+            treatment_jobs.append(
+                self._treatment_joblib_specifications(
+                    X=index_matrix(X, self._treatment_variants_indices[0]),
+                    y=imputed_te_control,
+                    model_kind=CONTROL_EFFECT_MODEL,
+                    model_ord=treatment_variant - 1,
+                    n_jobs_cross_fitting=n_jobs_cross_fitting,
+                    fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL],
+                    cv=cvs[0],
+                )
             )
+
+        results = parallel(
+            delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
+        )
+        self._assign_joblib_treatment_results(results)
         return self
 
     def predict(
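Across all three learners the same two-wave structure appears: the nuisance jobs must complete before pseudo-outcomes — and therefore the treatment jobs — can be computed, so the one `Parallel` instance is simply called twice. joblib supports reusing a `Parallel` object for successive batches; a toy sketch of that design choice:

```python
from joblib import Parallel, delayed


def fit_one(x: int) -> int:
    # Toy stand-in for _fit_cross_fit_estimator_joblib.
    return x * x


parallel = Parallel(n_jobs=2)
nuisance_results = parallel(delayed(fit_one)(i) for i in range(4))  # first wave
treatment_results = parallel(delayed(fit_one)(i) for i in range(4, 8))  # second wave
print(nuisance_results, treatment_results)  # [0, 1, 4, 9] [16, 25, 36, 49]
```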
diff --git a/tests/conftest.py b/tests/conftest.py
index 449ab24..862131b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -75,8 +75,9 @@ def mindset_data():
     return load_mindset_data()
 
 
-@pytest.fixture(scope="function")
-def twins_data(rng):
+@pytest.fixture(scope="session")
+def twins_data():
+    rng = np.random.default_rng(_SEED)
     (
         chosen_df,
         outcome_column,
@@ -94,28 +95,30 @@
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="session")
 def n_numericals():
     return 25
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="session")
 def n_categoricals():
     return 5
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="session")
 def sample_size():
     return 100_000
 
 
-@pytest.fixture(scope="function")
-def numerical_covariates(sample_size, n_numericals, rng):
+@pytest.fixture(scope="session")
+def numerical_covariates(sample_size, n_numericals):
+    rng = np.random.default_rng(_SEED)
     return generate_covariates(sample_size, n_numericals, format="numpy", rng=rng)
 
 
-@pytest.fixture(scope="function")
-def mixed_covariates(sample_size, n_numericals, n_categoricals, rng):
+@pytest.fixture(scope="session")
+def mixed_covariates(sample_size, n_numericals, n_categoricals):
+    rng = np.random.default_rng(_SEED)
     return generate_covariates(
         sample_size,
         n_numericals + n_categoricals,
@@ -125,52 +128,72 @@
     )
 
 
-@pytest.fixture(scope="function")
+@pytest.fixture(scope="session")
 def numerical_experiment_dataset_continuous_outcome_binary_treatment_linear_te(
-    numerical_covariates, rng
+    sample_size, n_numericals
 ):
-    covariates, _, _ = numerical_covariates
+    rng = np.random.default_rng(_SEED)
+    covariates, _, _ = generate_covariates(
+        sample_size, n_numericals, format="numpy", rng=rng
+    )
     return _generate_rct_experiment_data(covariates, False, rng, 0.3, None)
 
 
-@pytest.fixture(scope="function")
+@pytest.fixture(scope="session")
 def numerical_experiment_dataset_binary_outcome_binary_treatment_linear_te(
-    numerical_covariates, rng
+    sample_size, n_numericals
 ):
-    covariates, _, _ = numerical_covariates
+    rng = np.random.default_rng(_SEED)
+    covariates, _, _ = generate_covariates(
+        sample_size, n_numericals, format="numpy", rng=rng
+    )
     return _generate_rct_experiment_data(covariates, True, rng, 0.3, None)
 
 
-@pytest.fixture(scope="function")
+@pytest.fixture(scope="session")
 def mixed_experiment_dataset_continuous_outcome_binary_treatment_linear_te(
-    mixed_covariates, rng
+    sample_size, n_numericals, n_categoricals
 ):
-    covariates, _, _ = mixed_covariates
+    rng = np.random.default_rng(_SEED)
+    covariates, _, _ = generate_covariates(
+        sample_size,
+        n_numericals + n_categoricals,
+        n_categoricals=n_categoricals,
+        format="pandas",
+        rng=rng,
+    )
     return _generate_rct_experiment_data(covariates, False, rng, 0.3, None)
 
 
-@pytest.fixture(scope="function")
+@pytest.fixture(scope="session")
 def numerical_experiment_dataset_continuous_outcome_multi_treatment_linear_te(
-    numerical_covariates, rng
+    sample_size, n_numericals
 ):
-    covariates, _, _ = numerical_covariates
+    rng = np.random.default_rng(_SEED)
+    covariates, _, _ = generate_covariates(
+        sample_size, n_numericals, format="numpy", rng=rng
+    )
     return _generate_rct_experiment_data(
         covariates, False, rng, [0.2, 0.1, 0.3, 0.15, 0.25], None
     )
 
 
-@pytest.fixture(scope="function")
+@pytest.fixture(scope="session")
 def numerical_experiment_dataset_continuous_outcome_multi_treatment_constant_te(
-    numerical_covariates, rng
+    sample_size, n_numericals
 ):
-    covariates, _, _ = numerical_covariates
+    rng = np.random.default_rng(_SEED)
+    covariates, _, _ = generate_covariates(
+        sample_size, n_numericals, format="numpy", rng=rng
+    )
     return _generate_rct_experiment_data(
         covariates, False, rng, [0.2, 0.1, 0.3, 0.15, 0.25], np.array([-2, 5, 0, 3])
    )
 
 
-@pytest.fixture
-def dummy_dataset(rng):
+@pytest.fixture(scope="session")
+def dummy_dataset():
+    rng = np.random.default_rng(_SEED)
     sample_size = 100
     n_features = 10
     X = rng.standard_normal((sample_size, n_features))
@@ -179,8 +202,9 @@
     return X, y, w
 
 
-@pytest.fixture(scope="function")
-def feature_importance_dataset(rng):
+@pytest.fixture(scope="session")
+def feature_importance_dataset():
+    rng = np.random.default_rng(_SEED)
     n_samples = 10000
     x0 = rng.normal(10, 1, n_samples)
     x1 = rng.normal(2, 1, n_samples)
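The recurring move in `conftest.py`: a `scope="session"` fixture may only depend on other session-scoped fixtures, so the function-scoped `rng` dependency is replaced by a generator built locally from the module's seed. A minimal sketch of the pattern — `_SEED` is defined elsewhere in `conftest.py`, and the value 42 here is an assumption:

```python
import numpy as np
import pytest

_SEED = 42  # assumption; conftest.py defines its own module-level _SEED


@pytest.fixture(scope="session")
def dummy_dataset():
    # Session-scoped: computed once per test session, yet reproducible
    # because the generator is freshly seeded inside the fixture.
    rng = np.random.default_rng(_SEED)
    X = rng.standard_normal((100, 10))
    y = rng.standard_normal(100)
    w = rng.integers(0, 2, 100)
    return X, y, w
```

Trading function scope for session scope avoids regenerating the 100_000-row covariate datasets for every test, at the cost of all tests sharing one copy, which therefore must not be mutated.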
diff --git a/tests/test_cross_fit_estimator.py b/tests/test_cross_fit_estimator.py
index 8e34b00..bb102c5 100644
--- a/tests/test_cross_fit_estimator.py
+++ b/tests/test_cross_fit_estimator.py
@@ -223,7 +223,9 @@ def test_crossfitestimator_n_folds_1(rng, sample_size):
 )
 def test_validate_data_match(n_observations, test_indices, success):
     if n_observations < 1:
-        with pytest.raises(ValueError, match="was expected to be positive"):
+        with pytest.raises(
+            ValueError, match=r"was expected to be (strictly )?positive"
+        ):
             _validate_data_match_prior_split(n_observations, test_indices)
         return
     if success:
diff --git a/tests/test_learner.py b/tests/test_learner.py
index e76018e..30dce74 100644
--- a/tests/test_learner.py
+++ b/tests/test_learner.py
@@ -141,6 +141,7 @@ def test_learner_synthetic(
         observed_outcomes_train,
         treatment_train,
         synchronize_cross_fitting=True,
+        n_jobs_base_learners=-1,
     )
 
     # In sample CATEs
@@ -236,6 +237,7 @@ def test_learner_synthetic_oos_ate(metalearner, treatment_kind, request):
         observed_outcomes_train,
         treatment_train,
         synchronize_cross_fitting=True,
+        n_jobs_base_learners=-1,
     )
     for oos_method in _OOS_WHITELIST:
         cate_estimates = learner.predict(
@@ -312,9 +314,8 @@ def test_learner_twins(metalearner, reference_value, twins_data, rng):
 @pytest.mark.parametrize("n_classes", [2, 5, 10])
 @pytest.mark.parametrize("n_variants", [2, 5])
 @pytest.mark.parametrize("is_classification", [True, False])
-def test_learner_evaluate(
-    metalearner, is_classification, rng, sample_size, n_classes, n_variants
-):
+def test_learner_evaluate(metalearner, is_classification, rng, n_classes, n_variants):
+    sample_size = 1000
     factory = metalearner_factory(metalearner)
     if n_variants > 2 and not factory._supports_multi_treatment():
         pytest.skip()
@@ -427,12 +428,14 @@ def test_x_t_conditional_average_outcomes(outcome_kind, is_oos, request):
         observed_outcomes_train,
         treatment_train,
         synchronize_cross_fitting=False,
+        n_jobs_base_learners=-1,
     )
     xlearner.fit(
         covariates_train,
         observed_outcomes_train,
         treatment_train,
         synchronize_cross_fitting=False,
+        n_jobs_base_learners=-1,
     )
 
     if not is_oos:
@@ -617,8 +620,9 @@ def test_conditional_average_outcomes_smoke(
 @pytest.mark.parametrize("n_classes", [5, 10])
 @pytest.mark.parametrize("n_variants", [2, 5])
 def test_conditional_average_outcomes_smoke_multi_class(
-    metalearner_prefix, rng, sample_size, n_classes, n_variants
+    metalearner_prefix, rng, n_classes, n_variants
 ):
+    sample_size = 1000
     factory = metalearner_factory(metalearner_prefix)
 
     X = rng.standard_normal((sample_size, 10))
@@ -648,8 +652,9 @@
 @pytest.mark.parametrize("n_variants", [2, 5])
 @pytest.mark.parametrize("is_classification", [True, False])
 def test_predict_smoke(
-    metalearner_prefix, is_classification, rng, sample_size, n_classes, n_variants
+    metalearner_prefix, is_classification, rng, n_classes, n_variants
 ):
+    sample_size = 1000
     factory = metalearner_factory(metalearner_prefix)
     if n_variants > 2 and not factory._supports_multi_treatment():
         pytest.skip()
@@ -707,7 +712,7 @@ def test_model_reusage(outcome_kind, request):
         n_variants=len(np.unique(treatment)),
         nuisance_model_params=nuisance_learner_params,
     )
-    tlearner.fit(covariates, observed_outcomes, treatment)
+    tlearner.fit(covariates, observed_outcomes, treatment, n_jobs_base_learners=-1)
     xlearner = XLearner(
         is_classification=is_classification,
         n_variants=len(np.unique(treatment)),
@@ -731,7 +736,7 @@
     tlearner_pred_before_refitting = tlearner.predict_conditional_average_outcomes(
         covariates, False
     )
-    xlearner.fit(covariates, observed_outcomes, treatment)
+    xlearner.fit(covariates, observed_outcomes, treatment, n_jobs_base_learners=-1)
     np.testing.assert_allclose(
         tlearner.predict_conditional_average_outcomes(covariates, False),
         tlearner_pred_before_refitting,
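In the `test_cross_fit_estimator.py` change above, `pytest.raises(..., match=...)` treats `match` as a regular expression applied with `re.search`, so the optional `(strictly )?` group accepts both the strict and the non-strict error message produced by `validate_number_positive`. A quick check:

```python
import re

pattern = r"was expected to be (strictly )?positive"

# Both message variants from the _utils.py hunk match the same pattern:
assert re.search(pattern, "n_observations was expected to be strictly positive but was 0.")
assert re.search(pattern, "alpha was expected to be positive or zero but was -1.")
```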
diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py
index 9de1af3..aff7b48 100644
--- a/tests/test_metalearner.py
+++ b/tests/test_metalearner.py
@@ -152,7 +152,7 @@ def test_metalearner_init(
 
 @pytest.mark.parametrize(
     "implementation",
-    [_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
+    [TLearner, SLearner, XLearner, RLearner, DRLearner],
 )
 def test_metalearner_categorical(
     mixed_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
@@ -198,7 +198,7 @@
 
 @pytest.mark.parametrize(
     "implementation",
-    [_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
+    [TLearner, SLearner, XLearner, RLearner, DRLearner],
 )
 def test_metalearner_missing_data_smoke(
     mixed_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
@@ -227,7 +227,7 @@
 
 @pytest.mark.parametrize(
     "implementation",
-    [_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
+    [TLearner, SLearner, XLearner, RLearner, DRLearner],
 )
 def test_metalearner_missing_data_error(
     numerical_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
@@ -258,7 +258,7 @@
 
 @pytest.mark.parametrize(
     "implementation",
-    [_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
+    [TLearner, SLearner, XLearner, RLearner, DRLearner],
 )
 def test_metalearner_format_consistent(
     numerical_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
@@ -345,7 +345,7 @@
 
 @pytest.mark.parametrize(
     "implementation",
-    [_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
+    [TLearner, SLearner, XLearner, RLearner, DRLearner],
 )
 def test_metalearner_model_names(implementation):
     set1 = set(implementation.nuisance_model_specifications().keys())
@@ -702,7 +702,6 @@ def test_fit_params_rlearner_error(dummy_dataset):
 @pytest.mark.parametrize(
     "implementation, needs_estimates",
     [
-        (_TestMetaLearner, True),
         (TLearner, True),
         (SLearner, True),
         (XLearner, True),
@@ -988,7 +987,7 @@ def test_validate_n_folds_synchronize(n_folds, success):
 
 @pytest.mark.parametrize(
     "implementation",
-    [TLearner],
+    [TLearner, XLearner, RLearner, DRLearner],
 )
 def test_n_jobs_base_learners(implementation, rng):
     n_variants = 5
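End to end, the tests above exercise the new `n_jobs_base_learners` argument roughly like this. This is a hedged sketch, not taken from the diff: the constructor arguments and the lightgbm base learners are assumptions based on the library's public API, and any estimator with a scikit-learn interface should work in their place:

```python
import numpy as np
from lightgbm import LGBMClassifier, LGBMRegressor
from metalearners import DRLearner

rng = np.random.default_rng(42)
X = rng.standard_normal((1000, 5))
w = rng.integers(0, 2, 1000)  # one control and one treatment variant
y = rng.standard_normal(1000) + 0.5 * w

drlearner = DRLearner(
    is_classification=False,
    n_variants=2,
    nuisance_model_factory=LGBMRegressor,
    propensity_model_factory=LGBMClassifier,
    treatment_model_factory=LGBMRegressor,
)
# n_jobs_base_learners parallelizes across base models (joblib semantics,
# -1 = all cores); n_jobs_cross_fitting parallelizes within one estimator.
drlearner.fit(X, y, w, n_jobs_base_learners=-1)
cate_estimates = drlearner.predict(X, is_oos=False)
```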