class _RecordEvaluationCallback:
    """Internal callable that records evaluation results into a dictionary.

    The dictionary passed at construction time is mutated in place: it is
    cleared on the first iteration of a run and then extended with one value
    per recorded key on every call.

    Parameters
    ----------
    eval_result : dict
        Dictionary that receives the evaluation history.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dictionary.
    """

    def __init__(self, eval_result: Dict[str, Dict[str, List[Any]]]) -> None:
        if not isinstance(eval_result, dict):
            raise TypeError('eval_result should be a dictionary')
        self.order = 20
        self.before_iteration = False
        # Keep a reference (not a copy) so the caller's dictionary is filled.
        self.eval_result = eval_result

    @staticmethod
    def _records(item: Any) -> List[tuple]:
        """Translate one evaluation entry into ``(data_name, key, value)`` tuples.

        A 4-element entry (regular train) yields a single record keyed by the
        metric name. Any other entry is treated as a cv entry: ``item[1]`` holds
        ``"<data_name> <eval_name>"``, ``item[2]`` the mean and ``item[4]`` the
        standard deviation, yielding ``-mean`` and ``-stdv`` records.
        """
        if len(item) == 4:  # regular train
            data_name, eval_name, result = item[:3]
            return [(data_name, eval_name, result)]
        # cv: split only on the first run of whitespace so that metric names
        # which themselves contain spaces are kept intact instead of raising
        # "too many values to unpack".
        data_name, eval_name = item[1].split(maxsplit=1)
        return [
            (data_name, f'{eval_name}-mean', item[2]),
            (data_name, f'{eval_name}-stdv', item[4]),
        ]

    def _init(self, env: CallbackEnv) -> None:
        """Reset ``eval_result`` and create an empty list for every recorded key."""
        self.eval_result.clear()
        for item in env.evaluation_result_list:
            for data_name, key, _ in self._records(item):
                per_dataset = self.eval_result.setdefault(data_name, collections.OrderedDict())
                per_dataset.setdefault(key, [])

    def __call__(self, env: CallbackEnv) -> None:
        """Append the results found in ``env.evaluation_result_list``.

        On the first iteration (``env.iteration == env.begin_iteration``) the
        dictionary is re-initialised before anything is appended.
        """
        if env.iteration == env.begin_iteration:
            self._init(env)
        for item in env.evaluation_result_list:
            for data_name, key, value in self._records(item):
                self.eval_result[data_name][key].append(value)
@@ -126,40 +165,10 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: Returns ------- - callback : callable + callback : _RecordEvaluationCallback The callback that records the evaluation history into the passed dictionary. """ - if not isinstance(eval_result, dict): - raise TypeError('eval_result should be a dictionary') - - def _init(env: CallbackEnv) -> None: - eval_result.clear() - for item in env.evaluation_result_list: - if len(item) == 4: # regular train - data_name, eval_name = item[:2] - else: # cv - data_name, eval_name = item[1].split() - eval_result.setdefault(data_name, collections.OrderedDict()) - if len(item) == 4: - eval_result[data_name].setdefault(eval_name, []) - else: - eval_result[data_name].setdefault(f'{eval_name}-mean', []) - eval_result[data_name].setdefault(f'{eval_name}-stdv', []) - - def _callback(env: CallbackEnv) -> None: - if env.iteration == env.begin_iteration: - _init(env) - for item in env.evaluation_result_list: - if len(item) == 4: - data_name, eval_name, result = item[:3] - eval_result[data_name][eval_name].append(result) - else: - data_name, eval_name = item[1].split() - res_mean, res_stdv = item[2], item[4] - eval_result[data_name][f'{eval_name}-mean'].append(res_mean) - eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv) - _callback.order = 20 # type: ignore - return _callback + return _RecordEvaluationCallback(eval_result=eval_result) def reset_parameter(**kwargs: Union[list, Callable]) -> Callable: diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index 29609ebcbb22..447ebbc5c9a1 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -5,38 +5,51 @@ from .utils import pickle_obj, unpickle_obj +SERIALIZERS = ["pickle", "joblib", "cloudpickle"] -@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) -def test_early_stopping_callback_is_picklable(serializer, 
def pickle_and_unpickle_object(obj, serializer):
    """Write ``obj`` to a temporary file with ``serializer`` and read it back."""
    with lgb.basic._TempFile() as scratch:
        target = scratch.name
        pickle_obj(obj=obj, filepath=target, serializer=serializer)
        restored = unpickle_obj(filepath=target, serializer=serializer)
    return restored


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
    """The early-stopping callback keeps its attributes across a round trip."""
    rounds = 5
    original = lgb.early_stopping(stopping_rounds=rounds)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)
    assert restored.order == 30
    assert restored.before_iteration is False
    assert original.stopping_rounds == restored.stopping_rounds
    assert original.stopping_rounds == rounds


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
    """The log-evaluation callback keeps its attributes across a round trip."""
    periods = 42
    original = lgb.log_evaluation(period=periods)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)
    assert restored.order == 10
    assert restored.before_iteration is False
    assert original.period == restored.period
    assert original.period == periods


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_record_evaluation_callback_is_picklable(serializer):
    """The record-evaluation callback keeps its attributes across a round trip."""
    results = {}
    original = lgb.record_evaluation(eval_result=results)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)
    assert restored.order == 20
    assert restored.before_iteration is False
    assert original.eval_result == restored.eval_result
    # The in-memory callback must hold the caller's dictionary itself, not a copy.
    assert original.eval_result is results