diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index cb2e893c9612..ef20b78a40c4 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -2307,6 +2307,15 @@ def test_refit(): assert err_pred > new_err_pred +def test_refit_with_one_tree(): + X, y = load_breast_cancer(return_X_y=True) + lgb_train = lgb.Dataset(X, label=y) + params={"objective": "binary", "num_trees": 1, "verbosity": -1} + model = lgb.train(params, lgb_train, num_boost_round=1) + model_refit = model.refit(X, y) + assert isinstance(model_refit, lgb.Booster) + + def test_refit_dataset_params(rng): # check refit accepts dataset_params X, y = load_breast_cancer(return_X_y=True)