Skip to content

Commit

Permalink
[python-package] make record_evaluation compatible with cv (fixes #4943) (#4947)
Browse files Browse the repository at this point in the history

* make record_evaluation compatible with cv

* test multiple metrics in cv

* lint

* fix cv with train metric. save stdv as well

* always add dataset prefix to cv_agg

* remove unused function
  • Loading branch information
jmoralez authored Feb 15, 2022
1 parent 2d1caf1 commit 9fc348a
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 83 deletions.
23 changes: 19 additions & 4 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,30 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:

# NOTE(review): this is a rendered diff hunk. The +/- markers and the original
# indentation were lost, so replaced (pre-commit) and replacement (post-commit)
# lines appear interleaved below. Which lines are old vs. new is inferred from
# the "19 additions & 4 deletions" summary for this file; confirm upstream.
#
# Purpose: empty `eval_result` and create one empty list per metric found in
# `env.evaluation_result_list`, grouped by dataset name.
def _init(env: CallbackEnv) -> None:
eval_result.clear()
# Presumably the pre-commit loop header: it unpacks every entry as a 4-tuple.
# It looks superseded by the `for item in ...` header on the next line.
for data_name, eval_name, _, _ in env.evaluation_result_list:
for item in env.evaluation_result_list:
# A 4-element entry is treated as a regular train() result, where the first
# two fields are the dataset name and the metric name.
if len(item) == 4: # regular train
data_name, eval_name = item[:2]
# Any other length is treated as a cv() entry: item[1] is split on whitespace
# into exactly two names. Presumably this is the "<dataset> <metric>" key that
# _agg_cv_result builds in the engine.py hunk; verify against the caller.
# NOTE(review): the two-name unpacking raises ValueError when item[1] does not
# split into exactly two tokens (e.g. a dataset or metric name with a space).
else: # cv
data_name, eval_name = item[1].split()
# One ordered mapping of metric name -> list of values per dataset.
eval_result.setdefault(data_name, collections.OrderedDict())
# Presumably the pre-commit line: a single list keyed by the bare metric name.
# It looks superseded by the length-dependent branches that follow.
eval_result[data_name].setdefault(eval_name, [])
if len(item) == 4:
eval_result[data_name].setdefault(eval_name, [])
else:
# cv entries get two series per metric, suffixed '-mean' and '-stdv'.
eval_result[data_name].setdefault(f'{eval_name}-mean', [])
eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

# NOTE(review): this is a rendered diff hunk. The +/- markers and the original
# indentation were lost, so replaced (pre-commit) and replacement (post-commit)
# lines appear interleaved below; confirm the split against the upstream diff.
#
# Purpose: append the results in `env.evaluation_result_list` to the lists
# held in `eval_result`.
def _callback(env: CallbackEnv) -> None:
# When the current iteration equals the run's begin_iteration, rebuild the
# containers first. _init clears `eval_result` before repopulating it.
if env.iteration == env.begin_iteration:
_init(env)
# Presumably the pre-commit loop (next two lines): it assumes every entry is a
# 4-tuple and appends its third field. It looks superseded by the loop below.
for data_name, eval_name, result, _ in env.evaluation_result_list:
eval_result[data_name][eval_name].append(result)
for item in env.evaluation_result_list:
# 4-element entry: append the third field under the bare metric name.
if len(item) == 4:
data_name, eval_name, result = item[:3]
eval_result[data_name][eval_name].append(result)
else:
# Other lengths are handled as cv entries: item[1] is split into dataset and
# metric name, item[2] is stored as the mean and item[4] as the stdv.
# Presumably this matches the 5-tuples unpacked as `_, key, mean, _, std` in
# the engine.py hunk; verify against the caller.
# NOTE(review): raises ValueError if item[1] does not split into exactly two
# whitespace-separated tokens.
data_name, eval_name = item[1].split()
res_mean, res_stdv = item[2], item[4]
eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)
_callback.order = 20 # type: ignore
return _callback

Expand Down
9 changes: 3 additions & 6 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,13 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
return ret


def _agg_cv_result(raw_results, eval_train_metric=False):
def _agg_cv_result(raw_results):
"""Aggregate cross-validation results."""
cvmap = collections.OrderedDict()
metric_type = {}
for one_result in raw_results:
for one_line in one_result:
if eval_train_metric:
key = f"{one_line[0]} {one_line[1]}"
else:
key = one_line[1]
key = f"{one_line[0]} {one_line[1]}"
metric_type[key] = one_line[3]
cvmap.setdefault(key, [])
cvmap[key].append(one_line[2])
Expand Down Expand Up @@ -573,7 +570,7 @@ def cv(params, train_set, num_boost_round=100,
end_iteration=num_boost_round,
evaluation_result_list=None))
cvfolds.update(fobj=fobj)
res = _agg_cv_result(cvfolds.eval_valid(feval), eval_train_metric)
res = _agg_cv_result(cvfolds.eval_valid(feval))
for _, key, mean, _, std in res:
results[f'{key}-mean'].append(mean)
results[f'{key}-stdv'].append(std)
Expand Down
Loading

0 comments on commit 9fc348a

Please sign in to comment.