Skip to content

Commit

Permalink
Merge pull request dmlc#602 from Far0n/cv
Browse files Browse the repository at this point in the history
early stopping for CV (python) issue dmlc#529
  • Loading branch information
terrytangyuan committed Nov 7, 2015
2 parents 190e58a + 95cc900 commit a3a4439
Showing 1 changed file with 41 additions and 2 deletions.
43 changes: 41 additions & 2 deletions python-package/xgboost/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ def aggcv(rlist, show_stdv=True, show_progress=None, as_pandas=True):


def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
obj=None, feval=None, fpreproc=None, as_pandas=True,
show_progress=None, show_stdv=True, seed=0):
obj=None, feval=None, maximize=False, early_stopping_rounds=None,
fpreproc=None, as_pandas=True, show_progress=None, show_stdv=True, seed=0):
# pylint: disable = invalid-name
    """Cross-validation with given parameters.
Expand All @@ -313,6 +313,12 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
Custom objective function.
feval : function
Custom evaluation function.
maximize : bool
Whether to maximize feval.
early_stopping_rounds: int
Activates early stopping. CV error needs to decrease at least
every <early_stopping_rounds> round(s) to continue.
Last entry in evaluation history is the one from best iteration.
fpreproc : function
Preprocessing function that takes (dtrain, dtest, param) and returns
transformed versions of those.
Expand All @@ -332,6 +338,28 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
-------
evaluation history : list(string)
"""
if early_stopping_rounds is not None:
if len(metrics) > 1:
raise ValueError('Check your params.'\
'Early stopping works with single eval metric only.')

sys.stderr.write("Will train until cv error hasn't decreased in {} rounds.\n".format(\
early_stopping_rounds))

maximize_score = False
if len(metrics) == 1:
maximize_metrics = ('auc', 'map', 'ndcg')
if any(metrics[0].startswith(x) for x in maximize_metrics):
maximize_score = True
if feval is not None:
maximize_score = maximize

if maximize_score:
best_score = 0.0
else:
best_score = float('inf')

best_score_i = 0
results = []
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc)
for i in range(num_boost_round):
Expand All @@ -342,6 +370,17 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
as_pandas=as_pandas)
results.append(res)

if early_stopping_rounds is not None:
score = res[0]
if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score):
best_score = score
best_score_i = i
elif i - best_score_i >= early_stopping_rounds:
sys.stderr.write("Stopping. Best iteration: {}\n".format(best_score_i))
results = results[:best_score_i+1]
break

if as_pandas:
try:
import pandas as pd
Expand Down

0 comments on commit a3a4439

Please sign in to comment.