Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 17, 2024
1 parent b6c3dd0 commit e8e3b39
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
# SPDX-License-Identifier: BSD-3-Clause


from types import GeneratorType

import numpy as np
import pandas as pd
import pytest
from lightgbm import LGBMClassifier, LGBMRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
Expand Down Expand Up @@ -259,3 +262,42 @@ def test_metalearnergridsearch_reuse_propensity_smoke(grid_search_data):
assert gs.results_ is not None
assert gs.results_.shape[0] == 2
assert len(gs.results_.index.names) == 5


@pytest.mark.parametrize(
"store_raw_results, store_results, expected_type_raw_results, expected_type_results",
[
(True, True, list, pd.DataFrame),
(True, False, list, type(None)),
(False, True, type(None), pd.DataFrame),
(False, False, GeneratorType, type(None)),
],
)
def test_metalearnergridsearch_store(
store_raw_results,
store_results,
expected_type_raw_results,
expected_type_results,
grid_search_data,
):
X, _, y, w, X_test, _, y_test, w_test = grid_search_data
n_variants = len(np.unique(w))

metalearner_params = {
"is_classification": False,
"n_variants": n_variants,
"n_folds": 2,
}

gs = MetaLearnerGridSearch(
metalearner_factory=SLearner,
metalearner_params=metalearner_params,
base_learner_grid={"base_model": [LinearRegression, LGBMRegressor]},
param_grid={"base_model": {"LGBMRegressor": {"n_estimators": [1, 2]}}},
store_raw_results=store_raw_results,
store_results=store_results,
)

gs.fit(X, y, w, X_test, y_test, w_test)
assert isinstance(gs.raw_results_, expected_type_raw_results)
assert isinstance(gs.results_, expected_type_results)

0 comments on commit e8e3b39

Please sign in to comment.