diff --git a/src/learner.cc b/src/learner.cc
index 78c6c15ea35d..d3b3a9bfc607 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -1287,7 +1287,7 @@ class LearnerImpl : public LearnerIO {
     monitor_.Stop("GetGradient");
-
+/*
     if(collective::GetRank()==0){
       //print the total number of samples
       std::cout << "Total number of samples: " << train->Info().labels.Size() << std::endl;
@@ -1320,7 +1320,7 @@ class LearnerImpl : public LearnerIO {
         }
       }
     }
-
+*/
@@ -1372,9 +1372,6 @@ class LearnerImpl : public LearnerIO {
       std::shared_ptr<DMatrix> m = data_sets[i];
       auto &predt = prediction_container_.Cache(m, ctx_.Device());
       this->ValidateDMatrix(m.get(), false);
-      if(collective::GetRank()==0){
-        std::cout << "data size = " << data_sets[i]->Info().num_row_ << std::endl;
-      }
       this->PredictRaw(m.get(), &predt, false, 0, 0);

       auto &out = output_predictions_.Cache(m, ctx_.Device()).predictions;
@@ -1383,15 +1380,7 @@ class LearnerImpl : public LearnerIO {
       obj_->EvalTransform(&out);
       for (auto& ev : metrics_) {
-
-        auto metric = ev->Evaluate(out, m);
-
-        if(collective::GetRank()==0){
-          std::cout << "eval result = " << metric << std::endl;
-        }
-
-
-        os << '\t' << data_names[i] << '-' << ev->Name() << ':' << metric; //ev->Evaluate(out, m);
+        os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Evaluate(out, m);
       }
     }
@@ -1496,11 +1485,6 @@ class LearnerImpl : public LearnerIO {
     CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
     this->CheckModelInitialized();
     this->ValidateDMatrix(data, false);
-
-    if(collective::GetRank()==0){
-      std::cout << "PredictRaw training ? " << training << std::endl;
-    }
-
     gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end);
   }
diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h
index 2152155fd9ef..c12c5f9c206b 100644
--- a/src/tree/common_row_partitioner.h
+++ b/src/tree/common_row_partitioner.h
@@ -117,6 +117,12 @@ class CommonRowPartitioner {
       if (is_index) {
         // if the split_pt is already recorded as a bin_id, use it directly
         (*split_conditions)[i] = static_cast<int32_t>(split_pt);
+        // at this point, each participant has received the best split index,
+        // therefore can recover the split_pt from bin_id, update tree info
+        auto split_pt_local = vals[split_pt];
+        // make updates to the tree, replacing the existing index
+        // with cut value; note that we cast away const here, carefully
+        const_cast<RegTree::Node&>(tree.GetNodes()[nidx]).SetSplit(fidx, split_pt_local);
       } else {
         // otherwise find the bin_id that corresponds to split_pt
diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h
index 3a7179704e33..f505fdd78631 100644
--- a/src/tree/hist/evaluate_splits.h
+++ b/src/tree/hist/evaluate_splits.h
@@ -270,6 +270,7 @@ class HistEvaluator {
       } else {
         // secure mode: record the best split point, rather than the actual value since it is not accessible
+        // at this point (active party finding best-split)
         best.Update(loss_chg, fidx, i, d_step == -1, false, left_sum, right_sum);
       }