Skip to content

Commit

Permalink
code clean
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Feb 27, 2024
1 parent 1fd1fb0 commit 72159b9
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 82 deletions.
44 changes: 0 additions & 44 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,50 +496,6 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
}


/*
// save the cut values and cut pointers to files for examination
if (collective::GetRank() == 0) {
//print the entries to file for debug
std::ofstream file;
file.open("cut_info_0.txt", std::ios_base::app);
file << " Total cut ptr count: " << p_cuts->cut_ptrs_.HostVector().size() << std::endl;
file << " Total cut count: " << p_cuts->cut_values_.HostVector().size() << std::endl;
//iterate through the cut pointers
for (auto i = 0; i < p_cuts->cut_ptrs_.HostVector().size(); i++) {
file << "cut_ptr " << i << ": " << p_cuts->cut_ptrs_.HostVector()[i] << std::endl;
}
//iterate through the cut values
for (auto i = 0; i < p_cuts->cut_values_.HostVector().size(); i++) {
file << "cut_value " << i << ": " << p_cuts->cut_values_.HostVector()[i] << std::endl;
}
file.close();
}
if (collective::GetRank() == 1) {
//print the entries to file for debug
std::ofstream file;
file.open("cut_info_1.txt", std::ios_base::app);
file << " Total cut ptr count: " << p_cuts->cut_ptrs_.HostVector().size() << std::endl;
file << " Total cut count: " << p_cuts->cut_values_.HostVector().size() << std::endl;
//iterate through the cut pointers
for (auto i = 0; i < p_cuts->cut_ptrs_.HostVector().size(); i++) {
file << "cut_ptr " << i << ": " << p_cuts->cut_ptrs_.HostVector()[i] << std::endl;
}
//iterate through the cut values
for (auto i = 0; i < p_cuts->cut_values_.HostVector().size(); i++) {
file << "cut_value " << i << ": " << p_cuts->cut_values_.HostVector()[i] << std::endl;
}
file.close();
}
if (info.IsVerticalFederated() && info.IsSecure()) {
// cut values need to be synced across all workers via Allreduce
auto cut_val = p_cuts->cut_values_.HostVector().data();
std::size_t n = p_cuts->cut_values_.HostVector().size();
collective::Allreduce<collective::Operation::kSum>(cut_val, n);
}
*/

p_cuts->SetCategorical(this->has_categorical_, max_cat);
monitor_.Stop(__func__);
}
Expand Down
38 changes: 0 additions & 38 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1286,44 +1286,6 @@ class LearnerImpl : public LearnerIO {
GetGradient(predt.predictions, train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");


/*
if(collective::GetRank()==0){
//print the total number of samples
std::cout << "Total number of samples: " << train->Info().labels.Size() << std::endl;
auto i = 0;
// print the first five predictions
for (auto p : predt.predictions.HostVector()) {
std::cout << "Prediction " << i << ": " << p << std::endl;
i++;
if (i == 5) {
break;
}
}
// print the first five labels
std::cout << "Labels: " << std::endl;
i = 0;
while ( i<5 ) {
std::cout << "Label " << i << ": " << train->Info().labels.HostView()(i) << std::endl;
i++;
}
// print the first five gradients
std::cout << "Gradients: " << std::endl;
i = 0;
for (auto p : gpair_.Data()->HostVector()) {
std::cout << "Gradient " << i << ": " << p.GetGrad() << std::endl;
i++;
if (i == 5) {
break;
}
}
}
*/



TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients");

gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
Expand Down

0 comments on commit 72159b9

Please sign in to comment.