Skip to content

Commit

Permalink
Speed up tests
Browse files Browse the repository at this point in the history
Co-authored-by: Kevin Klein <[email protected]>
  • Loading branch information
FrancescMartiEscofetQC and kklein committed Jun 14, 2024
1 parent dfa95d8 commit e8b64e6
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 39 deletions.
80 changes: 52 additions & 28 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def mindset_data():
return load_mindset_data()


@pytest.fixture(scope="function")
def twins_data(rng):
@pytest.fixture(scope="session")
def twins_data():
rng = np.random.default_rng(_SEED)
(
chosen_df,
outcome_column,
Expand All @@ -94,28 +95,30 @@ def twins_data(rng):
)


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def n_numericals():
return 25


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def n_categoricals():
return 5


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def sample_size():
return 100_000


@pytest.fixture(scope="function")
def numerical_covariates(sample_size, n_numericals, rng):
@pytest.fixture(scope="session")
def numerical_covariates(sample_size, n_numericals):
rng = np.random.default_rng(_SEED)
return generate_covariates(sample_size, n_numericals, format="numpy", rng=rng)


@pytest.fixture(scope="function")
def mixed_covariates(sample_size, n_numericals, n_categoricals, rng):
@pytest.fixture(scope="session")
def mixed_covariates(sample_size, n_numericals, n_categoricals):
rng = np.random.default_rng(_SEED)
return generate_covariates(
sample_size,
n_numericals + n_categoricals,
Expand All @@ -125,52 +128,72 @@ def mixed_covariates(sample_size, n_numericals, n_categoricals, rng):
)


@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def numerical_experiment_dataset_continuous_outcome_binary_treatment_linear_te(
numerical_covariates, rng
sample_size, n_numericals
):
covariates, _, _ = numerical_covariates
rng = np.random.default_rng(_SEED)
covariates, _, _ = generate_covariates(
sample_size, n_numericals, format="numpy", rng=rng
)
return _generate_rct_experiment_data(covariates, False, rng, 0.3, None)


@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def numerical_experiment_dataset_binary_outcome_binary_treatment_linear_te(
numerical_covariates, rng
sample_size, n_numericals
):
covariates, _, _ = numerical_covariates
rng = np.random.default_rng(_SEED)
covariates, _, _ = generate_covariates(
sample_size, n_numericals, format="numpy", rng=rng
)
return _generate_rct_experiment_data(covariates, True, rng, 0.3, None)


@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def mixed_experiment_dataset_continuous_outcome_binary_treatment_linear_te(
mixed_covariates, rng
sample_size, n_numericals, n_categoricals
):
covariates, _, _ = mixed_covariates
rng = np.random.default_rng(_SEED)
covariates, _, _ = generate_covariates(
sample_size,
n_numericals + n_categoricals,
n_categoricals=n_categoricals,
format="pandas",
rng=rng,
)
return _generate_rct_experiment_data(covariates, False, rng, 0.3, None)


@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def numerical_experiment_dataset_continuous_outcome_multi_treatment_linear_te(
numerical_covariates, rng
sample_size, n_numericals
):
covariates, _, _ = numerical_covariates
rng = np.random.default_rng(_SEED)
covariates, _, _ = generate_covariates(
sample_size, n_numericals, format="numpy", rng=rng
)
return _generate_rct_experiment_data(
covariates, False, rng, [0.2, 0.1, 0.3, 0.15, 0.25], None
)


@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def numerical_experiment_dataset_continuous_outcome_multi_treatment_constant_te(
numerical_covariates, rng
sample_size, n_numericals
):
covariates, _, _ = numerical_covariates
rng = np.random.default_rng(_SEED)
covariates, _, _ = generate_covariates(
sample_size, n_numericals, format="numpy", rng=rng
)
return _generate_rct_experiment_data(
covariates, False, rng, [0.2, 0.1, 0.3, 0.15, 0.25], np.array([-2, 5, 0, 3])
)


@pytest.fixture
def dummy_dataset(rng):
@pytest.fixture(scope="session")
def dummy_dataset():
rng = np.random.default_rng(_SEED)
sample_size = 100
n_features = 10
X = rng.standard_normal((sample_size, n_features))
Expand All @@ -179,8 +202,9 @@ def dummy_dataset(rng):
return X, y, w


@pytest.fixture(scope="function")
def feature_importance_dataset(rng):
@pytest.fixture(scope="session")
def feature_importance_dataset():
rng = np.random.default_rng(_SEED)
n_samples = 10000
x0 = rng.normal(10, 1, n_samples)
x1 = rng.normal(2, 1, n_samples)
Expand Down
11 changes: 6 additions & 5 deletions tests/test_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,8 @@ def test_learner_twins(metalearner, reference_value, twins_data, rng):
@pytest.mark.parametrize("n_classes", [2, 5, 10])
@pytest.mark.parametrize("n_variants", [2, 5])
@pytest.mark.parametrize("is_classification", [True, False])
def test_learner_evaluate(
metalearner, is_classification, rng, sample_size, n_classes, n_variants
):
def test_learner_evaluate(metalearner, is_classification, rng, n_classes, n_variants):
sample_size = 1000
factory = metalearner_factory(metalearner)
if n_variants > 2 and not factory._supports_multi_treatment():
pytest.skip()
Expand Down Expand Up @@ -617,8 +616,9 @@ def test_conditional_average_outcomes_smoke(
@pytest.mark.parametrize("n_classes", [5, 10])
@pytest.mark.parametrize("n_variants", [2, 5])
def test_conditional_average_outcomes_smoke_multi_class(
metalearner_prefix, rng, sample_size, n_classes, n_variants
metalearner_prefix, rng, n_classes, n_variants
):
sample_size = 1000
factory = metalearner_factory(metalearner_prefix)

X = rng.standard_normal((sample_size, 10))
Expand Down Expand Up @@ -648,8 +648,9 @@ def test_conditional_average_outcomes_smoke_multi_class(
@pytest.mark.parametrize("n_variants", [2, 5])
@pytest.mark.parametrize("is_classification", [True, False])
def test_predict_smoke(
metalearner_prefix, is_classification, rng, sample_size, n_classes, n_variants
metalearner_prefix, is_classification, rng, n_classes, n_variants
):
sample_size = 1000
factory = metalearner_factory(metalearner_prefix)
if n_variants > 2 and not factory._supports_multi_treatment():
pytest.skip()
Expand Down
11 changes: 5 additions & 6 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_metalearner_init(

@pytest.mark.parametrize(
"implementation",
[_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
def test_metalearner_categorical(
mixed_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_metalearner_categorical(

@pytest.mark.parametrize(
"implementation",
[_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
def test_metalearner_missing_data_smoke(
mixed_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_metalearner_missing_data_smoke(

@pytest.mark.parametrize(
"implementation",
[_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
def test_metalearner_missing_data_error(
numerical_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_metalearner_missing_data_error(

@pytest.mark.parametrize(
"implementation",
[_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
def test_metalearner_format_consistent(
numerical_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
Expand Down Expand Up @@ -345,7 +345,7 @@ def test_n_folds(n_folds):

@pytest.mark.parametrize(
"implementation",
[_TestMetaLearner, TLearner, SLearner, XLearner, RLearner, DRLearner],
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
def test_metalearner_model_names(implementation):
set1 = set(implementation.nuisance_model_specifications().keys())
Expand Down Expand Up @@ -702,7 +702,6 @@ def test_fit_params_rlearner_error(dummy_dataset):
@pytest.mark.parametrize(
"implementation, needs_estimates",
[
(_TestMetaLearner, True),
(TLearner, True),
(SLearner, True),
(XLearner, True),
Expand Down

0 comments on commit e8b64e6

Please sign in to comment.