[Python] Refactor scikit-learn API to allow a list of evaluation metrics (microsoft#3254)

* Refactors sklearn API to allow a list of evaluation metrics in the eval_metric parameter of the LGBMModel class (and its subclasses). Also adds unit tests for this functionality

* Simplify the expression that checks whether the user passed one or multiple metrics to the eval_metric parameter

* Simplify new tests by using custom metrics already defined in the test file

* Update docstrings to reflect that the "feval" parameter of the "train" and "cv" functions can also receive a list of callables

* Remove Oxford comma from docstrings

Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* Use named parameters to make sure the code is compatible with future versions of scikit-learn

Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* Remove throwaway return value to make code more succinct

Co-authored-by: Nikita Titov <[email protected]>

* Move statement to group together the code related to feval

* Avoid modifying the original args, as doing so causes errors in scikit-learn tools

For details see: microsoft#2619

* Consolidate multiple eval-metrics unit tests into one test

Co-authored-by: German I Ramirez-Espinoza <gire@home>
Co-authored-by: Nikita Titov <[email protected]>
3 people authored Sep 6, 2020
1 parent 0faf874 commit afc76d2
Showing 5 changed files with 132 additions and 36 deletions.
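
At a glance, the new behavior in the scikit-learn API looks like the sketch below. This is a minimal illustration, not code from the diff: the dataset, the metric choices, and the error_rate helper are assumptions made for the example.

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2)

# sklearn-API custom metric: (y_true, y_pred) -> (eval_name, eval_result, is_higher_better)
def error_rate(y_true, y_pred):
    return 'error_rate', float(((y_pred > 0.5) != y_true).mean()), False

clf = lgb.LGBMClassifier(n_estimators=5, objective='binary', metric='binary_logloss')
# eval_metric now accepts a list mixing built-in metric names and callables;
# the 'metric' from the model parameters is still evaluated as well.
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)],
        eval_metric=['fair', error_rate], verbose=False)
print(sorted(clf.evals_result_['valid_0']))  # ['binary_logloss', 'error_rate', 'fair']
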
17 changes: 11 additions & 6 deletions python-package/lightgbm/basic.py
@@ -3111,18 +3111,23 @@ def __inner_eval(self, data_name, data_idx, feval=None):
         for i in range_(self.__num_inner_eval):
             ret.append((data_name, self.__name_inner_eval[i],
                         result[i], self.__higher_better_inner_eval[i]))
+        if callable(feval):
+            feval = [feval]
         if feval is not None:
             if data_idx == 0:
                 cur_data = self.train_set
             else:
                 cur_data = self.valid_sets[data_idx - 1]
-            feval_ret = feval(self.__inner_predict(data_idx), cur_data)
-            if isinstance(feval_ret, list):
-                for eval_name, val, is_higher_better in feval_ret:
-                    ret.append((data_name, eval_name, val, is_higher_better))
-            else:
-                eval_name, val, is_higher_better = feval_ret
-                ret.append((data_name, eval_name, val, is_higher_better))
+            for eval_function in feval:
+                if eval_function is None:
+                    continue
+                feval_ret = eval_function(self.__inner_predict(data_idx), cur_data)
+                if isinstance(feval_ret, list):
+                    for eval_name, val, is_higher_better in feval_ret:
+                        ret.append((data_name, eval_name, val, is_higher_better))
+                else:
+                    eval_name, val, is_higher_better = feval_ret
+                    ret.append((data_name, eval_name, val, is_higher_better))
         return ret
 
     def __inner_predict(self, data_idx):
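
For reference, the contract that __inner_eval applies to each entry of feval is sketched below with illustrative metric names and bodies: every callable receives the in-training predictions and the lgb.Dataset being evaluated, and returns either one (eval_name, eval_result, is_higher_better) tuple or a list of such tuples; None entries are skipped by the loop above.

def mean_abs_residual(preds, eval_data):
    # single-tuple form
    labels = eval_data.get_label()
    return 'mean_abs_residual', float(abs(preds - labels).mean()), False

def residual_stats(preds, eval_data):
    # list-of-tuples form: one callable can report several metrics at once
    residual = preds - eval_data.get_label()
    return [('mean_residual', float(residual.mean()), False),
            ('max_abs_residual', float(abs(residual).max()), False)]

# Both can now be passed together, e.g. feval=[mean_abs_residual, residual_stats, None];
# the None entry is simply ignored.
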
8 changes: 4 additions & 4 deletions python-package/lightgbm/engine.py
@@ -55,9 +55,9 @@ def train(params, train_set, num_boost_round=100,
         If you want to get i-th row preds in j-th class, the access way is score[j * num_data + i]
         and you should group grad and hess in this way as well.
-    feval : callable or None, optional (default=None)
+    feval : callable, list of callable functions or None, optional (default=None)
         Customized evaluation function.
-        Should accept two parameters: preds, train_data,
+        Each evaluation function should accept two parameters: preds, train_data,
         and return (eval_name, eval_result, is_higher_better) or list of such tuples.
 
             preds : list or numpy 1-D array
@@ -443,9 +443,9 @@ def cv(params, train_set, num_boost_round=100,
         If you want to get i-th row preds in j-th class, the access way is score[j * num_data + i]
         and you should group grad and hess in this way as well.
-    feval : callable or None, optional (default=None)
+    feval : callable, list of callable functions or None, optional (default=None)
         Customized evaluation function.
-        Should accept two parameters: preds, train_data,
+        Each evaluation function should accept two parameters: preds, train_data,
         and return (eval_name, eval_result, is_higher_better) or list of such tuples.
 
             preds : list or numpy 1-D array
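
A usage sketch matching the updated docstrings, shown here with lgb.cv; the synthetic data and the two metric functions are illustrative assumptions, not part of the commit.

import lightgbm as lgb
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
y = (X[:, 0] + 0.1 * rng.normal(size=500) > 0).astype(int)
train_data = lgb.Dataset(X, label=y)

def error_rate(preds, train_data):
    labels = train_data.get_label()
    return 'error_rate', float(((preds > 0.5) != labels).mean()), False

def mean_pred(preds, train_data):
    return 'mean_pred', float(preds.mean()), True

params = {'objective': 'binary', 'metric': 'binary_logloss', 'verbose': -1}

# feval accepts a single callable (as before) or, after this commit, a list of callables
cv_results = lgb.cv(params, train_data, num_boost_round=5, feval=[error_rate, mean_pred])
print(sorted(cv_results))
# ['binary_logloss-mean', 'binary_logloss-stdv', 'error_rate-mean', 'error_rate-stdv',
#  'mean_pred-mean', 'mean_pred-stdv']
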
59 changes: 34 additions & 25 deletions python-package/lightgbm/sklearn.py
@@ -2,6 +2,7 @@
 """Scikit-learn wrapper interface for LightGBM."""
 from __future__ import absolute_import
 
+import copy
 import warnings
 
 import numpy as np
@@ -388,9 +389,10 @@ def fit(self, X, y,
             Init score of eval data.
         eval_group : list of arrays or None, optional (default=None)
             Group data of eval data.
-        eval_metric : string, list of strings, callable or None, optional (default=None)
+        eval_metric : string, callable, list or None, optional (default=None)
             If string, it should be a built-in evaluation metric to use.
             If callable, it should be a custom evaluation metric, see note below for more details.
+            If list, it can be a list of built-in metrics, a list of custom evaluation metrics, or a mix of both.
             In either case, the ``metric`` from the model parameters will be evaluated and used as well.
             Default: 'l2' for LGBMRegressor, 'logloss' for LGBMClassifier, 'ndcg' for LGBMRanker.
         early_stopping_rounds : int or None, optional (default=None)
@@ -500,29 +502,36 @@ def fit(self, X, y,
         if self._fobj:
             params['objective'] = 'None'  # objective = nullptr for unknown objective
 
-        if callable(eval_metric):
-            feval = _EvalFunctionWrapper(eval_metric)
-        else:
-            feval = None
-            # register default metric for consistency with callable eval_metric case
-            original_metric = self._objective if isinstance(self._objective, string_type) else None
-            if original_metric is None:
-                # try to deduce from class instance
-                if isinstance(self, LGBMRegressor):
-                    original_metric = "l2"
-                elif isinstance(self, LGBMClassifier):
-                    original_metric = "multi_logloss" if self._n_classes > 2 else "binary_logloss"
-                elif isinstance(self, LGBMRanker):
-                    original_metric = "ndcg"
-            # overwrite default metric by explicitly set metric
-            for metric_alias in _ConfigAliases.get("metric"):
-                if metric_alias in params:
-                    original_metric = params.pop(metric_alias)
-            # concatenate metric from params (or default if not provided in params) and eval_metric
-            original_metric = [original_metric] if isinstance(original_metric, (string_type, type(None))) else original_metric
-            eval_metric = [eval_metric] if isinstance(eval_metric, (string_type, type(None))) else eval_metric
-            params['metric'] = [e for e in eval_metric if e not in original_metric] + original_metric
-            params['metric'] = [metric for metric in params['metric'] if metric is not None]
+        # Do not modify original args in fit function
+        # Refer to https://github.com/microsoft/LightGBM/pull/2619
+        eval_metric_list = copy.deepcopy(eval_metric)
+        if not isinstance(eval_metric_list, list):
+            eval_metric_list = [eval_metric_list]
+
+        # Separate built-in from callable evaluation metrics
+        eval_metrics_callable = [_EvalFunctionWrapper(f) for f in eval_metric_list if callable(f)]
+        eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, string_type)]
+
+        # register default metric for consistency with callable eval_metric case
+        original_metric = self._objective if isinstance(self._objective, string_type) else None
+        if original_metric is None:
+            # try to deduce from class instance
+            if isinstance(self, LGBMRegressor):
+                original_metric = "l2"
+            elif isinstance(self, LGBMClassifier):
+                original_metric = "multi_logloss" if self._n_classes > 2 else "binary_logloss"
+            elif isinstance(self, LGBMRanker):
+                original_metric = "ndcg"
+
+        # overwrite default metric by explicitly set metric
+        for metric_alias in _ConfigAliases.get("metric"):
+            if metric_alias in params:
+                original_metric = params.pop(metric_alias)
+
+        # concatenate metric from params (or default if not provided in params) and eval_metric
+        original_metric = [original_metric] if isinstance(original_metric, (string_type, type(None))) else original_metric
+        params['metric'] = [e for e in eval_metrics_builtin if e not in original_metric] + original_metric
+        params['metric'] = [metric for metric in params['metric'] if metric is not None]
 
         if not isinstance(X, (DataFrame, DataTable)):
             _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
@@ -595,7 +604,7 @@ def _get_meta_data(collection, name, i):
         self._Booster = train(params, train_set,
                               self.n_estimators, valid_sets=valid_sets, valid_names=eval_names,
                               early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, fobj=self._fobj, feval=feval,
+                              evals_result=evals_result, fobj=self._fobj, feval=eval_metrics_callable,
                               verbose_eval=verbose, feature_name=feature_name,
                               callbacks=callbacks, init_model=init_model)
 
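
In short, the refactored fit() deep-copies eval_metric so the caller's arguments are never mutated (the scikit-learn tooling issue tracked in microsoft#2619), routes wrapped callables to train() via feval, and merges built-in metric names into params['metric'] alongside the metric from the model parameters (or a class-based default). Below is a simplified, pure-Python sketch of that merge rule; resolve_metrics is a hypothetical helper, not part of LightGBM, and it omits the alias handling done by _ConfigAliases and the objective-based defaults.

def resolve_metrics(eval_metric, params_metric):
    # hypothetical helper mirroring the list handling in the diff above
    eval_metric = eval_metric if isinstance(eval_metric, list) else [eval_metric]
    callables = [f for f in eval_metric if callable(f)]        # -> passed to train() as feval
    builtins = [m for m in eval_metric if isinstance(m, str)]  # -> merged into params['metric']
    params_metric = params_metric if isinstance(params_metric, list) else [params_metric]
    merged = [m for m in builtins if m not in params_metric] + params_metric
    return callables, [m for m in merged if m is not None]

# resolve_metrics(['fair', my_custom_metric, 'binary_error'], 'binary_logloss')
# -> ([my_custom_metric], ['fair', 'binary_error', 'binary_logloss'])
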
45 changes: 45 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -1803,6 +1803,51 @@ def train_booster(params=params_obj_verbose, **kwargs):
         self.assertRaises(lgb.basic.LightGBMError, get_cv_result,
                           params_class_3_verbose, metrics='binary_error', fobj=dummy_obj)
 
+    def test_multiple_feval_train(self):
+        X, y = load_breast_cancer(return_X_y=True)
+
+        params = {'verbose': -1, 'objective': 'binary', 'metric': 'binary_logloss'}
+
+        X_train, X_validation, y_train, y_validation = train_test_split(X, y, test_size=0.2)
+
+        train_dataset = lgb.Dataset(data=X_train, label=y_train, silent=True)
+        validation_dataset = lgb.Dataset(data=X_validation, label=y_validation, reference=train_dataset, silent=True)
+        evals_result = {}
+        lgb.train(
+            params=params,
+            train_set=train_dataset,
+            valid_sets=validation_dataset,
+            num_boost_round=5,
+            feval=[constant_metric, decreasing_metric],
+            evals_result=evals_result)
+
+        self.assertEqual(len(evals_result['valid_0']), 3)
+        self.assertIn('binary_logloss', evals_result['valid_0'])
+        self.assertIn('error', evals_result['valid_0'])
+        self.assertIn('decreasing_metric', evals_result['valid_0'])
+
+    def test_multiple_feval_cv(self):
+        X, y = load_breast_cancer(return_X_y=True)
+
+        params = {'verbose': -1, 'objective': 'binary', 'metric': 'binary_logloss'}
+
+        train_dataset = lgb.Dataset(data=X, label=y, silent=True)
+
+        cv_results = lgb.cv(
+            params=params,
+            train_set=train_dataset,
+            num_boost_round=5,
+            feval=[constant_metric, decreasing_metric])
+
+        # Expect three metrics but mean and stdv for each metric
+        self.assertEqual(len(cv_results), 6)
+        self.assertIn('binary_logloss-mean', cv_results)
+        self.assertIn('error-mean', cv_results)
+        self.assertIn('decreasing_metric-mean', cv_results)
+        self.assertIn('binary_logloss-stdv', cv_results)
+        self.assertIn('error-stdv', cv_results)
+        self.assertIn('decreasing_metric-stdv', cv_results)
+
     @unittest.skipIf(psutil.virtual_memory().available / 1024 / 1024 / 1024 < 3, 'not enough RAM')
     def test_model_size(self):
         X, y = load_boston(return_X_y=True)
39 changes: 38 additions & 1 deletion tests/python_package_test/test_sklearn.py
@@ -862,7 +862,7 @@ def test_metrics(self):
         # custom metric for custom objective
         gbm = lgb.LGBMRegressor(objective=custom_dummy_obj,
                                 **params).fit(eval_metric=constant_metric, **params_fit)
-        self.assertEqual(len(gbm.evals_result_['training']), 1)
+        self.assertEqual(len(gbm.evals_result_['training']), 2)
         self.assertIn('error', gbm.evals_result_['training'])
 
         # non-default regression metric with custom metric for custom objective
@@ -922,6 +922,43 @@ def test_metrics(self):
         self.assertEqual(len(gbm.evals_result_['training']), 1)
         self.assertIn('binary_logloss', gbm.evals_result_['training'])
 
+    def test_multiple_eval_metrics(self):
+
+        X, y = load_breast_cancer(return_X_y=True)
+
+        params = {'n_estimators': 2, 'verbose': -1, 'objective': 'binary', 'metric': 'binary_logloss'}
+        params_fit = {'X': X, 'y': y, 'eval_set': (X, y), 'verbose': False}
+
+        # Verify that can receive a list of metrics, only callable
+        gbm = lgb.LGBMClassifier(**params).fit(eval_metric=[constant_metric, decreasing_metric], **params_fit)
+        self.assertEqual(len(gbm.evals_result_['training']), 3)
+        self.assertIn('error', gbm.evals_result_['training'])
+        self.assertIn('decreasing_metric', gbm.evals_result_['training'])
+        self.assertIn('binary_logloss', gbm.evals_result_['training'])
+
+        # Verify that can receive a list of custom and built-in metrics
+        gbm = lgb.LGBMClassifier(**params).fit(eval_metric=[constant_metric, decreasing_metric, 'fair'], **params_fit)
+        self.assertEqual(len(gbm.evals_result_['training']), 4)
+        self.assertIn('error', gbm.evals_result_['training'])
+        self.assertIn('decreasing_metric', gbm.evals_result_['training'])
+        self.assertIn('binary_logloss', gbm.evals_result_['training'])
+        self.assertIn('fair', gbm.evals_result_['training'])
+
+        # Verify that works as expected when eval_metric is empty
+        gbm = lgb.LGBMClassifier(**params).fit(eval_metric=[], **params_fit)
+        self.assertEqual(len(gbm.evals_result_['training']), 1)
+        self.assertIn('binary_logloss', gbm.evals_result_['training'])
+
+        # Verify that can receive a list of metrics, only built-in
+        gbm = lgb.LGBMClassifier(**params).fit(eval_metric=['fair', 'error'], **params_fit)
+        self.assertEqual(len(gbm.evals_result_['training']), 3)
+        self.assertIn('binary_logloss', gbm.evals_result_['training'])
+
+        # Verify that eval_metric is robust to receiving a list with None
+        gbm = lgb.LGBMClassifier(**params).fit(eval_metric=['fair', 'error', None], **params_fit)
+        self.assertEqual(len(gbm.evals_result_['training']), 3)
+        self.assertIn('binary_logloss', gbm.evals_result_['training'])
+
     def test_inf_handle(self):
         nrows = 100
         ncols = 10
