Skip to content

Commit

Permalink
fix the logic for secure vertical inference, each client saves a diffe…
Browse files Browse the repository at this point in the history
…rent model
  • Loading branch information
ZiyueXu77 committed Feb 27, 2024
1 parent 5e85438 commit e008818
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 19 deletions.
22 changes: 3 additions & 19 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ class LearnerImpl : public LearnerIO {
monitor_.Stop("GetGradient");



/*
if(collective::GetRank()==0){
//print the total number of samples
std::cout << "Total number of samples: " << train->Info().labels.Size() << std::endl;
Expand Down Expand Up @@ -1320,7 +1320,7 @@ class LearnerImpl : public LearnerIO {
}
}
}

*/



Expand Down Expand Up @@ -1372,9 +1372,6 @@ class LearnerImpl : public LearnerIO {
std::shared_ptr<DMatrix> m = data_sets[i];
auto &predt = prediction_container_.Cache(m, ctx_.Device());
this->ValidateDMatrix(m.get(), false);
if(collective::GetRank()==0){
std::cout << "data size = " << data_sets[i]->Info().num_row_ << std::endl;
}
this->PredictRaw(m.get(), &predt, false, 0, 0);

auto &out = output_predictions_.Cache(m, ctx_.Device()).predictions;
Expand All @@ -1383,15 +1380,7 @@ class LearnerImpl : public LearnerIO {

obj_->EvalTransform(&out);
for (auto& ev : metrics_) {

auto metric = ev->Evaluate(out, m);

if(collective::GetRank()==0){
std::cout << "eval result = " << metric << std::endl;
}


os << '\t' << data_names[i] << '-' << ev->Name() << ':' << metric; //ev->Evaluate(out, m);
os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Evaluate(out, m);
}
}

Expand Down Expand Up @@ -1496,11 +1485,6 @@ class LearnerImpl : public LearnerIO {
CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
this->CheckModelInitialized();
this->ValidateDMatrix(data, false);

if(collective::GetRank()==0){
std::cout << "PredictRaw training ? " << training << std::endl;
}

gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end);
}

Expand Down
6 changes: 6 additions & 0 deletions src/tree/common_row_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ class CommonRowPartitioner {
if (is_index) {
// if the split_pt is already recorded as a bin_id, use it directly
(*split_conditions)[i] = static_cast<int32_t>(split_pt);
// at this point, each participant has received the best split index,
// and can therefore recover the split_pt from the bin_id and update tree info
auto split_pt_local = vals[split_pt];
// update the tree, replacing the existing index with the cut value;
// note that const is cast away here — modify with care
const_cast<RegTree::Node&>(tree.GetNodes()[nidx]).SetSplit(fidx, split_pt_local);
}
else {
// otherwise find the bin_id that corresponds to split_pt
Expand Down
1 change: 1 addition & 0 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class HistEvaluator {
}
else {
// secure mode: record the best split point, rather than the actual value since it is not accessible
// at this point (active party finding best-split)
best.Update(loss_chg, fidx, i, d_step == -1, false, left_sum, right_sum);
}

Expand Down

0 comments on commit e008818

Please sign in to comment.