From 8b33e776cc53cf8040c6653b7e43e02fd250e97c Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Wed, 30 Mar 2022 21:52:46 +0300 Subject: [PATCH] [python] make `log_evaluation` callback pickleable (#5101) * make `log_evaluation` callback pickleable * make callback tests stricter --- python-package/lightgbm/callback.py | 27 +++++++++++++++------- tests/python_package_test/test_callback.py | 22 +++++++++++++++++- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 2fc301b0e509..3a125267c9a8 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -54,7 +54,23 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str: raise ValueError("Wrong metric value") -def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable: +class _LogEvaluationCallback: + """Internal log evaluation callable class.""" + + def __init__(self, period: int = 1, show_stdv: bool = True) -> None: + self.order = 10 + self.before_iteration = False + + self.period = period + self.show_stdv = show_stdv + + def __call__(self, env: CallbackEnv) -> None: + if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0: + result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list]) + _log_info(f'[{env.iteration + 1}]\t{result}') + + +def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback: """Create a callback that logs the evaluation results. By default, standard output resource is used. @@ -74,15 +90,10 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable: Returns ------- - callback : callable + callback : _LogEvaluationCallback The callback that logs the evaluation results every ``period`` boosting iteration(s). """ - def _callback(env: CallbackEnv) -> None: - if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0: - result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list]) - _log_info(f'[{env.iteration + 1}]\t{result}') - _callback.order = 10 # type: ignore - return _callback + return _LogEvaluationCallback(period=period, show_stdv=show_stdv) def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index 0f339aa3a53e..29609ebcbb22 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -8,7 +8,8 @@ @pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) def test_early_stopping_callback_is_picklable(serializer, tmp_path): - callback = lgb.early_stopping(stopping_rounds=5) + rounds = 5 + callback = lgb.early_stopping(stopping_rounds=rounds) tmp_file = tmp_path / "early_stopping.pkl" pickle_obj( obj=callback, @@ -20,3 +21,22 @@ def test_early_stopping_callback_is_picklable(serializer, tmp_path): serializer=serializer ) assert callback.stopping_rounds == callback_from_disk.stopping_rounds + assert callback.stopping_rounds == rounds + + +@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) +def test_log_evaluation_callback_is_picklable(serializer, tmp_path): + periods = 42 + callback = lgb.log_evaluation(period=periods) + tmp_file = tmp_path / "log_evaluation.pkl" + pickle_obj( + obj=callback, + filepath=tmp_file, + serializer=serializer + ) + callback_from_disk = unpickle_obj( + filepath=tmp_file, + serializer=serializer + ) + assert callback.period == callback_from_disk.period + assert callback.period == periods