Skip to content

Commit

Permalink
[python] make reset_parameter callback pickleable (#5109)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS authored Mar 31, 2022
1 parent 3ed0027 commit 4ae3d13
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
51 changes: 32 additions & 19 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def __call__(self, env: CallbackEnv) -> None:
self.eval_result[data_name][eval_name].append(result)
else:
data_name, eval_name = item[1].split()
res_mean, res_stdv = item[2], item[4]
res_mean = item[2]
res_stdv = item[4]
self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)

Expand Down Expand Up @@ -171,6 +172,34 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
return _RecordEvaluationCallback(eval_result=eval_result)


class _ResetParameterCallback:
"""Internal reset parameter callable class."""

def __init__(self, **kwargs: Union[list, Callable]) -> None:
self.order = 10
self.before_iteration = True

self.kwargs = kwargs

def __call__(self, env: CallbackEnv) -> None:
new_parameters = {}
for key, value in self.kwargs.items():
if isinstance(value, list):
if len(value) != env.end_iteration - env.begin_iteration:
raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
new_param = value[env.iteration - env.begin_iteration]
elif callable(value):
new_param = value(env.iteration - env.begin_iteration)
else:
raise ValueError("Only list and callable values are supported "
"as a mapping from boosting round index to new parameter value.")
if new_param != env.params.get(key, None):
new_parameters[key] = new_param
if new_parameters:
env.model.reset_parameter(new_parameters)
env.params.update(new_parameters)


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
"""Create a callback that resets the parameter after the first iteration.
Expand All @@ -189,26 +218,10 @@ def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
Returns
-------
callback : callable
callback : _ResetParameterCallback
The callback that resets the parameter after the first iteration.
"""
def _callback(env: CallbackEnv) -> None:
new_parameters = {}
for key, value in kwargs.items():
if isinstance(value, list):
if len(value) != env.end_iteration - env.begin_iteration:
raise ValueError(f"Length of list {key!r} has to equal to 'num_boost_round'.")
new_param = value[env.iteration - env.begin_iteration]
else:
new_param = value(env.iteration - env.begin_iteration)
if new_param != env.params.get(key, None):
new_parameters[key] = new_param
if new_parameters:
env.model.reset_parameter(new_parameters)
env.params.update(new_parameters)
_callback.before_iteration = True # type: ignore
_callback.order = 10 # type: ignore
return _callback
return _ResetParameterCallback(**kwargs)


class _EarlyStoppingCallback:
Expand Down
18 changes: 18 additions & 0 deletions tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def pickle_and_unpickle_object(obj, serializer):
return obj_from_disk


def reset_feature_fraction(boosting_round):
return 0.6 if boosting_round < 15 else 0.8


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
rounds = 5
Expand Down Expand Up @@ -53,3 +57,17 @@ def test_record_evaluation_callback_is_picklable(serializer):
assert callback_from_disk.before_iteration is False
assert callback.eval_result == callback_from_disk.eval_result
assert callback.eval_result is results


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_reset_parameter_callback_is_picklable(serializer):
params = {
'bagging_fraction': [0.7] * 5 + [0.6] * 5,
'feature_fraction': reset_feature_fraction
}
callback = lgb.reset_parameter(**params)
callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
assert callback_from_disk.order == 10
assert callback_from_disk.before_iteration is True
assert callback.kwargs == callback_from_disk.kwargs
assert callback.kwargs == params

0 comments on commit 4ae3d13

Please sign in to comment.