From bd08187590c8a9cd4262b3bc6c51aa18d4361b56 Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Wed, 30 Mar 2022 03:02:51 +0300 Subject: [PATCH 1/3] make `log_evaluation` callback pickleable --- python-package/lightgbm/callback.py | 27 +++++++++++++++------- tests/python_package_test/test_callback.py | 16 +++++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 2fc301b0e509..3a125267c9a8 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -54,7 +54,23 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str: raise ValueError("Wrong metric value") -def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable: +class _LogEvaluationCallback: + """Internal log evaluation callable class.""" + + def __init__(self, period: int = 1, show_stdv: bool = True) -> None: + self.order = 10 + self.before_iteration = False + + self.period = period + self.show_stdv = show_stdv + + def __call__(self, env: CallbackEnv) -> None: + if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0: + result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list]) + _log_info(f'[{env.iteration + 1}]\t{result}') + + +def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback: """Create a callback that logs the evaluation results. By default, standard output resource is used. @@ -74,15 +90,10 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable: Returns ------- - callback : callable + callback : _LogEvaluationCallback The callback that logs the evaluation results every ``period`` boosting iteration(s). """ - def _callback(env: CallbackEnv) -> None: - if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0: - result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list]) - _log_info(f'[{env.iteration + 1}]\t{result}') - _callback.order = 10 # type: ignore - return _callback + return _LogEvaluationCallback(period=period, show_stdv=show_stdv) def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index 0f339aa3a53e..be0cea8dd0e7 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -20,3 +20,19 @@ def test_early_stopping_callback_is_picklable(serializer, tmp_path): serializer=serializer ) assert callback.stopping_rounds == callback_from_disk.stopping_rounds + + +@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) +def test_log_evaluation_callback_is_picklable(serializer, tmp_path): + callback = lgb.log_evaluation(period=42) + tmp_file = tmp_path / "log_evaluation.pkl" + pickle_obj( + obj=callback, + filepath=tmp_file, + serializer=serializer + ) + callback_from_disk = unpickle_obj( + filepath=tmp_file, + serializer=serializer + ) + assert callback.period == callback_from_disk.period From 617d207a0b371c7aed206baff42f18a47a2eef72 Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Wed, 30 Mar 2022 18:12:56 +0300 Subject: [PATCH 2/3] make callback tests stricter --- tests/python_package_test/test_callback.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index be0cea8dd0e7..29609ebcbb22 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -8,7 +8,8 @@ @pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) def test_early_stopping_callback_is_picklable(serializer, tmp_path): - callback = lgb.early_stopping(stopping_rounds=5) + rounds = 5 + callback = lgb.early_stopping(stopping_rounds=rounds) tmp_file = tmp_path / "early_stopping.pkl" pickle_obj( obj=callback, @@ -20,11 +21,13 @@ def test_early_stopping_callback_is_picklable(serializer, tmp_path): serializer=serializer ) assert callback.stopping_rounds == callback_from_disk.stopping_rounds + assert callback.stopping_rounds == rounds @pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) def test_log_evaluation_callback_is_picklable(serializer, tmp_path): - callback = lgb.log_evaluation(period=42) + periods = 42 + callback = lgb.log_evaluation(period=periods) tmp_file = tmp_path / "log_evaluation.pkl" pickle_obj( obj=callback, @@ -36,3 +39,4 @@ def test_log_evaluation_callback_is_picklable(serializer, tmp_path): serializer=serializer ) assert callback.period == callback_from_disk.period + assert callback.period == periods From f4ed73588be0c35521f4396bed5d7d5c0e1c74ef Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Wed, 30 Mar 2022 21:45:13 +0300 Subject: [PATCH 3/3] make `record_evaluation` callback picklable --- python-package/lightgbm/callback.py | 73 ++++++++++++---------- tests/python_package_test/test_callback.py | 61 +++++++++++------- 2 files changed, 78 insertions(+), 56 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 3a125267c9a8..550dbfc6a134 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -96,6 +96,45 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCal return _LogEvaluationCallback(period=period, show_stdv=show_stdv) +class _RecordEvaluationCallback: + """Internal record evaluation callable class.""" + + def __init__(self, eval_result: Dict[str, Dict[str, List[Any]]]) -> None: + self.order = 20 + self.before_iteration = False + + if not isinstance(eval_result, dict): + raise TypeError('eval_result should be a dictionary') + self.eval_result = eval_result + + def _init(self, env: CallbackEnv) -> None: + self.eval_result.clear() + for item in env.evaluation_result_list: + if len(item) == 4: # regular train + data_name, eval_name = item[:2] + else: # cv + data_name, eval_name = item[1].split() + self.eval_result.setdefault(data_name, collections.OrderedDict()) + if len(item) == 4: + self.eval_result[data_name].setdefault(eval_name, []) + else: + self.eval_result[data_name].setdefault(f'{eval_name}-mean', []) + self.eval_result[data_name].setdefault(f'{eval_name}-stdv', []) + + def __call__(self, env: CallbackEnv) -> None: + if env.iteration == env.begin_iteration: + self._init(env) + for item in env.evaluation_result_list: + if len(item) == 4: + data_name, eval_name, result = item[:3] + self.eval_result[data_name][eval_name].append(result) + else: + data_name, eval_name = item[1].split() + res_mean, res_stdv = item[2], item[4] + self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean) + self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv) + + def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: """Create a callback that records the evaluation history into ``eval_result``. @@ -126,40 +165,10 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: Returns ------- - callback : callable + callback : _RecordEvaluationCallback The callback that records the evaluation history into the passed dictionary. """ - if not isinstance(eval_result, dict): - raise TypeError('eval_result should be a dictionary') - - def _init(env: CallbackEnv) -> None: - eval_result.clear() - for item in env.evaluation_result_list: - if len(item) == 4: # regular train - data_name, eval_name = item[:2] - else: # cv - data_name, eval_name = item[1].split() - eval_result.setdefault(data_name, collections.OrderedDict()) - if len(item) == 4: - eval_result[data_name].setdefault(eval_name, []) - else: - eval_result[data_name].setdefault(f'{eval_name}-mean', []) - eval_result[data_name].setdefault(f'{eval_name}-stdv', []) - - def _callback(env: CallbackEnv) -> None: - if env.iteration == env.begin_iteration: - _init(env) - for item in env.evaluation_result_list: - if len(item) == 4: - data_name, eval_name, result = item[:3] - eval_result[data_name][eval_name].append(result) - else: - data_name, eval_name = item[1].split() - res_mean, res_stdv = item[2], item[4] - eval_result[data_name][f'{eval_name}-mean'].append(res_mean) - eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv) - _callback.order = 20 # type: ignore - return _callback + return _RecordEvaluationCallback(eval_result=eval_result) def reset_parameter(**kwargs: Union[list, Callable]) -> Callable: diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index 29609ebcbb22..447ebbc5c9a1 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -5,38 +5,51 @@ from .utils import pickle_obj, unpickle_obj +SERIALIZERS = ["pickle", "joblib", "cloudpickle"] -@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) -def test_early_stopping_callback_is_picklable(serializer, tmp_path): + +def pickle_and_unpickle_object(obj, serializer): + with lgb.basic._TempFile() as tmp_file: + pickle_obj( + obj=obj, + filepath=tmp_file.name, + serializer=serializer + ) + obj_from_disk = unpickle_obj( + filepath=tmp_file.name, + serializer=serializer + ) + return obj_from_disk + + +@pytest.mark.parametrize('serializer', SERIALIZERS) +def test_early_stopping_callback_is_picklable(serializer): rounds = 5 callback = lgb.early_stopping(stopping_rounds=rounds) - tmp_file = tmp_path / "early_stopping.pkl" - pickle_obj( - obj=callback, - filepath=tmp_file, - serializer=serializer - ) - callback_from_disk = unpickle_obj( - filepath=tmp_file, - serializer=serializer - ) + callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer) + assert callback_from_disk.order == 30 + assert callback_from_disk.before_iteration is False assert callback.stopping_rounds == callback_from_disk.stopping_rounds assert callback.stopping_rounds == rounds -@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) -def test_log_evaluation_callback_is_picklable(serializer, tmp_path): +@pytest.mark.parametrize('serializer', SERIALIZERS) +def test_log_evaluation_callback_is_picklable(serializer): periods = 42 callback = lgb.log_evaluation(period=periods) - tmp_file = tmp_path / "log_evaluation.pkl" - pickle_obj( - obj=callback, - filepath=tmp_file, - serializer=serializer - ) - callback_from_disk = unpickle_obj( - filepath=tmp_file, - serializer=serializer - ) + callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer) + assert callback_from_disk.order == 10 + assert callback_from_disk.before_iteration is False assert callback.period == callback_from_disk.period assert callback.period == periods + + +@pytest.mark.parametrize('serializer', SERIALIZERS) +def test_record_evaluation_callback_is_picklable(serializer): + results = {} + callback = lgb.record_evaluation(eval_result=results) + callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer) + assert callback_from_disk.order == 20 + assert callback_from_disk.before_iteration is False + assert callback.eval_result == callback_from_disk.eval_result + assert callback.eval_result is results