Skip to content

Commit

Permalink
Commit: modify inference behavior of secure vertical from split value to index for training phase
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Feb 27, 2024
1 parent 087a8dd commit 5e85438
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 43 deletions.
88 changes: 72 additions & 16 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <limits>
#include <utility>
#include <fstream>

#include "../collective/aggregator.h"
#include "../data/adapter.h"
Expand Down Expand Up @@ -367,7 +368,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
}

template <typename SketchType>
void AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int max_bin,
double AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts) {
// For secure vertical pipeline, we fill the cut values corresponding to empty columns
// with a vector of minimum value
Expand All @@ -388,12 +389,15 @@ void AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int
cut_values.push_back(cpt);
}
}
return cut_values.back();
}
// if empty column, fill the cut values with 0
// if empty column, fill the cut values with NaN
else {
for (size_t i = 1; i < required_cuts; ++i) {
cut_values.push_back(0.0);
//cut_values.push_back(0.0);
cut_values.push_back(std::numeric_limits<double>::quiet_NaN());
}
return std::numeric_limits<double>::quiet_NaN();
}
}

Expand Down Expand Up @@ -448,6 +452,7 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
for (size_t fid = 0; fid < reduced.size(); ++fid) {
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
    // If vertical and secure mode, we need to sync the max_num_bins across workers
// to create the same global number of cut point bins for easier future processing
if (info.IsVerticalFederated() && info.IsSecure()) {
collective::Allreduce<collective::Operation::kMax>(&max_num_bins, 1);
}
Expand All @@ -457,17 +462,31 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
} else {
// use special AddCutPoint scheme for secure vertical federated learning
if (info.IsVerticalFederated() && info.IsSecure()) {
AddCutPointSecure<WQSketch>(a, max_num_bins, p_cuts);
double last_value = AddCutPointSecure<WQSketch>(a, max_num_bins, p_cuts);
// push a value that is greater than anything if the feature is not empty
// i.e. if the last value is not NaN
if (!std::isnan(last_value)) {
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
}
else {
// if the feature is empty, push a NaN value
p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<double>::quiet_NaN());
}
}
else {
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts);
// push a value that is greater than anything
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
}
// push a value that is greater than anything
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);

}

// Ensure that every feature gets at least one quantile point
Expand All @@ -477,12 +496,49 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
}

if (info.IsVerticalFederated() && info.IsSecure()) {
// cut values need to be synced across all workers via Allreduce
auto cut_val = p_cuts->cut_values_.HostVector().data();
std::size_t n = p_cuts->cut_values_.HostVector().size();
collective::Allreduce<collective::Operation::kSum>(cut_val, n);
}

/*
// save the cut values and cut pointers to files for examination
if (collective::GetRank() == 0) {
//print the entries to file for debug
std::ofstream file;
file.open("cut_info_0.txt", std::ios_base::app);
file << " Total cut ptr count: " << p_cuts->cut_ptrs_.HostVector().size() << std::endl;
file << " Total cut count: " << p_cuts->cut_values_.HostVector().size() << std::endl;
//iterate through the cut pointers
for (auto i = 0; i < p_cuts->cut_ptrs_.HostVector().size(); i++) {
file << "cut_ptr " << i << ": " << p_cuts->cut_ptrs_.HostVector()[i] << std::endl;
}
//iterate through the cut values
for (auto i = 0; i < p_cuts->cut_values_.HostVector().size(); i++) {
file << "cut_value " << i << ": " << p_cuts->cut_values_.HostVector()[i] << std::endl;
}
file.close();
}
if (collective::GetRank() == 1) {
//print the entries to file for debug
std::ofstream file;
file.open("cut_info_1.txt", std::ios_base::app);
file << " Total cut ptr count: " << p_cuts->cut_ptrs_.HostVector().size() << std::endl;
file << " Total cut count: " << p_cuts->cut_values_.HostVector().size() << std::endl;
//iterate through the cut pointers
for (auto i = 0; i < p_cuts->cut_ptrs_.HostVector().size(); i++) {
file << "cut_ptr " << i << ": " << p_cuts->cut_ptrs_.HostVector()[i] << std::endl;
}
//iterate through the cut values
for (auto i = 0; i < p_cuts->cut_values_.HostVector().size(); i++) {
file << "cut_value " << i << ": " << p_cuts->cut_values_.HostVector()[i] << std::endl;
}
file.close();
}
if (info.IsVerticalFederated() && info.IsSecure()) {
// cut values need to be synced across all workers via Allreduce
auto cut_val = p_cuts->cut_values_.HostVector().data();
std::size_t n = p_cuts->cut_values_.HostVector().size();
collective::Allreduce<collective::Operation::kSum>(cut_val, n);
}
*/

p_cuts->SetCategorical(this->has_categorical_, max_cat);
monitor_.Stop(__func__);
Expand Down
57 changes: 56 additions & 1 deletion src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,45 @@ class LearnerImpl : public LearnerIO {
monitor_.Start("GetGradient");
GetGradient(predt.predictions, train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");



if(collective::GetRank()==0){
//print the total number of samples
std::cout << "Total number of samples: " << train->Info().labels.Size() << std::endl;
auto i = 0;
// print the first five predictions
for (auto p : predt.predictions.HostVector()) {
std::cout << "Prediction " << i << ": " << p << std::endl;
i++;
if (i == 5) {
break;
}
}

// print the first five labels
std::cout << "Labels: " << std::endl;
i = 0;
while ( i<5 ) {
std::cout << "Label " << i << ": " << train->Info().labels.HostView()(i) << std::endl;
i++;
}

// print the first five gradients
std::cout << "Gradients: " << std::endl;
i = 0;
for (auto p : gpair_.Data()->HostVector()) {
std::cout << "Gradient " << i << ": " << p.GetGrad() << std::endl;
i++;
if (i == 5) {
break;
}
}
}




TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients");

gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
Expand Down Expand Up @@ -1333,6 +1372,9 @@ class LearnerImpl : public LearnerIO {
std::shared_ptr<DMatrix> m = data_sets[i];
auto &predt = prediction_container_.Cache(m, ctx_.Device());
this->ValidateDMatrix(m.get(), false);
if(collective::GetRank()==0){
std::cout << "data size = " << data_sets[i]->Info().num_row_ << std::endl;
}
this->PredictRaw(m.get(), &predt, false, 0, 0);

auto &out = output_predictions_.Cache(m, ctx_.Device()).predictions;
Expand All @@ -1341,7 +1383,15 @@ class LearnerImpl : public LearnerIO {

obj_->EvalTransform(&out);
for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Evaluate(out, m);

auto metric = ev->Evaluate(out, m);

if(collective::GetRank()==0){
std::cout << "eval result = " << metric << std::endl;
}


os << '\t' << data_names[i] << '-' << ev->Name() << ':' << metric; //ev->Evaluate(out, m);
}
}

Expand Down Expand Up @@ -1446,6 +1496,11 @@ class LearnerImpl : public LearnerIO {
CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
this->CheckModelInitialized();
this->ValidateDMatrix(data, false);

if(collective::GetRank()==0){
std::cout << "PredictRaw training ? " << training << std::endl;
}

gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end);
}

Expand Down
42 changes: 28 additions & 14 deletions src/tree/common_row_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ class CommonRowPartitioner {

CommonRowPartitioner() = default;
CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid,
bool is_col_split)
: base_rowid{_base_rowid}, is_col_split_{is_col_split} {
bool is_col_split, bool is_secure)
: base_rowid{_base_rowid}, is_col_split_{is_col_split}, is_secure_{is_secure} {
row_set_collection_.Clear();
std::vector<size_t>& row_indices = *row_set_collection_.Data();
row_indices.resize(num_row);
Expand All @@ -106,26 +106,33 @@ class CommonRowPartitioner {

template <typename ExpandEntry>
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions, bool is_index) {
auto const& ptrs = gmat.cut.Ptrs();
auto const& vals = gmat.cut.Values();

for (std::size_t i = 0; i < nodes.size(); ++i) {
bst_node_t const nidx = nodes[i].nid;
bst_feature_t const fidx = tree.SplitIndex(nidx);
float const split_pt = tree.SplitCond(nidx);
std::uint32_t const lower_bound = ptrs[fidx];
std::uint32_t const upper_bound = ptrs[fidx + 1];
bst_bin_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (auto bound = lower_bound; bound < upper_bound; ++bound) {
if (split_pt == vals[bound]) {
split_cond = static_cast<bst_bin_t>(bound);
if (is_index) {
// if the split_pt is already recorded as a bin_id, use it directly
(*split_conditions)[i] = static_cast<int32_t>(split_pt);
}
else {
// otherwise find the bin_id that corresponds to split_pt
std::uint32_t const lower_bound = ptrs[fidx];
std::uint32_t const upper_bound = ptrs[fidx + 1];
bst_bin_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (auto bound = lower_bound; bound < upper_bound; ++bound) {
if (split_pt == vals[bound]) {
split_cond = static_cast<bst_bin_t>(bound);
}
}
(*split_conditions)[i] = split_cond;
}
(*split_conditions)[i] = split_cond;
}
}

Expand Down Expand Up @@ -194,7 +201,13 @@ class CommonRowPartitioner {
std::vector<int32_t> split_conditions;
if (column_matrix.IsInitialized()) {
split_conditions.resize(n_nodes);
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
if (is_secure_) {
// in secure mode, the split index is kept instead of the split value
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions, true);
}
else {
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions, false);
}
}

// 2.1 Create a blocked space of size SUM(samples in each node)
Expand Down Expand Up @@ -294,6 +307,7 @@ class CommonRowPartitioner {
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
common::RowSetCollection row_set_collection_;
bool is_col_split_;
bool is_secure_;
ColumnSplitHelper column_split_helper_;
};

Expand Down
34 changes: 25 additions & 9 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,36 @@ class HistEvaluator {
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) -
parent.root_gain);
split_pt = cut_val[i]; // not used for partition based
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
if (!is_secure_) {
split_pt = cut_val[i]; // not used for partition based
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
}
else {
// secure mode: record the best split point, rather than the actual value since it is not accessible
best.Update(loss_chg, fidx, i, d_step == -1, false, left_sum, right_sum);
}

} else {
// backward enumeration: split at left bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
GradStats{left_sum}) -
parent.root_gain);
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
if (!is_secure_) {
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
else {
// secure mode: record the best split point, rather than the actual value since it is not accessible
if (i != imin) {
i = i - 1;
}
best.Update(loss_chg, fidx, i, d_step == -1, false, right_sum, left_sum);
}
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
}
}
Expand Down Expand Up @@ -387,7 +403,7 @@ class HistEvaluator {
auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);

// print the best split for each feature
// std::cout << "Best split for feature " << fidx << " is " << best->split_value << " with gain " << best->loss_chg << std::endl;
//std::cout << "Current best split at feature " << fidx << " is: " << std::endl << *best << std::endl;


if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
Expand All @@ -408,7 +424,7 @@ class HistEvaluator {

if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
// expand entries accordingly. Update() takes care of selecting the best one.
// Note that under secure vertical setting, only the label owner is able to evaluate the split
// based on the global histogram. The other parties will receive the final best splits
// allgather is capable of performing this (0-gain entries for non-label owners),
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class GloablApproxBuilder {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
p_fmat->Info().IsColumnSplit());
p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure());
n_batches_++;
}

Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class MultiTargetHistBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit());
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure());
}

bst_target_t n_targets = p_tree->NumTargets();
Expand Down Expand Up @@ -355,7 +355,7 @@ class HistUpdater {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
fmat->Info().IsColumnSplit());
fmat->Info().IsColumnSplit(), fmat->Info().IsSecure());
}
histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(),
fmat->Info().IsColumnSplit(), fmat->Info().IsSecure(), hist_param_);
Expand Down

0 comments on commit 5e85438

Please sign in to comment.