[python-package] remove support for passing 'feature_name' and 'categorical_feature' through train() and cv() (#6706)
jameslamb authored Oct 31, 2024
1 parent dc0ed53 commit 8d5dca2
Showing 2 changed files with 33 additions and 87 deletions.
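
For user code, the change means 'feature_name' and 'categorical_feature' now have to be passed to lightgbm.Dataset() rather than to train() or cv(). A minimal before/after sketch of the migration (synthetic data; column names are illustrative):

import lightgbm as lgb
import numpy as np
import pandas as pd

X = pd.DataFrame({"A": np.random.rand(100), "B": np.random.randint(0, 4, 100)})
y = np.random.rand(100)

# Before this commit (deprecated earlier, see issue #6435; removed here):
# bst = lgb.train({"verbosity": -1}, lgb.Dataset(X, y), categorical_feature=["B"])

# After this commit: configure the Dataset instead.
ds = lgb.Dataset(X, y, feature_name=["A", "B"], categorical_feature=["B"])
bst = lgb.train({"objective": "regression", "verbosity": -1}, ds, num_boost_round=10)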
69 changes: 2 additions & 67 deletions python-package/lightgbm/engine.py
@@ -3,7 +3,6 @@
 
 import copy
 import json
-import warnings
 from collections import OrderedDict, defaultdict
 from operator import attrgetter
 from pathlib import Path
@@ -15,17 +14,14 @@
 from .basic import (
     Booster,
     Dataset,
-    LGBMDeprecationWarning,
     LightGBMError,
     _choose_param_value,
     _ConfigAliases,
     _InnerPredictor,
     _LGBM_BoosterEvalMethodResultType,
     _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
-    _LGBM_CategoricalFeatureConfiguration,
     _LGBM_CustomObjectiveFunction,
     _LGBM_EvalFunctionResultType,
-    _LGBM_FeatureNameConfiguration,
     _log_warning,
 )
 from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
@@ -54,15 +50,6 @@
 ]
 
 
-def _emit_dataset_kwarg_warning(calling_function: str, argname: str) -> None:
-    msg = (
-        f"Argument '{argname}' to {calling_function}() is deprecated and will be removed in "
-        f"a future release. Set '{argname}' when calling lightgbm.Dataset() instead. "
-        "See https://github.com/microsoft/LightGBM/issues/6435."
-    )
-    warnings.warn(msg, category=LGBMDeprecationWarning, stacklevel=2)
-
-
 def _choose_num_iterations(num_boost_round_kwarg: int, params: Dict[str, Any]) -> Dict[str, Any]:
     """Choose number of boosting rounds.
@@ -127,8 +114,6 @@ def train(
     valid_names: Optional[List[str]] = None,
     feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
     init_model: Optional[Union[str, Path, Booster]] = None,
-    feature_name: _LGBM_FeatureNameConfiguration = "auto",
-    categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
     keep_training_booster: bool = False,
     callbacks: Optional[List[Callable]] = None,
 ) -> Booster:
@@ -170,21 +155,6 @@
         set the ``metric`` parameter to the string ``"None"`` in ``params``.
     init_model : str, pathlib.Path, Booster or None, optional (default=None)
         Filename of LightGBM model or Booster instance used for continue training.
-    feature_name : list of str, or 'auto', optional (default="auto")
-        **Deprecated.** Set ``feature_name`` on ``train_set`` instead.
-        Feature names.
-        If 'auto' and data is pandas DataFrame, data columns names are used.
-    categorical_feature : list of str or int, or 'auto', optional (default="auto")
-        **Deprecated.** Set ``categorical_feature`` on ``train_set`` instead.
-        Categorical features.
-        If list of int, interpreted as indices.
-        If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
-        If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
-        All values in categorical features will be cast to int32 and thus should be less than int32 max value (2147483647).
-        Large values could be memory consuming. Consider using consecutive integers starting from zero.
-        All negative values in categorical features will be treated as missing values.
-        The output cannot be monotonically constrained with respect to a categorical feature.
-        Floating point numbers in categorical features will be rounded towards 0.
     keep_training_booster : bool, optional (default=False)
         Whether the returned Booster will be used to keep training.
         If False, the returned value will be converted into _InnerPredictor before returning.
@@ -233,13 +203,6 @@
                     f"Item {i} has type '{type(valid_item).__name__}'."
                 )
 
-    # raise deprecation warnings if necessary
-    # ref: https://github.com/microsoft/LightGBM/issues/6435
-    if categorical_feature != "auto":
-        _emit_dataset_kwarg_warning("train", "categorical_feature")
-    if feature_name != "auto":
-        _emit_dataset_kwarg_warning("train", "feature_name")
-
     # create predictor first
     params = copy.deepcopy(params)
     params = _choose_param_value(
@@ -278,9 +241,7 @@
     else:
         init_iteration = 0
 
-    train_set._update_params(params)._set_predictor(predictor).set_feature_name(feature_name).set_categorical_feature(
-        categorical_feature
-    )
+    train_set._update_params(params)._set_predictor(predictor)
 
     is_valid_contain_train = False
     train_data_name = "training"
@@ -642,8 +603,6 @@ def cv(
     metrics: Optional[Union[str, List[str]]] = None,
     feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
     init_model: Optional[Union[str, Path, Booster]] = None,
-    feature_name: _LGBM_FeatureNameConfiguration = "auto",
-    categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
     fpreproc: Optional[_LGBM_PreprocFunction] = None,
     seed: int = 0,
     callbacks: Optional[List[Callable]] = None,
@@ -699,21 +658,6 @@
         set ``metrics`` to the string ``"None"``.
     init_model : str, pathlib.Path, Booster or None, optional (default=None)
         Filename of LightGBM model or Booster instance used for continue training.
-    feature_name : list of str, or 'auto', optional (default="auto")
-        **Deprecated.** Set ``feature_name`` on ``train_set`` instead.
-        Feature names.
-        If 'auto' and data is pandas DataFrame, data columns names are used.
-    categorical_feature : list of str or int, or 'auto', optional (default="auto")
-        **Deprecated.** Set ``categorical_feature`` on ``train_set`` instead.
-        Categorical features.
-        If list of int, interpreted as indices.
-        If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
-        If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
-        All values in categorical features will be cast to int32 and thus should be less than int32 max value (2147483647).
-        Large values could be memory consuming. Consider using consecutive integers starting from zero.
-        All negative values in categorical features will be treated as missing values.
-        The output cannot be monotonically constrained with respect to a categorical feature.
-        Floating point numbers in categorical features will be rounded towards 0.
     fpreproc : callable or None, optional (default=None)
         Preprocessing function that takes (dtrain, dtest, params)
         and returns transformed versions of those.
@@ -767,13 +711,6 @@
     if not isinstance(train_set, Dataset):
         raise TypeError(f"cv() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.")
 
-    # raise deprecation warnings if necessary
-    # ref: https://github.com/microsoft/LightGBM/issues/6435
-    if categorical_feature != "auto":
-        _emit_dataset_kwarg_warning("cv", "categorical_feature")
-    if feature_name != "auto":
-        _emit_dataset_kwarg_warning("cv", "feature_name")
-
     params = copy.deepcopy(params)
     params = _choose_param_value(
         main_param_name="objective",
@@ -818,9 +755,7 @@
         params.pop(metric_alias, None)
     params["metric"] = metrics
 
-    train_set._update_params(params)._set_predictor(predictor).set_feature_name(feature_name).set_categorical_feature(
-        categorical_feature
-    )
+    train_set._update_params(params)._set_predictor(predictor)
 
     results = defaultdict(list)
     cvfolds = _make_n_folds(
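
Both call sites now rely entirely on the Dataset carrying this configuration. For code that set these options after constructing the Dataset, the public setters remain available; a small sketch, assuming a plain numeric matrix and illustrative feature names:

import lightgbm as lgb
import numpy as np

X = np.random.rand(200, 3)
y = np.random.rand(200)

ds = lgb.Dataset(X, y)
# Equivalent Dataset-level configuration, replacing the removed train()/cv() kwargs.
ds.set_feature_name(["f0", "f1", "f2"])
ds.set_categorical_feature([2])

cv_results = lgb.cv({"objective": "regression", "verbosity": -1}, ds, num_boost_round=5, nfold=3)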
51 changes: 31 additions & 20 deletions tests/python_package_test/test_engine.py
@@ -1459,7 +1459,7 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng):
         ]
     )
     y = rng.uniform(size=(100,))
-    ds = lgb.Dataset(X, y)
+    ds = lgb.Dataset(X, y, categorical_feature=[1, 2])
     params = {
         "bagging_fraction": 0.8,
         "bagging_freq": 2,
@@ -1474,7 +1474,7 @@
         "verbosity": 0,
     }
     model_file = tmp_path / "model.txt"
-    orig_bst = lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2])
+    orig_bst = lgb.train(params, ds, num_boost_round=1)
     orig_bst.save_model(model_file)
     with model_file.open("rt") as f:
         model_contents = f.readlines()
@@ -1746,16 +1746,18 @@ def test_pandas_categorical(rng_fixed_seed, tmp_path):
     gbm0 = lgb.train(params, lgb_train, num_boost_round=10)
     pred0 = gbm0.predict(X_test)
     assert lgb_train.categorical_feature == "auto"
-    lgb_train = lgb.Dataset(X, pd.DataFrame(y))  # also test that label can be one-column pd.DataFrame
-    gbm1 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=[0])
+    lgb_train = lgb.Dataset(
+        X, pd.DataFrame(y), categorical_feature=[0]
+    )  # also test that label can be one-column pd.DataFrame
+    gbm1 = lgb.train(params, lgb_train, num_boost_round=10)
     pred1 = gbm1.predict(X_test)
     assert lgb_train.categorical_feature == [0]
-    lgb_train = lgb.Dataset(X, pd.Series(y))  # also test that label can be pd.Series
-    gbm2 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=["A"])
+    lgb_train = lgb.Dataset(X, pd.Series(y), categorical_feature=["A"])  # also test that label can be pd.Series
+    gbm2 = lgb.train(params, lgb_train, num_boost_round=10)
     pred2 = gbm2.predict(X_test)
     assert lgb_train.categorical_feature == ["A"]
-    lgb_train = lgb.Dataset(X, y)
-    gbm3 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=["A", "B", "C", "D"])
+    lgb_train = lgb.Dataset(X, y, categorical_feature=["A", "B", "C", "D"])
+    gbm3 = lgb.train(params, lgb_train, num_boost_round=10)
     pred3 = gbm3.predict(X_test)
     assert lgb_train.categorical_feature == ["A", "B", "C", "D"]
     categorical_model_path = tmp_path / "categorical.model"
@@ -1767,12 +1769,12 @@
     pred5 = gbm4.predict(X_test)
     gbm5 = lgb.Booster(model_str=model_str)
     pred6 = gbm5.predict(X_test)
-    lgb_train = lgb.Dataset(X, y)
-    gbm6 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=["A", "B", "C", "D", "E"])
+    lgb_train = lgb.Dataset(X, y, categorical_feature=["A", "B", "C", "D", "E"])
+    gbm6 = lgb.train(params, lgb_train, num_boost_round=10)
     pred7 = gbm6.predict(X_test)
     assert lgb_train.categorical_feature == ["A", "B", "C", "D", "E"]
-    lgb_train = lgb.Dataset(X, y)
-    gbm7 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=[])
+    lgb_train = lgb.Dataset(X, y, categorical_feature=[])
+    gbm7 = lgb.train(params, lgb_train, num_boost_round=10)
     pred8 = gbm7.predict(X_test)
     assert lgb_train.categorical_feature == []
     with pytest.raises(AssertionError):
@@ -3672,12 +3674,11 @@ def test_linear_trees(tmp_path, rng_fixed_seed):
     # test with a categorical feature
     x[:250, 0] = 0
     y[:250] += 10
-    lgb_train = lgb.Dataset(x, label=y)
+    lgb_train = lgb.Dataset(x, label=y, categorical_feature=[0])
     est = lgb.train(
         dict(params, linear_tree=True, subsample=0.8, bagging_freq=1),
         lgb_train,
         num_boost_round=10,
-        categorical_feature=[0],
     )
     # test refit: same results on same data
     est2 = est.refit(x, label=y)
@@ -3700,10 +3701,20 @@
     # test when num_leaves - 1 < num_features and when num_leaves - 1 > num_features
     X_train, _, y_train, _ = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=2)
     params = {"linear_tree": True, "verbose": -1, "metric": "mse", "seed": 0}
-    train_data = lgb.Dataset(X_train, label=y_train, params=dict(params, num_leaves=2))
-    est = lgb.train(params, train_data, num_boost_round=10, categorical_feature=[0])
-    train_data = lgb.Dataset(X_train, label=y_train, params=dict(params, num_leaves=60))
-    est = lgb.train(params, train_data, num_boost_round=10, categorical_feature=[0])
+    train_data = lgb.Dataset(
+        X_train,
+        label=y_train,
+        params=dict(params, num_leaves=2),
+        categorical_feature=[0],
+    )
+    est = lgb.train(params, train_data, num_boost_round=10)
+    train_data = lgb.Dataset(
+        X_train,
+        label=y_train,
+        params=dict(params, num_leaves=60),
+        categorical_feature=[0],
+    )
+    est = lgb.train(params, train_data, num_boost_round=10)
 
 
 def test_save_and_load_linear(tmp_path):
@@ -3714,8 +3725,8 @@ def test_save_and_load_linear(tmp_path):
     X_train[: X_train.shape[0] // 2, 0] = 0
     y_train[: X_train.shape[0] // 2] = 1
     params = {"linear_tree": True}
-    train_data_1 = lgb.Dataset(X_train, label=y_train, params=params)
-    est_1 = lgb.train(params, train_data_1, num_boost_round=10, categorical_feature=[0])
+    train_data_1 = lgb.Dataset(X_train, label=y_train, params=params, categorical_feature=[0])
+    est_1 = lgb.train(params, train_data_1, num_boost_round=10)
     pred_1 = est_1.predict(X_train)
 
     tmp_dataset = str(tmp_path / "temp_dataset.bin")
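
As the updated tests show, behavior is unchanged once the configuration moves to the Dataset, and the 'categorical_feature' attribute reflects what was passed at construction time. A quick sketch with a pandas categorical column (synthetic data; names are illustrative):

import lightgbm as lgb
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
X = pd.DataFrame(
    {
        "A": pd.Categorical(rng.choice(["a", "b", "c"], size=200)),  # unordered categorical
        "B": rng.random(200),
    }
)
y = rng.random(200)

ds = lgb.Dataset(X, y, categorical_feature=["A"])
bst = lgb.train({"objective": "regression", "verbosity": -1}, ds, num_boost_round=5)
assert ds.categorical_feature == ["A"]
preds = bst.predict(X)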