Skip to content

Commit

Permalink
fix single machine gbdt
Browse files (browse the repository at this point in the history)
  • Loading branch information
shiyu1994 committed Oct 25, 2024
1 parent 4bb4411 commit b56b39e
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 6 deletions.
1 change: 1 addition & 0 deletions include/LightGBM/cuda/cuda_column_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <LightGBM/bin.h>
#include <LightGBM/utils/openmp_wrapper.h>

#include <memory>
#include <vector>

namespace LightGBM {
Expand Down
1 change: 1 addition & 0 deletions include/LightGBM/cuda/cuda_nccl_topology.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <LightGBM/network.h>
#include <LightGBM/utils/common.h>

#include <memory>
#include <set>
#include <string>
#include <vector>
Expand Down
2 changes: 2 additions & 0 deletions src/boosting/cuda/nccl_gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ void NCCLGBDT<GBDT_T>::Init(
const std::vector<const Metric*>& training_metrics) {
GBDT_T::Init(gbdt_config, train_data, objective_function, training_metrics);

this->tree_learner_.reset();

nccl_topology_.reset(new NCCLTopology(this->config_->gpu_device_id, this->config_->num_gpus, this->config_->gpu_device_id_list, train_data->num_data()));

nccl_topology_->InitNCCL();
Expand Down
5 changes: 3 additions & 2 deletions src/boosting/cuda/nccl_gbdt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

#ifdef USE_CUDA

#include <LightGBM/objective_function.h>
#include <LightGBM/network.h>

#include "cuda_score_updater.hpp"
#include "nccl_gbdt_component.hpp"

#include <LightGBM/cuda/cuda_nccl_topology.hpp>
#include <LightGBM/objective_function.h>
#include <LightGBM/network.h>

#include <pthread.h>
#include <memory>
Expand Down
8 changes: 4 additions & 4 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective() &&
!data_sample_strategy_->IsHessianChange(); // for sample strategy with Hessian change, fall back to boosting on CPU

-tree_learner_ = nullptr; // std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type,
-// config_.get(), boosting_on_gpu_));
+tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type,
+config_.get(), boosting_on_gpu_));

// init tree learner
-// tree_learner_->Init(train_data_, is_constant_hessian_);
-// tree_learner_->SetForcedSplit(&forced_splits_json_);
+tree_learner_->Init(train_data_, is_constant_hessian_);
+tree_learner_->SetForcedSplit(&forced_splits_json_);

// push training metrics
training_metrics_.clear();
Expand Down

0 comments on commit b56b39e

Please sign in to comment.