Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] make log_evaluation callback pickleable #5101

Merged
merged 2 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,23 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
raise ValueError("Wrong metric value")


def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
class _LogEvaluationCallback:
"""Internal log evaluation callable class."""

def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
self.order = 10
self.before_iteration = False

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very very minor comment, but could this empty line be removed, to group all assignments to self.* properties?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it would be better for code readability to split properties that are common for all callbacks (order and before_iteration) and specific properties for a particular callback. WDYT?
Refer to

self.order = 30
self.before_iteration = False
self.stopping_rounds = stopping_rounds
self.first_metric_only = first_metric_only
self.verbose = verbose
self.min_delta = min_delta

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see, that makes sense. Ok yeah I'm fine with that too! I don't feel strongly either way.

self.period = period
self.show_stdv = show_stdv

def __call__(self, env: CallbackEnv) -> None:
if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
_log_info(f'[{env.iteration + 1}]\t{result}')


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
"""Create a callback that logs the evaluation results.

By default, standard output resource is used.
Expand All @@ -74,15 +90,10 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:

Returns
-------
callback : callable
callback : _LogEvaluationCallback
The callback that logs the evaluation results every ``period`` boosting iteration(s).
"""
def _callback(env: CallbackEnv) -> None:
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
_log_info(f'[{env.iteration + 1}]\t{result}')
_callback.order = 10 # type: ignore
return _callback
return _LogEvaluationCallback(period=period, show_stdv=show_stdv)


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
Expand Down
22 changes: 21 additions & 1 deletion tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_early_stopping_callback_is_picklable(serializer, tmp_path):
callback = lgb.early_stopping(stopping_rounds=5)
rounds = 5
callback = lgb.early_stopping(stopping_rounds=rounds)
tmp_file = tmp_path / "early_stopping.pkl"
pickle_obj(
obj=callback,
Expand All @@ -20,3 +21,22 @@ def test_early_stopping_callback_is_picklable(serializer, tmp_path):
serializer=serializer
)
assert callback.stopping_rounds == callback_from_disk.stopping_rounds
assert callback.stopping_rounds == rounds


@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_log_evaluation_callback_is_picklable(serializer, tmp_path):
periods = 42
callback = lgb.log_evaluation(period=periods)
tmp_file = tmp_path / "log_evaluation.pkl"
pickle_obj(
obj=callback,
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
assert callback.period == callback_from_disk.period
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
assert callback.period == periods