Skip to content

Commit

Permalink
Fixes for multiple and default metrics (dmlc#1239)
Browse files Browse the repository at this point in the history
* fix multiple evaluation metrics

* create DefaultEvalMetric only when really necessary

* py test for dmlc#1239

* make travis happy
  • Loading branch information
khotilov authored and tqchen committed Jun 5, 2016
1 parent 9ef8607 commit 9a48a40
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
5 changes: 4 additions & 1 deletion src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ class Booster {

inline void SetParam(const std::string& name, const std::string& val) {
auto it = std::find_if(cfg_.begin(), cfg_.end(),
[&name](decltype(*cfg_.begin()) &x) {
[&name, &val](decltype(*cfg_.begin()) &x) {
if (name == "eval_metric") {
return x.first == name && x.second == val;
}
return x.first == name;
});
if (it == cfg_.end()) {
Expand Down
9 changes: 3 additions & 6 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,6 @@ class LearnerImpl : public Learner {
attributes_ = std::map<std::string, std::string>(
attr.begin(), attr.end());
}
if (metrics_.size() == 0) {
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
}
this->base_score_ = mparam.base_score;
gbm_->ResetPredBuffer(pred_buffer_size_);
cfg_["num_class"] = common::ToString(mparam.num_class);
Expand Down Expand Up @@ -307,6 +304,9 @@ class LearnerImpl : public Learner {
std::ostringstream os;
os << '[' << iter << ']'
<< std::setiosflags(std::ios::fixed);
if (metrics_.size() == 0) {
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
}
for (size_t i = 0; i < data_sets.size(); ++i) {
this->PredictRaw(data_sets[i], &preds_);
obj_->EvalTransform(&preds_);
Expand Down Expand Up @@ -445,9 +445,6 @@ class LearnerImpl : public Learner {

// reset the base score
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
if (metrics_.size() == 0) {
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
}

this->base_score_ = mparam.base_score;
gbm_->ResetPredBuffer(pred_buffer_size_);
Expand Down
10 changes: 10 additions & 0 deletions tests/python/test_basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def neg_evalerror(preds, dtrain):
if int(preds2[i] > 0.5) != labels[i]) / float(len(preds2))
assert err == err2

def test_multi_eval_metric(self):
    """Train with several eval metrics given as a list and check that the
    evaluation history records an entry for every requested metric."""
    eval_sets = [(dtest, 'eval'), (dtrain, 'train')]
    params = {
        'max_depth': 2,
        'eta': 0.2,
        'silent': 1,
        'objective': 'binary:logistic',
        'eval_metric': ["auc", "logloss", 'error'],
    }
    history = {}
    booster = xgb.train(params, dtrain, 4, eval_sets, evals_result=history)
    # The trained model must be a Booster, and all three metrics must have
    # been evaluated and stored for the 'eval' data set.
    assert isinstance(booster, xgb.core.Booster)
    assert len(history['eval']) == 3
    assert set(history['eval'].keys()) == {'auc', 'error', 'logloss'}

def test_fpreproc(self):
param = {'max_depth': 2, 'eta': 1, 'silent': 1,
'objective': 'binary:logistic'}
Expand Down

0 comments on commit 9a48a40

Please sign in to comment.