Skip to content

Commit

Permalink
Identify the locations where HE (homomorphic encryption) needs to be added
Browse files — browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Feb 23, 2024
1 parent 5d542f8 commit d91be10
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ void ApplyWithLabelsEncrypted(MetaInfo const& info, HostDeviceVector<T>* result,
}
}
// print 1 sample
std::cout << " g[0]: " << result_vector_g[0] << " h[0]: " << result_vector_h[0] << std::endl;
//std::cout << " g[0]: " << result_vector_g[0] << " h[0]: " << result_vector_h[0] << std::endl;
// print max and min
std::cout << "max_g: " << max_g << " min_g: " << min_g << " max_h: " << max_h << " min_h: " << min_h << std::endl;
//std::cout << "max_g: " << max_g << " min_g: " << min_g << " max_h: " << max_h << " min_h: " << min_h << std::endl;
}

result->Resize(size);
Expand Down
67 changes: 62 additions & 5 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,29 @@ class HistogramBuilder {
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection,
common::Span<GradientPair const> gpair_h, bool force_read_by_column) {


if ((collective::GetRank() == 1)) {
std::cout << "Current samples on nodes: " << std::endl;
// print info on all nodes
for (bst_node_t nit = 0; nit < row_set_collection.Size(); ++nit) {
auto size = row_set_collection[nit].Size();
std::cout << "Node " << nit << " has " << size << " rows." << std::endl;
}


for (auto nit = nodes_to_build.begin(); nit != nodes_to_build.end(); ++nit) {
std::cout << "Building local histogram for node ID: " << *nit << " with " << row_set_collection[*nit].Size() << " samples." << std::endl;
}
std::cout << std::endl;

}




// Parallel processing by nodes and data in each node
bool print_once = true;
common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
const auto tid = static_cast<unsigned>(omp_get_thread_num());
bst_node_t const nidx = nodes_to_build[nid_in_set];
Expand All @@ -86,6 +108,19 @@ class HistogramBuilder {
auto rid_set = common::RowSetCollection::Elem(elem.begin + start_of_row_set,
elem.begin + end_of_row_set, nidx);
auto hist = buffer_.GetInitializedHist(tid, nid_in_set);

// print info
//if ((collective::GetRank() == 0) && print_once ) {
//std::cout << "Sample of row set for node " << nidx << ": ";
//std::cout << "Size: " << row_set_collection[nidx].Size() << ", ";
//for (auto i = 0; i < 10; i++) {
// std::cout << rid_set.begin[i] << ", ";
//}
//std::cout << std::endl;
//print_once = false;
//}


if (rid_set.Size() != 0) {
common::BuildHist<any_missing>(gpair_h, rid_set, gidx, hist, force_read_by_column);
}
Expand Down Expand Up @@ -156,6 +191,11 @@ class HistogramBuilder {
if (page_idx == 0) {
// Add the local histogram cache to the parallel buffer before processing the first page.
auto n_nodes = nodes_to_build.size();

if ((collective::GetRank() == 0)) {
std::cout << "Building histogram for " << n_nodes << " nodes" << std::endl;
}

std::vector<common::GHistRow> target_hists(n_nodes);
for (size_t i = 0; i < n_nodes; ++i) {
auto const nidx = nodes_to_build[i];
Expand Down Expand Up @@ -213,27 +253,27 @@ class HistogramBuilder {
std::vector<double> hist_flat;
hist_flat.resize(n);
// iterate through the nodes_to_build
std::cout << "nodes_to_build.size() = " << nodes_to_build.size() << std::endl;
//std::cout << "nodes_to_build.size() = " << nodes_to_build.size() << std::endl;
// front pointer
auto it = reinterpret_cast<double *>(this->hist_[nodes_to_build.front()].data());
auto hist_size = this->hist_[nodes_to_build.front()].size();
std::cout<< "n=" << n << std::endl;
std::cout << "hist_size = " << hist_size << std::endl;
//std::cout<< "n=" << n << std::endl;
//std::cout << "hist_size = " << hist_size << std::endl;
for (size_t i = 0; i < n; i++) {
// get item with iterator
auto item = *it;
hist_flat[i] = item;
it++;
}
std::cout << "hist_flat.size() = " << hist_flat.size() << std::endl;
//std::cout << "hist_flat.size() = " << hist_flat.size() << std::endl;

// Perform AllGather
auto hist_entries = collective::Allgather(hist_flat);

// Update histogram for data owner
if (collective::GetRank() == 0) {
// skip rank 0, as local hist already contains its own entries
std::cout << "hist_entries.size() = " << hist_entries.size() << std::endl;
//std::cout << "hist_entries.size() = " << hist_entries.size() << std::endl;
// reposition iterator to the beginning of the vector
it = reinterpret_cast<double *>(this->hist_[nodes_to_build.front()].data());
for (auto rank_idx = 1; rank_idx < hist_entries.size()/n; rank_idx++) {
Expand Down Expand Up @@ -317,6 +357,10 @@ class MultiHistogramBuilder {
linalg::MatrixView<GradientPair const> gpair, ExpandEntry const &best,
BatchParam const &param, bool force_read_by_column = false) {
auto n_targets = p_tree->NumTargets();


std::cout << "Root n_targets = " << n_targets << std::endl;

CHECK_EQ(gpair.Shape(1), n_targets);
CHECK_EQ(p_fmat->Info().num_row_, gpair.Shape(0));
CHECK_EQ(target_builders_.size(), n_targets);
Expand Down Expand Up @@ -357,6 +401,16 @@ class MultiHistogramBuilder {
std::vector<bst_node_t> nodes_to_sub(valid_candidates.size());
AssignNodes(p_tree, valid_candidates, nodes_to_build, nodes_to_sub);


// print index for nodes_to_build and nodes_to_sub
if (collective::GetRank() == 0) {
for (int i = 0; i < nodes_to_build.size(); i++) {
std::cout<< "Left-Right: nodes_to_build index " << nodes_to_build[i] << "; ";
std::cout<< "nodes_to_sub index " << nodes_to_sub[i] << std::endl;
}
}


// use the first builder for getting number of valid nodes.
target_builders_.front().AddHistRows(p_tree, &nodes_to_build, &nodes_to_sub, true);
CHECK_GE(nodes_to_build.size(), nodes_to_sub.size());
Expand All @@ -373,6 +427,9 @@ class MultiHistogramBuilder {
CHECK_EQ(gpair.Shape(1), p_tree->NumTargets());
for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) {
auto t_gpair = gpair.Slice(linalg::All(), t);
if (collective::GetRank() == 0) {
std::cout<< "Total row count: " << p_fmat->Info().num_row_ << std::endl;
}
CHECK_EQ(t_gpair.Shape(0), p_fmat->Info().num_row_);
this->target_builders_[t].BuildHist(page_idx, space, page,
partitioners[page_idx].Partitions(), nodes_to_build,
Expand Down
12 changes: 11 additions & 1 deletion src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,15 @@ class HistUpdater {
monitor_->Start(__func__);
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0));

this->histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, node,



std::cout<<"InitRoot: --------------------------------------"<<std::endl;




this->histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, node,
HistBatch(param_));

{
Expand Down Expand Up @@ -439,6 +447,8 @@ class HistUpdater {
std::vector<CPUExpandEntry> const &valid_candidates,
linalg::MatrixView<GradientPair const> gpair) {
monitor_->Start(__func__);

std::cout << "BuildHistogram: --------------------------------------" << std::endl;
this->histogram_builder_->BuildHistLeftRight(ctx_, p_fmat, p_tree, partitioner_,
valid_candidates, gpair, HistBatch(param_));
monitor_->Stop(__func__);
Expand Down

0 comments on commit d91be10

Please sign in to comment.