Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[python] make record_evaluation callback pickleable #5107

Merged
merged 4 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 41 additions & 32 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,45 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCal
return _LogEvaluationCallback(period=period, show_stdv=show_stdv)


class _RecordEvaluationCallback:
"""Internal record evaluation callable class."""

def __init__(self, eval_result: Dict[str, Dict[str, List[Any]]]) -> None:
self.order = 20
self.before_iteration = False

if not isinstance(eval_result, dict):
raise TypeError('eval_result should be a dictionary')
self.eval_result = eval_result

def _init(self, env: CallbackEnv) -> None:
self.eval_result.clear()
for item in env.evaluation_result_list:
if len(item) == 4: # regular train
data_name, eval_name = item[:2]
else: # cv
data_name, eval_name = item[1].split()
self.eval_result.setdefault(data_name, collections.OrderedDict())
if len(item) == 4:
self.eval_result[data_name].setdefault(eval_name, [])
else:
self.eval_result[data_name].setdefault(f'{eval_name}-mean', [])
self.eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

def __call__(self, env: CallbackEnv) -> None:
if env.iteration == env.begin_iteration:
self._init(env)
for item in env.evaluation_result_list:
if len(item) == 4:
data_name, eval_name, result = item[:3]
self.eval_result[data_name][eval_name].append(result)
else:
data_name, eval_name = item[1].split()
res_mean, res_stdv = item[2], item[4]
self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
"""Create a callback that records the evaluation history into ``eval_result``.

Expand Down Expand Up @@ -126,40 +165,10 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:

Returns
-------
callback : callable
callback : _RecordEvaluationCallback
The callback that records the evaluation history into the passed dictionary.
"""
if not isinstance(eval_result, dict):
raise TypeError('eval_result should be a dictionary')

def _init(env: CallbackEnv) -> None:
eval_result.clear()
for item in env.evaluation_result_list:
if len(item) == 4: # regular train
data_name, eval_name = item[:2]
else: # cv
data_name, eval_name = item[1].split()
eval_result.setdefault(data_name, collections.OrderedDict())
if len(item) == 4:
eval_result[data_name].setdefault(eval_name, [])
else:
eval_result[data_name].setdefault(f'{eval_name}-mean', [])
eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

def _callback(env: CallbackEnv) -> None:
if env.iteration == env.begin_iteration:
_init(env)
for item in env.evaluation_result_list:
if len(item) == 4:
data_name, eval_name, result = item[:3]
eval_result[data_name][eval_name].append(result)
else:
data_name, eval_name = item[1].split()
res_mean, res_stdv = item[2], item[4]
eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)
_callback.order = 20 # type: ignore
return _callback
return _RecordEvaluationCallback(eval_result=eval_result)


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
Expand Down
61 changes: 37 additions & 24 deletions tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,51 @@

from .utils import pickle_obj, unpickle_obj

SERIALIZERS = ["pickle", "joblib", "cloudpickle"]

@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_early_stopping_callback_is_picklable(serializer, tmp_path):

def pickle_and_unpickle_object(obj, serializer):
    """Write ``obj`` to a temporary file with ``serializer`` and return the object read back from it."""
    with lgb.basic._TempFile() as scratch:
        path = scratch.name
        pickle_obj(obj=obj, filepath=path, serializer=serializer)
        restored = unpickle_obj(filepath=path, serializer=serializer)
    return restored


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
rounds = 5
callback = lgb.early_stopping(stopping_rounds=rounds)
tmp_file = tmp_path / "early_stopping.pkl"
pickle_obj(
obj=callback,
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
assert callback_from_disk.order == 30
assert callback_from_disk.before_iteration is False
assert callback.stopping_rounds == callback_from_disk.stopping_rounds
assert callback.stopping_rounds == rounds


@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_log_evaluation_callback_is_picklable(serializer, tmp_path):
@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
periods = 42
callback = lgb.log_evaluation(period=periods)
tmp_file = tmp_path / "log_evaluation.pkl"
pickle_obj(
obj=callback,
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
assert callback_from_disk.order == 10
assert callback_from_disk.before_iteration is False
assert callback.period == callback_from_disk.period
assert callback.period == periods


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_record_evaluation_callback_is_picklable(serializer):
    """Check that a ``record_evaluation`` callback keeps its attributes across a serialization round trip."""
    history = {}
    original = lgb.record_evaluation(eval_result=history)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)
    # Attributes on the reloaded callback.
    assert restored.order == 20
    assert restored.before_iteration is False
    # The reloaded history compares equal; the original still aliases the caller's dict.
    assert original.eval_result == restored.eval_result
    assert original.eval_result is history