Skip to content

Commit

Permalink
[python] make log_evaluation callback pickleable (#5101)
Browse files Browse the repository at this point in the history
* make `log_evaluation` callback pickleable

* make callback tests stricter
  • Loading branch information
StrikerRUS authored Mar 30, 2022
1 parent 417c732 commit 8b33e77
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
27 changes: 19 additions & 8 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,23 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
raise ValueError("Wrong metric value")


def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
class _LogEvaluationCallback:
"""Internal log evaluation callable class."""

def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
self.order = 10
self.before_iteration = False

self.period = period
self.show_stdv = show_stdv

def __call__(self, env: CallbackEnv) -> None:
if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
_log_info(f'[{env.iteration + 1}]\t{result}')


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
"""Create a callback that logs the evaluation results.
By default, standard output resource is used.
Expand All @@ -74,15 +90,10 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
Returns
-------
callback : callable
callback : _LogEvaluationCallback
The callback that logs the evaluation results every ``period`` boosting iteration(s).
"""
def _callback(env: CallbackEnv) -> None:
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
_log_info(f'[{env.iteration + 1}]\t{result}')
_callback.order = 10 # type: ignore
return _callback
return _LogEvaluationCallback(period=period, show_stdv=show_stdv)


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
Expand Down
22 changes: 21 additions & 1 deletion tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_early_stopping_callback_is_picklable(serializer, tmp_path):
callback = lgb.early_stopping(stopping_rounds=5)
rounds = 5
callback = lgb.early_stopping(stopping_rounds=rounds)
tmp_file = tmp_path / "early_stopping.pkl"
pickle_obj(
obj=callback,
Expand All @@ -20,3 +21,22 @@ def test_early_stopping_callback_is_picklable(serializer, tmp_path):
serializer=serializer
)
assert callback.stopping_rounds == callback_from_disk.stopping_rounds
assert callback.stopping_rounds == rounds


@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_log_evaluation_callback_is_picklable(serializer, tmp_path):
periods = 42
callback = lgb.log_evaluation(period=periods)
tmp_file = tmp_path / "log_evaluation.pkl"
pickle_obj(
obj=callback,
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
assert callback.period == callback_from_disk.period
assert callback.period == periods

0 comments on commit 8b33e77

Please sign in to comment.