Skip to content

Commit

Permalink
add explicit tests for pred_leaf shape
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Dec 15, 2024
1 parent b6bb3c9 commit ba39a6f
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import psutil
import pytest
from scipy.sparse import csr_matrix, isspmatrix_csc, isspmatrix_csr
from sklearn.datasets import load_svmlight_file, make_blobs, make_multilabel_classification
from sklearn.datasets import load_svmlight_file, make_blobs, make_multilabel_classification, make_regression
from sklearn.metrics import average_precision_score, log_loss, mean_absolute_error, mean_squared_error, roc_auc_score
from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split

Expand Down Expand Up @@ -2316,6 +2316,14 @@ def test_refit_with_one_tree():
assert isinstance(model_refit, lgb.Booster)


def test_pred_leaf_output_shape():
X, y = make_regression(n_samples=10_000, n_features=10)
dtrain = lgb.Dataset(X, label=y)
params = {"objective": "regression", "verbosity": -1}
assert lgb.train(params, dtrain, num_boost_round=1).predict(X, pred_leaf=True).shape == (10_000, 1)
assert lgb.train(params, dtrain, num_boost_round=2).predict(X, pred_leaf=True).shape == (10_000, 2)


def test_refit_dataset_params(rng):
# check refit accepts dataset_params
X, y = load_breast_cancer(return_X_y=True)
Expand Down

0 comments on commit ba39a6f

Please sign in to comment.