diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 550dbfc6a134..05539b6396ac 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -130,7 +130,8 @@ def __call__(self, env: CallbackEnv) -> None: self.eval_result[data_name][eval_name].append(result) else: data_name, eval_name = item[1].split() - res_mean, res_stdv = item[2], item[4] + res_mean = item[2] + res_stdv = item[4] self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean) self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv) @@ -171,6 +172,34 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: return _RecordEvaluationCallback(eval_result=eval_result) +class _ResetParameterCallback: + """Internal reset parameter callable class.""" + + def __init__(self, **kwargs: Union[list, Callable]) -> None: + self.order = 10 + self.before_iteration = True + + self.kwargs = kwargs + + def __call__(self, env: CallbackEnv) -> None: + new_parameters = {} + for key, value in self.kwargs.items(): + if isinstance(value, list): + if len(value) != env.end_iteration - env.begin_iteration: + raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.") + new_param = value[env.iteration - env.begin_iteration] + elif callable(value): + new_param = value(env.iteration - env.begin_iteration) + else: + raise ValueError("Only list and callable values are supported " + "as a mapping from boosting round index to new parameter value.") + if new_param != env.params.get(key, None): + new_parameters[key] = new_param + if new_parameters: + env.model.reset_parameter(new_parameters) + env.params.update(new_parameters) + + def reset_parameter(**kwargs: Union[list, Callable]) -> Callable: """Create a callback that resets the parameter after the first iteration. @@ -189,26 +218,10 @@ def reset_parameter(**kwargs: Union[list, Callable]) -> Callable: Returns ------- - callback : callable + callback : _ResetParameterCallback The callback that resets the parameter after the first iteration. """ - def _callback(env: CallbackEnv) -> None: - new_parameters = {} - for key, value in kwargs.items(): - if isinstance(value, list): - if len(value) != env.end_iteration - env.begin_iteration: - raise ValueError(f"Length of list {key!r} has to equal to 'num_boost_round'.") - new_param = value[env.iteration - env.begin_iteration] - else: - new_param = value(env.iteration - env.begin_iteration) - if new_param != env.params.get(key, None): - new_parameters[key] = new_param - if new_parameters: - env.model.reset_parameter(new_parameters) - env.params.update(new_parameters) - _callback.before_iteration = True # type: ignore - _callback.order = 10 # type: ignore - return _callback + return _ResetParameterCallback(**kwargs) class _EarlyStoppingCallback: diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index 447ebbc5c9a1..1a101fd6799b 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -22,6 +22,10 @@ def pickle_and_unpickle_object(obj, serializer): return obj_from_disk +def reset_feature_fraction(boosting_round): + return 0.6 if boosting_round < 15 else 0.8 + + @pytest.mark.parametrize('serializer', SERIALIZERS) def test_early_stopping_callback_is_picklable(serializer): rounds = 5 @@ -53,3 +57,17 @@ def test_record_evaluation_callback_is_picklable(serializer): assert callback_from_disk.before_iteration is False assert callback.eval_result == callback_from_disk.eval_result assert callback.eval_result is results + + +@pytest.mark.parametrize('serializer', SERIALIZERS) +def test_reset_parameter_callback_is_picklable(serializer): + params = { + 'bagging_fraction': [0.7] * 5 + [0.6] * 5, + 'feature_fraction': reset_feature_fraction + } + callback = lgb.reset_parameter(**params) + callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer) + assert callback_from_disk.order == 10 + assert callback_from_disk.before_iteration is True + assert callback.kwargs == callback_from_disk.kwargs + assert callback.kwargs == params